Source code for proxtorch.operators.tracenorm
import torch
from proxtorch.base import ProxOperator
class TraceNorm(ProxOperator):
    r"""
    Proximal operator for the trace (nuclear) norm regularization.

    The regularizer is ``alpha * ||X||_*``, where ``||X||_*`` is the sum of the
    singular values of ``X``; its proximal operator is singular value
    soft-thresholding.

    Inputs are treated as matrices in their last two dimensions; any leading
    dimensions are treated as a batch of independent matrices.

    Attributes:
        alpha (float): Regularization strength.
    """

    def __init__(self, alpha: float = 1.0):
        super().__init__()
        self.alpha = alpha

    def prox(self, x: torch.Tensor, tau: float) -> torch.Tensor:
        r"""
        Proximal operator for the trace norm regularization.

        With the reduced SVD ``x = U diag(s) V^H``, returns
        ``U diag(max(s - alpha * tau, 0)) V^H``.

        Args:
            x (torch.Tensor): Input tensor of shape ``(..., m, n)``.
            tau (float): Proximal step size.

        Returns:
            torch.Tensor: Resultant tensor after applying the proximal
            operator, with the same shape as ``x``.
        """
        # Reduced SVD (same factorization as the deprecated ``torch.svd``),
        # but returning V^H directly so complex inputs are reconstructed with
        # the conjugate transpose rather than a plain transpose.
        u, s, vh = torch.linalg.svd(x, full_matrices=False)
        # Soft-threshold the singular values.
        s = torch.clamp(s - self.alpha * tau, min=0)
        # Scale column j of U by s[j]: equivalent to ``u @ diag(s)``, but it
        # works for batched inputs and avoids a real/complex dtype mismatch
        # in matmul (s is always real).
        return (u * s.unsqueeze(-2)) @ vh

    def _nonsmooth(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Compute the trace norm regularization.

        Args:
            x (torch.Tensor): Input tensor of shape ``(..., m, n)``.

        Returns:
            torch.Tensor: 0-dim tensor equal to ``alpha`` times the sum of the
            nuclear norms of the matrices in the last two dimensions (for a
            single matrix this is just ``alpha * ||x||_*``).
        """
        # Reduce only over the matrix dims, then sum over any batch dims so
        # the value matches the batched ``prox`` above.
        return self.alpha * torch.linalg.norm(x, ord="nuc", dim=(-2, -1)).sum()
# Alias: the nuclear norm is another name for the trace norm, so both names
# refer to the same operator class.
NuclearNorm = TraceNorm