Lasso Regression with FISTA Optimization using PyTorch's Optimizer

Dependencies: numpy, torch, matplotlib, proxtorch, scikit-learn (sklearn.datasets), and pytorch_lightning (for seed_everything).

Lasso Coefficients with FISTA using PyTorch's Optimizer
Global seed set to 42
/home/docs/checkouts/readthedocs.org/user_builds/proxtorch/checkouts/latest/docs/source/examples/plot_fista.py:68: MatplotlibDeprecationWarning: The 'use_line_collection' parameter of stem() was deprecated in Matplotlib 3.6 and will be removed two minor releases later. If any parameter follows 'use_line_collection', they should be passed as keyword, not positionally.
  plt.stem(coef, linefmt="r-", markerfmt="ro", use_line_collection=True, label="True")

import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from pytorch_lightning import seed_everything
from sklearn.datasets import make_regression
from torch.nn import Parameter

from proxtorch.operators import L1

# Fix the global seed so the example is reproducible.
seed_everything(42)

# Synthetic regression problem: 100 samples, 20 features, mild noise.
# ``coef=True`` also returns the ground-truth coefficients for plotting.
X, y, coef = make_regression(
    n_samples=100,
    n_features=20,
    noise=0.1,
    coef=True,
    random_state=42,
)

# Convert the NumPy arrays to float32 tensors for the PyTorch optimizer.
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

# Hyper-parameters
alpha = 0.1  # Lasso regularization strength
lr = 0.01  # Step size shared by the gradient step and the proximal step
n_iter = 100  # Number of proximal-gradient iterations

# L1 proximal operator configured with the regularization strength above.
l1_prox = L1(alpha=alpha)


def fista(X, y, l1_prox, lr, n_iter):
    """Minimize a mean-squared-error loss plus a proximable penalty with FISTA.

    Each iteration takes a gradient step on ``mean((X @ theta - y) ** 2)``
    using ``torch.optim.SGD``, applies ``l1_prox.prox``, and then extrapolates
    with the Beck-Teboulle momentum sequence
    ``t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2``. Without that extrapolation the
    loop is plain proximal gradient descent (ISTA), not FISTA.

    Args:
        X: Design matrix; ``X.shape[1]`` sets the number of weights.
        y: Target vector compatible with ``X @ theta``.
        l1_prox: Object exposing ``prox(tensor, step)``; it is called as
            ``l1_prox.prox(theta, lr)`` and is expected to return the
            proximal map of the penalty (assumed not to require grad).
        lr: Step size used for both the gradient step and the proximal step.
        n_iter: Number of iterations; ``0`` returns the zero initialization.

    Returns:
        torch.nn.Parameter: The last proximal iterate (not the extrapolated
        point), so sparsity produced by the proximal operator is preserved.
    """
    # ``theta`` holds the extrapolated point at which the gradient is taken.
    theta = Parameter(torch.zeros(X.shape[1]))
    optimizer = optim.SGD([theta], lr=lr)

    x_prev = theta.detach().clone()  # previous proximal iterate x_k
    t_prev = 1.0  # momentum sequence value t_k

    for _ in range(n_iter):
        optimizer.zero_grad()  # Reset gradients

        # Smooth part of the objective, evaluated at the extrapolated point.
        loss = ((X @ theta - y) ** 2).mean()
        loss.backward()

        # Gradient step: theta <- theta - lr * grad
        optimizer.step()

        with torch.no_grad():
            # Proximal step. Clone so the result never aliases ``theta``,
            # which is overwritten in place just below.
            x_next = l1_prox.prox(theta, lr).detach().clone()

            # Momentum update and extrapolation.
            t_next = (1.0 + (1.0 + 4.0 * t_prev**2) ** 0.5) / 2.0
            momentum = (t_prev - 1.0) / t_next
            theta.copy_(x_next + momentum * (x_next - x_prev))

            x_prev, t_prev = x_next, t_next

    # Return the proximal iterate rather than the extrapolated point.
    with torch.no_grad():
        theta.copy_(x_prev)

    return theta


# Run FISTA
weights = fista(X, y, l1_prox, lr, n_iter)

# Plot the estimated coefficients against the ground-truth coefficients.
plt.stem(weights.detach().numpy(), label="FISTA")
# ``use_line_collection`` is intentionally not passed: Matplotlib deprecated
# it in 3.6 and announced its removal, so passing it is no longer safe.
plt.stem(coef, linefmt="r-", markerfmt="ro", label="True")
plt.title("Lasso Coefficients with FISTA using PyTorch's Optimizer")
plt.legend()
plt.show()

Total running time of the script: (0 minutes 1.578 seconds)

Gallery generated by Sphinx-Gallery