Source code for proxtorch.constraints.lInfinityBall

import torch

from proxtorch.base import Constraint


[docs]class LInfinityBall(Constraint): r""" Projection onto the LInfinity ball of radius `s`. Attributes: s (float): Radius of the LInfinity ball. """ def __init__(self, s: float = 1.0): super().__init__() self.s = s
[docs] def prox(self, x: torch.Tensor) -> torch.Tensor: r""" Project x onto the LInfinity-ball of radius `s`. Args: x (torch.Tensor): Input tensor. tau (float): Proximal step size. Not used for LInfinityBall, but kept for API consistency. Returns: torch.Tensor: Resultant tensor after the projection. """ return torch.clamp(x, min=-self.s, max=self.s)
def __call__(self, x: torch.Tensor) -> bool: r"""Check if the tensor satisfies the LInfinity constraint. Args: x (torch.Tensor): Input tensor. Returns: bool: True if the infinity-norm of tensor is less than or equal to `s`, False otherwise. """ return torch.max(torch.abs(x)) <= self.s