Source code for proxtorch.constraints.box

# Box Constraint
import torch

from proxtorch.base import Constraint


class Box(Constraint):
    """Elementwise box constraint ``a <= x <= b``.

    Args:
        a: Lower bound applied to every element (default ``0.0``).
        b: Upper bound applied to every element (default ``1.0``).

    Raises:
        ValueError: If any lower bound exceeds the corresponding upper bound,
            i.e. the box would be empty.
    """

    def __init__(self, a: float = 0.0, b: float = 1.0):
        super().__init__()
        # An empty box (a > b) makes ``prox`` inconsistent with ``__call__``:
        # torch.clamp would map every element to ``b``, a point that then fails
        # the membership test. Reject it up front. ``torch.as_tensor`` lets the
        # same check work for scalar and tensor bounds; ``None`` bounds are
        # left unchecked so the one-sided ``clamp`` behavior is unchanged.
        if a is not None and b is not None:
            if torch.any(torch.as_tensor(a) > torch.as_tensor(b)):
                raise ValueError(
                    f"Box lower bound must not exceed upper bound (got a={a}, b={b})"
                )
        self.a = a
        self.b = b

    def prox(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` onto the box by clamping each element into ``[a, b]``.

        Args:
            x: Input tensor.

        Returns:
            A new tensor of the same shape with values clipped to ``[a, b]``.
        """
        return torch.clamp(x, min=self.a, max=self.b)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Check whether every element of ``x`` lies inside the box.

        Args:
            x: Input tensor.

        Returns:
            A 0-d boolean tensor that is True iff ``a <= x <= b`` holds
            elementwise (True for an empty ``x``).
        """
        return torch.all((x >= self.a) & (x <= self.b))