Implicit Dictionary Learning
This page gives an overview of the model used in the toy experiment.
Full Tutorial
See the tutorial page for a Jupyter notebook using this model.
Model Overview
Bases: ImplicitL2OModel
Model to recover a signal from measurements by leveraging sparse structure.
Inferences are defined by
\(\mathsf{x^\star(d) = \underset{x}{argmin}\ \|Kx\|_1 \quad s.t. \quad Ax = d,}\)
where \(\mathsf{K}\) is a tunable matrix. Because the model is equivalent under scaling of \(\mathsf{K}\), we fix \(\mathsf{\|K\|_2 = 1}\). This is enforced during training by dividing by the matrix norm at the beginning of forward propagation. The forward iteration is Linearized ADMM (L-ADMM).
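As a minimal sketch of that normalization step (the shape of `K` is hypothetical; the spectral norm is computed with `torch.linalg.matrix_norm`):

```python
import torch

# Hypothetical tunable transform K; the shape is illustrative only.
K = torch.nn.Parameter(torch.randn(64, 128))

# At the start of forward propagation, divide by the matrix 2-norm
# (largest singular value) so that || K ||_2 = 1.
K_unit = K / torch.linalg.matrix_norm(K, ord=2)
```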
Apply Optimization Step
Apply the model operator using an L-ADMM update.
Core functionality is the single-iteration update for Linearized ADMM, which is rearranged to make the signal \(\mathsf{x}\) update last. This ordering is needed so that JFB (Jacobian-free backprop) attaches gradients to the final update. Here the tuple \(\mathsf{(\hat{x}, \hat{p}, \hat{v}_1, \hat{v}_2)}\) is given as input. Each update is given by the following.
\(\mathsf{p \leftarrow shrink(\hat{p} + \lambda (\hat{v}_1 + \alpha (K\hat{x} - \hat{p})))}\)
\(\mathsf{v_1 \leftarrow \hat{v}_1 + \alpha (K\hat{x} - p)}\)
\(\mathsf{v_2 \leftarrow \hat{v}_2 + \alpha (A\hat{x} - d)}\)
\(\mathsf{r \leftarrow K^\top (2v_1 - \hat{v}_1) + A^\top (2v_2 - \hat{v}_2)}\)
\(\mathsf{x \leftarrow \hat{x} - \beta r}\)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | tensor | Signal estimate | required |
p | tensor | Sparse transform \(\mathsf{Kx}\) of signal | required |
v1 | tensor | Dual variable for sparsity transform constraint \(\mathsf{Kx=p}\) | required |
v2 | tensor | Dual variable for linear constraint \(\mathsf{Ax=d}\) | required |
Returns:
Name | Type | Description |
---|---|---|
x | tensor | Updated Signal |
Source code in src/models.py
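A minimal sketch of this update, assuming PyTorch tensors; the names `shrink` and `apply_T`, the shrink threshold, and the default step sizes are assumptions for illustration, not the exact implementation in `src/models.py`:

```python
import torch

def shrink(u, thresh):
    """Soft-thresholding: the proximal operator of thresh * ||.||_1."""
    return torch.sign(u) * torch.clamp(u.abs() - thresh, min=0.0)

def apply_T(x, p, v1, v2, d, K, A, alpha=0.5, beta=0.5, lam=1.0):
    """One L-ADMM update, rearranged so the signal x updates last."""
    Kx = K @ x
    # p-update: proximal (shrink) step; the threshold lam is an assumption,
    # since the docstring leaves it implicit.
    p_next = shrink(p + lam * (v1 + alpha * (Kx - p)), lam)
    # Dual ascent steps for the constraints Kx = p and Ax = d.
    v1_next = v1 + alpha * (Kx - p_next)
    v2_next = v2 + alpha * (A @ x - d)
    # Gradient-like step on x using extrapolated duals (last, for JFB).
    r = K.T @ (2 * v1_next - v1) + A.T @ (2 * v2_next - v2)
    x_next = x - beta * r
    return x_next, p_next, v1_next, v2_next
```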
Get Convergence Criteria
Identify criteria for whether the forward iteration converges.
Convergence Criteria
- Fidelity must satisfy \(\mathsf{\|Ax - d\| \leq tol \cdot \|d\|}\)
- Update residual should be small for x and v, i.e. the expression \(\mathsf{\|x^{k+1} - x^k\| + \|v^{k+1} - v^k\|}\) is close to zero relative to \(\mathsf{\|x^k\| + \|v^k\|}\).
Note
Tolerance is added to `norm_data` to handle the case where \(\mathsf{d = 0}\).
Source code in src/models.py
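A sketch of these checks; the function name, signature, and default `tol` are hypothetical, while the two criteria and the `norm_data` safeguard follow the description above:

```python
import torch

def converged(x, x_prev, v, v_prev, d, A, tol=1e-4):
    """Check the two stopping criteria for the forward iteration."""
    # Tolerance is added to norm_data to handle the case d = 0.
    norm_data = d.norm() + tol
    fidelity_ok = (A @ x - d).norm() <= tol * norm_data
    # Relative update residual for the primal x and the dual v.
    residual = (x - x_prev).norm() + (v - v_prev).norm()
    reference = x_prev.norm() + v_prev.norm() + tol
    return bool(fidelity_ok and residual <= tol * reference)
```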
Forward
Compute inference using L-ADMM.
The aim is to find \(\mathsf{v^\star}\) satisfying \(\mathsf{v^\star = T(v^\star; d)}\), where \(\mathsf{v^\star}\) is the dual variable of the minimization problem and \(\mathsf{T}\) is the L-ADMM update operation. This operation is applied repeatedly until an approximate fixed point of \(\mathsf{T(\cdot; d)}\) is found. From the optimal dual we then obtain the inference \(\mathsf{x^\star}\).
Note
We write the dual as a tuple \(\mathsf{v = (v_1, v_2)}\).
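Putting the pieces together, a minimal sketch of the forward pass, reusing the `apply_T` and `converged` sketches above (zero initialization and the single gradient-attached final step of JFB are assumptions):

```python
import torch

def forward(d, K, A, max_iters=500, tol=1e-4):
    """Iterate T(.; d) to an approximate fixed point, then apply one
    gradient-attached update so JFB can backpropagate through it."""
    x = torch.zeros(K.shape[1])
    p = torch.zeros(K.shape[0])
    v1 = torch.zeros(K.shape[0])
    v2 = torch.zeros(A.shape[0])
    with torch.no_grad():  # fixed-point search carries no gradients
        for _ in range(max_iters):
            x_prev, v_prev = x, torch.cat([v1, v2])
            x, p, v1, v2 = apply_T(x, p, v1, v2, d, K, A)
            if converged(x, x_prev, torch.cat([v1, v2]), v_prev, d, A, tol):
                break
    # One last update with gradients attached (the x update is last).
    x, p, v1, v2 = apply_T(x, p, v1, v2, d, K, A)
    return x
```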