plenoptic.models.CenterSurround
Note
This object is a torch.nn.Module. It therefore has all the methods and attributes
from that class, even though they are not documented here (to avoid cluttering this page).
- class plenoptic.models.CenterSurround(kernel_size, on_center=True, amplitude_ratio=1.25, center_std=1.0, surround_std=4.0, out_channels=None, pad_mode='reflect', cache_filt=False)
Bases: Module
Center-Surround, Difference of Gaussians (DoG) filter model.
Can be either on-center/off-surround, or vice versa.
Filter is constructed as:
f = amplitude_ratio * center - surround
f = f / f.sum()
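The construction above can be sketched in plain PyTorch for a single filter. This is a minimal sketch, not plenoptic's implementation: the helpers `circular_gaussian` and `dog_filter` are hypothetical, each Gaussian is assumed to be normalized to sum to 1 before the two are combined, and the off-center sign flip is assumed to be applied after the final normalization (flipping before dividing by `f.sum()` would cancel out).

```python
import torch


def circular_gaussian(kernel_size: int, std: float) -> torch.Tensor:
    """Hypothetical helper: 2D circular Gaussian on a kernel_size x kernel_size grid, normalized to sum to 1."""
    # Pixel coordinates measured from the grid center (works for odd and even sizes).
    coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    yy, xx = torch.meshgrid(coords, coords, indexing="ij")
    g = torch.exp(-(xx**2 + yy**2) / (2 * std**2))
    return g / g.sum()


def dog_filter(
    kernel_size: int,
    on_center: bool = True,
    amplitude_ratio: float = 1.25,
    center_std: float = 1.0,
    surround_std: float = 4.0,
) -> torch.Tensor:
    """Hypothetical helper: one DoG filter following f = amplitude_ratio * center - surround."""
    center = circular_gaussian(kernel_size, center_std)
    surround = circular_gaussian(kernel_size, surround_std)
    f = amplitude_ratio * center - surround
    f = f / f.sum()  # normalize so the filter sums to 1
    # Assumption: the off-center sign flip happens after normalization,
    # because flipping before dividing by f.sum() would cancel out.
    return f if on_center else -f


f_on = dog_filter(11)
f_off = dog_filter(11, on_center=False)
print(f_on[5, 5].item() > 0, f_on[0, 0].item() < 0)  # positive center, negative surround
```

Under this sketch's unit-sum assumption, `amplitude_ratio = 1` would make `f.sum()` zero, so only ratios strictly above 1 yield a well-defined filter here; the library's own normalization may behave differently.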
The signs of center and surround are determined by the on_center argument.
- Parameters:
kernel_size (int | tuple[int, int]) – Shape of convolutional kernel.
on_center (bool | list[bool] (default: True)) – Dictates whether center is on or off; surround will be the opposite of center (i.e. on-off or off-on). If a list of bools, its length must equal out_channels; if a single bool, all out_channels filters will be on-off or all will be off-on.
amplitude_ratio (float (default: 1.25)) – Ratio of center/surround amplitude. Applied before filter normalization. Must be greater than or equal to 1.
center_std (int | list[int] | float | list[float] | Tensor (default: 1.0)) – Standard deviation of circular Gaussian for center.
surround_std (int | list[int] | float | list[float] | Tensor (default: 4.0)) – Standard deviation of circular Gaussian for surround.
out_channels (int | None (default: None)) – Number of filters. If None, inferred from shape of center_std.
pad_mode (str (default: 'reflect')) – Padding mode used for the convolution.
cache_filt (bool (default: False)) – Whether or not to cache the filter. Avoids regenerating filt with each forward pass.
- Raises:
ValueError – If out_channels is not a positive integer.
ValueError – If kernel_size is not a positive integer.
ValueError – If center_std or surround_std are not positive.
ValueError – If center_std and surround_std do not have the same number of values.
ValueError – If center_std or surround_std are non-scalar and their lengths do not equal out_channels.
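The conditions under Raises can be read as a single validation pass. The hypothetical `resolve_channels` helper below sketches one way to check them and broadcast scalar arguments to one value per channel. It is not plenoptic's code: it handles Python scalars and lists only (not tensors), and inferring `out_channels` from the longest of `center_std` and `on_center` is an assumption chosen so that `CenterSurround(10, [True, False])` yields two filters, as in the examples.

```python
def _as_list(value):
    """Wrap a scalar (int, float, or bool) in a list; copy a list or tuple. Tensors are not handled."""
    return [value] if isinstance(value, (int, float)) else list(value)


def resolve_channels(kernel_size, center_std=1.0, surround_std=4.0, on_center=True, out_channels=None):
    """Hypothetical helper mirroring the documented ValueError conditions.

    Returns (out_channels, center_stds, surround_stds, on_centers) as per-channel lists.
    """
    sizes = [kernel_size] * 2 if isinstance(kernel_size, int) else list(kernel_size)
    if not all(isinstance(k, int) and k > 0 for k in sizes):
        raise ValueError("kernel_size must be a positive integer")
    c, s, ons = _as_list(center_std), _as_list(surround_std), _as_list(on_center)
    if any(x <= 0 for x in c + s):
        raise ValueError("center_std and surround_std must be positive")
    if len(c) != len(s):
        raise ValueError("center_std and surround_std must have the same number of values")
    if out_channels is None:
        # Assumption: infer from the longest per-channel argument.
        out_channels = max(len(c), len(ons))
    if not isinstance(out_channels, int) or out_channels <= 0:
        raise ValueError("out_channels must be a positive integer")
    # The docs require on_center lists to match out_channels; raising
    # ValueError for a mismatch is an assumption of this sketch.
    for name, vals in (("center_std/surround_std", c), ("on_center", ons)):
        if len(vals) not in (1, out_channels):
            raise ValueError(f"{name} must be scalar or have length out_channels")
    # Broadcast length-1 arguments to one value per channel.
    c, s, ons = (v * out_channels if len(v) == 1 else v for v in (c, s, ons))
    return out_channels, c, s, ons


print(resolve_channels(10, on_center=[True, False]))
```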
Examples
>>> import plenoptic as po
>>> cs_model = po.models.CenterSurround(kernel_size=10)
>>> cs_model
CenterSurround()
Model with both on-center/off-surround and off-center/on-surround:
>>> import plenoptic as po
>>> cs_model = po.models.CenterSurround(10, [True, False])
>>> cs_model
CenterSurround()
Methods
forward(x) – Convolve center-surround filter with input tensor.
Attributes
filt – Center-surround filter(s).
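The cache_filt option trades memory for speed: with caching on, the filter is built once and reused on later accesses. The hypothetical `CachedFilterModule` below sketches that pattern with a `filt` property backed by a buffer registered through `register_buffer`; it builds a placeholder box filter instead of the DoG and is not plenoptic's implementation.

```python
import torch
from torch import nn


class CachedFilterModule(nn.Module):
    """Hypothetical module illustrating the cache_filt pattern with a `filt` property."""

    def __init__(self, kernel_size: int = 5, cache_filt: bool = False):
        super().__init__()
        self.kernel_size = kernel_size
        self.cache_filt = cache_filt
        self.register_buffer("_filt", None)  # empty cache slot; follows .to(device) once filled
        self.n_builds = 0  # counts how often the filter is regenerated

    def _build_filter(self) -> torch.Tensor:
        self.n_builds += 1
        # Placeholder: a normalized box filter; a real model would build the DoG here.
        k = self.kernel_size
        return torch.ones(1, 1, k, k) / (k * k)

    @property
    def filt(self) -> torch.Tensor:
        if self._filt is not None:  # reuse the cached filter
            return self._filt
        filt = self._build_filter()
        if self.cache_filt:
            self._filt = filt  # stored in the registered buffer
        return filt


cached = CachedFilterModule(cache_filt=True)
_ = cached.filt, cached.filt
print(cached.n_builds)  # built once, then reused
```

Keeping the cache in a buffer rather than a plain attribute means it moves with the module under `.to(device)`. If the filter's parameters change after the first access (for example, during optimization), a cached filter would be stale, which is a reason to leave caching off in that situation.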
- forward(x)
Convolve center-surround filter with input tensor.
We use same-padding to ensure that the output and input shapes are matched.
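Same-padding here means padding the input by kernel_size - 1 pixels in total along each spatial dimension and then running an unpadded convolution, so the output keeps the input's height and width. For even kernel sizes (such as the kernel_size=10 used in the examples) the padding cannot be split evenly; the sketch below puts the extra pixel on the bottom and right, which is an assumption and may differ from plenoptic's own padding helper. `same_pad_conv2d` is a hypothetical name; `F.pad` and `F.conv2d` are standard PyTorch functions.

```python
import torch
import torch.nn.functional as F


def same_pad_conv2d(x: torch.Tensor, filt: torch.Tensor, pad_mode: str = "reflect") -> torch.Tensor:
    """Convolve x of shape (batch, 1, H, W) with filt of shape (out_channels, 1, kh, kw), keeping H and W."""
    kh, kw = filt.shape[-2:]
    # Total padding of k - 1 per dimension preserves size; for even k the extra
    # pixel goes on the bottom/right (an assumption of this sketch).
    top, left = (kh - 1) // 2, (kw - 1) // 2
    bottom, right = kh - 1 - top, kw - 1 - left
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    x = F.pad(x, (left, right, top, bottom), mode=pad_mode)
    return F.conv2d(x, filt)  # no further padding: output is H x W


x = torch.rand(2, 1, 32, 40)
y = same_pad_conv2d(x, torch.rand(3, 1, 10, 10))
print(y.shape)  # one output channel per filter, same height and width
```

`F.conv2d` computes a cross-correlation; for a radially symmetric DoG filter this is identical to convolution. Note that 'reflect' padding requires each pad width to be smaller than the corresponding input dimension.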
- Parameters:
x (Tensor) – The input tensor, should be 4d (batch, channel, height, width).
- Return type:
Tensor
- Returns:
y – A linear convolution of the input image, with the same height and width as the input (one output channel per filter).
Examples
>>> import plenoptic as po
>>> cs_model = po.models.CenterSurround(kernel_size=10)
>>> img = po.data.curie()
>>> y = cs_model.forward(img)
>>> po.plot.imshow([img, y], title=["Input image", "Output"])
<PyrFigure size...>
Model with both on-center/off-surround and off-center/on-surround:
>>> import plenoptic as po
>>> cs_model = po.models.CenterSurround(10, [True, False])
>>> img = po.data.curie()
>>> y = cs_model.forward(img)
>>> titles = [
...     "Input image",
...     "On-center/off-surround",
...     "Off-center/on-surround",
... ]
>>> po.plot.imshow([img, y], title=titles)
<PyrFigure size...>