"""Source code for proxtorch.operators.graphnet: GraphNet (L1 + squared-gradient) penalty operators."""
import torch
from proxtorch.operators.l1 import L1
from proxtorch.operators.tvl1_2d import TVL1_2D
from proxtorch.operators.tvl1_3d import TVL1_3D
class GraphNet3D(TVL1_3D):
    """GraphNet (L1 + squared-gradient) penalty operator for 3D tensors.

    The penalty is split into two parts:

    * a non-smooth L1 part, delegated to an ``L1`` operator built with
      weight ``alpha * l1_ratio``;
    * a smooth part ``0.5 * alpha * (1 - l1_ratio) * sum(g ** 2)``, where
      ``g`` are the non-L1 channels returned by the inherited
      ``self.gradient`` after rescaling by ``1 / (1 - l1_ratio)``.

    Only the L1 part is handled by :meth:`prox`.

    Args:
        alpha: Overall regularisation strength.
        l1_ratio: Mixing weight. ``l1_ratio`` weights the L1 part and
            ``1 - l1_ratio`` weights the squared-gradient part.
    """

    def __init__(self, alpha: float, l1_ratio: float):
        # NOTE(review): presumably the base class stores ``alpha`` and
        # ``l1_ratio`` as attributes; ``_smooth`` reads both.
        super().__init__(alpha=alpha, l1_ratio=l1_ratio)
        self.l1_prox = L1(alpha * l1_ratio)

    def prox(self, x: torch.Tensor, tau: float) -> torch.Tensor:
        """Apply the proximal operator of the L1 part only.

        The squared-gradient part is smooth. Presumably callers handle it
        with a gradient step rather than through this prox.

        Args:
            x: Input tensor.
            tau: Step size forwarded to ``L1.prox``.

        Returns:
            The result of ``self.l1_prox.prox(x, tau)``.
        """
        return self.l1_prox.prox(x, tau)

    def _smooth(self, x: torch.Tensor) -> torch.Tensor:
        """Return the smooth squared-gradient part of the penalty as a 0-dim tensor."""
        # When l1_ratio == 1 the smooth term vanishes. The division below
        # would compute 0/0 or x/0, and the final multiply by 0 would then
        # turn that into NaN, so short-circuit to an exact zero.
        if self.l1_ratio == 1:
            return x.new_zeros(())
        # The last channel of self.gradient(x) is for the L1 norm, so drop it.
        # NOTE(review): assumes self.gradient scales the remaining channels by
        # (1 - l1_ratio), and this division undoes that. Confirm in TVL1_3D.
        grad = self.gradient(x)[:-1] / (1 - self.l1_ratio)
        # Sum of squares of the gradient components.
        norm = torch.sum(grad**2)
        return 0.5 * norm * self.alpha * (1 - self.l1_ratio)

    def _nonsmooth(self, x: torch.Tensor) -> torch.Tensor:
        """Return the L1 part of the penalty as evaluated by ``self.l1_prox``."""
        return self.l1_prox(x)
class GraphNet2D(GraphNet3D, TVL1_2D):
    """GraphNet penalty operator for 2D tensors.

    Adds no members of its own. All behaviour comes from ``GraphNet3D``,
    with ``TVL1_2D`` as an additional base class.
    """

    # NOTE(review): GraphNet3D is listed first. Under Python's C3 method
    # resolution order, members defined on TVL1_3D (for example
    # ``gradient``) are found before TVL1_2D's, unless TVL1_2D itself
    # subclasses TVL1_3D. Confirm that the 2D gradient is the one used.