# Source code for proxtorch.operators.fused_lasso
import torch
from torch import Tensor
from proxtorch.base import ProxOperator
class FusedLasso(ProxOperator):
    r"""Proximal-style operator for the 1D fused lasso (total variation) penalty.

    The penalty is ``alpha * sum_i |x[i] - x[i + 1]|``, with differences taken
    along the first dimension of the input.

    Args:
        alpha (float): Weight of the fused lasso penalty.
    """

    def __init__(self, alpha: float):
        super().__init__()
        self.alpha = alpha

    def prox(self, x: Tensor, tau: float) -> Tensor:
        r"""
        Apply the soft-thresholding fused lasso step to ``x``.

        Successive differences along the first dimension are soft-thresholded
        by ``alpha * tau`` and the signal is rebuilt from them, anchored at
        ``x[0]``. With a zero threshold the input is returned unchanged.

        Args:
            x (Tensor): Input tensor; differences are taken along dim 0.
            tau (float): Proximal step size.

        Returns:
            Tensor: New tensor with the same shape as ``x``.

        Note:
            Shrinking each difference independently is a simple heuristic and
            is not, in general, the exact proximal operator of the fused lasso
            penalty. Exact direct solvers exist for 1D total variation.
        """
        # With fewer than two entries there is no difference to shrink, so the
        # result is just a copy of the input (also avoids indexing x[0] on an
        # empty tensor).
        if x.dim() == 0 or x.shape[0] < 2:
            return x.clone()

        threshold = self.alpha * tau
        diff = x[:-1] - x[1:]
        shrunk = torch.sign(diff) * torch.clamp(torch.abs(diff) - threshold, min=0)

        # diff[i] = x[i] - x[i + 1], so x[i + 1] = x[0] - sum(diff[:i + 1]).
        # Anchoring at x[0] makes the reconstruction the identity when nothing
        # is shrunk; the cumulative sum replaces the element-by-element loop.
        tail = x[0] - torch.cumsum(shrunk, dim=0)
        return torch.cat((x[:1], tail), dim=0)

    def _nonsmooth(self, x: Tensor) -> torch.Tensor:
        r"""Compute the fused lasso penalty ``alpha * sum |x[:-1] - x[1:]|``."""
        return self.alpha * torch.sum(torch.abs(x[:-1] - x[1:]))