Source code for proxtorch.constraints.frobenius

import torch

from proxtorch.base import Constraint


class Frobenius(Constraint):
    r"""
    Constraint restricting a tensor to the Frobenius-norm ball of radius ``s``.

    The feasible set is :math:`\{x : \|x\|_F \le s\}`, where the Frobenius norm
    is the Euclidean (2-)norm of all entries of ``x`` taken together, whatever
    the tensor's shape.

    Attributes:
        s (float): Radius of the Frobenius-norm ball. Must be non-negative.
    """

    def __init__(self, s: float = 1.0):
        r"""
        Args:
            s (float): Radius of the feasible ball. Defaults to 1.0.

        Raises:
            ValueError: If ``s`` is negative. The feasible set would then be
                empty, and ``prox`` would flip the sign of its input.
        """
        super().__init__()
        if s < 0:
            raise ValueError(f"s must be non-negative, got {s}")
        self.s = s

    @staticmethod
    def _frobenius_norm(x: torch.Tensor) -> torch.Tensor:
        """Return the 2-norm of all entries of ``x`` as a 0-dim tensor."""
        # vector_norm with dim=None flattens the input, matching the
        # deprecated torch.norm(x, p="fro") for tensors of any rank.
        return torch.linalg.vector_norm(x)

    def __call__(self, x: torch.Tensor) -> bool:
        r"""
        Check if the constraint is satisfied for the given tensor.

        Args:
            x (torch.Tensor): Input tensor (floating point or complex).

        Returns:
            bool: True if the Frobenius norm of x is less than or equal to s,
            False otherwise.
        """
        # Convert the 0-dim comparison tensor to a Python bool, as annotated.
        return bool(self._frobenius_norm(x) <= self.s)

    def prox(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Projects the tensor onto the feasible set defined by the Frobenius
        norm constraint.

        A tensor already inside the ball is returned unchanged (the same
        object, not a copy). Otherwise it is rescaled radially onto the ball's
        boundary, which is the Euclidean projection onto that ball.

        Args:
            x (torch.Tensor): Input tensor (floating point or complex).

        Returns:
            torch.Tensor: Tensor after projection.
        """
        frobenius_norm = self._frobenius_norm(x)
        if frobenius_norm > self.s:
            # norm > s >= 0 here, so the division cannot be by zero.
            return self.s * (x / frobenius_norm)
        return x