# Source code for proxtorch.operators.tvl1_3d

"""
Total Variation L1 (TV-L1) denoising algorithm.

Inspired by Nilearn's proxtvl1 implementation.

References:
- Wikipedia: Total Variation Denoising (http://en.wikipedia.org/wiki/Total_variation_denoising)
- Beck, A., & Teboulle, M. (2009). Fast gradient-based algorithms for constrained total variation image denoising and deblurring problems.
- Nilearn (https://nilearn.github.io/)
"""

from math import sqrt

import torch
import torch.nn.functional as F

from proxtorch.base import ProxOperator


def get_padding_tuple(dim_index, ndim):
    """
    Build a padding specification that pads a single dimension by one element.

    The result has ``2 * ndim`` entries, all zero except for a single ``1``
    stored at position ``-(2 * dim_index + 1)``. In the last-dimension-first
    layout expected by ``torch.nn.functional.pad`` this is the trailing-side
    slot of dimension ``dim_index``.

    Args:
        dim_index (int): Index of the dimension that should receive padding.
        ndim (int): Total number of dimensions in the tensor.

    Returns:
        tuple: Padding tuple with a value of 1 for the requested dimension and 0 elsewhere.
    """
    pad_spec = [0 for _ in range(2 * ndim)]
    # Counting from the end of the list mirrors F.pad's reversed dimension order.
    pad_spec[-(2 * dim_index + 1)] = 1
    return tuple(pad_spec)


