"""Source code for proxtorch.constraints.non_negative: the non-negative constraint."""

from typing import Optional

import torch
from torch import Tensor

from proxtorch.base import Constraint


class NonNegative(Constraint):
    r"""Proximal operator for the non-negative constraint.

    The constraint set is the non-negative orthant
    :math:`\{x : x_i \ge 0 \ \forall i\}`. The proximal operator of its
    indicator function is the Euclidean projection onto that set, which is
    element-wise clamping of negative entries to zero.
    """

    def prox(self, x: Tensor, tau: Optional[float] = None) -> Tensor:
        r"""Project ``x`` onto the non-negative orthant.

        Args:
            x (Tensor): Input tensor.
            tau (float, optional): Proximal step size. It is ignored, because
                scaling an indicator function leaves it unchanged, so the
                projection does not depend on the step size. It is accepted
                only so callers can pass a step size uniformly to any
                proximal operator.

        Returns:
            Tensor: A new tensor equal to ``x`` with every negative entry
            replaced by 0. ``x`` itself is not modified.
        """
        return torch.clamp(x, min=0)

    def __call__(self, x: Tensor) -> bool:
        r"""Check whether ``x`` satisfies the non-negative constraint.

        Args:
            x (Tensor): Input tensor.

        Returns:
            bool: ``True`` if every element of ``x`` is >= 0, otherwise
            ``False``. An empty tensor counts as feasible.
        """
        # torch.all returns a 0-dim bool Tensor. Convert it so the declared
        # ``-> bool`` return type actually holds.
        return bool(torch.all(x >= 0))