plenoptic.models.CenterSurround

Note

This object is a torch.nn.Module. It therefore has all the methods and attributes from that class, even though they are not documented here (to avoid cluttering this page).

class plenoptic.models.CenterSurround(kernel_size, on_center=True, amplitude_ratio=1.25, center_std=1.0, surround_std=4.0, out_channels=None, pad_mode='reflect', cache_filt=False)

Bases: Module

Center-Surround, Difference of Gaussians (DoG) filter model.

Can be either on-center/off-surround or off-center/on-surround.

Filter is constructed as:

f = amplitude_ratio * center - surround
f = f / f.sum()

The signs of the center and surround are determined by the on_center argument.
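The construction above can be sketched in NumPy as follows. This is an illustrative sketch, not plenoptic's implementation: the helper names circular_gaussian and dog_filter are hypothetical, each Gaussian is assumed to be normalized to unit sum before the two are combined, and the off-center filter is assumed to be the on-center filter with its sign flipped.

```python
import numpy as np


def circular_gaussian(kernel_size, std):
    """Unit-sum circular 2D Gaussian (hypothetical helper, not plenoptic's)."""
    # Pixel coordinates measured from the kernel center.
    coords = np.arange(kernel_size) - (kernel_size - 1) / 2
    xx, yy = np.meshgrid(coords, coords)
    g = np.exp(-(xx**2 + yy**2) / (2 * std**2))
    return g / g.sum()


def dog_filter(kernel_size, on_center=True, amplitude_ratio=1.25,
               center_std=1.0, surround_std=4.0):
    """Sketch of f = amplitude_ratio * center - surround; f = f / f.sum()."""
    # With unit-sum Gaussians, f.sum() equals amplitude_ratio - 1, so this
    # sketch needs amplitude_ratio > 1 to avoid dividing by zero.
    if amplitude_ratio <= 1:
        raise ValueError("this sketch requires amplitude_ratio > 1")
    center = circular_gaussian(kernel_size, center_std)
    surround = circular_gaussian(kernel_size, surround_std)
    f = amplitude_ratio * center - surround
    f = f / f.sum()
    # Assumption: an off-center filter is the on-center filter negated.
    return f if on_center else -f
```

For example, dog_filter(11) returns an 11 x 11 array that sums to 1, with a positive value at the center and negative values in the corners. Because both Gaussians sum to 1 here, the unnormalized f sums to amplitude_ratio - 1, which is why this sketch rejects amplitude_ratio == 1.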

Parameters:
  • kernel_size (int | tuple[int, int]) – Shape of convolutional kernel.

  • on_center (bool | list[bool] (default: True)) – Dictates whether the center is on or off; the surround is the opposite of the center (i.e. on-off or off-on). If a list of bools, its length must equal out_channels; if a single bool, all out_channels filters are on-off or all are off-on.

  • amplitude_ratio (float (default: 1.25)) – Ratio of center/surround amplitude. Applied before filter normalization. Must be greater than or equal to 1.

  • center_std (int | list[int] | float | list[float] | Tensor (default: 1.0)) – Standard deviation of the circular Gaussian for the center.

  • surround_std (int | list[int] | float | list[float] | Tensor (default: 4.0)) – Standard deviation of the circular Gaussian for the surround.

  • out_channels (int | None (default: None)) – Number of filters. If None, inferred from shape of center_std.

  • pad_mode (str (default: 'reflect')) – Padding mode used for the convolution.

  • cache_filt (bool (default: False)) – Whether to cache the filter, which avoids regenerating filt on each forward pass.

Raises:
  • ValueError – If out_channels is not a positive integer.

  • ValueError – If kernel_size is not a positive integer.

  • ValueError – If center_std or surround_std are not positive.

  • ValueError – If center_std and surround_std do not have the same number of values.

  • ValueError – If center_std or surround_std are non-scalar and their lengths do not equal out_channels.
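The argument rules in the Parameters and Raises lists above can be summarized with a small validation helper. resolve_args is hypothetical and not part of plenoptic; it handles only Python scalars and lists (not Tensor standard deviations), and inferring out_channels from the longer of on_center and center_std is an assumption, chosen so that the two-channel example below would resolve to two filters.

```python
from numbers import Real


def resolve_args(kernel_size, on_center=True, center_std=1.0,
                 surround_std=4.0, out_channels=None):
    """Hypothetical helper mirroring the documented argument rules."""

    def as_list(value):
        # Scalars (including bools) become length-1 lists so that scalar
        # and list inputs share one code path.
        return [value] if isinstance(value, Real) else list(value)

    on, centers, surrounds = as_list(on_center), as_list(center_std), as_list(surround_std)
    if len(centers) != len(surrounds):
        raise ValueError("center_std and surround_std must have the same number of values")
    if any(s <= 0 for s in centers + surrounds):
        raise ValueError("center_std and surround_std must be positive")
    # Assumption: infer out_channels from the longer of on_center and center_std.
    if out_channels is None:
        out_channels = max(len(on), len(centers))
    if not isinstance(out_channels, int) or out_channels <= 0:
        raise ValueError("out_channels must be a positive integer")
    sizes = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
    if any(not isinstance(k, int) or k <= 0 for k in sizes):
        raise ValueError("kernel_size must be a positive integer")
    for name, values in (("on_center", on), ("center_std", centers), ("surround_std", surrounds)):
        if len(values) not in (1, out_channels):
            raise ValueError(f"{name} must have length 1 or out_channels")

    def expand(values):
        # Broadcast a single value so every filter gets its own setting.
        return values * out_channels if len(values) == 1 else values

    return {
        "kernel_size": sizes,
        "on_center": expand(on),
        "center_std": expand(centers),
        "surround_std": expand(surrounds),
        "out_channels": out_channels,
    }
```

With these rules, resolve_args(10, [True, False]) yields two filters that share the default standard deviations, while a negative standard deviation or mismatched center_std and surround_std lengths raise ValueError.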

Examples

>>> import plenoptic as po
>>> cs_model = po.models.CenterSurround(kernel_size=10)
>>> cs_model
CenterSurround()

Model with both on-center/off-surround and off-center/on-surround:

>>> import plenoptic as po
>>> cs_model = po.models.CenterSurround(10, [True, False])
>>> cs_model
CenterSurround()

Methods

forward(x)

Convolve center-surround filter with input tensor.

Attributes

filt

Center-surround filter(s).

forward(x)

Convolve center-surround filter with input tensor.

We use same-padding so that the output has the same height and width as the input.
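Same-padding can be sketched in NumPy: pad the image by kernel_size - 1 pixels along each axis, then take a valid cross-correlation, which leaves the height and width unchanged. The helper names are hypothetical, and putting the extra pixel at the end when the total padding is odd (as for kernel_size=10) is an assumption; plenoptic's own padding may split it differently.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def same_pad_amounts(kernel_size):
    """Padding (before, after) for one axis; extra pixel goes at the end (assumption)."""
    total = kernel_size - 1
    before = total // 2
    return before, total - before


def same_conv2d(img, filt, pad_mode="reflect"):
    """Same-padded 2D cross-correlation of a single-channel image (sketch)."""
    kh, kw = filt.shape
    padded = np.pad(img, (same_pad_amounts(kh), same_pad_amounts(kw)), mode=pad_mode)
    # One (kh, kw) window per output pixel: shape (H, W, kh, kw).
    windows = sliding_window_view(padded, (kh, kw))
    # Cross-correlation, as in torch.nn.functional.conv2d: the kernel is not flipped.
    return np.einsum("ijkl,kl->ij", windows, filt)
```

Because reflect-padding a constant image keeps it constant, a filter that sums to 1 maps a constant image to the same constant under this sketch.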

Parameters:

x (Tensor) – The input tensor, should be 4d (batch, channel, height, width).

Return type:

Tensor

Returns:

y – A linear convolution of the input image, with the same height and width as the input and one output channel per filter.

Examples

>>> import plenoptic as po
>>> cs_model = po.models.CenterSurround(kernel_size=10)
>>> img = po.data.curie()
>>> y = cs_model.forward(img)
>>> po.plot.imshow([img, y], title=["Input image", "Output"])
<PyrFigure size...>

Model with both on-center/off-surround and off-center/on-surround:

>>> import plenoptic as po
>>> cs_model = po.models.CenterSurround(10, [True, False])
>>> img = po.data.curie()
>>> y = cs_model.forward(img)
>>> titles = [
...     "Input image",
...     "On-center/off-surround",
...     "Off-center/on-surround",
... ]
>>> po.plot.imshow([img, y], title=titles)
<PyrFigure size...>

property filt: Tensor

Center-surround filter(s).
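Exposing the filter as a property with optional caching, as the cache_filt parameter describes, can be sketched as below. CachedFilter is a hypothetical stand-in, not plenoptic's class; its _build method is a placeholder for the Gaussian construction, and the n_builds counter exists only to show how often the filter is regenerated.

```python
class CachedFilter:
    """Hypothetical stand-in for a module whose filter is a property."""

    def __init__(self, cache_filt=False):
        self.cache_filt = cache_filt
        self._filt = None
        self.n_builds = 0  # how many times the filter has been generated

    def _build(self):
        # Placeholder for the (comparatively expensive) filter construction.
        self.n_builds += 1
        return [[0.0, -1.0, 0.0], [-1.0, 5.0, -1.0], [0.0, -1.0, 0.0]]

    @property
    def filt(self):
        if self._filt is not None:  # reuse the cached filter
            return self._filt
        filt = self._build()
        if self.cache_filt:  # keep it for later accesses
            self._filt = filt
        return filt
```

In this sketch a cached filter is built once and then reused, so it would not reflect later changes to the values it was built from; without caching, the filter is rebuilt on every access.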