proxtorch.constraints
- class proxtorch.constraints.Box(a: float = 0.0, b: float = 1.0)
Bases: Constraint
- prox(x: Tensor) → Tensor
Projects the tensor onto the feasible set defined by the constraint.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Tensor after projection.
- Return type:
torch.Tensor
- Raises:
NotImplementedError – If the method is not implemented in a subclass.
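As a rough illustration of what a box projection does, the sketch below clamps each element into an interval using plain PyTorch. It assumes that `a` and `b` are the lower and upper bounds of the interval; `project_box` is a hypothetical helper written for this example, not the library's implementation.

```python
import torch


def project_box(x: torch.Tensor, a: float = 0.0, b: float = 1.0) -> torch.Tensor:
    # Assumption: the feasible set is the elementwise interval [a, b].
    # Projecting onto it amounts to clamping each entry independently.
    return torch.clamp(x, min=a, max=b)


x = torch.tensor([-0.5, 0.25, 1.5])
projected = project_box(x)  # entries below a become a, entries above b become b
```

Going by the signature documented above, the corresponding library call would presumably be `Box(a, b).prox(x)`.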
- class proxtorch.constraints.Frobenius(s: float = 1.0)
Bases: Constraint
Constraint for the Frobenius norm.
- prox(x: Tensor) → Tensor
Projects the tensor onto the feasible set defined by the Frobenius norm constraint.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Tensor after projection.
- Return type:
torch.Tensor
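A minimal sketch of a Frobenius-norm projection, assuming the feasible set is the ball of tensors whose Frobenius norm is at most `s`. The helper name `project_frobenius_ball` is hypothetical and the code is plain PyTorch, not taken from the library.

```python
import torch


def project_frobenius_ball(x: torch.Tensor, s: float = 1.0) -> torch.Tensor:
    # With no ord/dim arguments, torch.linalg.norm flattens x and returns
    # its 2-norm, which equals the Frobenius norm for a matrix.
    norm = torch.linalg.norm(x)
    if norm <= s:
        # Already feasible: the projection leaves x unchanged.
        return x.clone()
    # Otherwise rescale x so that its norm is exactly s.
    return x * (s / norm)


x = torch.tensor([[3.0, 0.0], [0.0, 4.0]])  # Frobenius norm is 5
projected = project_frobenius_ball(x, s=1.0)
```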
- class proxtorch.constraints.L0Ball(s: int)
Bases: Constraint
Projection onto the L0 ball: the set of tensors with at most s non-zero elements.
- prox(x: Tensor) → Tensor
Proximal operator for the L0 ball.
Keeps the s largest elements in magnitude and sets the rest to zero.
- Parameters:
x (torch.Tensor) – Input tensor.
tau (float) – Proximal step size. Not used for L0Ball, but kept for API consistency.
- Returns:
Resultant tensor after applying the proximal operator.
- Return type:
torch.Tensor
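The "keep the s largest elements in magnitude" rule can be sketched with `torch.topk`. This is an illustrative stand-in under the assumption that the selection is made over the flattened tensor; `project_l0_ball` is a hypothetical name, not the library's code.

```python
import torch


def project_l0_ball(x: torch.Tensor, s: int) -> torch.Tensor:
    flat = x.flatten()
    if s >= flat.numel():
        # At most s entries exist, so nothing has to be zeroed.
        return x.clone()
    # Indices of the s entries with the largest absolute value.
    _, idx = torch.topk(flat.abs(), k=s)
    out = torch.zeros_like(flat)
    out[idx] = flat[idx]  # copy the kept entries, leave the rest at zero
    return out.reshape(x.shape)


x = torch.tensor([0.1, -3.0, 2.0, 0.5])
projected = project_l0_ball(x, s=2)
```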
- class proxtorch.constraints.L1Ball(s: float = 1.0)
Bases: Constraint
Projection onto the L1 ball.
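One common way to compute a Euclidean projection onto an L1 ball is the sort-based thresholding method. The sketch below uses that method and assumes `s` is the radius of the ball. Whether the library uses the same algorithm is not stated here, and `project_l1_ball` is a hypothetical helper.

```python
import torch


def project_l1_ball(x: torch.Tensor, s: float = 1.0) -> torch.Tensor:
    flat = x.flatten()
    if flat.abs().sum() <= s:
        # Already inside the ball.
        return x.clone()
    # Sort magnitudes in descending order and find the soft threshold theta.
    u, _ = torch.sort(flat.abs(), descending=True)
    cssv = torch.cumsum(u, dim=0) - s
    k = torch.arange(1, flat.numel() + 1, dtype=flat.dtype, device=flat.device)
    # Last sorted position whose magnitude still exceeds the running threshold.
    rho = torch.nonzero(u - cssv / k > 0).max()
    theta = cssv[rho] / (rho + 1)
    # Soft-threshold the magnitudes by theta and restore the signs.
    out = torch.sign(flat) * torch.clamp(flat.abs() - theta, min=0)
    return out.reshape(x.shape)


x = torch.tensor([2.0, -1.0])
projected = project_l1_ball(x, s=1.5)
```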
- class proxtorch.constraints.L2Ball(s: float = 1.0)
Bases: Constraint
Projection onto the L2 ball.
- class proxtorch.constraints.LInfinityBall(s: float = 1.0)
Bases: Constraint
Projection onto the LInfinity ball of radius s.
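Projection onto an LInfinity ball of radius s limits every entry to the range [-s, s]. A minimal plain-PyTorch sketch, with `project_linf_ball` as a hypothetical helper name rather than library code:

```python
import torch


def project_linf_ball(x: torch.Tensor, s: float = 1.0) -> torch.Tensor:
    # The max-norm is at most s exactly when every entry lies in [-s, s],
    # so the projection is an elementwise clamp.
    return torch.clamp(x, min=-s, max=s)


x = torch.tensor([-2.0, 0.3, 5.0])
projected = project_linf_ball(x, s=1.0)
```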
- class proxtorch.constraints.NonNegative
Bases: Constraint
Proximal operator for the non-negative constraint.
- class proxtorch.constraints.Rank(max_rank: int)
Bases: Constraint
Constraint for rank regularization.
- prox(x: Tensor) → Tensor
Projects the tensor onto the feasible set defined by the rank constraint.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Tensor after projection.
- Return type:
torch.Tensor
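A rank constraint is typically enforced with a truncated singular value decomposition. The sketch below assumes a 2-D input and keeps only the `max_rank` largest singular values. It illustrates the idea in plain PyTorch, and `project_rank` is a hypothetical helper, not the library's method.

```python
import torch


def project_rank(x: torch.Tensor, max_rank: int) -> torch.Tensor:
    # Reduced SVD: x = U @ diag(S) @ Vh, with S sorted in descending order.
    U, S, Vh = torch.linalg.svd(x, full_matrices=False)
    r = min(max_rank, S.numel())
    # Rebuild the matrix from the r leading singular triplets only.
    return (U[:, :r] * S[:r]) @ Vh[:r, :]


x = torch.diag(torch.tensor([3.0, 2.0, 1.0]))  # rank 3
projected = project_rank(x, max_rank=2)
```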
- class proxtorch.constraints.TraceNorm(alpha: float = 1.0)
Bases: Constraint
Constraint for trace norm regularization.
- prox(x: Tensor) → Tensor
Projects the tensor onto the feasible set defined by the trace norm constraint.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Tensor after projection.
- Return type:
torch.Tensor
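The text above does not say how `alpha` defines the feasible set. The sketch below takes one plausible reading, which is an assumption: `alpha` is the radius of a trace-norm (nuclear-norm) ball, and the projection shrinks the singular values until they sum to `alpha`. The library may instead apply singular-value soft-thresholding with `alpha` as the threshold. `project_trace_norm_ball` is a hypothetical helper for a 2-D input.

```python
import torch


def project_trace_norm_ball(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Assumption: feasible set = matrices whose singular values sum to <= alpha.
    U, S, Vh = torch.linalg.svd(x, full_matrices=False)
    if S.sum() <= alpha:
        return x.clone()
    # Project the (non-negative, descending) singular values onto the L1 ball
    # of radius alpha using the sort-based thresholding rule.
    cssv = torch.cumsum(S, dim=0) - alpha
    k = torch.arange(1, S.numel() + 1, dtype=S.dtype, device=S.device)
    rho = torch.nonzero(S - cssv / k > 0).max()
    theta = cssv[rho] / (rho + 1)
    S_proj = torch.clamp(S - theta, min=0)
    # Rebuild the matrix from the shrunken singular values.
    return (U * S_proj) @ Vh


x = torch.diag(torch.tensor([2.0, 1.0]))  # trace norm is 3
projected = project_trace_norm_ball(x, alpha=1.5)
```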