Tutorial¶
Let's consider a toy problem with a signal $\mathsf{x^\star \in \mathbb{R}^n}$, where $\mathsf{x^\star = M s}$ with a dictionary $\mathsf{M \in \mathbb{R}^{n\times r}}$ and coefficients $\mathsf{s\in\mathbb{R}^r}$. The vector $\mathsf{s}$ is generated as the element-wise product of a Gaussian random vector and a Bernoulli random vector (with probability $\mathsf{p}$ that each entry is nonzero), each with i.i.d. entries.
Access is given to linear measurements $ \mathsf{d = A x^\star}, $ where $\mathsf{A\in\mathbb{R}^{m\times n}}$ is a matrix with normalized columns (up to numerical tolerance). The task at hand is to
$$\mathsf{ Find\ x^\star\ given\ d\ and\ A}.$$
Since $\mathsf{s}$ has, on average, only $\mathsf{p\cdot r\ll n}$ nonzero entries (with the values below, $\mathsf{p\cdot r = 5}$ versus $\mathsf{n = 250}$), we know $\mathsf{x^\star}$ admits a sparse representation. Thus, we estimate
$$ \mathsf{x^\star \approx argmin_{x} \ \|K x\|_1 \ \ \mbox{s.t.}\ \ Ax=d.} $$
In this case, the implicit L2O model uses a learnable matrix $\mathsf{K\in\mathbb{R}^{n\times n}}$; it takes as input $\mathsf{d}$ and outputs an inference via
$$ \mathsf{{N_{\theta}}(d) = argmin_{x} \ \|K x\|_1 \ \ \mbox{s.t.}\ \ Ax=d}.$$
Throughout, we take $\mathsf{m=100}$, $\mathsf{n=250}$, $\mathsf{r=50}$, and $\mathsf{p=0.1}$.
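To make the setup concrete, here is a minimal sketch of how such data could be generated; the repo's create_dict_loaders presumably handles this internally, so the seed and exact normalization below are illustrative only.

import numpy as np

m, n, r, p = 100, 250, 50, 0.1
rng = np.random.default_rng(0)

M = rng.standard_normal((n, r))                # dictionary
A = rng.standard_normal((m, n))
A /= np.linalg.norm(A, axis=0, keepdims=True)  # normalize columns of A

# Sparse code: Gaussian entries kept with probability p.
s = rng.standard_normal(r) * rng.binomial(1, p, size=r)
x_star = M @ s                                 # ground-truth signal
d = A @ x_star                                 # linear measurements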
First, we import various utilities and mount Google Drive (where this notebook was executed).
import os
import sys

# Mount Google Drive so model weights persist across sessions.
from google.colab import drive
drive.mount('/content/drive')

# Make the project source importable and set the save directory.
sys.path.append('/content/drive/MyDrive/xai-via-l2o/src/')
save_dir = './drive/MyDrive/xai-via-l2o/'

import scipy.io
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.dataset import random_split

# Project utilities: data loaders, training helpers, plotting, and models.
from certificate import CertificateModel, CertificateEnsemble
from utils import create_property_loaders, solve_least_squares
from utils import plot_cmf, print_model_params
from utils import train_certificate, create_dict_loaders, plot_dict_signal
from models import ImpDictModel
# Build data loaders and retrieve the measurement matrix A.
loader_train, loader_test, A = create_dict_loaders()

# Training configuration.
max_epoch = 500
device = 'cuda:0'
learning_rate = 1.0e-4

# Implicit dictionary model with a learnable analysis operator K.
model = ImpDictModel(A)
model = model.to(device=device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

file_name = save_dir + 'weights/dictionary_model_weights.pth'
loss_best = 1.0e10
MSE_ave = 0.0

training_msg = '[{:5d}] train loss = {:2.3e} | depth = {:3.0f} | lr = {:2.3e}'
training_msg += ' | K 2-norm = {:2.3e}'

print_model_params(model)

# Optionally resume from a saved checkpoint.
load_weights = False
if load_weights:
    state = torch.load(file_name, map_location=torch.device(device))
    model.load_state_dict(state['model_state_dict'])
    print('Loaded model from file.')
epochs_adm = 0
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
+-------------------+--------------+
| Network Component | # Parameters |
+-------------------+--------------+
| K                 |        62500 |
| TOTAL             |        62500 |
+-------------------+--------------+
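The table above confirms that the only trainable component is the matrix $\mathsf{K}$ ($\mathsf{62500 = 250\times 250}$ parameters). The forward pass of ImpDictModel is not shown in this notebook, so for intuition here is a minimal sketch, not the repo's code, of one standard way to solve $\mathsf{argmin_{x} \ \|K x\|_1 \ \mbox{s.t.}\ Ax=d}$: primal-dual (Chambolle-Pock) iterations, which converge when $\mathsf{\tau\sigma\|K\|_2^2 \le 1}$. This matches the normalize_K=True flag used below, which keeps $\mathsf{\|K\|_2 = 1}$. The helper name l1_analysis_solve and the step sizes are hypothetical.

import torch

def l1_analysis_solve(d, A, K, max_depth=250, tau=0.5, sigma=0.5, tol=1.0e-4):
    """Hypothetical sketch: Chambolle-Pock iterations for
    min_x ||K x||_1  s.t.  A x = d, assuming ||K||_2 <= 1."""
    AAt = A @ A.T  # cached Gram matrix for projections onto {x : A x = d}

    def proj_affine(v):
        # Orthogonal projection onto the affine set {x : A x = d}.
        return v - A.T @ torch.linalg.solve(AAt, A @ v - d)

    x = proj_affine(torch.zeros(A.shape[1]))  # least-norm feasible start
    x_bar = x.clone()
    y = torch.zeros(K.shape[0])
    for k in range(max_depth):
        # Dual ascent, then projection onto the l-infinity unit ball
        # (the proximal operator of the conjugate of the l1 norm).
        y = torch.clamp(y + sigma * (K @ x_bar), -1.0, 1.0)
        # Primal descent, then restore measurement consistency.
        x_new = proj_affine(x - tau * (K.T @ y))
        x_bar = 2.0 * x_new - x
        if torch.norm(x_new - x) <= tol * (1.0 + torch.norm(x)):
            x = x_new
            break
        x = x_new
    return x, k + 1  # estimate and iteration count ("depth")

The returned iteration count plays the role of the depth reported in the training logs; the actual model may instead differentiate through the fixed point implicitly rather than unrolling these iterations.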
Model Training¶
With the model loaded, we next train it to predict $\mathsf{x^\star}$ from $\mathsf{d}$. We use the Adam optimizer and, every five epochs, plot a sample reconstruction from the test data to give intuition for how well the parameters are tuned.
for epoch in range(max_epoch):
    model.train()
    for x_true, d_batch in loader_train:
        optimizer.zero_grad()
        # Forward pass: iterate to a fixed point (at most 250 iterations),
        # keeping K normalized to unit spectral norm.
        x_pred, depth = model(d_batch, max_depth=250,
                              normalize_K=True, return_depth=True)
        loss = criterion(x_pred, x_true.to(device).float())
        loss.backward()
        optimizer.step()
        loss_curr = loss.detach().item()

    # Periodically plot a test-sample reconstruction.
    if epoch % 5 == 0:
        plot_dict_signal(model, loader_test, inference_depth=500)

    # Checkpoint whenever the training loss improves.
    if loss_curr < loss_best:
        loss_best = loss_curr
        state = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }
        torch.save(state, file_name)
        print('Model weights saved to ' + file_name)

    K_norm = torch.linalg.matrix_norm(model.K.detach(), ord=2)
    print(training_msg.format(epoch, loss_curr, depth,
                              optimizer.param_groups[0]['lr'], K_norm))
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[    0] train loss = 2.967e+00 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[    1] train loss = 1.845e+00 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[    2] train loss = 1.155e+00 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[    3] train loss = 7.241e-01 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[    4] train loss = 6.072e-01 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[    5] train loss = 3.936e-01 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[    6] train loss = 2.700e-01 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[    7] train loss = 2.247e-01 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[    8] train loss = 1.804e-01 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[    9] train loss = 1.279e-01 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[   10] train loss = 8.161e-02 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[   11] train loss = 6.750e-02 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[   12] train loss = 5.095e-02 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[   13] train loss = 3.192e-02 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[   14] train loss = 2.809e-02 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[   15] train loss = 2.261e-02 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[   16] train loss = 1.661e-02 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[   17] train loss = 1.362e-02 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[   18] train loss = 1.237e-02 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[   19] train loss = 1.055e-02 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-via-l2o/weights/dictionary_model_weights.pth
[   20] train loss = 9.322e-03 | depth = 250 | lr = 1.000e-04 | K 2-norm = 1.000e+00
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-2-30d19760215d> in <cell line: 1>()
      3 for x_true, d_batch in loader_train:
      4     optimizer.zero_grad()
----> 5     x_pred, depth = model(d_batch, max_depth=250,
      6                           normalize_K=True, return_depth=True)
      7     loss = criterion(x_pred, x_true.to(device).float())

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [
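The run above was stopped manually (hence the KeyboardInterrupt) after epoch 20; since the loop checkpoints whenever the training loss improves, the best weights so far are already on disk. Below is a minimal sketch of reloading them for inference, reusing the variables defined above; it assumes that loader_test yields (x_true, d) pairs like loader_train, and that the forward call without return_depth=True returns only the prediction.

# Reload the best checkpoint saved during training.
state = torch.load(file_name, map_location=torch.device(device))
model.load_state_dict(state['model_state_dict'])
model.eval()

with torch.no_grad():
    x_true, d_batch = next(iter(loader_test))
    # Use a larger depth at inference time, mirroring inference_depth=500 above.
    x_pred = model(d_batch, max_depth=500, normalize_K=True)
    mse = nn.functional.mse_loss(x_pred, x_true.to(device).float())
    print('test-batch MSE: {:2.3e}'.format(mse.item()))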