Source code for proxtorch.operators.tvl1_2d

import torch

from proxtorch.operators.tvl1_3d import TVL1_3D


[docs]class TVL1_2D(TVL1_3D):
[docs] def divergence(self, p: torch.Tensor) -> torch.Tensor: div_x = torch.zeros_like(p[-1]) div_y = torch.zeros_like(p[-1]) div_x[:-1].add_(p[0, :-1, :]) div_y[:, :-1].add_(p[1, :, :-1]) div_x[1:-1].sub_(p[0, :-2, :]) div_y[:, 1:-1].sub_(p[1, :, :-2]) div_x[-1].sub_(p[0, -2, :]) div_y[:, -1].sub_(p[1, :, -2]) return (div_x + div_y) * (1 - self.l1_ratio) - self.l1_ratio * p[-1]