# Source code for plenoptic.models.frontend

"""
Simple convolutional models of the visual system's front-end.

All models are some combination of linear filtering, non-linear activation, and
(optionally) gain control. Model architectures in this file are described in Berardino
et al., 2017 [1]_, found online [2]_, and the pretrained parameters come from Berardino,
2018 [3]_.

References
----------
.. [1] A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of hierarchical
    representations, NeurIPS 2017; https://arxiv.org/abs/1710.02266
.. [2] https://www.cns.nyu.edu/~lcv/eigendistortions/ModelsIQA.html
.. [3] A Berardino, Hierarchically normalized models of visual distortion
   sensitivity: Physiology, perception, and application; Ph.D. Thesis,
   2018; https://www.cns.nyu.edu/pub/lcv/berardino-phd.pdf
"""  # numpydoc ignore=EX01

from collections import OrderedDict
from collections.abc import Callable
from typing import Any
from warnings import warn

import torch
import torch.nn as nn
import torch.nn.functional as F
from pyrtools.tools.display import PyrFigure
from torch import Tensor

from ..data import disk
from ..plot import imshow
from .naive import CenterSurround, Gaussian

# Names that make up this module's public API; the module-level ``__dir__``
# below returns this same list.
__all__ = [
    "LinearNonlinear",
    "LuminanceGainControl",
    "LuminanceContrastGainControl",
    "OnOff",
]


def __dir__() -> list[str]:
    """Return the module's public names, i.e. the contents of ``__all__``."""
    return __all__


