.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_fista.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plot_fista.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_fista.py:

Lasso Regression with FISTA Optimization using PyTorch's Optimizer
===================================================================

...

Dependencies:

- `numpy`
- `torch`
- `matplotlib`
- `proxtorch`
- `pytorch_lightning`
- `sklearn.datasets`

.. GENERATED FROM PYTHON SOURCE LINES 14-72

.. image-sg:: /auto_examples/images/sphx_glr_plot_fista_001.png
   :alt: Lasso Coefficients with FISTA using PyTorch's Optimizer
   :srcset: /auto_examples/images/sphx_glr_plot_fista_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Global seed set to 42

|

.. code-block:: default

    import matplotlib.pyplot as plt
    import torch
    import torch.optim as optim
    from pytorch_lightning import seed_everything
    from sklearn.datasets import make_regression
    from torch.nn import Parameter

    from proxtorch.operators import L1

    seed_everything(42)

    # Create synthetic data
    X, y, coef = make_regression(
        n_samples=100, n_features=20, noise=0.1, coef=True, random_state=42
    )
    X, y = torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

    # Parameters
    alpha = 0.1  # Regularization parameter for Lasso
    lr = 0.01  # Learning rate
    n_iter = 100  # Number of iterations

    l1_prox = L1(alpha=alpha)


    def fista(X, y, l1_prox, lr, n_iter):
        # Note: each iteration below is a gradient step followed by a proximal
        # step; no FISTA momentum (extrapolation) term is applied in this loop.
        theta = Parameter(torch.zeros(X.shape[1]))  # Initialize weights
        optimizer = optim.SGD([theta], lr=lr)

        for _ in range(n_iter):
            optimizer.zero_grad()  # Reset gradients

            # Forward pass: Compute predicted y
            y_pred = X @ theta

            # Compute loss (mean squared error)
            loss = ((y_pred - y) ** 2).mean()

            # Backward pass: Compute gradient of the loss with respect to model parameters
            loss.backward()

            # Optimizer step (gradient descent)
            optimizer.step()

            # Proximal operation
            with torch.no_grad():
                theta.data = l1_prox.prox(theta, lr)

        return theta


    # Run FISTA
    weights = fista(X, y, l1_prox, lr, n_iter)

    # Plot estimated and true coefficients
    plt.stem(weights.detach().numpy(), label="FISTA")
    # `use_line_collection` was deprecated in Matplotlib 3.6, so it is not passed here
    plt.stem(coef, linefmt="r-", markerfmt="ro", label="True")
    plt.title("Lasso Coefficients with FISTA using PyTorch's Optimizer")
    plt.legend()
    plt.show()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.578 seconds)


.. _sphx_glr_download_auto_examples_plot_fista.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_fista.py <plot_fista.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_fista.ipynb <plot_fista.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
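
The loop in the example script applies a plain gradient step followed by the L1 proximal step. For comparison, the following is a minimal NumPy sketch of the accelerated variant, in which a momentum (extrapolation) point is updated between iterations. It deliberately does not use ``proxtorch``: ``soft_threshold``, ``lasso_objective`` and ``fista_lasso`` are hypothetical helper names, the step size is assumed to be 1/L with L computed from the spectral norm of ``X``, and the synthetic data is made up for illustration.

```python
import numpy as np


def soft_threshold(v, t):
    """Elementwise soft-thresholding: the proximal operator of t * ||.||_1."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def lasso_objective(X, y, w, alpha):
    """Mean squared error plus alpha times the L1 norm of w."""
    return np.mean((X @ w - y) ** 2) + alpha * np.sum(np.abs(w))


def fista_lasso(X, y, alpha, n_iter=200):
    """Accelerated proximal gradient sketch for the objective above."""
    n_samples, n_features = X.shape
    # Gradient of the MSE term is (2/n) X^T (Xw - y), so its Lipschitz
    # constant is 2 * sigma_max(X)^2 / n (assumed step size: 1 / L).
    lipschitz = 2.0 * np.linalg.norm(X, 2) ** 2 / n_samples
    step = 1.0 / lipschitz

    w = np.zeros(n_features)  # current iterate
    z = w.copy()              # extrapolated (momentum) point
    t = 1.0                   # momentum scalar

    for _ in range(n_iter):
        # Gradient step taken at the extrapolated point, then the L1 prox
        grad = 2.0 * X.T @ (X @ z - y) / n_samples
        w_next = soft_threshold(z - step * grad, step * alpha)

        # Momentum update that distinguishes the accelerated variant
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t ** 2)) / 2.0
        z = w_next + ((t - 1.0) / t_next) * (w_next - w)
        w, t = w_next, t_next

    return w


# Illustrative synthetic problem (assumed values, not the example's dataset)
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 20))
w_true = np.zeros(20)
w_true[:5] = [3.0, -2.0, 1.5, -1.0, 2.5]
y = X @ w_true + 0.1 * rng.standard_normal(100)

alpha = 0.1
w_hat = fista_lasso(X, y, alpha, n_iter=200)
```

Computing the step size from the spectral norm avoids hand-tuning a learning rate, at the cost of one singular value computation up front; a fixed ``lr`` as in the example script is the simpler alternative.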