# Source code for proxtorch.constraints.l1ball

import torch

from proxtorch.base import Constraint


class L1Ball(Constraint):
    r"""Projection onto the L1 ball ``{x : ||x||_1 <= s}``.

    The tensor is always treated as one flat vector, whatever its shape.

    Attributes:
        s (float): Radius of the L1 ball. Must be non-negative for
            :meth:`prox` to be defined.
    """

    def __init__(self, s: float = 1.0):
        super().__init__()
        self.s = s

    def prox(self, x: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
        r"""Project ``x`` onto the L1 ball of radius ``s``.

        Uses the sort-based algorithm. It finds the soft-threshold ``theta``
        such that ``sum(max(|x_i| - theta, 0)) == s`` and then shrinks every
        entry's magnitude by ``theta`` while keeping its sign. The result has
        the same shape, dtype and device as ``x``.

        Args:
            x (torch.Tensor): Input tensor of any shape.
            tau (float): Proximal step size. It is ignored, because projecting
                onto a set does not depend on it. It is accepted only for API
                consistency.

        Returns:
            torch.Tensor: ``x`` itself (not a copy) if it already lies in the
            ball, otherwise its projection onto the ball.

        Raises:
            ValueError: If ``s`` is negative, since the ball is then empty.
        """
        if torch.linalg.norm(x.reshape(-1), ord=1) <= self.s:
            return x
        if self.s < 0:
            raise ValueError(f"L1Ball radius must be non-negative, got s={self.s}")
        if self.s == 0:
            # The ball of radius zero is the single point 0.
            return torch.zeros_like(x)

        # Sort the flattened magnitudes so tensors of any shape are handled.
        # theta ends up a scalar, so the final shrinkage applies to x directly.
        u, _ = torch.sort(torch.abs(x).reshape(-1), descending=True)
        cssv = torch.cumsum(u, dim=0) - self.s
        idx = torch.arange(1, u.numel() + 1, device=x.device)
        cond = u - cssv / idx > 0
        # For s > 0 the first entry satisfies the condition exactly
        # (u_1 - (u_1 - s) = s > 0). Force it so that round-off with a tiny s
        # relative to max|x| can never leave `cond` all-False.
        cond[0] = True
        theta = cssv[cond][-1] / idx[cond][-1]
        return torch.sign(x) * torch.clamp(torch.abs(x) - theta, min=0)

    def __call__(self, x: torch.Tensor) -> bool:
        r"""Check whether the tensor satisfies the L1 constraint.

        Args:
            x (torch.Tensor): Input tensor of any shape. It is flattened
                before the norm is taken.

        Returns:
            bool: True if the L1 norm of ``x`` is less than or equal to
            ``s``, False otherwise.
        """
        return bool(torch.linalg.norm(x.reshape(-1), ord=1) <= self.s)