class LinearNonlinear(nn.Module):
    """
    Linear-Nonlinear model.

    This model applies a difference of Gaussians filter followed by an activation
    function. Model is described in Berardino et al., 2017 [4]_ and online [5]_,
    where it is called LN.

    Parameters
    ----------
    kernel_size
        Shape of convolutional kernel.
    on_center
        Dictates whether center is on or off; surround will be the opposite of
        center (i.e. on-off or off-on).
    amplitude_ratio
        Ratio of center/surround amplitude. Applied before filter normalization.
    pad_mode
        Padding for convolution.
    pretrained
        Whether or not to load model params from Berardino, 2018 [6]_. See Notes
        for details.
    activation
        Activation function following linear convolution.
    cache_filt
        Whether or not to cache the filter. Avoids regenerating filt with each
        forward pass.

    Attributes
    ----------
    center_surround: ~plenoptic.models.CenterSurround
        Difference of Gaussians filter.

    Raises
    ------
    ValueError
        If ``pretrained`` is ``True`` and ``kernel_size`` is neither ``31`` nor
        ``(31, 31)``.

    Warns
    -----
    UserWarning
        If ``pretrained`` is ``True`` and ``cache_filt`` is ``False``, or if
        ``pretrained`` is ``True`` and ``on_center`` is ``False``.

    Notes
    -----
    These 2 parameters (standard deviations) were taken from Table 2, page 149
    from Berardino, 2018 [6]_ and are the values used in Berardino et al., 2017
    [4]_. Please use these pretrained weights at your own discretion.

    References
    ----------
    .. [4] A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of
       hierarchical representations, NeurIPS 2017;
       https://arxiv.org/abs/1710.02266
    .. [5] https://www.cns.nyu.edu/~lcv/eigendistortions/ModelsIQA.html
    .. [6] A Berardino, Hierarchically normalized models of visual distortion
       sensitivity: Physiology, perception, and application; Ph.D. Thesis,
       2018; https://www.cns.nyu.edu/pub/lcv/berardino-phd.pdf

    Examples
    --------
    >>> import plenoptic as po
    >>> ln_model = po.models.LinearNonlinear(31, pretrained=True, cache_filt=True)
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        on_center: bool = True,
        amplitude_ratio: float = 1.25,
        pad_mode: str = "reflect",
        pretrained: bool = False,
        activation: Callable[[Tensor], Tensor] = F.softplus,
        cache_filt: bool = False,
    ):
        super().__init__()
        if pretrained:
            # Validate the arguments against the configuration the pretrained
            # parameters were fit with.
            if kernel_size not in [31, (31, 31)]:
                raise ValueError("pretrained model has kernel_size (31, 31)")
            # stacklevel=2 attributes the warnings to the line that constructs
            # the model rather than to this file.
            if cache_filt is False:
                warn(
                    "pretrained is True but cache_filt is False. Set cache_filt to "
                    "True for efficiency unless you are fine-tuning.",
                    stacklevel=2,
                )
            if not on_center:
                warn(
                    "pretrained model had on_center=True, so on_center=False might "
                    "not make sense",
                    stacklevel=2,
                )

        self.center_surround = CenterSurround(
            kernel_size,
            on_center,
            amplitude_ratio,
            pad_mode=pad_mode,
            cache_filt=cache_filt,
        )

        if pretrained:
            # NOTE: the pretrained state dict carries its own
            # ``center_surround.amplitude_ratio`` (1.25), which replaces the
            # value built from the ``amplitude_ratio`` argument above.
            self.load_state_dict(self._pretrained_state_dict())

        self.activation = activation

    def forward(self, x: Tensor) -> Tensor:
        """
        Compute model response on input tensor.

        We use same-padding to ensure that the output and input shapes are
        matched.

        Parameters
        ----------
        x
            The input tensor, should be 4d (batch, channel, height, width).

        Returns
        -------
        y
            Model response to input.

        Examples
        --------
        .. plot::

          >>> import plenoptic as po
          >>> ln_model = po.models.LinearNonlinear(31, pretrained=True, cache_filt=True)
          >>> img = po.data.einstein()
          >>> y = ln_model.forward(img)
          >>> titles = ["Input image", "Output"]
          >>> po.plot.imshow([img, y], title=titles)
          <PyrFigure size...>
        """
        # Linear stage (center-surround filter) followed by the pointwise
        # nonlinearity.
        y = self.activation(self.center_surround(x))
        return y

    def display_filters(
        self,
        vrange: tuple[float, float] | str = "indep0",
        zoom: float | None = 5.0,
        title: str | list[str] | None = "linear filter",
        **kwargs: Any,
    ) -> PyrFigure:
        """
        Display convolutional filter of model.

        Parameters
        ----------
        vrange, zoom, title
            Arguments for :func:`~plenoptic.plot.imshow`, see its docstrings for
            details.
        **kwargs
            Keyword args for :func:`~plenoptic.plot.imshow`.

        Returns
        -------
        fig:
            The figure containing the image.

        Examples
        --------
        .. plot::

          >>> import plenoptic as po
          >>> ln_model = po.models.LinearNonlinear(31, pretrained=True, cache_filt=True)
          >>> ln_model.display_filters()
          <PyrFigure ...>
        """  # numpydoc ignore=ES01
        # Detach so the filter is handed to the plotting code without its
        # autograd graph.
        weights = self.center_surround.filt.detach()
        fig = imshow(weights, title=title, zoom=zoom, vrange=vrange, **kwargs)
        return fig

    @staticmethod
    def _pretrained_state_dict() -> OrderedDict:
        """
        Return parameters fit to human distortion judgments.

        Values copied from Table 2 in Berardino, 2018 [6]_.

        Returns
        -------
        state_dict
            Dictionary of parameters, to pass to :func:`load_state_dict`.
        """  # numpydoc ignore=EX01
        state_dict = OrderedDict(
            [
                ("center_surround.center_std", torch.as_tensor([0.5339])),
                ("center_surround.surround_std", torch.as_tensor([6.148])),
                ("center_surround.amplitude_ratio", torch.as_tensor([1.25])),
            ]
        )
        return state_dict
class LuminanceGainControl(nn.Module):
    """
    Linear center-surround followed by luminance gain control and activation.

    Model is described in Berardino et al., 2017 [7]_ and online [8]_, where it
    is called LG.

    Parameters
    ----------
    kernel_size
        Shape of convolutional kernel.
    on_center
        Dictates whether center is on or off; surround will be the opposite of
        center (i.e. on-off or off-on).
    amplitude_ratio
        Ratio of center/surround amplitude. Applied before filter normalization.
    pad_mode
        Padding for convolution.
    pretrained
        Whether or not to load model params from Berardino, 2018 [9]_. See Notes
        for details.
    activation
        Activation function following linear convolution.
    cache_filt
        Whether or not to cache the filter. Avoids regenerating filt with each
        forward pass.

    Attributes
    ----------
    center_surround: ~plenoptic.models.CenterSurround
        Difference of Gaussians linear filter.
    luminance: ~plenoptic.models.Gaussian
        Gaussian convolutional kernel used to normalize signal by local
        luminance.
    luminance_scalar: torch.nn.parameter.Parameter
        Scale factor for luminance normalization.

    Raises
    ------
    ValueError
        If ``pretrained`` is ``True`` and ``kernel_size`` is neither ``31`` nor
        ``(31, 31)``.

    Warns
    -----
    UserWarning
        If ``pretrained`` is ``True`` and ``cache_filt`` is ``False``, or if
        ``pretrained`` is ``True`` and ``on_center`` is ``False``.

    Notes
    -----
    These 4 parameters (standard deviations and scalar constants) were taken
    from Table 2, page 149 from Berardino, 2018 [9]_ and are the values used in
    Berardino et al., 2017 [7]_. Please use these pretrained weights at your own
    discretion.

    References
    ----------
    .. [7] A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of
       hierarchical representations, NeurIPS 2017;
       https://arxiv.org/abs/1710.02266
    .. [8] https://www.cns.nyu.edu/~lcv/eigendistortions/ModelsIQA.html
    .. [9] A Berardino, Hierarchically normalized models of visual distortion
       sensitivity: Physiology, perception, and application; Ph.D. Thesis,
       2018; https://www.cns.nyu.edu/pub/lcv/berardino-phd.pdf

    Examples
    --------
    >>> import plenoptic as po
    >>> lg_model = po.models.LuminanceGainControl(31, pretrained=True, cache_filt=True)
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        on_center: bool = True,
        amplitude_ratio: float = 1.25,
        pad_mode: str = "reflect",
        pretrained: bool = False,
        activation: Callable[[Tensor], Tensor] = F.softplus,
        cache_filt: bool = False,
    ):
        super().__init__()
        if pretrained:
            # Validate the arguments against the configuration the pretrained
            # parameters were fit with.
            if kernel_size not in [31, (31, 31)]:
                raise ValueError("pretrained model has kernel_size (31, 31)")
            # stacklevel=2 attributes the warnings to the line that constructs
            # the model rather than to this file.
            if cache_filt is False:
                warn(
                    "pretrained is True but cache_filt is False. Set cache_filt to "
                    "True for efficiency unless you are fine-tuning.",
                    stacklevel=2,
                )
            if not on_center:
                warn(
                    "pretrained model had on_center=True, so on_center=False might "
                    "not make sense",
                    stacklevel=2,
                )

        self.center_surround = CenterSurround(
            kernel_size,
            on_center,
            amplitude_ratio,
            pad_mode=pad_mode,
            cache_filt=cache_filt,
        )
        self.luminance = Gaussian(
            kernel_size=kernel_size,
            pad_mode=pad_mode,
            cache_filt=cache_filt,
        )
        # Random initial value drawn from [0, 10); replaced below when
        # pretrained parameters are loaded.
        self.luminance_scalar = nn.Parameter(torch.rand(1) * 10)

        if pretrained:
            # NOTE: the pretrained state dict carries its own
            # ``center_surround.amplitude_ratio`` (1.25), which replaces the
            # value built from the ``amplitude_ratio`` argument above.
            self.load_state_dict(self._pretrained_state_dict())

        self.activation = activation

    def forward(self, x: Tensor) -> Tensor:
        """
        Compute model response on input tensor.

        We use same-padding to ensure that the output and input shapes are
        matched.

        Parameters
        ----------
        x
            The input tensor, should be 4d (batch, channel, height, width).

        Returns
        -------
        y
            Model response to input.

        Examples
        --------
        .. plot::

          >>> import plenoptic as po
          >>> lg_model = po.models.LuminanceGainControl(
          ...     31, pretrained=True, cache_filt=True
          ... )
          >>> img = po.data.einstein()
          >>> y = lg_model.forward(img)
          >>> titles = ["Input image", "Output"]
          >>> po.plot.imshow([img, y], title=titles)
          <PyrFigure size...>
        """
        linear = self.center_surround(x)
        lum = self.luminance(x)
        # Divisive luminance gain control: scale the linear response by the
        # Gaussian-filtered input.
        lum_normed = linear / (1 + self.luminance_scalar * lum)
        y = self.activation(lum_normed)
        return y

    def display_filters(
        self,
        vrange: tuple[float, float] | str = "indep0",
        zoom: float | None = 5.0,
        title: str | list[str] | None = ["linear filt", "luminance filt"],
        col_wrap: int | None = 2,
        **kwargs: Any,
    ) -> PyrFigure:
        """
        Display convolutional filters of model.

        Parameters
        ----------
        vrange, zoom, title, col_wrap
            Arguments for :func:`~plenoptic.plot.imshow`, see its docstrings for
            details.
        **kwargs
            Keyword args for :func:`~plenoptic.plot.imshow`.

        Returns
        -------
        fig:
            The figure containing the displayed filters.

        Examples
        --------
        .. plot::

          >>> import plenoptic as po
          >>> lg_model = po.models.LuminanceGainControl(
          ...     31, pretrained=True, cache_filt=True
          ... )
          >>> lg_model.display_filters()
          <PyrFigure ...>
        """  # numpydoc ignore=ES01
        # Stack the filters along the first dimension, in the same order as the
        # default titles, and drop the autograd graph before plotting.
        weights = torch.cat(
            [
                self.center_surround.filt,
                self.luminance.filt,
            ],
            dim=0,
        ).detach()
        fig = imshow(
            weights,
            title=title,
            col_wrap=col_wrap,
            zoom=zoom,
            vrange=vrange,
            **kwargs,
        )
        return fig

    @staticmethod
    def _pretrained_state_dict() -> OrderedDict:
        """
        Return parameters fit to human distortion judgments.

        Values copied from Table 2 in Berardino, 2018 [9]_.

        Returns
        -------
        state_dict
            Dictionary of parameters, to pass to :func:`load_state_dict`.
        """  # numpydoc ignore=EX01
        state_dict = OrderedDict(
            [
                ("luminance_scalar", torch.as_tensor([14.95])),
                ("center_surround.center_std", torch.as_tensor([1.962])),
                ("center_surround.surround_std", torch.as_tensor([4.235])),
                ("center_surround.amplitude_ratio", torch.as_tensor([1.25])),
                ("luminance.std", torch.as_tensor([4.235])),
            ]
        )
        return state_dict
