"""Proximal operator for the squared L2 penalty (module ``proxtorch.operators.l2``)."""

import torch

from proxtorch.base import ProxOperator


class L2(ProxOperator):
    r"""Proximal operator of the squared L2 (ridge / Tikhonov) penalty.

    The penalty handled by this class is

    .. math::

        f(x) = \frac{\alpha}{2} \lVert x \rVert_2^2 ,

    taken over all entries of ``x``. Its proximal map is a uniform rescaling
    (shrinkage toward zero) of the input. It is *not* a soft-thresholding
    operation.

    Attributes:
        alpha (float): Regularization strength. Expected to be non-negative.
            For negative values the penalty is non-convex, and ``prox``
            divides by zero when ``alpha * tau == -1``.
    """

    def __init__(self, alpha: float = 1.0):
        super().__init__()
        self.alpha = alpha

    def prox(self, x: torch.Tensor, tau: float) -> torch.Tensor:
        r"""Apply the proximal operator of the squared L2 penalty.

        Solves

        .. math::

            \operatorname{argmin}_z \; \frac{\alpha}{2} \lVert z \rVert_2^2
            + \frac{1}{2\tau} \lVert z - x \rVert_2^2 ,

        whose closed-form solution is ``x / (1 + alpha * tau)``. Setting the
        gradient ``alpha * z + (z - x) / tau`` to zero gives this result.

        Args:
            x (torch.Tensor): Input tensor.
            tau (float): Proximal step size.

        Returns:
            torch.Tensor: ``x`` scaled by ``1 / (1 + alpha * tau)``, with the
            same shape as ``x``.
        """
        return x / (1.0 + self.alpha * tau)

    def _nonsmooth(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the penalty ``0.5 * alpha * ||x||_2^2``.

        ``x`` is flattened first, so the norm is taken over all of its entries
        regardless of shape. The result is a 0-dim tensor.
        """
        return 0.5 * self.alpha * torch.linalg.norm(x.reshape(-1), 2) ** 2