Source code for proxtorch.operators.elastic_net

import torch

from proxtorch.base import ProxOperator


class ElasticNet(ProxOperator):
    r"""
    Elastic Net proximal operator.

    Combines an L1 penalty and a squared-L2 penalty into one regularizer:

    .. math::
        g(x) = \alpha \left( \rho \|x\|_1
               + \tfrac{1 - \rho}{2} \|x\|_2^2 \right)

    where :math:`\alpha` is ``alpha`` and :math:`\rho` is ``l1_ratio``.

    Attributes:
        alpha (float): Regularization strength.
        l1_ratio (float): Proportion of L1 regularization.
        l2_ratio (float): Proportion of L2 regularization, stored as
            ``1.0 - l1_ratio``.

    Note:
        The constructor stores its arguments without validating them. The
        closed form used by :meth:`prox` is the proximal map of ``g`` when
        ``alpha >= 0`` and ``0 <= l1_ratio <= 1``.
    """

    def __init__(self, alpha: float = 1.0, l1_ratio: float = 0.5):
        r"""
        Store the regularization strength and the L1/L2 mixing weights.

        Args:
            alpha (float): Regularization strength.
            l1_ratio (float): Proportion of the penalty assigned to the L1
                term; the remainder ``1.0 - l1_ratio`` goes to the L2 term.
        """
        super().__init__()
        self.alpha = alpha
        self.l1_ratio = l1_ratio
        self.l2_ratio = 1.0 - l1_ratio

    def prox(self, x: torch.Tensor, tau: float) -> torch.Tensor:
        r"""
        Proximal operator for the Elastic Net regularization.

        Soft-thresholds ``x`` element-wise by ``alpha * l1_ratio * tau`` and
        then divides the result by ``1 + alpha * l2_ratio * tau``.

        Args:
            x (torch.Tensor): Input tensor.
            tau (float): Proximal step size.

        Returns:
            torch.Tensor: Resultant tensor after applying the proximal
            operator.
        """
        l1_threshold = self.alpha * self.l1_ratio * tau
        l2_shrinkage = 1.0 + self.alpha * self.l2_ratio * tau
        # L1 part: shrink every magnitude towards zero by the threshold and
        # restore the sign; entries with |x| <= threshold become exactly zero.
        soft_thresholded = torch.sign(x) * torch.clamp(
            torch.abs(x) - l1_threshold, min=0
        )
        # L2 part: a uniform rescaling of the soft-thresholded values.
        return soft_thresholded / l2_shrinkage

    def _nonsmooth(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Compute the combined L1 and L2 regularizations.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Scalar tensor holding the Elastic Net
            regularization term
            ``alpha * (l1_ratio * sum(|x|) + l2_ratio * 0.5 * sum(x ** 2))``.
        """
        # Reduce over all elements directly instead of going through the
        # deprecated ``torch.norm``; summing squares also avoids taking a
        # square root only to square it again.
        l1_term = x.abs().sum()
        l2_term = 0.5 * x.pow(2).sum()
        return self.alpha * (self.l1_ratio * l1_term + self.l2_ratio * l2_term)