Source code for plenoptic.models.naive

"""
Very simple visual models.

While these may be useful as is, they are more useful when combined with each other
or with non-linearities, as in :mod:`plenoptic.models.frontend`.
"""  # numpydoc ignore=EX01

from typing import Any

import torch
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F

from ..process.convolutions import same_padding
from ..process.filters import _validate_filter_args, circular_gaussian2d

__all__ = ["Identity", "Linear", "Gaussian", "CenterSurround"]


def __dir__() -> list[str]:
    """Return the public names of this module, i.e. the contents of ``__all__``."""
    # Module-level ``__dir__`` hook: ``dir()`` on this module lists only ``__all__``.
    return __all__


class Identity(torch.nn.Module):
    r"""
    Pass-through model whose output is a copy of its input.

    Intended as a "dummy model" for metrics that have no explicit representation:
    use this as the model and put all of the work in the objective function.

    Examples
    --------
    >>> import plenoptic as po
    >>> identity_model = po.models.Identity()
    >>> identity_model
    Identity()
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """
        Return a copy of the tensor.

        Parameters
        ----------
        x
            The tensor to return.

        Returns
        -------
        x
            A new tensor holding the same values as the input.

        Examples
        --------
        .. plot::

          >>> import plenoptic as po
          >>> identity_model = po.models.Identity()
          >>> img = po.data.curie()
          >>> y = identity_model.forward(img)
          >>> titles = ["Input", "Output (identical)"]
          >>> po.plot.imshow([img, y], title=titles)
          <PyrFigure ...>
        """  # numpydoc ignore=ES01
        # Multiplying by one yields a new tensor instead of handing back ``x`` itself.
        return x * 1
class Linear(nn.Module):
    r"""
    Simplistic linear convolutional model.

    A single bias-free 2D convolution mapping one input channel to two output
    channels. If ``default_filters=True``, this model splits the input image into
    low and high frequencies.

    Parameters
    ----------
    kernel_size
        Convolutional kernel size.
    pad_mode
        Mode with which to pad image using :func:`torch.nn.functional.pad()`.
    default_filters
        Initialize the filters to a low-pass and a band-pass. If ``False``, filters
        are randomly initialized.

    Raises
    ------
    ValueError
        If kernel_size is not one or two positive integers.

    Examples
    --------
    >>> import plenoptic as po
    >>> linear_model = po.models.Linear()
    >>> linear_model
    Linear(
      (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1), bias=False)
    )

    To specify the kernel size :

    >>> linear_model = po.models.Linear(kernel_size=(5, 5))
    >>> linear_model
    Linear(
      (conv): Conv2d(1, 2, kernel_size=(5, 5), stride=(1, 1), bias=False)
    )
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int] = (3, 3),
        pad_mode: str = "circular",
        default_filters: bool = True,
    ):
        super().__init__()
        self.pad_mode = pad_mode

        # Linear has no use for std or out_channels, so hand the validator
        # placeholder values that are known to pass.
        self.kernel_size, _, _ = _validate_filter_args(kernel_size, 1, 1)
        self.conv = nn.Conv2d(1, 2, kernel_size, bias=False)

        if not default_filters:
            # Keep Conv2d's own random initialization.
            return

        variance = torch.as_tensor(3.0)
        lowpass = circular_gaussian2d(kernel_size, std=torch.sqrt(variance))
        narrow = circular_gaussian2d(kernel_size, std=torch.sqrt(variance / 3))
        # Band-pass filter: narrow Gaussian minus the wide (low-pass) one.
        bandpass = narrow - lowpass
        # NOTE(review): if circular_gaussian2d returns unit-sum kernels, this sum is
        # close to zero and the division is ill-conditioned -- confirm.
        bandpass = bandpass / bandpass.sum()

        self.conv.weight.data = torch.cat([lowpass, bandpass], dim=0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Convolve filter with input tensor.

        The input is same-padded (using ``pad_mode``) before the convolution so
        that the output and input shapes are matched.

        Parameters
        ----------
        x
            The input tensor, with (batch, channel, height, width).

        Returns
        -------
        y
            A linear convolution of the input image, of same shape as the input.

        Examples
        --------
        .. plot::

          >>> import plenoptic as po
          >>> linear_model = po.models.Linear()
          >>> img = po.data.curie()
          >>> y = linear_model.forward(img)
          >>> po.plot.imshow(
          ...     [img, y],
          ...     title=[
          ...         "Input image",
          ...         "Lowpass channel output",
          ...         "Bandpass channel output",
          ...     ],
          ... )
          <PyrFigure size...>
        """
        padded = same_padding(x, self.kernel_size, pad_mode=self.pad_mode)
        return self.conv(padded)
class Gaussian(nn.Module):
    """
    Isotropic Gaussian convolutional filter.

    Kernel elements are normalized and sum to one. The standard deviation is
    stored as a :class:`torch.nn.Parameter`.

    Parameters
    ----------
    kernel_size
        Size of convolutional kernel.
    std
        Standard deviation of circularly symmetric Gaussian kernel.
    pad_mode
        Padding mode argument to pass to :func:`torch.nn.functional.pad`.
    out_channels
        Number of filters. If ``None``, inferred from shape of ``std``.
    cache_filt
        Whether or not to cache the filter. Avoids regenerating filt with each
        forward pass.

    Raises
    ------
    ValueError
        If out_channels is not a positive integer.
    ValueError
        If kernel_size is not a positive integer.
    ValueError
        If std is not positive.
    ValueError
        If std is non-scalar and ``len(std) != out_channels``

    Examples
    --------
    >>> import plenoptic as po
    >>> gaussian_model = po.models.Gaussian(kernel_size=10)
    >>> gaussian_model
    Gaussian()
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        std: int | list[int] | float | list[float] | Tensor = 3.0,
        pad_mode: str = "reflect",
        out_channels: int | None = None,
        cache_filt: bool = False,
    ):
        super().__init__()
        validated = _validate_filter_args(kernel_size, std, out_channels)
        self.kernel_size, std_value, n_filters = validated

        self.std = nn.Parameter(std_value)
        self.pad_mode = pad_mode
        self.out_channels = n_filters
        self.cache_filt = cache_filt
        # Empty slot for the cached kernel; the ``filt`` property fills it on demand.
        self.register_buffer("_filt", None)

    @property
    def filt(self) -> Tensor:
        """Gaussian filter(s)."""  # numpydoc ignore=ES01,RT01,EX01
        cached = self._filt
        if cached is not None:
            # A kernel was cached earlier: reuse it as is.
            return cached

        # Build the kernel from the current std and, if requested, store it.
        kernel = circular_gaussian2d(self.kernel_size, self.std, self.out_channels)
        if self.cache_filt:
            self.register_buffer("_filt", kernel)
        return kernel

    def forward(self, x: Tensor, **conv2d_kwargs: Any) -> Tensor:
        """
        Convolve Gaussian filter with input tensor.

        The input is same-padded (using ``pad_mode``) before the convolution so
        that the output and input shapes are matched.

        Parameters
        ----------
        x
            The input tensor, should be 4d (batch, channel, height, width).
        **conv2d_kwargs
            Passed to :func:`torch.nn.functional.conv2d`.

        Returns
        -------
        y
            A linear convolution of the input image, of same shape as the input.

        Examples
        --------
        .. plot::

          >>> import plenoptic as po
          >>> gaussian_model = po.models.Gaussian(kernel_size=10)
          >>> img = po.data.curie()
          >>> y = gaussian_model.forward(img)
          >>> po.plot.imshow([img, y], title=["Input image", "Output"])
          <PyrFigure size...>

        Multiple output channels with different standard deviations.

        .. plot::

          >>> import plenoptic as po
          >>> gaussian_model = po.models.Gaussian(
          ...     kernel_size=10, std=[2, 5], out_channels=2
          ... )
          >>> img = po.data.curie()
          >>> y = gaussian_model.forward(img)
          >>> po.plot.imshow(
          ...     [img, y],
          ...     title=["Input image", "Output Channel 0", "Output Channel 1"],
          ... )
          <PyrFigure ...>
        """
        # Overwrite the parameter's data with its absolute value so that the
        # standard deviation stays positive.
        self.std.data = torch.abs(self.std.data)
        padded = same_padding(x, self.kernel_size, pad_mode=self.pad_mode)
        return F.conv2d(padded, self.filt, **conv2d_kwargs)
