# Source code for proxtorch.operators.l1
import torch
from proxtorch.base import ProxOperator
class L1(ProxOperator):
    r"""
    L1 norm proximal operator.

    Models the function :math:`f(x) = \alpha \|x\|_1`, which promotes
    sparsity in the tensor.

    Attributes:
        alpha (float): Regularization strength.
    """

    def __init__(self, alpha: float = 1.0):
        super().__init__()
        self.alpha = alpha

    def prox(self, x: torch.Tensor, tau: float) -> torch.Tensor:
        r"""
        Soft-thresholding for the L1 norm.

        Computes, element-wise,
        :math:`\operatorname{sgn}(x) \cdot \max(|x| - \tau\alpha, 0)`.
        For real tensors ``torch.sgn`` is identical to ``torch.sign``; for
        complex tensors it is ``x / |x|`` (and 0 at 0), so the same formula
        gives the complex soft-thresholding operator.

        Args:
            x (torch.Tensor): Input tensor.
            tau (float): Proximal operator step size.

        Returns:
            torch.Tensor: Resultant tensor after soft-thresholding.
        """
        threshold = tau * self.alpha
        # torch.abs returns a real tensor even for complex x, so clamping
        # the shrunken magnitude at zero is always well defined.
        magnitude = torch.clamp(torch.abs(x) - threshold, min=0)
        return torch.sgn(x) * magnitude

    def _nonsmooth(self, x):
        """Return ``alpha`` times the L1 norm of ``x`` taken over all elements."""
        # Flatten first so the vector 1-norm (sum of absolute values) is
        # applied to every element regardless of the input's shape.
        return self.alpha * torch.linalg.norm(x.reshape(-1), ord=1)