Tutorial
Let's consider a toy problem: recovering a signal $\mathsf{x^\star \in \mathbb{R}^n}$ of the form $\mathsf{x^\star = M s}$, with $\mathsf{M \in \mathbb{R}^{n\times r}}$ and $\mathsf{s\in\mathbb{R}^r}$. The vector $\mathsf{s}$ is generated as the element-wise product of a Gaussian random vector and a Bernoulli vector (each entry nonzero with probability $\mathsf{p}$), both with i.i.d. entries.
Access is given to linear measurements $ \mathsf{d = A x^\star}, $ where $\mathsf{A\in\mathbb{R}^{m\times n}}$ is a matrix with normalized columns (up to numerical tolerance). The task at hand is to
$$\mathsf{ Find\ x^\star\ given\ d\ and\ A}.$$
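For concreteness, here is a minimal sketch of how such a problem instance can be generated. The names and steps are illustrative assumptions, not necessarily what the `create_dict_loaders` utility used below does internally.

import torch

m, n, r, p = 100, 250, 50, 0.1   # sizes used throughout this tutorial

# Dictionary M and measurement matrix A with unit-norm columns
M = torch.randn(n, r)
A = torch.randn(m, n)
A = A / A.norm(dim=0, keepdim=True)

# Bernoulli-Gaussian code: each entry is nonzero with probability p
s = torch.randn(r) * torch.bernoulli(p * torch.ones(r))

x_star = M @ s   # signal admitting a sparse representation
d = A @ x_star   # linear measurements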
Since $\mathsf{p\cdot r\ll n}$, the vector $\mathsf{s}$ has, on average, only $\mathsf{p\cdot r}$ nonzero entries, and so $\mathsf{x^\star}$ admits a sparse representation. Thus, we estimate
$$ \mathsf{x^\star \approx argmin_{x} \ \|K x\|_1 \ \ \mbox{s.t.}\ \ Ax=d.} $$
In this case, the implicit L2O model takes $\mathsf{d}$ as input and outputs an inference via
$$ \mathsf{{N_{\theta}}(d) = argmin_{x} \ \|K x\|_1 \ \ \mbox{s.t.}\ \ Ax=d},$$
where the matrix $\mathsf{K}$ comprises the trainable parameters $\mathsf{\theta}$.
Throughout, we take $\mathsf{m=100}$, $\mathsf{n=250}$, $\mathsf{r=50}$, and $\mathsf{p=0.1}$.
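Evaluating $\mathsf{N_{\theta}}$ means solving the inner $\ell_1$ minimization for each input $\mathsf{d}$. As a rough sketch of how such an evaluation can work, the function below runs a primal-dual (Chambolle-Pock) iteration; it assumes $\mathsf{A}$ has full row rank and is an illustration only, not necessarily the fixed-point iteration `ImpDictModel` implements.

import torch

def solve_inner(K, A, d, max_depth=2000, tol=1.0e-4):
    """Illustrative Chambolle-Pock solver for min_x ||Kx||_1 s.t. Ax = d.

    Assumes A has full row rank. Not necessarily the iteration used
    inside ImpDictModel; shown only to make the implicit model concrete.
    """
    AAt_inv = torch.linalg.inv(A @ A.T)

    def proj_affine(x):
        # Projection onto the constraint set {x : Ax = d}
        return x - A.T @ (AAt_inv @ (A @ x - d))

    # Step sizes satisfying tau * sigma * ||K||_2^2 <= 1
    L = torch.linalg.matrix_norm(K, ord=2)
    tau = sigma = 1.0 / L

    x = proj_affine(torch.zeros(A.shape[1]))
    x_bar = x.clone()
    y = torch.zeros(K.shape[0])
    for depth in range(1, max_depth + 1):
        # Dual step: prox of the conjugate of ||.||_1 is a clamp to [-1, 1]
        y = torch.clamp(y + sigma * (K @ x_bar), -1.0, 1.0)
        # Primal step: prox of the affine constraint is a projection
        x_new = proj_affine(x - tau * (K.T @ y))
        x_bar = 2.0 * x_new - x
        converged = (x_new - x).norm() <= tol * max(x.norm().item(), 1.0)
        x = x_new
        if converged:
            break
    return x, depth

# e.g., with K = identity this solves classic basis pursuit:
# x_hat, depth = solve_inner(torch.eye(250), A, d)

Counting iterations until (approximate) convergence mirrors the `depth` values returned by the model during training below.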
First, we import various utilities and mount Google Drive (where this notebook was executed).
import os
import sys

# Mount Google Drive and add the project source to the Python path
from google.colab import drive
drive.mount('/content/drive')
sys.path.append('/content/drive/MyDrive/xai-l2o/src/')
save_dir = './drive/MyDrive/xai-l2o/'

# Project utilities and the implicit dictionary model
# from certificate import CertificateModel, CertificateEnsemble
from utils import solve_least_squares, create_dict_loaders
from utils import plot_dict_signal, print_model_params
from models import ImpDictModel

import scipy.io
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.dataset import random_split

torch.manual_seed(31415)  # reproducibility
# Data loaders and measurement matrix
loader_train, loader_test, A = create_dict_loaders()

# Training configuration
max_epoch = 75
device = 'cuda:0'
learning_rate = 4.0e-5
max_depth_train = 400    # cap on model depth (iterations) during training
max_depth_test = 2000    # cap on model depth during inference/plotting

model = ImpDictModel(A).to(device=device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                             weight_decay=1.0e-5)

file_name = save_dir + 'weights/dictionary_model_weights.pth'
loss_best = 1.0e10
MSE_ave = 0.0
epochs_adm = 0

training_msg = '[{:5d}] train loss = {:2.3e} | depth = {:3.0f} | lr = {:2.3e}'
training_msg += ' | K 2-norm = {:2.3e}'

print_model_params(model)

# Optionally resume from previously saved weights
load_weights = False
if load_weights:
    state = torch.load(file_name, map_location=torch.device(device))
    model.load_state_dict(state['model_state_dict'])
    print('Loaded model from file.')
Mounted at /content/drive

+-------------------+--------------+
| Network Component | # Parameters |
+-------------------+--------------+
| K                 | 62500        |
| TOTAL             | 62500        |
+-------------------+--------------+
Model Training
With the model loaded, we next train it to predict $\mathsf{x^\star}$ from $\mathsf{d}$. We use the Adam optimizer and, every ten epochs, plot samples from the test data to give intuition for how well the parameters are tuned. Whenever the training loss improves, the model weights are checkpointed to file.
for epoch in range(max_epoch):
    model.train()
    for x_true, d_batch in loader_train:
        optimizer.zero_grad()
        x_pred, depth = model(d_batch, max_depth=max_depth_train,
                              normalize_K=True, return_depth=True)
        loss = criterion(x_pred, x_true.to(device).float())
        loss.backward()
        optimizer.step()
        loss_curr = loss.detach().item()

    # Every ten epochs, plot test samples to inspect reconstruction quality
    if epoch % 10 == 0:
        plot_dict_signal(model, loader_test, inference_depth=max_depth_test)

    # Checkpoint whenever the training loss improves
    if loss_curr < loss_best:
        loss_best = loss_curr
        state = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }
        torch.save(state, file_name)
        print('Model weights saved to ' + file_name)

    K_norm = torch.linalg.matrix_norm(model.K.detach(), ord=2)
    print(training_msg.format(epoch, loss_curr, depth,
                              optimizer.param_groups[0]['lr'], K_norm))
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[    0] train loss = 4.517e+00 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[    1] train loss = 3.454e+00 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[    2] train loss = 2.508e+00 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[    3] train loss = 2.057e+00 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[    4] train loss = 1.678e+00 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
[    5] train loss = 1.681e+00 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[    6] train loss = 1.224e+00 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[    7] train loss = 1.100e+00 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[    8] train loss = 7.953e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
[    9] train loss = 8.571e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   10] train loss = 7.052e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   11] train loss = 5.494e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   12] train loss = 5.340e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   13] train loss = 4.789e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   14] train loss = 4.106e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   15] train loss = 3.093e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   16] train loss = 2.922e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   17] train loss = 2.745e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   18] train loss = 2.305e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   19] train loss = 1.662e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   20] train loss = 1.486e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   21] train loss = 1.426e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   22] train loss = 1.246e-01 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   23] train loss = 9.714e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   24] train loss = 8.895e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   25] train loss = 8.003e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   26] train loss = 6.286e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   27] train loss = 5.849e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   28] train loss = 5.447e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   29] train loss = 4.025e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   30] train loss = 2.939e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   31] train loss = 2.802e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   32] train loss = 2.250e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   33] train loss = 2.195e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   34] train loss = 1.735e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
[   35] train loss = 1.775e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   36] train loss = 1.341e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   37] train loss = 1.308e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   38] train loss = 1.020e-02 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00
Model weights saved to ./drive/MyDrive/xai-l2o/weights/dictionary_model_weights.pth
[   39] train loss = 9.429e-03 | depth = 400 | lr = 4.000e-05 | K 2-norm = 1.000e+00