Source code for proxtorch.constraints.l2ball

import torch

from proxtorch.base import Constraint


[docs]class L2Ball(Constraint): r""" Projection onto the L1 ball. Attributes: s (float): Radius of the L1 ball. """ def __init__(self, s: float = 1.0): super().__init__() self.s = s
[docs] def prox(self, x: torch.Tensor) -> torch.Tensor: r""" Project x onto the L1-ball of radius `s`. Args: x (torch.Tensor): Input tensor. tau (float): Proximal step size. Not used for L1Ball, but kept for API consistency. Returns: torch.Tensor: Resultant tensor after the projection. """ # The logic provided is one of the ways to achieve this projection. if torch.norm(x, p=2) <= self.s: return x else: return x * self.s / torch.norm(x, p=2)
def __call__(self, x: torch.Tensor) -> bool: r"""Check if the tensor satisfies the L1 constraint. Args: x (torch.Tensor): Input tensor. Returns: bool: True if L1-norm of tensor is less than or equal to `s`, False otherwise. """ return torch.linalg.norm(x.reshape(-1), ord=2) <= self.s