
Implicit Dictionary Learning

This page gives an overview of the model used in the toy experiment.

Full Tutorial

See the tutorial page for a Jupyter notebook using this model.

Model Overview

Bases: ImplicitL2OModel

Model to recover a signal from measurements by leveraging sparse structure.

Inferences are defined by

\[ \mathsf{model(d) = argmin_x\ \| Kx \|_1 \quad s.t. \quad Ax = d,} \]

where \(\mathsf{K}\) is a tunable matrix. Because the minimizer is unchanged when \(\mathsf{K}\) is scaled by a positive constant, we fix \(\mathsf{\|K\|_2 = 1}\). During training, this is enforced by dividing \(\mathsf{K}\) by its spectral norm at the beginning of forward propagation (the `normalize_K` option of `forward`). The forward iteration is Linearized ADMM (L-ADMM).
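Both claims in the paragraph above can be checked numerically: scaling \(\mathsf{K}\) by \(\mathsf{c > 0}\) scales \(\mathsf{\|Kx\|_1}\) by \(\mathsf{c}\) for every \(\mathsf{x}\), so the argmin is unchanged, and dividing by the spectral norm gives \(\mathsf{\|K\|_2 = 1}\). The source uses `torch.linalg.matrix_norm(self.K, ord=2)`; the sketch below is a NumPy stand-in, and the random `K` and `x` are hypothetical values chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
K = rng.standard_normal((4, 4))  # hypothetical stand-in for the learned K
x = rng.standard_normal(4)

# For a 2-D array, ord=2 gives the spectral norm (largest singular value)
K_unit = K / np.linalg.norm(K, 2)
print(np.linalg.norm(K_unit, 2))  # approximately 1.0

# Scaling K by c > 0 scales ||Kx||_1 by c for every x,
# so the constrained minimizer does not change.
c = 3.0
obj = np.abs(K_unit @ x).sum()
obj_scaled = np.abs((c * K_unit) @ x).sum()
print(np.isclose(obj_scaled, c * obj))  # True
```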


Apply Optimization Step

Apply the model operator using an L-ADMM update.

The core functionality is a single iteration of Linearized ADMM, rearranged so that the signal \(\mathsf{x}\) is updated last. This ordering is needed so that JFB (Jacobian-free backpropagation) attaches gradients. The tuple \(\mathsf{(\hat{x}, \hat{p}, \hat{v}_1, \hat{v}_2)}\) is given as input, and the updates are as follows.

\(\mathsf{p \leftarrow shrink(\hat{p} + \lambda (\hat{v}_1 + \alpha (K\hat{x} - \hat{p})))}\)

\(\mathsf{v_1 \leftarrow \hat{v}_1 + \alpha (K\hat{x} - p)}\)

\(\mathsf{v_2 \leftarrow \hat{v}_2 + \alpha (A\hat{x} - d)}\)

\(\mathsf{r \leftarrow K^\top (2v_1 - \hat{v}_1) + A^\top (2v_2 - \hat{v}_2)}\)

\(\mathsf{x \leftarrow \hat{x} - \beta r}\)
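The five updates above can be sketched as a single NumPy function in which each column is one sample, as in `_apply_T`. The `shrink` operator is not defined in this section; as an assumption, the sketch uses soft-thresholding with threshold \(\mathsf{\lambda}\) (the proximal operator of \(\mathsf{\lambda \|\cdot\|_1}\)). The toy problem and the step sizes `lambd=0.5`, `alpha=1.0`, `beta=0.25` are hypothetical and only illustrate one step, not a tuned or convergent configuration.

```python
import numpy as np


def shrink(z, thresh):
    # ASSUMPTION: `shrink` is taken to be soft-thresholding, the proximal
    # operator of thresh * ||.||_1, with threshold equal to lambda.
    return np.sign(z) * np.maximum(np.abs(z) - thresh, 0.0)


def ladmm_step(x, p, v1, v2, d, K, A, lambd, alpha, beta):
    """One L-ADMM step in the order listed above; each column is a sample."""
    Kx = K @ x
    Ax = A @ x
    p_new = shrink(p + lambd * (v1 + alpha * (Kx - p)), lambd)
    v1_new = v1 + alpha * (Kx - p_new)
    v2_new = v2 + alpha * (Ax - d)
    r = K.T @ (2 * v1_new - v1) + A.T @ (2 * v2_new - v2)
    x_new = x - beta * r  # the signal update comes last
    return x_new, p_new, v1_new, v2_new


# Toy problem: K = I, one measurement x_1 + x_2 = 2, all iterates start at zero
K = np.eye(2)
A = np.array([[1.0, 1.0]])
d = np.array([[2.0]])
x0, p0, v10 = np.zeros((2, 1)), np.zeros((2, 1)), np.zeros((2, 1))
v20 = np.zeros((1, 1))

x1, p1, v1_1, v2_1 = ladmm_step(x0, p0, v10, v20, d, K, A,
                                lambd=0.5, alpha=1.0, beta=0.25)
print(x1.ravel())  # [1. 1.]
```

Here only the \(\mathsf{v_2}\) update is nonzero on the first step (\(\mathsf{v_2 = -2}\)), and the linearized \(\mathsf{x}\) step moves the signal onto the constraint \(\mathsf{Ax = d}\).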

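Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | tensor | Signal estimate | required |
| `p` | tensor | Sparse transform \(\mathsf{Kx}\) of signal | required |
| `v1` | tensor | Dual variable for sparsity transform constraint \(\mathsf{Kx=p}\) | required |
| `v2` | tensor | Dual variable for linear constraint \(\mathsf{Ax=d}\) | required |
| `d` | tensor | Measurement data | required |
| `return_tuple` | bool | If `True`, return the full tuple `(x, p, v1, v2)` | `False` |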

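Returns:

| Name | Type | Description |
|------|------|-------------|
| `x` | tensor | Updated signal (the tuple `(x, p, v1, v2)` when `return_tuple=True`) |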

Source code in src/models.py
```python
def _apply_T(self, x: inference, p: torch.Tensor, v1: dual, v2: dual,
             d: input_data, return_tuple=False) -> inference:
    r""" Apply the model operator using an L-ADMM update.

        The core functionality is a single iteration of Linearized ADMM,
        rearranged so that the signal $\mathsf{x}$ is updated last. This
        ordering is needed so that JFB backprop attaches gradients. The
        tuple $\mathsf{(\hat{x}, \hat{p}, \hat{v}_1, \hat{v}_2)}$ is given
        as input, and the updates are as follows.

        $\mathsf{p \leftarrow shrink(\hat{p} + {\lambda} (\hat{v}_1
        + \alpha (K\hat{x} - \hat{p})))}$

        $\mathsf{v_1 \leftarrow \hat{v}_1 + \alpha (K\hat{x} - p)}$

        $\mathsf{v_2 \leftarrow \hat{v}_2 + \alpha (A\hat{x} - d)}$

        $\mathsf{r \leftarrow K^\top (2v_1 - \hat{v}_1)
        + A^\top (2v_2 - \hat{v}_2)}$

        $\mathsf{x \leftarrow \hat{x} - {\beta} r}$

        Args:
            x (tensor): Signal estimate
            p (tensor): Sparse transform $\mathsf{Kx}$ of signal
            v1 (tensor): Dual variable for sparsity transform constraint $\mathsf{Kx=p}$
            v2 (tensor): Dual variable for linear constraint $\mathsf{Ax=d}$
            d (tensor): Measurement data
            return_tuple (bool): If True, return the full tuple (x, p, v1, v2)

        Returns:
            x (tensor): Updated signal, or (x, p, v1, v2) if return_tuple is True
    """
    # Work column-major, (features, batch), so torch.mm(K, x) acts on each sample
    x  = x.permute(1, 0).float()
    p  = p.permute(1, 0).float()
    d  = d.permute(1, 0).float()
    v1 = v1.permute(1, 0).float()
    v2 = v2.permute(1, 0).float()

    # Sparse-variable update via the shrink operator
    Kx = torch.mm(self.K.float(), x)
    Ax = torch.mm(self.A.to(device=self.device()), x)
    p  = self.shrink(p + self.lambd * (v1 + self.alpha * (Kx - p)))

    v1_prev = v1.to(self.device())
    v2_prev = v2.to(self.device())

    # Dual updates for the constraints Kx = p and Ax = d
    v1 = v1 + self.alpha * (Kx - p)
    v2 = v2 + self.alpha * (Ax - d)

    # Signal update comes last so that JFB attaches gradients to x
    r  = self.K.t().mm(2 * v1 - v1_prev).to(self.device())
    r += self.A.to(device=self.device()).t().mm(2 * v2 - v2_prev)
    x  = x - self.beta * r

    # Restore the batch-major layout, (batch, features)
    x  = x.permute(1, 0)
    p  = p.permute(1, 0)
    v1 = v1.permute(1, 0)
    v2 = v2.permute(1, 0)
    if return_tuple:
        return x, p, v1, v2
    else:
        return x
```
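`_apply_T` receives batch-major tensors (one sample per row), permutes them to column-major form so that `torch.mm(self.K, x)` applies \(\mathsf{K}\) to every sample at once, and permutes back before returning. `_get_conv_crit` stays batch-major and forms \(\mathsf{Ax - d}\) as \(\mathsf{xA^\top - d}\). The NumPy sketch below, with sizes and random matrices chosen only for illustration, checks that the two layouts agree.

```python
import numpy as np

rng = np.random.default_rng(1)
batch, n, m = 3, 4, 2                 # illustrative sizes
x = rng.standard_normal((batch, n))   # batch-major: one sample per row
K = rng.standard_normal((n, n))
A = rng.standard_normal((m, n))

# Column-major products, as computed inside _apply_T after permute(1, 0)
Kx_cols = K @ x.T                     # shape (n, batch)
Ax_cols = A @ x.T                     # shape (m, batch)

# Batch-major equivalents, e.g. x A^T as used in _get_conv_crit
print(np.allclose(Kx_cols.T, x @ K.T), np.allclose(Ax_cols.T, x @ A.T))  # True True
print(Kx_cols.shape, Ax_cols.shape)  # (4, 3) (2, 3)
```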


Get Convergence Criteria

Identify criteria for whether the forward iteration has converged.

Convergence Criteria
  1. Fidelity must satisfy \(\mathsf{\|Ax - d\| \leq tol \cdot \|d\|}\)
  2. Update residual should be small for x and v, i.e. the expression \(\mathsf{\|x^{k+1} - x^k\| + \|v^{k+1} - v^k\|}\) is close to zero relative to \(\mathsf{\|x^k\| + \|v^k\|}\).
Note

A small multiple (`tol_num_stability`) of the fidelity norm is added to `norm_data` to handle the case where \(\mathsf{d = 0}\).
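The two criteria can be sketched in NumPy as follows. The sketch mirrors the batch reductions in `_get_conv_crit` (maximum over samples for the residuals and the data norm, minimum for the fidelity); the function name `converged` and the toy inputs are hypothetical.

```python
import numpy as np


def converged(x, x_prev, v1, v1_prev, v2, v2_prev, d, A,
              tol_fidelity=1e-2, tol_residual=1e-4, tol_num_stability=1e-8):
    """Batch convergence test; arrays are (batch, features), one sample per row."""
    def row_norm(z):
        return np.linalg.norm(z, axis=1)  # one Euclidean norm per sample

    # Criterion 2: update residual relative to the size of the previous iterates
    norm_res = (row_norm(v1 - v1_prev).max() + row_norm(v2 - v2_prev).max()
                + row_norm(x - x_prev).max())
    norm_res_ref = (row_norm(x_prev).max() + row_norm(v1_prev).max()
                    + row_norm(v2_prev).max())

    # Criterion 1: fidelity ||Ax - d|| relative to ||d||
    norm_fidelity = row_norm(x @ A.T - d).min()
    norm_data = row_norm(d).max() + tol_num_stability * norm_fidelity

    return bool(norm_res <= tol_residual * norm_res_ref
                and norm_fidelity <= tol_fidelity * norm_data)


A = np.array([[1.0, 1.0]])
d = np.array([[2.0]])
x = np.array([[1.0, 1.0]])   # satisfies Ax = d
v1 = np.full((1, 2), 0.5)
v2 = np.array([[-1.0]])

ok = converged(x, x.copy(), v1, v1.copy(), v2, v2.copy(), d, A)
bad = converged(0 * x, 0 * x, v1, v1.copy(), v2, v2.copy(), d, A)
print(ok, bad)  # True False
```

In the second call the iterates have stopped changing, so criterion 2 holds, but \(\mathsf{x = 0}\) violates the fidelity criterion, so the test fails.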

Source code in src/models.py
```python
def _get_conv_crit(self, x, x_prev, v1, v1_prev, v2, v2_prev, d,
                   tol_fidelity=1.0e-2, tol_residual=1.0e-4,
                   tol_num_stability=1.0e-8):
    r""" Identify criteria for whether the forward iteration has converged.

        Convergence Criteria:
            1. Fidelity must satisfy $\mathsf{\|Ax - d\| \leq tol \cdot \|d\|}$
            2. Update residual should be small for x and v, i.e. the
               expression $\mathsf{\|x^{k+1} - x^k\| + \|v^{k+1} - v^k\|}$ is close
               to zero relative to $\mathsf{\|x^k\| + \|v^k\|}$.

        Note:
            A small multiple (`tol_num_stability`) of the fidelity norm is
            added to `norm_data` to handle the case where $\mathsf{d = 0}$.
    """
    # Update residual: largest per-sample change in v1, v2 and x
    norm_res = torch.max(torch.norm(v1 - v1_prev, dim=1))
    norm_res += torch.max(torch.norm(v2 - v2_prev, dim=1))
    norm_res += torch.max(torch.norm(x - x_prev, dim=1))

    # Reference scale: largest per-sample norms of the previous iterates
    norm_res_ref = torch.max(torch.norm(x_prev, dim=1))
    norm_res_ref += torch.max(torch.norm(v1_prev, dim=1))
    norm_res_ref += torch.max(torch.norm(v2_prev, dim=1))

    # Fidelity Ax - d per sample (rows are samples, hence x A^T);
    # the smallest per-sample fidelity norm in the batch is used
    fidelity = torch.mm(x, self.A.t().to(self.device())) - d
    norm_fidelity = torch.min(torch.norm(fidelity, dim=1))

    norm_data = torch.max(torch.norm(d, dim=1))
    norm_data += tol_num_stability * norm_fidelity

    residual_conv = norm_res <= tol_residual * norm_res_ref
    feasible_sol = norm_fidelity <= tol_fidelity * norm_data
    return residual_conv and feasible_sol
```


Forward

Compute inference using L-ADMM.

The aim is to find \(\mathsf{v^\star}\) satisfying \(\mathsf{v^\star = T(v^\star; d)}\), where \(\mathsf{v^\star}\) is the dual variable for the minimization problem and \(\mathsf{T}\) is the L-ADMM update operation. This operation is applied repeatedly until an approximate fixed point of \(\mathsf{T(\cdot; d)}\) is found. The optimal dual then yields the associated inference \(\mathsf{x^\star}\).

Note

We write the dual as a tuple \(\mathsf{v = (v_1, v_2)}\).
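The fixed-point iteration described above can be sketched generically: apply \(\mathsf{T}\) until the relative change between iterates is small or a maximum depth is reached. The operator `T` below is a hypothetical toy contraction with fixed point \(\mathsf{v^\star = 2}\), standing in for the L-ADMM operator, and the stopping rule is a simplified version of the criteria in `_get_conv_crit`.

```python
import numpy as np


def fixed_point(T, v0, tol=1e-8, max_depth=5000):
    """Iterate v <- T(v) until the relative change is below tol or max_depth is hit."""
    v, depth = v0, 0
    while depth < max_depth:
        v_next = T(v)
        depth += 1
        # Stop once the update is small relative to the current iterate
        if np.linalg.norm(v_next - v) <= tol * max(np.linalg.norm(v), 1.0):
            return v_next, depth
        v = v_next
    return v, depth


# Hypothetical toy operator: a contraction with unique fixed point v* = 2
def T(v):
    return 0.5 * v + 1.0


v_star, depth = fixed_point(T, np.zeros(1))
print(v_star, depth)
```

In `forward`, this loop runs under `torch.no_grad()`, and `_apply_T` is then called once more outside that block, so gradients flow through only the final application of the operator.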

Source code in src/models.py
```python
def forward(self, d: input_data, max_depth=5000,
            depth_warning=False, return_depth=False,
            normalize_K=False, return_certs=False) -> inference:
    r""" Compute inference using L-ADMM.

        The aim is to find $\mathsf{v^\star}$ satisfying $\mathsf{v^\star = T(v^\star; d)}$,
        where $\mathsf{v^\star}$ is the dual variable for the minimization problem
        and $\mathsf{T}$ is the L-ADMM update operation. This operation is applied
        repeatedly until an approximate fixed point of $\mathsf{T(\cdot; d)}$ is found.
        The optimal dual then yields the associated inference $\mathsf{x^\star}$.

        Note:
            We write the dual as a tuple $\mathsf{v = (v_1, v_2)}$.
    """
    self.A = self.A.to(self.device())
    d = d.to(self.device()).float()

    # Fixed-point iterations run without tracking gradients (JFB)
    with torch.no_grad():
        if normalize_K:
            # Enforce ||K||_2 = 1 by dividing by the spectral norm
            K_norm = torch.linalg.matrix_norm(self.K, ord=2)
            self.K /= K_norm

        self.depth = 0.0
        x_size = self.A.size()[1]
        d_size = self.A.size()[0]
        p_size = self.K.size()[0]  # p = Kx has one entry per row of K

        x = torch.zeros((d.size()[0], x_size),
                        device=self.device(), dtype=float)
        p = torch.zeros((d.size()[0], p_size),
                        device=self.device(), dtype=float)
        v1 = torch.zeros((d.size()[0], p_size),
                         device=self.device(), dtype=float)
        v2 = torch.zeros((d.size()[0], d_size),
                         device=self.device(), dtype=float)

        all_samp_conv = False
        while not all_samp_conv and self.depth < max_depth:
            v1_prev = v1.clone()
            v2_prev = v2.clone()
            x_prev = x.clone()

            x, p, v1, v2 = self._apply_T(x, p, v1, v2, d,
                                         return_tuple=True)
            all_samp_conv = self._get_conv_crit(x, x_prev, v1, v1_prev,
                                                v2, v2_prev, d)

            self.depth += 1

    if self.depth >= max_depth and depth_warning:
        print("\nWarning: Max Depth Reached - Break Forward Loop\n")

    # One final application with gradients tracked, for JFB backprop
    Tx = self._apply_T(x, p, v1, v2, d)
    output = [Tx]
    if return_depth:
        output.append(self.depth)
    if return_certs:
        output.append(self.get_certs(Tx.detach(), d))
    return output if len(output) > 1 else Tx
```