Source code for proxtorch.constraints.l0ball
import torch
from proxtorch.base import Constraint
[docs]class L0Ball(Constraint):
r"""
L0Prox ball proximal operator.
Projects onto a vector with at most `s` non-zero elements.
Attributes:
s (int): Budget of non-zero elements.
"""
def __init__(self, s: int):
super().__init__()
self.s = s
[docs] def prox(self, x: torch.Tensor) -> torch.Tensor:
r"""
Proximal operator for the L0Prox ball.
Keeps the `s` largest elements in magnitude and sets the rest to zero.
Args:
x (torch.Tensor): Input tensor.
tau (float): Proximal step size. Not used for L0Ball, but kept for API consistency.
Returns:
torch.Tensor: Resultant tensor after applying the proximal operator.
"""
_, indices = torch.topk(torch.abs(x), self.s)
result = torch.zeros_like(x)
result[indices] = x[indices]
return result
def __call__(self, x: torch.Tensor) -> bool:
r"""Check if the tensor satisfies the L0 constraint.
Args:
x (torch.Tensor): Input tensor.
Returns:
bool: True if tensor has at most `s` non-zero elements, False otherwise.
"""
return torch.linalg.norm(x.reshape(-1), 0) <= self.s