Source code for proxtorch.constraints.tracenorm
import torch
from proxtorch.base import Constraint
class TraceNorm(Constraint):
    r"""
    Constraint restricting a matrix to a trace-norm (nuclear-norm) ball.

    The feasible set is :math:`\{X : \|X\|_* \le \alpha\}`, where
    :math:`\|X\|_*` is the sum of the singular values of :math:`X`.

    Attributes:
        alpha (float): Radius of the trace-norm ball.
    """

    def __init__(self, alpha: float = 1.0):
        r"""
        Store the radius of the trace-norm ball.

        Args:
            alpha (float): Upper bound on the sum of singular values.
        """
        super().__init__()
        self.alpha = alpha

    def __call__(self, x: torch.Tensor) -> bool:
        r"""
        Check if the constraint is satisfied for the given tensor.

        Args:
            x (torch.Tensor): Input matrix of shape ``(..., m, n)``. Leading
                dimensions, if any, are treated as a batch of matrices.

        Returns:
            bool: True if the trace norm of every matrix in ``x`` is less
            than or equal to ``alpha``, False otherwise.
        """
        trace_norm = torch.linalg.svdvals(x).sum(dim=-1)
        # Collapse to a plain Python bool, as the signature promises.
        return bool((trace_norm <= self.alpha).all())

    def prox(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Project the tensor onto the trace-norm ball of radius ``alpha``.

        The Euclidean (Frobenius) projection keeps the singular vectors of
        ``x`` and replaces its singular values by their projection onto the
        l1 ball of radius ``alpha``.

        Args:
            x (torch.Tensor): Input matrix of shape ``(..., m, n)``.

        Returns:
            torch.Tensor: Projected tensor with the same shape as ``x``. The
            input is returned (up to SVD round-off) when it is already
            feasible.
        """
        u, s, vh = torch.linalg.svd(x, full_matrices=False)
        projected = self._project_onto_l1_ball(s, self.alpha)
        # Scale the columns of ``u`` instead of building a dense diagonal
        # matrix; this also broadcasts over leading batch dimensions.
        return (u * projected.unsqueeze(-2)) @ vh

    @staticmethod
    def _project_onto_l1_ball(s: torch.Tensor, radius: float) -> torch.Tensor:
        r"""
        Project non-negative, descending-sorted values onto an l1 ball.

        Args:
            s (torch.Tensor): Singular values along the last dimension, sorted
                in descending order (the order ``torch.linalg.svd`` returns).
            radius (float): Radius of the l1 ball.

        Returns:
            torch.Tensor: New tensor shaped like ``s`` whose entries are
            non-negative and sum to at most ``radius``; ``s`` is not modified.
        """
        # A non-positive radius leaves at most the zero matrix as feasible.
        if radius <= 0 or s.shape[-1] == 0:
            return torch.zeros_like(s)

        cumulative = torch.cumsum(s, dim=-1)
        counts = torch.arange(
            1, s.shape[-1] + 1, device=s.device, dtype=s.dtype
        )
        # Candidate shift if exactly the first j values stayed positive.
        thresholds = (cumulative - radius) / counts
        # With descending ``s`` the test below holds for a prefix of the
        # entries, so counting the hits gives the number of values that
        # survive the shrinkage.
        active = (s > thresholds).sum(dim=-1, keepdim=True)
        # ``active`` is at least 1 for finite input because radius > 0; the
        # clamp only keeps the index valid if ``s`` contains NaNs.
        theta = thresholds.gather(-1, (active - 1).clamp(min=0))
        # A non-positive shift means the values were already inside the ball.
        theta = theta.clamp(min=0)
        return (s - theta).clamp(min=0)
# Alias: the same class is also exposed under the name ``NuclearNorm``.
NuclearNorm = TraceNorm