Implicit Deep Learning for CT

This page gives an overview of the model and setup for the CT image reconstruction experiments.

CT Data

The datasets used in these experiments are stored in a publicly accessible Google Drive folder.

Download CT Data


CT Model Overview

Bases: ImplicitL2OModel

Model to reconstruct CT image from measurements.

Inferences are defined by

model(d) = argmin_x f_theta(Kx)   s.t.   ||Ax - d|| <= delta,

where d is the measurement data, A is the measurement operator, and K, theta, and delta are tunable parameters. The forward iteration is Linearized ADMM (L-ADMM), and the step sizes in the algorithm are tunable as well.
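
To connect this formulation with the update steps in the source below, it helps to write the problem in the split form that L-ADMM works with; this reading is inferred from the code rather than stated in the original docs. Introducing auxiliary variables p = Kx and w = Ax gives

model(d) = argmin_{x, p, w} f_theta(p)   s.t.   p = Kx,   w = Ax,   ||w - d|| <= delta.

The dual variables nu1 and nu2 in the code correspond to the two equality constraints: p is updated through the learned operator R, w through a projection onto the delta-ball around d, and x through a clamped gradient step with step size beta.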


Apply Optimization Step

Apply the model operator using an L-ADMM update.

The core functionality is a single-iteration update for Linearized ADMM, rearranged so that the signal 'u' update comes last. This ordering is needed to ensure that Jacobian-free backpropagation (JFB) attaches gradients to the final update.

Source code in src/models.py
def _apply_T(self, x: inference, d: input_data, return_tuple=False):
    ''' Apply model operator using L-ADMM update

        Core functionality is single iteration update for Linearized ADMM,
        which is rearranged to make the signal 'u' update last. This is
        needed to ensure the JFB attaches gradients.
    '''
    batch_size = x.shape[0]

    d = d.view(d.shape[0], -1).to(self.device())
    d = d.permute(1, 0)
    xk = x.view(x.shape[0], -1)
    xk = xk.permute(1, 0)
    pk = self.K(xk)
    wk = torch.matmul(self.A, xk)
    nuk1 = torch.zeros(pk.size(), device=self.device())
    nuk2 = torch.zeros(d.size(), device=self.device())

    alpha = torch.clamp(self.alpha.data, min=0, max=2)
    beta = torch.clamp(self.beta.data, min=0, max=2)
    lambd = torch.clamp(self.lambd.data, min=0, max=2)
    delta = self.delta.data

    # pk step
    pk = pk + lambd*(nuk1 + alpha * (self.K(xk) - pk))
    pk = self.R(pk)

    # wk step
    Axk = torch.matmul(self.A, xk)
    res_temp = nuk2 + alpha * (Axk - wk)
    temp_term = wk + lambd * res_temp
    # temp_term = self.S(wk + lambd * res_temp)
    wk = self.ball_proj(temp_term, d, delta)

    # nuk1 step
    res_temp = self.K(xk) - pk
    nuk1_plus = nuk1 + alpha * res_temp

    # nuk2 step
    res_temp = Axk - wk
    nuk2_plus = nuk2 + alpha * res_temp

    # rk step
    self.convK_T.weight.data = self.convK.weight.data
    rk = self.Kt(2*nuk1_plus - nuk1)
    rk = rk + torch.matmul(self.At, 2*nuk2_plus - nuk2)

    # xk step
    xk = torch.clamp(xk - beta * rk, min=0, max=1)

    if return_tuple:
        return xk.permute(1, 0).view(batch_size, 1, 128, 128), nuk1_plus, pk
    else:
        return xk.permute(1, 0).view(batch_size, 1, 128, 128)
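
The update above calls self.ball_proj, which enforces the data-fidelity constraint but is not shown in this excerpt. A minimal sketch of such a projection onto the ball {w : ||w - d|| <= delta}, assuming the (features, batch) column layout used in _apply_T (the name and signature are assumptions, not the repository's definition):

import torch

def ball_proj(w, d, delta):
    ''' Project each column of w onto the ball {w : ||w - d|| <= delta}. '''
    diff = w - d
    norms = torch.norm(diff, dim=0, keepdim=True)           # per-sample residual norms
    scale = torch.clamp(delta / (norms + 1e-12), max=1.0)   # shrink only points outside the ball
    return d + scale * diff

Columns whose residual already satisfies the constraint are returned unchanged, since the clamp caps the scaling factor at one.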


Get Convergence Criteria

Identify the criterion for whether the forward iteration has converged.

The criterion requires the update residual for x to be small, i.e. that ||x^{k+1} - x^k|| is close to zero; the residual is taken as the maximum over the batch.

Source code in src/models.py
def _get_conv_crit(self, x, x_prev, d, tol=1.0e-2):
    ''' Identify criteria for whether forward iteration to converges

        Criteria implies update residual should be small for x, i.e. the
               expression |x^{k+1} - x^k|| is close to zero
    '''
    batch_size = x.shape[0]
    x = x.view(batch_size, -1)
    d = d.view(batch_size, -1)
    x_prev = x_prev.view(batch_size, -1)

    res_norm = torch.max(torch.norm(x - x_prev, dim=1))
    residual_conv = res_norm <= tol

    return residual_conv
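
Because the residual is reduced with a max over the batch, the loop stops only when the worst sample has converged. A toy illustration of the same stopping rule (the shapes and tolerance here are assumptions, not values from the repository):

import torch

x      = torch.rand(4, 1, 128, 128)                  # current iterate for a batch of 4
x_prev = x + 1e-3 * torch.randn_like(x)              # previous iterate

res = torch.norm((x - x_prev).view(4, -1), dim=1)    # per-sample update residuals
all_converged = torch.max(res) <= 1.0e-2             # worst case over the batch
print(all_converged)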


Forward

Compute inference using L-ADMM.

The aim is to find nu* = T(nu*; d), where nu* is the dual variable for the minimization problem and T is the L-ADMM update operation. From the optimal dual variable we obtain the inference u*.

Source code in src/models.py
def forward(self, d, depth_warning=False, return_depth=False, tol=1e-3, return_all_vars=False):
    ''' Compute inference using L-ADMM.

        The aim is to find nu* = T(nu*; d) where nu* is the dual variable
        for minimization problem, and T is the update operation for L-ADMM.
        Associated with optimal dual, we obtain the inference u*.
    '''
    with torch.no_grad():

        self.depth = 0.0
        x = torch.zeros((d.size()[0], 1, 128, 128),
                        device=self.device())
        x_prev = np.Inf*torch.ones(x.shape, device=self.device())
        all_samp_conv = False

        while not all_samp_conv and self.depth < self.max_depth:
            x_prev = x.clone()
            x = self._apply_T(x, d)
            all_samp_conv = self._get_conv_crit(x,
                                                x_prev,
                                                d,
                                                tol=tol)

            self.depth += 1

    if self.depth >= self.max_depth and depth_warning:
        print("\nWarning: Max Depth Reached - Break Forward Loop\n")

    self.fixed_point_error = torch.max(torch.norm(x - x_prev, dim=1))

    if return_depth and return_all_vars==False:
        Tx = self._apply_T(x, d)
        return Tx, self.depth
    elif return_all_vars:
        Tx, nuk1, pk = self._apply_T(x, d, return_tuple=True)
        return Tx, nuk1, pk
    else:
        Tx = self._apply_T(x, d)
        return Tx
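
In this forward pass the fixed-point loop runs under torch.no_grad(), and gradients attach only through the final _apply_T call, which is the JFB convention described above. A hedged sketch of a training step built on that convention (the model construction, loss, and optimizer choices below are assumptions, not code from the repository):

import torch

# model: an instance of the CT model above; d: a batch of measurements;
# x_true: the corresponding ground-truth images.
optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-4)
criterion = torch.nn.MSELoss()

optimizer.zero_grad()
x_hat = model(d)                  # no-grad fixed-point loop + one differentiable update
loss = criterion(x_hat, x_true)
loss.backward()                   # gradients flow only through the final _apply_T
optimizer.step()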