# Implicit Dictionary Learning

This page gives an overview of the model used in the toy experiment.

**Full Tutorial:** See the tutorial page for a Jupyter notebook using this model.

## Model Overview

Bases: ImplicitL2OModel

Model to recover a signal from measurements by leveraging its sparse structure.

Inferences are defined by

$$\mathsf{model(d) = argmin_x\ \| Kx \|_1 \quad s.t. \quad Ax = d,}$$

where $$\mathsf{K}$$ is a tunable matrix. Because the model is invariant under rescaling of $$\mathsf{K}$$, we fix $$\mathsf{\| K \|_2 = 1}$$. This is enforced during training by dividing $$\mathsf{K}$$ by its matrix norm at the beginning of forward propagation. The forward iteration is Linearized ADMM (L-ADMM).
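As a minimal sketch of that normalization (a standalone version; in the source it is performed in place on `self.K` inside `forward`):

```python
import torch

def normalize_spectral(K: torch.Tensor) -> torch.Tensor:
    """Rescale K so that its spectral norm ||K||_2 equals 1."""
    return K / torch.linalg.matrix_norm(K, ord=2)  # largest singular value
```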

## Apply Optimization Step

Apply the model operator using an L-ADMM update.

The core functionality is a single-iteration update for Linearized ADMM, rearranged so that the signal $$\mathsf{x}$$ update comes last. This ordering is needed so that the JFB backpropagation attaches gradients. The tuple $$\mathsf{(\hat{x}, \hat{p}, \hat{v}_1, \hat{v}_2)}$$ is given as input, and each update is given by the following.

$$\mathsf{p \leftarrow shrink(\hat{p} + \lambda (\hat{v}_1 + \alpha (K\hat{x} - \hat{p})))}$$

$$\mathsf{v_1 \leftarrow \hat{v}_1 + \alpha (K\hat{x} - p)}$$

$$\mathsf{v_2 \leftarrow \hat{v}_2 + \alpha (A\hat{x} - d)}$$

$$\mathsf{r \leftarrow K^\top (2v_1 - \hat{v}_1) + A^\top (2v_2 - \hat{v}_2)}$$

$$\mathsf{x \leftarrow \hat{x} - \beta r}$$
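Here $$\mathsf{shrink}$$ is soft thresholding, the proximal operator of the $$\mathsf{\ell_1}$$ norm. A minimal sketch, assuming the threshold equals the step size `lambd` from the $$\mathsf{p}$$ update (the source provides this operator as `self.shrink`):

```python
import torch

def shrink(z: torch.Tensor, lambd: float) -> torch.Tensor:
    """Soft thresholding: sign(z) * max(|z| - lambd, 0), entrywise."""
    return torch.sign(z) * torch.clamp(z.abs() - lambd, min=0.0)
```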

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | tensor | Signal estimate | *required* |
| `p` | tensor | Sparse transform $$\mathsf{Kx}$$ of the signal | *required* |
| `v1` | tensor | Dual variable for the sparsity transform constraint $$\mathsf{Kx = p}$$ | *required* |
| `v2` | tensor | Dual variable for the linear constraint $$\mathsf{Ax = d}$$ | *required* |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `x` | tensor | Updated signal estimate |

Source code in src/models.py
```python
def _apply_T(self, x: inference, p: torch.tensor, v1: dual, v2: dual,
             d: input_data, return_tuple=False) -> inference:
    r"""Apply the model operator using an L-ADMM update

    The core functionality is a single-iteration update for Linearized
    ADMM, rearranged so that the signal $\mathsf{x}$ update comes last.
    This ordering is needed so that the JFB backpropagation attaches
    gradients. The tuple $\mathsf{(\hat{x}, \hat{p}, \hat{v}_1, \hat{v}_2)}$
    is given as input, and each update is given by the following.

    $\mathsf{p \leftarrow shrink(\hat{p} + \lambda (\hat{v}_1 + \alpha (K\hat{x} - \hat{p})))}$

    $\mathsf{v_1 \leftarrow \hat{v}_1 + \alpha (K\hat{x} - p)}$

    $\mathsf{v_2 \leftarrow \hat{v}_2 + \alpha (A\hat{x} - d)}$

    $\mathsf{r \leftarrow K^\top (2v_1 - \hat{v}_1) + A^\top (2v_2 - \hat{v}_2)}$

    $\mathsf{x \leftarrow \hat{x} - \beta r}$

    Args:
        x (tensor): Signal estimate
        p (tensor): Sparse transform $\mathsf{Kx}$ of the signal
        v1 (tensor): Dual variable for the sparsity transform constraint $\mathsf{Kx = p}$
        v2 (tensor): Dual variable for the linear constraint $\mathsf{Ax = d}$

    Returns:
        x (tensor): Updated signal estimate
    """
    # Switch to column-major layout: (features, batch)
    x = x.permute(1, 0).float()
    p = p.permute(1, 0).float()
    d = d.permute(1, 0).float()
    v1 = v1.permute(1, 0).float()
    v2 = v2.permute(1, 0).float()

    Kx = torch.mm(self.K.float(), x)
    Ax = torch.mm(self.A.to(device=self.device()), x)

    # Proximal (soft-thresholding) update for the sparse code p
    p = self.shrink(p + self.lambd * (v1 + self.alpha * (Kx - p)))

    # Dual ascent steps for both constraints
    v1_prev = v1.to(self.device())
    v2_prev = v2.to(self.device())
    v1 = v1 + self.alpha * (Kx - p)
    v2 = v2 + self.alpha * (Ax - d)

    # Primal update performed last so JFB attaches gradients here
    r = self.K.t().mm(2 * v1 - v1_prev).to(self.device())
    r += self.A.to(device=self.device()).t().mm(2 * v2 - v2_prev)
    x = x - self.beta * r

    # Restore (batch, features) layout
    x = x.permute(1, 0)
    p = p.permute(1, 0)
    v1 = v1.permute(1, 0)
    v2 = v2.permute(1, 0)
    if return_tuple:
        return x, p, v1, v2
    else:
        return x
```

## Get Convergence Criteria

Identify criteria for whether the forward iteration converges.

Convergence Criteria:

1. Fidelity must satisfy $$\mathsf{\| Ax - d \| \leq tol \cdot \| d \|}$$.
2. The update residual should be small for $$\mathsf{x}$$ and $$\mathsf{v}$$, i.e. the expression $$\mathsf{\| x^{k+1} - x^k \| + \| v^{k+1} - v^k \|}$$ should be close to zero relative to $$\mathsf{\| x^k \| + \| v^k \|}$$.

**Note:** A numerical-stability tolerance is added to `norm_data` to handle the case where $$\mathsf{d = 0}$$.
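A distilled sketch of the two tests (this simplified version collapses the dual pair into a single tensor `v`; the source below tracks $$\mathsf{v_1}$$ and $$\mathsf{v_2}$$ separately and takes worst-case norms over the batch):

```python
import torch

def converged(x, x_prev, v, v_prev, Ax, d,
              tol_fidelity=1.0e-2, tol_residual=1.0e-4,
              tol_num_stability=1.0e-8) -> bool:
    """Check both convergence criteria at once."""
    # Criterion 2: update residual, relative to the previous iterate sizes
    norm_res = torch.norm(x - x_prev) + torch.norm(v - v_prev)
    norm_ref = torch.norm(x_prev) + torch.norm(v_prev)
    # Criterion 1: fidelity relative to ||d||, stabilized for d = 0
    norm_fid = torch.norm(Ax - d)
    norm_data = torch.norm(d) + tol_num_stability * norm_fid
    return bool((norm_res <= tol_residual * norm_ref)
                and (norm_fid <= tol_fidelity * norm_data))
```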

Source code in src/models.py
```python
def _get_conv_crit(self, x, x_prev, v1, v1_prev, v2, v2_prev, d,
                   tol_fidelity=1.0e-2, tol_residual=1.0e-4,
                   tol_num_stability=1.0e-8):
    r"""Identify criteria for whether the forward iteration converges

    Convergence Criteria:
        1. Fidelity must satisfy $\mathsf{\| Ax - d \| \leq tol \cdot \| d \|}$
        2. The update residual should be small for x and v, i.e. the
           expression $\mathsf{\| x^{k+1} - x^k \| + \| v^{k+1} - v^k \|}$
           should be close to zero relative to $\mathsf{\| x^k \| + \| v^k \|}$.

    Note:
        A numerical-stability tolerance is added to norm_data to handle
        the case where $\mathsf{d = 0}$.
    """
    # Worst-case (over the batch) update residual and its reference scale
    norm_res = torch.max(torch.norm(v1 - v1_prev, dim=1))
    norm_res += torch.max(torch.norm(v2 - v2_prev, dim=1))
    norm_res += torch.max(torch.norm(x - x_prev, dim=1))

    norm_res_ref = torch.max(torch.norm(x_prev, dim=1))
    norm_res_ref += torch.max(torch.norm(v1_prev, dim=1))
    norm_res_ref += torch.max(torch.norm(v2_prev, dim=1))

    # Constraint fidelity Ax - d, taken over the worst sample in the
    # batch so that every sample must be feasible
    fidelity = torch.mm(x, self.A.t().to(self.device())) - d
    norm_fidelity = torch.max(torch.norm(fidelity, dim=1))
    norm_data = torch.max(torch.norm(d, dim=1))
    norm_data += tol_num_stability * norm_fidelity

    residual_conv = norm_res <= tol_residual * norm_res_ref
    feasible_sol = norm_fidelity <= tol_fidelity * norm_data
    return residual_conv and feasible_sol
```

## Forward

Compute the inference using L-ADMM. The aim is to find $$\mathsf{v^\star}$$ satisfying $$\mathsf{v^\star = T(v^\star; d)}$$, where $$\mathsf{v^\star}$$ is the dual variable for the minimization problem and $$\mathsf{T}$$ is the L-ADMM update operation. This operation is applied repeatedly until an approximate fixed point of $$\mathsf{T(\cdot; d)}$$ is found. From the optimal dual, we obtain the inference $$\mathsf{x^\star}$$.

**Note:** We write the dual as a tuple $$\mathsf{v = (v_1, v_2)}$$.
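This is the Jacobian-free backprop (JFB) pattern: iterate with gradients disabled until a fixed point is reached, then apply the operator once more with autograd enabled, so that backpropagation only differentiates the final application. A minimal sketch with a generic update map `T` (the function below is illustrative, not the repository API):

```python
import torch

def jfb_fixed_point(T, v0, max_depth=5000, tol=1.0e-4):
    """Approximate v* = T(v*), attaching gradients only to the last step."""
    v = v0
    with torch.no_grad():  # fixed-point iterations build no autograd graph
        for _ in range(max_depth):
            v_next = T(v)
            if torch.norm(v_next - v) <= tol * (torch.norm(v) + 1.0e-8):
                v = v_next
                break
            v = v_next
    return T(v)  # one differentiable application for backprop
```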
Source code in src/models.py
```python
def forward(self, d: input_data, max_depth=5000, depth_warning=False,
            return_depth=False, normalize_K=False,
            return_certs=False) -> inference:
    r"""Compute inference using L-ADMM.

    The aim is to find $\mathsf{v^\star}$ satisfying
    $\mathsf{v^\star = T(v^\star; d)}$, where $\mathsf{v^\star}$ is the
    dual variable for the minimization problem and $\mathsf{T}$ is the
    L-ADMM update operation. This operation is applied repeatedly until
    an approximate fixed point of $\mathsf{T(\cdot; d)}$ is found. From
    the optimal dual, we obtain the inference $\mathsf{x^\star}$.

    Note:
        We write the dual as a tuple $\mathsf{v = (v_1, v_2)}$.
    """
    self.A = self.A.to(self.device())
    d = d.to(self.device()).float()

    with torch.no_grad():
        if normalize_K:
            # Enforce ||K||_2 = 1 (the model is invariant to rescaling K)
            K_norm = torch.linalg.matrix_norm(self.K, ord=2)
            self.K /= K_norm

        self.depth = 0.0
        batch_size = d.size(0)
        x_size = self.A.size(1)  # signal dimension
        d_size = self.A.size(0)  # measurement dimension
        x = torch.zeros((batch_size, x_size), device=self.device(),
                        dtype=torch.float)
        p = torch.zeros((batch_size, x_size), device=self.device(),
                        dtype=torch.float)
        v1 = torch.zeros((batch_size, x_size), device=self.device(),
                         dtype=torch.float)
        v2 = torch.zeros((batch_size, d_size), device=self.device(),
                         dtype=torch.float)
        x_prev = x.clone()

        # Iterate T (gradient-free) until every sample converges
        # or max_depth is reached
        all_samp_conv = False
        while not all_samp_conv and self.depth < max_depth:
            v1_prev = v1.clone()
            v2_prev = v2.clone()
            x_prev = x.clone()
            x, p, v1, v2 = self._apply_T(x, p, v1, v2, d, return_tuple=True)
            all_samp_conv = self._get_conv_crit(x, x_prev, v1, v1_prev,
                                                v2, v2_prev, d)
            self.depth += 1

    if self.depth >= max_depth and depth_warning:
        print("\nWarning: Max Depth Reached - Break Forward Loop\n")

    # One final application of T with gradients attached (JFB)
    Tx = self._apply_T(x, p, v1, v2, d)
    output = [Tx]
    if return_depth:
        output.append(self.depth)
    if return_certs:
        output.append(self.get_certs(Tx.detach(), d))
    return output if len(output) > 1 else Tx
```