Source code for plenoptic.simulate.models.naive

import torch
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F

from ...tools.conv import same_padding
from ..canonical_computations.filters import circular_gaussian2d

__all__ = ["Identity", "Linear", "Gaussian", "CenterSurround"]


[docs] class Identity(torch.nn.Module): r"""simple class that just returns a copy of the image We use this as a "dummy model" for metrics that we don't have the representation for. We use this as the model and then just change the objective function. """ def __init__(self, name=None): super().__init__() if name is not None: self.name = name
[docs] def forward(self, img): """Return a copy of the image Parameters ---------- img : torch.Tensor The image to return Returns ------- img : torch.Tensor a clone of the input image """ y = 1 * img return y
[docs] class Linear(nn.Module): r"""Simplistic linear convolutional model: It splits the input greyscale image into low and high frequencies. Parameters ---------- kernel_size: Convolutional kernel size. pad_mode: Mode with which to pad image using `nn.functional.pad()`. default_filters: Initialize the filters to a low-pass and a band-pass. """ def __init__( self, kernel_size: int | tuple[int, int] = (3, 3), pad_mode: str = "circular", default_filters: bool = True, ): super().__init__() if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) self.kernel_size = kernel_size self.pad_mode = pad_mode self.conv = nn.Conv2d(1, 2, kernel_size, bias=False) if default_filters: var = torch.as_tensor(3.0) f1 = circular_gaussian2d(kernel_size, std=torch.sqrt(var)) f2 = circular_gaussian2d(kernel_size, std=torch.sqrt(var / 3)) f2 = f2 - f1 f2 = f2 / f2.sum() self.conv.weight.data = torch.cat([f1, f2], dim=0)
[docs] def forward(self, x: Tensor) -> Tensor: y = same_padding(x, self.kernel_size, pad_mode=self.pad_mode) h = self.conv(y) return h
[docs] class Gaussian(nn.Module): """Isotropic Gaussian convolutional filter. Kernel elements are normalized and sum to one. Parameters ---------- kernel_size: Size of convolutional kernel. std: Standard deviation of circularly symmtric Gaussian kernel. pad_mode: Padding mode argument to pass to `torch.nn.functional.pad`. out_channels: Number of filters with which to convolve. cache_filt: Whether or not to cache the filter. Avoids regenerating filt with each forward pass. Cached to `self._filt`. """ def __init__( self, kernel_size: int | tuple[int, int], std: float | Tensor = 3.0, pad_mode: str = "reflect", out_channels: int = 1, cache_filt: bool = False, ): super().__init__() assert std > 0, "Gaussian standard deviation must be positive" if isinstance(std, float) or std.shape == torch.Size([]): std = torch.ones(out_channels) * std self.std = nn.Parameter(torch.as_tensor(std)) if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) self.kernel_size = kernel_size self.pad_mode = pad_mode self.out_channels = out_channels self.cache_filt = cache_filt self.register_buffer("_filt", None) @property def filt(self): if self._filt is not None: # use old filter return self._filt else: # create new filter, optionally cache it filt = circular_gaussian2d(self.kernel_size, self.std, self.out_channels) if self.cache_filt: self.register_buffer("_filt", filt) return filt
[docs] def forward(self, x: Tensor, **conv2d_kwargs) -> Tensor: self.std.data = self.std.data.abs() # ensure stdev is positive x = same_padding(x, self.kernel_size, pad_mode=self.pad_mode) y = F.conv2d(x, self.filt, **conv2d_kwargs) return y
[docs] class CenterSurround(nn.Module): """Center-Surround, Difference of Gaussians (DoG) filter model. Can be either on-center/off-surround, or vice versa. Filter is constructed as: .. math:: f &= amplitude_ratio * center - surround \\ f &= f/f.sum() The signs of center and surround are determined by ``on_center`` argument. Parameters ---------- kernel_size: Shape of convolutional kernel. on_center: Dictates whether center is on or off; surround will be the opposite of center (i.e. on-off or off-on). If List of bools, then list length must equal `out_channels`, if just a single bool, then all `out_channels` will be assumed to be all on-off or off-on. amplitude_ratio: Ratio of center/surround amplitude. Applied before filter normalization. center_std: Standard deviation of circular Gaussian for center. surround_std: Standard deviation of circular Gaussian for surround. Must be at least `ratio_limit` times `center_std`. out_channels: Number of filters. pad_mode: Padding for convolution, defaults to "circular". cache_filt: Whether or not to cache the filter. Avoids regenerating filt with each forward pass. Cached to `self._filt` """ def __init__( self, kernel_size: int | tuple[int, int], on_center: bool | list[bool] = True, amplitude_ratio: float = 1.25, center_std: float | Tensor = 1.0, surround_std: float | Tensor = 4.0, out_channels: int = 1, pad_mode: str = "reflect", cache_filt: bool = False, ): super().__init__() # make sure each channel is on-off or off-on if isinstance(on_center, bool): on_center = [on_center] * out_channels assert len(on_center) == out_channels, "len(on_center) must match out_channels" # make sure each channel has a center and surround std if isinstance(center_std, float) or center_std.shape == torch.Size([]): center_std = torch.ones(out_channels) * center_std if isinstance(surround_std, float) or surround_std.shape == torch.Size([]): surround_std = torch.ones(out_channels) * surround_std assert ( len(center_std) == out_channels and len(surround_std) == out_channels ), "stds must correspond to each out_channel" assert amplitude_ratio >= 1.0, "ratio of amplitudes must at least be 1." self.on_center = on_center if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) self.kernel_size = kernel_size self.register_buffer("amplitude_ratio", torch.as_tensor(amplitude_ratio)) self.center_std = nn.Parameter(torch.ones(out_channels) * center_std) self.surround_std = nn.Parameter(torch.ones(out_channels) * surround_std) self.out_channels = out_channels self.pad_mode = pad_mode self.cache_filt = cache_filt self.register_buffer("_filt", None) @property def filt(self) -> Tensor: """Creates an on center/off surround, or off center/on surround conv filter""" if self._filt is not None: # use cached filt return self._filt else: # generate new filt and optionally cache on_amp = self.amplitude_ratio device = on_amp.device filt_center = circular_gaussian2d( self.kernel_size, self.center_std, self.out_channels ) filt_surround = circular_gaussian2d( self.kernel_size, self.surround_std, self.out_channels ) # sign is + or - depending on center is on or off sign = torch.as_tensor([1.0 if x else -1.0 for x in self.on_center]).to( device ) sign = sign.view(self.out_channels, 1, 1, 1) filt = on_amp * (sign * (filt_center - filt_surround)) if self.cache_filt: self.register_buffer("_filt", filt) return filt
[docs] def forward(self, x: Tensor) -> Tensor: x = same_padding(x, self.kernel_size, pad_mode=self.pad_mode) y = F.conv2d(x, self.filt, bias=None) return y