Source code for proxtorch.constraints.rank

import torch

from proxtorch.base import Constraint


class Rank(Constraint):
    r"""
    Constraint for rank regularization.

    The feasible set is the set of matrices whose rank does not exceed
    ``max_rank``. Inputs are expected to have shape ``(..., m, n)``; leading
    dimensions are treated as batch dimensions and every matrix in the batch
    is handled independently.

    Attributes:
        max_rank (int): Maximum allowable rank.
        tol (float): Singular values less than or equal to this threshold are
            treated as zero when the rank is counted.
    """

    def __init__(self, max_rank: int, tol: float = 1e-5):
        r"""
        Args:
            max_rank (int): Maximum allowable rank. Must be non-negative.
            tol (float): Threshold below which a singular value is considered
                negligible when counting rank. Defaults to ``1e-5``.

        Raises:
            ValueError: If ``max_rank`` is negative.
        """
        super().__init__()
        # A negative bound is meaningless and would silently turn the
        # truncation in ``prox`` into "drop the last |max_rank| values".
        if max_rank < 0:
            raise ValueError(f"max_rank must be non-negative, got {max_rank}")
        self.max_rank = max_rank
        self.tol = tol

    def __call__(self, x: torch.Tensor) -> bool:
        r"""
        Check if the constraint is satisfied for the given tensor.

        Args:
            x (torch.Tensor): Input tensor of shape ``(..., m, n)``.

        Returns:
            bool: True if the rank of every matrix in ``x`` is less than or
            equal to ``max_rank``, False otherwise.
        """
        # Only the singular values are needed here, so skip computing U and V.
        singular_values = torch.linalg.svdvals(x)
        # Count non-negligible singular values per matrix (last dimension) so
        # that batched inputs are not pooled into a single rank.
        rank = (singular_values > self.tol).sum(dim=-1)
        return bool((rank <= self.max_rank).all())

    def prox(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Projects the tensor onto the feasible set defined by the rank constraint.

        The projection is the truncated SVD that keeps only the ``max_rank``
        largest singular values of each matrix.

        Args:
            x (torch.Tensor): Input tensor of shape ``(..., m, n)``.

        Returns:
            torch.Tensor: Tensor after projection, with the same shape as ``x``.
        """
        u, s, vh = torch.linalg.svd(x, full_matrices=False)
        # Guard against a negative value assigned to the attribute after
        # construction: a negative slice bound would keep the wrong components.
        k = max(int(self.max_rank), 0)
        # Slice instead of zeroing ``s`` in place: an in-place write to the SVD
        # output invalidates the tensor autograd saved for the backward pass.
        # Singular values are sorted in descending order, so the first ``k``
        # components are the ones to keep.
        u_k = u[..., :, :k]
        s_k = s[..., :k]
        vh_k = vh[..., :k, :]
        # Scaling the columns of U by the singular values is equivalent to
        # ``U @ diag(S)`` and also works for batched inputs.
        return (u_k * s_k.unsqueeze(-2)) @ vh_k