class CenterSurround(nn.Module):
    r"""
    Center-Surround, Difference of Gaussians (DoG) filter model.

    Can be either on-center/off-surround, or vice versa.

    Filter is constructed as:

    .. code::

        f = amplitude_ratio * sign * (center - surround)

    where ``center`` and ``surround`` are circular Gaussians with standard
    deviations ``center_std`` and ``surround_std``, and ``sign`` is ``+1`` for
    on-center channels and ``-1`` for off-center channels, as determined by the
    ``on_center`` argument. No further normalization is applied to ``f`` by this
    class.

    Parameters
    ----------
    kernel_size
        Shape of convolutional kernel.
    on_center
        Dictates whether center is on or off; surround will be the opposite of
        center (i.e. on-off or off-on). If List of bools, then list length must
        equal ``out_channels``, if just a single bool, then all ``out_channels``
        will be assumed to be all on-off or off-on.
    amplitude_ratio
        Scalar gain, must be greater than or equal to 1. Nominally the ratio of
        center/surround amplitude; in the current implementation it multiplies the
        whole center-minus-surround difference (see the formula above).
    center_std
        Standard deviation of circular Gaussian for center.
    surround_std
        Standard deviation of circular Gaussian for surround.
    out_channels
        Number of filters. If ``None``, inferred from the length of ``on_center``
        when that has more than one element, otherwise from shape of
        ``center_std``.
    pad_mode
        Padding for convolution.
    cache_filt
        Whether or not to cache the filter. Avoids regenerating filt with each
        forward pass.

    Raises
    ------
    ValueError
        If out_channels is not a positive integer.
    ValueError
        If kernel_size is not a positive integer.
    ValueError
        If center_std or surround_std are not positive.
    ValueError
        If center_std and surround_std do not have the same number of values.
    ValueError
        If center_std or surround_std are non-scalar and their lengths do not
        equal ``out_channels``
    ValueError
        If ``len(on_center)`` does not equal ``out_channels``.
    ValueError
        If amplitude_ratio is not a scalar or is less than 1.

    Examples
    --------
    >>> import plenoptic as po
    >>> cs_model = po.models.CenterSurround(kernel_size=10)
    >>> cs_model
    CenterSurround()

    Model with both on-center/off-surround and off-center/on-surround:

    >>> import plenoptic as po
    >>> cs_model = po.models.CenterSurround(10, [True, False])
    >>> cs_model
    CenterSurround()
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        on_center: bool | list[bool] = True,
        amplitude_ratio: float = 1.25,
        center_std: int | list[int] | float | list[float] | Tensor = 1.0,
        surround_std: int | list[int] | float | list[float] | Tensor = 4.0,
        out_channels: int | None = None,
        pad_mode: str = "reflect",
        cache_filt: bool = False,
    ):
        super().__init__()

        on_center = torch.as_tensor(on_center)
        # A multi-element on_center fixes the number of channels when the caller
        # did not give one explicitly.
        if out_channels is None and on_center.numel() != 1:
            out_channels = len(on_center)

        self.kernel_size, center_std, out_channels = _validate_filter_args(
            kernel_size,
            center_std,
            out_channels,
            "center_std",
        )
        # Validate surround_std against the channel count settled on above.
        _, surround_std, _ = _validate_filter_args(
            kernel_size, surround_std, out_channels, "surround_std", "len(center_std)"
        )
        self.center_std = nn.Parameter(center_std)
        self.surround_std = nn.Parameter(surround_std)

        # make sure each channel is on-off or off-on
        if on_center.numel() == 1:
            on_center = on_center.repeat(out_channels)
        if len(on_center) != out_channels:
            raise ValueError("len(on_center) must equal out_channels")
        self.on_center = on_center

        amplitude_ratio = torch.as_tensor(amplitude_ratio)
        if amplitude_ratio.nelement() > 1:
            raise ValueError("amplitude_ratio must be a scalar")
        if amplitude_ratio < 1.0:
            raise ValueError("amplitude_ratio must at least be 1.")
        self.register_buffer("amplitude_ratio", amplitude_ratio)

        self.out_channels = out_channels
        self.pad_mode = pad_mode
        self.cache_filt = cache_filt
        # Empty slot for the cached kernel; the ``filt`` property fills it on demand.
        self.register_buffer("_filt", None)

    @property
    def filt(self) -> Tensor:
        """Center-surround filter(s)."""  # numpydoc ignore=ES01,RT01,EX01
        if self._filt is not None:
            # use cached filt
            return self._filt

        # generate new filt and optionally cache
        filt_center = circular_gaussian2d(
            self.kernel_size, self.center_std, self.out_channels
        )
        filt_surround = circular_gaussian2d(
            self.kernel_size, self.surround_std, self.out_channels
        )

        # sign is + or - depending on center is on or off; it is created on the
        # same device as the amplitude_ratio buffer.
        sign = torch.as_tensor(
            [1.0 if is_on else -1.0 for is_on in self.on_center.tolist()],
            device=self.amplitude_ratio.device,
        )
        sign = sign.view(self.out_channels, 1, 1, 1)

        # NOTE(review): amplitude_ratio scales the whole difference here rather
        # than only the center term, and no normalization follows -- confirm this
        # is the intended construction.
        filt = self.amplitude_ratio * (sign * (filt_center - filt_surround))

        if self.cache_filt:
            self.register_buffer("_filt", filt)
        return filt

    def forward(self, x: Tensor) -> Tensor:
        """
        Convolve center-surround filter with input tensor.

        The input is same-padded (using ``pad_mode``) before the convolution so
        that the output and input shapes are matched. Before the filter is built,
        ``center_std`` and ``surround_std`` are replaced in place by their absolute
        values so that they stay positive.

        Parameters
        ----------
        x
            The input tensor, should be 4d (batch, channel, height, width).

        Returns
        -------
        y
            A linear convolution of the input image, of same shape as the input.

        Examples
        --------
        .. plot::

          >>> import plenoptic as po
          >>> cs_model = po.models.CenterSurround(kernel_size=10)
          >>> img = po.data.curie()
          >>> y = cs_model.forward(img)
          >>> po.plot.imshow([img, y], title=["Input image", "Output"])
          <PyrFigure size...>

        Model with both on-center/off-surround and off-center/on-surround:

        .. plot::

          >>> import plenoptic as po
          >>> cs_model = po.models.CenterSurround(10, [True, False])
          >>> img = po.data.curie()
          >>> y = cs_model.forward(img)
          >>> titles = [
          ...     "Input image",
          ...     "On-center/off-surround",
          ...     "Off-center/on-surround",
          ... ]
          >>> po.plot.imshow([img, y], title=titles)
          <PyrFigure size...>
        """
        # Keep the learnable stds positive, mirroring what Gaussian.forward does
        # for its ``std``; a no-op whenever they are already positive.
        self.center_std.data = self.center_std.data.abs()
        self.surround_std.data = self.surround_std.data.abs()

        x = same_padding(x, self.kernel_size, pad_mode=self.pad_mode)
        return F.conv2d(x, self.filt, bias=None)