Source code for plenoptic.simulate.canonical_computations.laplacian_pyramid

import torch
import torch.nn as nn

from ...tools.conv import blur_downsample, upsample_blur


class LaplacianPyramid(nn.Module):
    """Laplacian Pyramid in Torch.

    The Laplacian pyramid [1]_ is a multiscale image representation. It
    decomposes the image by computing the local mean using Gaussian blurring
    filters and subtracting it from the image and repeating this operation on
    the local mean itself after downsampling. This representation is
    overcomplete and invertible.

    Parameters
    ----------
    n_scales: int
        number of scales to compute, must be at least 1
    scale_filter: bool, optional
        If true, the norm of the downsampling/upsampling filter is 1. If false
        (default), it is 2. If the norm is 1, the image is multiplied by 4
        during the upsampling operation; the net effect is that the `n`th
        scale of the pyramid is divided by `2^n`.

    Raises
    ------
    ValueError
        If `n_scales` is smaller than 1.

    References
    ----------
    .. [1] Burt, P. and Adelson, E., 1983. The Laplacian pyramid as a compact
       image code. IEEE Transactions on communications, 31(4), pp.532-540.
    """

    def __init__(self, n_scales: int = 5, scale_filter: bool = False):
        super().__init__()
        # A pyramid always holds at least the (possibly un-downsampled)
        # low-pass residual; fewer than one scale has no meaning and would
        # make `forward` and `recon_pyr` disagree about the list length.
        if n_scales < 1:
            raise ValueError(
                f"n_scales must be at least 1, but got {n_scales}"
            )
        self.n_scales = n_scales
        self.scale_filter = scale_filter

    def forward(self, x: torch.Tensor) -> list:
        """Build the Laplacian pyramid of an image.

        Parameters
        ----------
        x: torch.Tensor of shape (batch, channel, height, width)
            Image, or batch of images. If there are multiple channels, the
            Laplacian is computed separately for each of them

        Returns
        -------
        y: list of torch.Tensor
            Laplacian pyramid representation, each element of the list
            corresponds to a scale, from fine to coarse. The list has
            `n_scales` elements; the last one is the low-pass residual.
        """
        y = []
        for _ in range(self.n_scales - 1):
            # Parity of height and width at the current scale, forwarded to
            # upsample_blur together with the downsampled image so the
            # upsampled result can be subtracted from `x`.
            odd = torch.as_tensor(x.shape)[2:4] % 2
            x_down = blur_downsample(x, scale_filter=self.scale_filter)
            x_up = upsample_blur(x_down, odd, scale_filter=self.scale_filter)
            # Band-pass detail: what the blurred/downsampled copy cannot
            # represent at this scale.
            y.append(x - x_up)
            x = x_down
        # Coarsest scale: the remaining low-pass image.
        y.append(x)
        return y

    def recon_pyr(self, y) -> torch.Tensor:
        """Reconstruct the image from its Laplacian pyramid.

        Parameters
        ----------
        y: list of torch.Tensor
            Laplacian pyramid representation, each element of the list
            corresponds to a scale, from fine to coarse. Must contain exactly
            `n_scales` elements.

        Returns
        -------
        x: torch.Tensor of shape (batch, channel, height, width)
            Image, or batch of images

        Raises
        ------
        ValueError
            If the number of elements in `y` differs from `n_scales`.
        """
        if len(y) != self.n_scales:
            # Without this check a longer list would silently drop its
            # coarsest levels and a shorter one would fail with an opaque
            # IndexError.
            raise ValueError(
                f"Expected a pyramid with {self.n_scales} scales, "
                f"but got {len(y)}"
            )
        # Start from the low-pass residual and add back each band, going
        # from coarse to fine.
        x = y[-1]
        for band in reversed(y[:-1]):
            # Parity of the target height and width at the finer scale.
            odd = torch.as_tensor(band.shape)[2:4] % 2
            x = band + upsample_blur(x, odd, scale_filter=self.scale_filter)
        return x