Source code for proxtorch.operators.group_lasso

import torch

from proxtorch.base import ProxOperator


class GroupLasso(ProxOperator):
    r"""Group Lasso penalty and its proximal operator.

    The input is split into consecutive groups along its first dimension,
    group ``i`` holding ``group_sizes[i]`` entries. The penalty is

    .. math::

        g(x) = \alpha \sum_i \lVert x_{G_i} \rVert_2 ,

    and its proximal mapping shrinks each group by block soft-thresholding.
    Entries past ``sum(group_sizes)`` belong to no group: they contribute
    nothing to the penalty and ``prox`` returns them unchanged.
    """

    def __init__(self, alpha: float, group_sizes: list):
        r"""Initialize the GroupLasso operator.

        Args:
            alpha (float): Group Lasso regularization parameter.
            group_sizes (list): Sizes of each group, in order along the
                first dimension of the input.
        """
        super().__init__()
        self.alpha = alpha
        # Materialize once: a generator would otherwise be exhausted after the
        # first call, and a private copy is immune to later caller mutation.
        self.group_sizes = list(group_sizes)

    def _group_slices(self):
        """Yield one ``slice`` per group, in order, along the first dimension."""
        start = 0
        for size in self.group_sizes:
            end = start + size
            yield slice(start, end)
            start = end

    def prox(self, x: torch.Tensor, tau: float) -> torch.Tensor:
        r"""Proximal mapping of the Group Lasso operator.

        Each group is scaled by ``max(1 - alpha * tau / ||x_g||_2, 0)``
        (block soft-thresholding). A group whose norm is not positive is
        set to zero. Entries outside every group are returned unchanged.

        Args:
            x (torch.Tensor): Input tensor; groups are slices of its first
                dimension.
            tau (float): Proximal parameter.

        Returns:
            torch.Tensor: Result of the proximal mapping, same shape as ``x``.
        """
        threshold = self.alpha * tau
        # Start from a copy so ungrouped entries pass through: they carry no
        # penalty, so their proximal map is the identity.
        result = x.clone()
        for group in self._group_slices():
            x_g = x[group]
            group_norm = torch.norm(x_g, p=2)
            if group_norm > 0:
                scale = torch.clamp(1 - threshold / group_norm, min=0)
                result[group] = scale * x_g
            else:
                # Zero (or NaN) norm: emit zeros and avoid dividing by zero.
                result[group] = 0
        return result

    def _nonsmooth(self, x: torch.Tensor) -> torch.Tensor:
        r"""Evaluate the Group Lasso penalty ``alpha * sum_g ||x_g||_2``.

        Args:
            x (torch.Tensor): Input tensor; groups are slices of its first
                dimension.

        Returns:
            torch.Tensor: Scalar tensor holding the penalty. It is zero when
            there are no groups.
        """
        # Accumulate in a tensor on x's device/dtype so the return type is
        # always a tensor, even when ``group_sizes`` is empty.
        penalty = x.new_zeros(())
        for group in self._group_slices():
            penalty = penalty + torch.norm(x[group], p=2)
        return self.alpha * penalty