class TVL1_3D(ProxOperator):
    """
    Proximal operator of the 3D TV-L1 penalty.

    For ``l1_ratio`` in ``[0, 1]`` the penalty evaluated by this class is
    ``alpha * ((1 - l1_ratio) * TV(x) + l1_ratio * ||x||_1)``, where ``TV`` is
    the isotropic total variation built from forward finite differences.
    """

    def __init__(
        self,
        alpha: float,
        l1_ratio: float = 0.05,
        max_iter: int = 200,
        tol: float = 5e-5,
    ) -> None:
        """
        Initialize the 3D TV-L1 proximal operator.

        Args:
            alpha (float): Regularization strength.
            l1_ratio (float, optional): Weight of the L1 term relative to the
                TV term. Defaults to 0.05.
            max_iter (int, optional): Maximum iterations for the iterative
                algorithm. Defaults to 200.
            tol (float, optional): Dual-gap tolerance for early stopping.
                Defaults to 5e-5.
        """
        super().__init__()
        self.alpha = alpha
        self.max_iter = max_iter
        self.tol = tol
        self.l1_ratio = l1_ratio

    def gradient(self, x):
        """
        Compute the weighted finite-difference gradient of x plus an identity channel.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Tensor of shape ``(x.dim() + 1, *x.shape)``. Channel
            ``d < x.dim()`` holds the forward difference of ``x`` along
            dimension ``d`` (zero in the last position along that dimension)
            scaled by ``1 - l1_ratio``; the last channel holds ``l1_ratio * x``.
        """
        ndim = x.dim()
        # Keep the precision of x (e.g. float64) rather than silently
        # truncating to the default dtype, while still promoting integer or
        # lower-precision inputs to the default floating type as before.
        out_dtype = torch.promote_types(x.dtype, torch.get_default_dtype())
        gradients = torch.zeros(
            (ndim + 1,) + x.shape, dtype=out_dtype, device=x.device
        )
        # torch.diff shortens the dimension by one; pad a trailing zero so
        # every channel has the shape of x.
        for d in range(ndim):
            gradients[d, ...] = F.pad(
                torch.diff(x, dim=d, n=1), pad=get_padding_tuple(d, ndim)
            )
        gradients[:-1] *= 1.0 - self.l1_ratio
        gradients[-1] = self.l1_ratio * x
        return gradients

    def divergence(self, p: torch.Tensor) -> torch.Tensor:
        """
        Compute the weighted divergence of p minus its weighted identity channel.

        Args:
            p (torch.Tensor): Tensor with at least four channels along its
                first dimension; channels 0-2 are the components along the
                three axes and the last channel is the identity component.

        Returns:
            torch.Tensor: ``(1 - l1_ratio) * div(p[0:3]) - l1_ratio * p[-1]``,
            with the shape of ``p[-1]``.
        """
        div_x = torch.zeros_like(p[-1])
        div_y = torch.zeros_like(p[-1])
        div_z = torch.zeros_like(p[-1])

        # Contribution p[i] for every position but the last along each axis.
        div_x[:-1].add_(p[0, :-1, :, :])
        div_y[:, :-1].add_(p[1, :, :-1, :])
        div_z[:, :, :-1].add_(p[2, :, :, :-1])

        # Contribution -p[i - 1] for interior positions.
        div_x[1:-1].sub_(p[0, :-2, :, :])
        div_y[:, 1:-1].sub_(p[1, :, :-2, :])
        div_z[:, :, 1:-1].sub_(p[2, :, :, :-2])

        # Boundary term -p[-2] at the last position. An axis of length 1 has
        # no element at index -2, so it is skipped there instead of raising
        # IndexError.
        if p.shape[1] > 1:
            div_x[-1].sub_(p[0, -2, :, :])
        if p.shape[2] > 1:
            div_y[:, -1].sub_(p[1, :, -2, :])
        if p.shape[3] > 1:
            div_z[:, :, -1].sub_(p[2, :, :, -2])

        return (div_x + div_y + div_z) * (1 - self.l1_ratio) - self.l1_ratio * p[-1]

    def _projector_on_tvl1_dual(self, grad):
        """
        Project the dual variable onto the TV-L1 dual constraint set.

        Modifies ``grad`` IN PLACE: the gradient channels are projected on the
        l21 unit ball and the identity (last) channel on the unit ball of its
        absolute value.

        Args:
            grad (torch.Tensor): Gradient tensor as returned by ``gradient``.

        Returns:
            torch.Tensor: The same tensor object as ``grad``, after projection.
        """
        # The l21 ball for the gradient direction
        if self.l1_ratio < 1.0:
            # The identity channel is excluded from the joint norm only when
            # it is active (l1_ratio > 0).
            end = len(grad) - int(self.l1_ratio > 0.0)
            norm = torch.sqrt(torch.sum(grad[:end] * grad[:end], 0))
            norm = torch.clamp(norm, min=1.0)  # set everything < 1 to 1
            grad[:end].div_(norm)

        # The L1 ball for the identity direction
        if self.l1_ratio > 0.0:
            norm = torch.clamp(torch.abs(grad[-1]), min=1.0)
            grad[-1].div_(norm)

        return grad

    def _dual_gap_prox_tvl1(self, input_img_norm, new, gap, weight):
        """
        Compute the dual gap of TV-L1 denoising.

        Args:
            input_img_norm (torch.Tensor): Squared norm of the input image.
            new (torch.Tensor): Current primal estimate.
            gap (torch.Tensor): Difference between the input and ``new``.
            weight (float): Regularization strength.

        Returns:
            torch.Tensor: Scalar dual gap value.

        Notes:
            see "Total variation regularization for fMRI-based prediction of
            behavior", by Michel et al. (2011) for a derivation of the dual gap
        """
        tv_new = self.tvl1_from_grad(self.gradient(new))
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat_gap = gap.reshape(-1)
        d_gap = (
            torch.dot(flat_gap, flat_gap)
            + 2 * weight * tv_new
            - input_img_norm
            + torch.sum(new * new)
        )
        return 0.5 * d_gap

    def prox(self, x: torch.Tensor, lr: float) -> torch.Tensor:
        """
        Iterative algorithm to compute the proximal mapping of the tensor.

        Args:
            x (torch.Tensor): Input tensor (three-dimensional).
            lr (float): Learning rate; the effective penalty weight is
                ``alpha * lr``.

        Returns:
            torch.Tensor: Tensor after applying the proximal operation.

        Notes
        -----
        Total variation denoising aims to minimize the total variation of the
        image, which can be roughly described as the integral of the norm of
        the image gradient. As a result, it produces "cartoon-like" images,
        i.e., piecewise-constant images. For more details, refer to:
        http://en.wikipedia.org/wiki/Total_variation_denoising

        This function implements the FISTA (Fast Iterative Shrinkage
        Thresholding Algorithm) algorithm of Beck and Teboulle, adapted to
        total variation denoising in "Fast gradient-based algorithms for
        constrained total variation image denoising and deblurring problems"
        (2009). Once the dual gap increases between two checks, the iteration
        switches to plain ISTA steps.
        """
        weight = self.alpha * lr
        if weight == 0:
            # With no penalty the proximal map is the identity; returning here
            # also avoids dividing by zero in the step size below.
            return x.clone()

        fista = True
        input_img_norm = torch.norm(x) ** 2
        # Fixed step-size bound; presumably 4 per axis for the three axes with
        # a 10% safety margin -- TODO confirm against the reference derivation.
        lipschitz_constant = 1.1 * (4 * 3)
        negated_output = -x
        grad_aux = torch.zeros_like(self.gradient(x))
        grad_im = torch.zeros_like(grad_aux)
        t = 1.0
        dgap = torch.tensor(float("inf"), device=x.device)

        for i in range(self.max_iter):
            grad_tmp = self.gradient(negated_output)
            grad_tmp *= 1.0 / (lipschitz_constant * weight)
            grad_aux += grad_tmp
            grad_tmp = self._projector_on_tvl1_dual(grad_aux)

            # Careful, in the next few lines, grad_tmp and grad_aux are the
            # same tensor, as _projector_on_tvl1_dual works in place and
            # returns its input.
            t_new = 0.5 * (1 + sqrt(1 + 4 * t**2))
            t_factor = (t - 1) / t_new
            if fista:
                # fista: extrapolate from the previous projected iterate
                grad_aux = (1 + t_factor) * grad_tmp - t_factor * grad_im
            else:
                # ista
                grad_aux = grad_tmp
            grad_im = grad_tmp
            t = t_new
            gap = weight * self.divergence(grad_aux)
            negated_output = gap - x

            # The dual gap is only evaluated every fourth iteration.
            if i % 4 == 0:
                old_dgap = dgap
                dgap = self._dual_gap_prox_tvl1(
                    input_img_norm, -negated_output, gap, weight
                )
                if dgap < self.tol:
                    break
                if old_dgap < dgap:
                    # The gap went up: fall back to ISTA steps from now on.
                    fista = False

        # The returned primal variable uses the ISTA iterate (grad_im), not
        # the extrapolated FISTA one.
        return x - weight * self.divergence(grad_im)

    def _nonsmooth(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the TV-L1 penalty for a given tensor x.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``alpha`` times the TV-L1 value of ``gradient(x)``.
        """
        gradients = self.gradient(x)
        return self.tvl1_from_grad(gradients) * self.alpha

    @staticmethod
    def tvl1_from_grad(gradients: torch.Tensor) -> torch.Tensor:
        r"""
        Calculate the TV-L1 value from a gradient tensor.

        Args:
            gradients (torch.Tensor): Gradient tensor whose last channel is
                the identity component and whose other channels are the
                finite-difference components.

        Returns:
            torch.Tensor: Scalar tensor: the sum over positions of the
            Euclidean norm across the gradient channels, plus the sum of
            absolute values of the last channel.
        """
        tv = torch.sum(torch.sqrt(torch.sum(gradients[:-1] * gradients[:-1], dim=0)))
        l1 = torch.sum(torch.abs(gradients[-1]))
        return tv + l1