# Source code for proxtorch.operators.huber

import torch
from torch import Tensor

from proxtorch.base import ProxOperator


class Huber(ProxOperator):
    r"""Proximal operator for the Huber penalty.

    The penalty (see :meth:`_nonsmooth`) is applied element-wise and summed:

    .. math::

        h(x) = \alpha \sum_i
        \begin{cases}
            \tfrac{1}{2} x_i^2 & \text{if } |x_i| \le \delta \\
            \delta |x_i| - \tfrac{1}{2} \delta^2 & \text{otherwise}
        \end{cases}

    Args:
        alpha (float): Multiplicative weight of the penalty. Defaults to 1.0.
        delta (float): Magnitude at which the penalty switches from its
            quadratic to its linear regime. Defaults to 1.0.
    """

    def __init__(self, alpha: float = 1.0, delta: float = 1.0):
        super().__init__()
        self.alpha = alpha
        self.delta = delta

    def prox(self, x: Tensor, tau: float) -> Tensor:
        r"""Apply the proximal operator of ``tau`` times the Huber penalty.

        Solves, element-wise,

        .. math::

            \operatorname{prox}_{\tau h}(x)
                = \arg\min_u \tfrac{1}{2}(u - x)^2 + \tau h(u),

        whose closed form is

        .. math::

            x - \tau\alpha \,
            \operatorname{clamp}\!\left(
                \frac{x}{1 + \tau\alpha}, -\delta, \delta
            \right).

        Equivalently: ``x / (1 + tau * alpha)`` when
        ``|x| <= delta * (1 + tau * alpha)`` and
        ``x - sign(x) * tau * alpha * delta`` otherwise.

        Args:
            x (Tensor): Input tensor. It is not modified; a new tensor is
                returned.
            tau (float): Proximal step size.

        Returns:
            Tensor: Result of the proximal step, same shape as ``x``.
        """
        step = tau * self.alpha
        # The quadratic/linear switch must be decided on the *solution*, i.e.
        # at |x| = delta * (1 + step), not at |x| = delta. Clamping the
        # quadratic-branch candidate x / (1 + step) to [-delta, delta] yields
        # exactly the gradient of the penalty at the solution in both regimes.
        penalty_grad = torch.clamp(
            x / (1.0 + step), min=-self.delta, max=self.delta
        )
        # Out-of-place on purpose: the caller's tensor is left untouched and
        # the operation stays differentiable.
        return x - step * penalty_grad

    def _nonsmooth(self, x: Tensor) -> torch.Tensor:
        r"""Compute the Huber penalty for a given input tensor.

        Args:
            x (Tensor): Input tensor.

        Returns:
            torch.Tensor: Scalar tensor holding ``alpha`` times the sum over
            all elements of ``0.5 * x**2`` where ``|x| <= delta`` and
            ``delta * |x| - 0.5 * delta**2`` elsewhere.
        """
        magnitude = x.abs()
        quadratic = 0.5 * x**2
        linear = self.delta * magnitude - 0.5 * self.delta**2
        per_element = torch.where(magnitude <= self.delta, quadratic, linear)
        return self.alpha * per_element.sum()