class LuminanceContrastGainControl(nn.Module):
    """
    Center-surround followed by luminance and contrast gain control, then activation.

    Model is described in Berardino et al., 2017 [10]_ and online [11]_, where
    it is called LGG.

    Parameters
    ----------
    kernel_size
        Shape of convolutional kernel.
    on_center
        Dictates whether center is on or off; surround will be the opposite of
        center (i.e. on-off or off-on).
    amplitude_ratio
        Ratio of center/surround amplitude. Applied before filter normalization.
    pad_mode
        Padding for convolution.
    pretrained
        Whether or not to load model params from Berardino, 2018 [12]_. See
        Notes for details.
    activation
        Activation function following linear convolution.
    cache_filt
        Whether or not to cache the filter. Avoids regenerating filt with each
        forward pass.

    Attributes
    ----------
    center_surround: ~plenoptic.models.CenterSurround
        Difference of Gaussians linear filter.
    luminance: ~plenoptic.models.Gaussian
        Gaussian convolutional kernel used to normalize signal by local
        luminance.
    contrast: ~plenoptic.models.Gaussian
        Gaussian convolutional kernel used to normalize signal by local
        contrast.
    luminance_scalar: torch.nn.parameter.Parameter
        Scale factor for luminance normalization.
    contrast_scalar: torch.nn.parameter.Parameter
        Scale factor for contrast normalization.

    Raises
    ------
    ValueError
        If ``pretrained`` is ``True`` and ``kernel_size`` is neither ``31`` nor
        ``(31, 31)``.

    Warns
    -----
    UserWarning
        If ``pretrained`` is ``True`` and ``cache_filt`` is ``False``, or if
        ``pretrained`` is ``True`` and ``on_center`` is ``False``.

    Notes
    -----
    These 6 parameters (standard deviations and constants) were taken from
    Table 2, page 149 from Berardino, 2018 [12]_ and are the values used in
    Berardino et al., 2017 [10]_. Please use these pretrained weights at your
    own discretion.

    References
    ----------
    .. [10] A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions
       of hierarchical representations, NeurIPS 2017;
       https://arxiv.org/abs/1710.02266
    .. [11] https://www.cns.nyu.edu/~lcv/eigendistortions/ModelsIQA.html
    .. [12] A Berardino, Hierarchically normalized models of visual distortion
       sensitivity: Physiology, perception, and application; Ph.D. Thesis,
       2018; https://www.cns.nyu.edu/pub/lcv/berardino-phd.pdf

    Examples
    --------
    >>> import plenoptic as po
    >>> lgg_model = po.models.LuminanceContrastGainControl(
    ...     31, pretrained=True, cache_filt=True
    ... )
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        on_center: bool = True,
        amplitude_ratio: float = 1.25,
        pad_mode: str = "reflect",
        pretrained: bool = False,
        activation: Callable[[Tensor], Tensor] = F.softplus,
        cache_filt: bool = False,
    ):
        super().__init__()
        if pretrained:
            # Validate the arguments against the configuration the pretrained
            # parameters were fit with.
            if kernel_size not in [31, (31, 31)]:
                raise ValueError("pretrained model has kernel_size (31, 31)")
            # stacklevel=2 attributes the warnings to the line that constructs
            # the model rather than to this file.
            if cache_filt is False:
                warn(
                    "pretrained is True but cache_filt is False. Set cache_filt to "
                    "True for efficiency unless you are fine-tuning.",
                    stacklevel=2,
                )
            if not on_center:
                warn(
                    "pretrained model had on_center=True, so on_center=False might "
                    "not make sense",
                    stacklevel=2,
                )

        self.center_surround = CenterSurround(
            kernel_size,
            on_center,
            amplitude_ratio,
            pad_mode=pad_mode,
            cache_filt=cache_filt,
        )
        self.luminance = Gaussian(
            kernel_size=kernel_size,
            pad_mode=pad_mode,
            cache_filt=cache_filt,
        )
        self.contrast = Gaussian(
            kernel_size=kernel_size,
            pad_mode=pad_mode,
            cache_filt=cache_filt,
        )
        # Random initial values drawn from [0, 10); replaced below when
        # pretrained parameters are loaded.
        self.luminance_scalar = nn.Parameter(torch.rand(1) * 10)
        self.contrast_scalar = nn.Parameter(torch.rand(1) * 10)

        if pretrained:
            # NOTE: the pretrained state dict carries its own
            # ``center_surround.amplitude_ratio`` (1.25), which replaces the
            # value built from the ``amplitude_ratio`` argument above.
            self.load_state_dict(self._pretrained_state_dict())

        self.activation = activation

    def forward(self, x: Tensor) -> Tensor:
        """
        Compute model response on input tensor.

        We use same-padding to ensure that the output and input shapes are
        matched.

        Parameters
        ----------
        x
            The input tensor, should be 4d (batch, channel, height, width).

        Returns
        -------
        y
            Model response to input.

        Examples
        --------
        .. plot::

          >>> import plenoptic as po
          >>> lgg_model = po.models.LuminanceContrastGainControl(
          ...     31, pretrained=True, cache_filt=True
          ... )
          >>> img = po.data.einstein()
          >>> y = lgg_model.forward(img)
          >>> titles = ["Input image", "Output"]
          >>> po.plot.imshow([img, y], title=titles)
          <PyrFigure size...>
        """
        linear = self.center_surround(x)
        lum = self.luminance(x)
        # Divisive luminance gain control.
        lum_normed = linear / (1 + self.luminance_scalar * lum)
        # Local contrast estimate: square root of the Gaussian-filtered squared
        # response, plus a small positive offset. (The offset was originally
        # annotated "avoid div by zero"; note it is added after the sqrt.)
        con = self.contrast(lum_normed.pow(2)).sqrt() + 1e-6
        # Divisive contrast gain control.
        con_normed = lum_normed / (1 + self.contrast_scalar * con)
        y = self.activation(con_normed)
        return y

    def display_filters(
        self,
        vrange: tuple[float, float] | str = "indep0",
        zoom: float | None = 5.0,
        title: str | list[str] | None = [
            "linear filt",
            "luminance filt",
            "contrast filt",
        ],
        col_wrap: int | None = 3,
        **kwargs: Any,
    ) -> PyrFigure:
        """
        Display convolutional filters of model.

        Parameters
        ----------
        vrange, zoom, title, col_wrap
            Arguments for :func:`~plenoptic.plot.imshow`, see its docstrings for
            details.
        **kwargs
            Keyword args for :func:`~plenoptic.plot.imshow`.

        Returns
        -------
        fig:
            The figure containing the displayed filters.

        Examples
        --------
        .. plot::

          >>> import plenoptic as po
          >>> lgg_model = po.models.LuminanceContrastGainControl(
          ...     31, pretrained=True, cache_filt=True
          ... )
          >>> lgg_model.display_filters()
          <PyrFigure ...>
        """  # numpydoc ignore=ES01
        # Stack the filters along the first dimension, in the same order as the
        # default titles, and drop the autograd graph before plotting.
        weights = torch.cat(
            [
                self.center_surround.filt,
                self.luminance.filt,
                self.contrast.filt,
            ],
            dim=0,
        ).detach()
        fig = imshow(
            weights,
            title=title,
            col_wrap=col_wrap,
            zoom=zoom,
            vrange=vrange,
            **kwargs,
        )
        return fig

    @staticmethod
    def _pretrained_state_dict() -> OrderedDict:
        """
        Return parameters fit to human distortion judgments.

        Values copied from Table 2 in Berardino, 2018 [12]_.

        Returns
        -------
        state_dict
            Dictionary of parameters, to pass to :func:`load_state_dict`.
        """  # numpydoc ignore=EX01
        state_dict = OrderedDict(
            [
                ("luminance_scalar", torch.as_tensor([2.94])),
                ("contrast_scalar", torch.as_tensor([34.03])),
                ("center_surround.center_std", torch.as_tensor([0.7363])),
                ("center_surround.surround_std", torch.as_tensor([48.37])),
                ("center_surround.amplitude_ratio", torch.as_tensor([1.25])),
                ("luminance.std", torch.as_tensor([170.99])),
                ("contrast.std", torch.as_tensor([2.658])),
            ]
        )
        return state_dict
class OnOff(nn.Module):
    """
    On-off and off-on center-surround with contrast and luminance gain control.

    Model is described in Berardino et al., 2017 [13]_ and online [14]_, where
    it is called OnOff.

    Parameters
    ----------
    kernel_size
        Shape of convolutional kernel.
    amplitude_ratio
        Ratio of center/surround amplitude. Applied before filter normalization.
    pad_mode
        Padding for convolution.
    pretrained
        Whether or not to load model params from Berardino, 2018 [15]_. See
        Notes for details.
    activation
        Activation function following linear and gain control operations.
    apply_mask
        Whether or not to apply circular disk mask centered on the input image.
        This is useful for synthesis methods like Eigendistortions to ensure
        that the synthesized distortion will not appear in the periphery. See
        :func:`plenoptic.data.disk()` for details on how mask is created.
    cache_filt
        Whether or not to cache the filter. Avoids regenerating filt with each
        forward pass.

    Attributes
    ----------
    center_surround: ~plenoptic.models.CenterSurround
        2-channel (on-off and off-on) difference of Gaussians linear filter.
    luminance: ~plenoptic.models.Gaussian
        2-channel Gaussian convolutional kernel used to normalize signal by
        local luminance.
    contrast: ~plenoptic.models.Gaussian
        2-channel Gaussian convolutional kernel used to normalize signal by
        local contrast.
    luminance_scalar: torch.nn.parameter.Parameter
        Scale factor for luminance normalization.
    contrast_scalar: torch.nn.parameter.Parameter
        Scale factor for contrast normalization.

    Raises
    ------
    ValueError
        If ``pretrained`` is ``True`` and ``kernel_size`` is neither ``31`` nor
        ``(31, 31)``.

    Warns
    -----
    UserWarning
        If ``pretrained`` is ``True`` and ``cache_filt`` is ``False``.

    Notes
    -----
    These 12 parameters (standard deviations & scalar constants) were taken
    from Table 2, page 149 from Berardino, 2018 [15]_ and are the values used in
    Berardino et al., 2017 [13]_. Please use these pretrained weights at your
    own discretion.

    References
    ----------
    .. [13] A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions
       of hierarchical representations, NeurIPS 2017;
       https://arxiv.org/abs/1710.02266
    .. [14] https://www.cns.nyu.edu/~lcv/eigendistortions/ModelsIQA.html
    .. [15] A Berardino, Hierarchically normalized models of visual distortion
       sensitivity: Physiology, perception, and application; Ph.D. Thesis,
       2018; https://www.cns.nyu.edu/pub/lcv/berardino-phd.pdf

    Examples
    --------
    >>> import plenoptic as po
    >>> onoff_model = po.models.OnOff(31, pretrained=True, cache_filt=True)
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        amplitude_ratio: float = 1.25,
        pad_mode: str = "reflect",
        pretrained: bool = False,
        activation: Callable[[Tensor], Tensor] = F.softplus,
        apply_mask: bool = False,
        cache_filt: bool = False,
    ):
        super().__init__()
        if pretrained:
            # Validate the arguments against the configuration the pretrained
            # parameters were fit with.
            if kernel_size not in [31, (31, 31)]:
                raise ValueError("pretrained model has kernel_size (31, 31)")
            if cache_filt is False:
                # stacklevel=2 attributes the warning to the line that
                # constructs the model rather than to this file.
                warn(
                    "pretrained is True but cache_filt is False. Set"
                    " cache_filt to True for efficiency unless you are"
                    " fine-tuning.",
                    stacklevel=2,
                )

        # One on-center and one off-center channel.
        self.center_surround = CenterSurround(
            kernel_size=kernel_size,
            on_center=[True, False],
            amplitude_ratio=amplitude_ratio,
            pad_mode=pad_mode,
            cache_filt=cache_filt,
        )
        self.luminance = Gaussian(
            kernel_size=kernel_size,
            out_channels=2,
            pad_mode=pad_mode,
            cache_filt=cache_filt,
        )
        self.contrast = Gaussian(
            kernel_size=kernel_size,
            out_channels=2,
            pad_mode=pad_mode,
            cache_filt=cache_filt,
        )

        # init scalar values around fitted parameters found in Berardino et al 2017
        self.luminance_scalar = nn.Parameter(torch.rand(2) * 10)
        self.contrast_scalar = nn.Parameter(torch.rand(2) * 10)

        if pretrained:
            # NOTE: the pretrained state dict carries its own
            # ``center_surround.amplitude_ratio`` (1.25), which replaces the
            # value built from the ``amplitude_ratio`` argument above.
            self.load_state_dict(self._pretrained_state_dict())

        self.apply_mask = apply_mask
        self._disk = None  # cached disk to apply to image
        self.activation = activation

    def forward(self, x: Tensor) -> Tensor:
        """
        Compute model response on input tensor.

        We use same-padding to ensure that the output and input shapes are
        matched.

        Parameters
        ----------
        x
            The input tensor, should be 4d (batch, channel, height, width).

        Returns
        -------
        y
            Model response to input.

        Examples
        --------
        .. plot::

          >>> import plenoptic as po
          >>> onoff_model = po.models.OnOff(31, pretrained=True, cache_filt=True)
          >>> img = po.data.einstein()
          >>> y = onoff_model.forward(img)
          >>> titles = ["Input image", "Output channel 0", "Output channel 1"]
          >>> po.plot.imshow([img, y], title=titles)
          <PyrFigure size...>
        """
        linear = self.center_surround(x)
        lum = self.luminance(x)
        # Reshape the per-channel scalars to (1, 2, 1, 1) so they broadcast over
        # the batch and spatial dimensions.
        lum_normed = linear / (1 + self.luminance_scalar.view(1, 2, 1, 1) * lum)
        # Local contrast estimate: square root of the Gaussian-filtered squared
        # response, plus a small positive offset. (The offset was originally
        # annotated "avoid div by 0"; note it is added after the sqrt.)
        con = self.contrast(lum_normed.pow(2), groups=2).sqrt() + 1e-6
        con_normed = lum_normed / (1 + self.contrast_scalar.view(1, 2, 1, 1) * con)
        y = self.activation(con_normed)

        if self.apply_mask:
            im_shape = x.shape[-2:]
            # (Re)build the cached mask when there is none yet or the input's
            # spatial size has changed.
            if self._disk is None or self._disk.shape != im_shape:
                self._disk = disk(im_shape).to(x.device)
            # A cached mask of the right size may still sit on another device.
            if self._disk.device != x.device:
                self._disk = self._disk.to(x.device)

            y = self._disk * y  # apply the mask

        return y

    def display_filters(
        self,
        vrange: tuple[float, float] | str = "indep0",
        zoom: float | None = 5.0,
        title: str | list[str] | None = [
            "linear filt on",
            "linear filt off",
            "luminance filt on",
            "luminance filt off",
            "contrast filt on",
            "contrast filt off",
        ],
        col_wrap: int | None = 2,
        **kwargs: Any,
    ) -> PyrFigure:
        """
        Display convolutional filters of model.

        Parameters
        ----------
        vrange, zoom, title, col_wrap
            Arguments for :func:`~plenoptic.plot.imshow`, see its docstrings for
            details.
        **kwargs
            Keyword args for :func:`~plenoptic.plot.imshow`.

        Returns
        -------
        fig:
            The figure containing the displayed filters.

        Examples
        --------
        .. plot::

          >>> import plenoptic as po
          >>> onoff_model = po.models.OnOff(31, pretrained=True, cache_filt=True)
          >>> onoff_model.display_filters()
          <PyrFigure ...>
        """  # numpydoc ignore=ES01
        # Stack the filters along the first dimension, in the same order as the
        # default titles, and drop the autograd graph before plotting.
        weights = torch.cat(
            [
                self.center_surround.filt,
                self.luminance.filt,
                self.contrast.filt,
            ],
            dim=0,
        ).detach()
        fig = imshow(
            weights,
            title=title,
            col_wrap=col_wrap,
            zoom=zoom,
            vrange=vrange,
            **kwargs,
        )
        return fig

    @staticmethod
    def _pretrained_state_dict() -> OrderedDict:
        """
        Return parameters fit to human distortion judgments.

        Values copied from Table 2 in Berardino, 2018 [15]_.

        Returns
        -------
        state_dict
            Dictionary of parameters, to pass to :func:`load_state_dict`.
        """  # numpydoc ignore=EX01
        state_dict = OrderedDict(
            [
                ("luminance_scalar", torch.as_tensor([3.2637, 14.3961])),
                ("contrast_scalar", torch.as_tensor([7.3405, 16.7423])),
                ("center_surround.center_std", torch.as_tensor([1.237, 0.3233])),
                ("center_surround.surround_std", torch.as_tensor([30.12, 2.184])),
                ("center_surround.amplitude_ratio", torch.as_tensor([1.25])),
                ("luminance.std", torch.as_tensor([76.4, 2.184])),
                ("contrast.std", torch.as_tensor([7.49, 2.43])),
            ]
        )
        return state_dict