plenoptic.simulate.models package

Submodules

plenoptic.simulate.models.frontend module

The model architectures in this module are described in [1] and [2]. frontend.OnOff() has optional pretrained filters that were reverse-engineered from a previously trained model; use them at your own discretion.

References

[1]

A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of hierarchical representations, NeurIPS 2017; https://arxiv.org/abs/1710.02266

class plenoptic.simulate.models.frontend.LinearNonlinear(kernel_size, on_center=True, width_ratio_limit=4.0, amplitude_ratio=1.25, pad_mode='reflect', activation=<built-in function softplus>)[source]

Bases: Module

Linear-Nonlinear model: applies a difference-of-Gaussians filter followed by an activation function. The model is described in [1] and [2].

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Shape of convolutional kernel.

  • on_center (bool) – Dictates whether center is on or off; surround will be the opposite of center (i.e. on-off or off-on).

  • width_ratio_limit (float) – Sets a lower bound on the ratio of surround_std to center_std. The surround Gaussian must be wider than the center Gaussian for the filter to be a proper difference of Gaussians, so surround_std is clamped to be at least width_ratio_limit times center_std.

  • amplitude_ratio (float) – Ratio of center/surround amplitude. Applied before filter normalization.

  • pad_mode (str) – Padding for convolution, defaults to “reflect”.

  • activation (Callable[[Tensor], Tensor]) – Activation function following linear convolution.

center_surround

CenterSurround difference of Gaussians filter.

Type:

nn.Module
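As a rough illustration of the linear-nonlinear cascade described above, the NumPy sketch below filters an image with a normalized difference-of-Gaussians kernel (using reflect padding, like the default pad_mode) and then applies softplus, the default activation. It is not plenoptic's implementation: the function names are hypothetical, and center_std and surround_std are fixed here at the naive.CenterSurround defaults, whereas in the actual model they are learnable parameters.

```python
import numpy as np


def gaussian_kernel(size, std):
    """Circularly symmetric 2D Gaussian kernel that sums to one."""
    coords = np.arange(size) - (size - 1) / 2
    xx, yy = np.meshgrid(coords, coords)
    kernel = np.exp(-(xx**2 + yy**2) / (2 * std**2))
    return kernel / kernel.sum()


def linear_nonlinear(image, kernel_size=7, center_std=1.0, surround_std=4.0,
                     amplitude_ratio=1.25, width_ratio_limit=4.0, on_center=True):
    """Hypothetical sketch of a DoG filter followed by softplus."""
    # Linear stage: DoG kernel with surround_std >= width_ratio_limit * center_std.
    surround_std = max(surround_std, width_ratio_limit * center_std)
    dog = (amplitude_ratio * gaussian_kernel(kernel_size, center_std)
           - gaussian_kernel(kernel_size, surround_std))
    dog = dog / dog.sum()
    if not on_center:
        dog = -dog  # off-center/on-surround polarity
    # Reflect-pad so the output has the same height and width as the input
    # (assumes an odd kernel_size), then cross-correlate like conv2d does.
    padded = np.pad(image, kernel_size // 2, mode="reflect")
    windows = np.lib.stride_tricks.sliding_window_view(padded, dog.shape)
    linear = np.einsum("ijkl,kl->ij", windows, dog)
    # Nonlinear stage: softplus, log(1 + exp(x)), computed stably.
    return np.logaddexp(0.0, linear)
```

Because the kernel sums to one, a uniform image passes through the linear stage unchanged, so every output pixel equals softplus of the input intensity (or of its negation for the off-center polarity).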

References

[1]

A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of hierarchical representations, NeurIPS 2017; https://arxiv.org/abs/1710.02266

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

display_filters([zoom])

Display the convolutional filters of the model.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

display_filters(zoom=5.0, **kwargs)[source]

Display the convolutional filters of the model.

Parameters:
  • zoom (float) – Magnification factor for po.imshow()

  • **kwargs – Keyword arguments passed to po.imshow().

Returns:

fig

Return type:

PyrFigure

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class plenoptic.simulate.models.frontend.LuminanceContrastGainControl(kernel_size, on_center=True, width_ratio_limit=4.0, amplitude_ratio=1.25, pad_mode='reflect', activation=<built-in function softplus>)[source]

Bases: Module

Linear center-surround filter followed by luminance gain control, contrast gain control, and an activation function. The model is described in [1] and [2].

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Shape of convolutional kernel.

  • on_center (bool) – Dictates whether center is on or off; surround will be the opposite of center (i.e. on-off or off-on).

  • width_ratio_limit (float) – Sets a lower bound on the ratio of surround_std to center_std. The surround Gaussian must be wider than the center Gaussian for the filter to be a proper difference of Gaussians, so surround_std is clamped to be at least width_ratio_limit times center_std.

  • amplitude_ratio (float) – Ratio of center/surround amplitude. Applied before filter normalization.

  • pad_mode (str) – Padding for convolution, defaults to “reflect”.

  • activation (Callable[[Tensor], Tensor]) – Activation function following linear convolution.

center_surround

Difference of Gaussians linear filter.

Type:

nn.Module

luminance

Gaussian convolutional kernel used to normalize signal by local luminance.

Type:

nn.Module

contrast

Gaussian convolutional kernel used to normalize signal by local contrast.

Type:

nn.Module

luminance_scalar

Scale factor for luminance normalization.

Type:

nn.Parameter

contrast_scalar

Scale factor for contrast normalization.

Type:

nn.Parameter
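The sketch below shows one plausible form of the two gain-control stages, written as divisive normalization in NumPy: the center-surround response is divided by a Gaussian-weighted local luminance, then by a Gaussian-weighted local RMS contrast, and finally passed through softplus. The exact divisive expressions, the Gaussian widths, and the helpers gaussian_kernel and filter2d are assumptions for illustration only; check the [source] link for the computation plenoptic actually performs.

```python
import numpy as np


def gaussian_kernel(size, std):
    """Circularly symmetric 2D Gaussian kernel that sums to one."""
    coords = np.arange(size) - (size - 1) / 2
    xx, yy = np.meshgrid(coords, coords)
    kernel = np.exp(-(xx**2 + yy**2) / (2 * std**2))
    return kernel / kernel.sum()


def filter2d(image, kernel):
    """Reflect-padded cross-correlation; assumes a square, odd-sized kernel."""
    padded = np.pad(image, kernel.shape[0] // 2, mode="reflect")
    windows = np.lib.stride_tricks.sliding_window_view(padded, kernel.shape)
    return np.einsum("ijkl,kl->ij", windows, kernel)


def luminance_contrast_gain_control(image, cs_kernel, lum_std=4.0, con_std=4.0,
                                    luminance_scalar=1.0, contrast_scalar=1.0,
                                    size=11):
    """Hypothetical sketch; assumes a non-negative image."""
    linear = filter2d(image, cs_kernel)                     # center-surround response
    lum = filter2d(image, gaussian_kernel(size, lum_std))   # local mean luminance
    lum_normed = linear / (1.0 + luminance_scalar * lum)
    # Local contrast: Gaussian-weighted RMS of the luminance-normalized signal.
    con = np.sqrt(filter2d(lum_normed**2, gaussian_kernel(size, con_std)))
    con_normed = lum_normed / (1.0 + contrast_scalar * con)
    return np.logaddexp(0.0, con_normed)                    # softplus activation
```

With a unit-sum center-surround kernel and a uniform image of intensity 2, the luminance stage gives 2/3, the contrast stage gives (2/3)/(1 + 2/3) = 0.4, and the output is softplus(0.4) everywhere.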

References

[1]

A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of hierarchical representations, NeurIPS 2017; https://arxiv.org/abs/1710.02266

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

display_filters([zoom])

Display the convolutional filters of the model.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

display_filters(zoom=5.0, **kwargs)[source]

Display the convolutional filters of the model.

Parameters:
  • zoom (float) – Magnification factor for po.imshow()

  • **kwargs – Keyword arguments passed to po.imshow().

Returns:

fig

Return type:

PyrFigure

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class plenoptic.simulate.models.frontend.LuminanceGainControl(kernel_size, on_center=True, width_ratio_limit=4.0, amplitude_ratio=1.25, pad_mode='reflect', activation=<built-in function softplus>)[source]

Bases: Module

Linear center-surround filter followed by luminance gain control and an activation function. The model is described in [1] and [2].

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Shape of convolutional kernel.

  • on_center (bool) – Dictates whether center is on or off; surround will be the opposite of center (i.e. on-off or off-on).

  • width_ratio_limit (float) – Sets a lower bound on the ratio of surround_std to center_std. The surround Gaussian must be wider than the center Gaussian for the filter to be a proper difference of Gaussians, so surround_std is clamped to be at least width_ratio_limit times center_std.

  • amplitude_ratio (float) – Ratio of center/surround amplitude. Applied before filter normalization.

  • pad_mode (str) – Padding for convolution, defaults to “reflect”.

  • activation (Callable[[Tensor], Tensor]) – Activation function following linear convolution.

center_surround

Difference of Gaussians linear filter.

Type:

nn.Module

luminance

Gaussian convolutional kernel used to normalize signal by local luminance.

Type:

nn.Module

luminance_scalar

Scale factor for luminance normalization.

Type:

nn.Parameter
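A minimal NumPy sketch of luminance gain control, assuming the divisive form linear / (1 + luminance_scalar * local_luminance) followed by softplus. The divisive expression, the Gaussian width, and the helpers gaussian_kernel and filter2d are illustrative assumptions rather than plenoptic's code; the [source] link shows the exact computation.

```python
import numpy as np


def gaussian_kernel(size, std):
    """Circularly symmetric 2D Gaussian kernel that sums to one."""
    coords = np.arange(size) - (size - 1) / 2
    xx, yy = np.meshgrid(coords, coords)
    kernel = np.exp(-(xx**2 + yy**2) / (2 * std**2))
    return kernel / kernel.sum()


def filter2d(image, kernel):
    """Reflect-padded cross-correlation; assumes a square, odd-sized kernel."""
    padded = np.pad(image, kernel.shape[0] // 2, mode="reflect")
    windows = np.lib.stride_tricks.sliding_window_view(padded, kernel.shape)
    return np.einsum("ijkl,kl->ij", windows, kernel)


def luminance_gain_control(image, cs_kernel, lum_std=4.0,
                           luminance_scalar=1.0, size=11):
    """Hypothetical sketch; assumes a non-negative image."""
    linear = filter2d(image, cs_kernel)                    # center-surround response
    lum = filter2d(image, gaussian_kernel(size, lum_std))  # local mean luminance
    # Divide by local luminance, then apply softplus.
    return np.logaddexp(0.0, linear / (1.0 + luminance_scalar * lum))
```

The division makes the response saturate with overall brightness: with a unit-sum center-surround kernel, raising a uniform image from intensity 1 to 2 moves the pre-activation only from 0.5 to about 0.67.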

References

[1]

A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of hierarchical representations, NeurIPS 2017; https://arxiv.org/abs/1710.02266

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

display_filters([zoom])

Display the convolutional filters of the model.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

display_filters(zoom=5.0, **kwargs)[source]

Display the convolutional filters of the model.

Parameters:
  • zoom (float) – Magnification factor for po.imshow()

  • **kwargs – Keyword arguments passed to po.imshow().

Returns:

fig

Return type:

PyrFigure

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class plenoptic.simulate.models.frontend.OnOff(kernel_size, width_ratio_limit=4.0, amplitude_ratio=1.25, pad_mode='reflect', pretrained=False, activation=<built-in function softplus>, apply_mask=False, cache_filt=False)[source]

Bases: Module

Two-channel on-off and off-on center-surround model with local contrast and luminance gain control.

This model is called OnOff in Berardino et al. (2017).

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Shape of convolutional kernel.

  • width_ratio_limit (float) – Sets a lower bound on the ratio of surround_std to center_std. The surround Gaussian must be wider than the center Gaussian for the filter to be a proper difference of Gaussians, so surround_std is clamped to be at least width_ratio_limit times center_std.

  • amplitude_ratio (float) – Ratio of center/surround amplitude. Applied before filter normalization.

  • pad_mode (str) – Padding for convolution, defaults to “reflect”.

  • pretrained – Whether to load model parameters estimated from [1]. See Notes for details.

  • activation (Callable[[Tensor], Tensor]) – Activation function following linear and gain control operations.

  • apply_mask (bool) – Whether to apply a circular disk mask centered on the input image. This is useful for synthesis methods like Eigendistortions, to ensure that the synthesized distortion does not appear in the periphery. See plenoptic.tools.signal.make_disk() for details on how the mask is created.

  • cache_filt (bool) – Whether or not to cache the filter. Avoids regenerating filt with each forward pass. Cached to self._filt.
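Conceptually, apply_mask zeroes the periphery of the model's output. The NumPy sketch below builds a centered, hard-edged disk and multiplies it into a made-up two-channel (on-off and off-on) response. disk_mask is a hypothetical helper: plenoptic.tools.signal.make_disk() may construct its mask differently (for example, in how the edge is handled), and the exact point at which OnOff applies the mask is an assumption here.

```python
import numpy as np


def disk_mask(height, width, radius=None):
    """Hypothetical helper: 1.0 inside a centered circle, 0.0 outside."""
    if radius is None:
        radius = min(height, width) / 2  # largest circle that fits the image
    yy, xx = np.mgrid[0:height, 0:width]
    dist = np.hypot(yy - (height - 1) / 2, xx - (width - 1) / 2)
    return (dist <= radius).astype(float)


# A made-up two-channel response (channel 0: on-off, channel 1: off-on).
response = np.ones((2, 8, 8))
masked = response * disk_mask(8, 8)  # the (8, 8) mask broadcasts over channels
```

Because any change to a masked-out pixel cannot affect the output, gradient-based synthesis has no incentive to place distortion there.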

Notes

The 12 pretrained parameters (standard deviations and scalar constants) were reverse-engineered from the model in [1], [2]. Use these pretrained weights at your own discretion.

References

[1]

A Berardino, J Ballé, V Laparra, EP Simoncelli, Eigen-distortions of hierarchical representations, NeurIPS 2017; https://arxiv.org/abs/1710.02266

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

display_filters([zoom])

Display the convolutional filters of the model.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

display_filters(zoom=5.0, **kwargs)[source]

Display the convolutional filters of the model.

Parameters:
  • zoom (float) – Magnification factor for po.imshow()

  • **kwargs – Keyword arguments passed to po.imshow().

Returns:

fig

Return type:

PyrFigure

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

plenoptic.simulate.models.naive module

class plenoptic.simulate.models.naive.CenterSurround(kernel_size, on_center=True, width_ratio_limit=2.0, amplitude_ratio=1.25, center_std=1.0, surround_std=4.0, out_channels=1, pad_mode='reflect', cache_filt=False)[source]

Bases: Module

Center-surround difference-of-Gaussians (DoG) filter model. The filter can be either on-center/off-surround or off-center/on-surround.

The filter is constructed as:

$$
\begin{aligned}
f &= \text{amplitude\_ratio} \cdot \text{center} - \text{surround} \\
f &\leftarrow \frac{f}{\sum f}
\end{aligned}
$$

The signs of center and surround are determined by the on_center argument. The standard deviation of the surround Gaussian is constrained to be at least width_ratio_limit times that of the center Gaussian.

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Shape of convolutional kernel.

  • on_center (Union[bool, List[bool]]) – Dictates whether the center is on or off; the surround will be the opposite of the center (i.e. on-off or off-on). If a list of bools, its length must equal out_channels; if a single bool, all out_channels share the same polarity (all on-off or all off-on).

  • width_ratio_limit (float) – Sets a lower bound on the ratio of surround_std to center_std. The surround Gaussian must be wider than the center Gaussian for the filter to be a proper difference of Gaussians, so surround_std is clamped to be at least width_ratio_limit times center_std.

  • amplitude_ratio (float) – Ratio of center/surround amplitude. Applied before filter normalization.

  • center_std (Union[float, Tensor]) – Standard deviation of circular Gaussian for center.

  • surround_std (Union[float, Tensor]) – Standard deviation of circular Gaussian for surround. Must be at least width_ratio_limit times center_std.

  • out_channels (int) – Number of filters.

  • pad_mode (str) – Padding for convolution, defaults to “reflect”.

  • cache_filt (bool) – Whether or not to cache the filter. Avoids regenerating filt with each forward pass. Cached to self._filt
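To make the construction formula and the on_center/out_channels rules concrete, here is a NumPy sketch that builds a bank of DoG filters. It follows the formula above and clamps surround_std as described, but it assumes an off-center channel is simply the negation of the normalized on-center filter; center_surround_bank and gaussian_kernel are hypothetical names, not plenoptic functions.

```python
import numpy as np


def gaussian_kernel(size, std):
    """Circularly symmetric 2D Gaussian kernel that sums to one."""
    coords = np.arange(size) - (size - 1) / 2
    xx, yy = np.meshgrid(coords, coords)
    kernel = np.exp(-(xx**2 + yy**2) / (2 * std**2))
    return kernel / kernel.sum()


def center_surround_bank(size, on_center=True, width_ratio_limit=2.0,
                         amplitude_ratio=1.25, center_std=1.0,
                         surround_std=4.0, out_channels=1):
    """Hypothetical sketch: DoG filters of shape (out_channels, size, size)."""
    if isinstance(on_center, bool):
        on_center = [on_center] * out_channels  # same polarity for every channel
    if len(on_center) != out_channels:
        raise ValueError("len(on_center) must equal out_channels")
    # Enforce surround_std >= width_ratio_limit * center_std.
    surround_std = max(surround_std, width_ratio_limit * center_std)
    f = (amplitude_ratio * gaussian_kernel(size, center_std)
         - gaussian_kernel(size, surround_std))
    f = f / f.sum()  # the on-center filter now sums to one
    # Assumption: off-center channels are the negated on-center filter.
    signs = np.array([1.0 if on else -1.0 for on in on_center])
    return signs[:, None, None] * f
```

For example, center_surround_bank(9, on_center=[True, False], out_channels=2) returns an on-center and an off-center filter, and passing surround_std=1.5 with the default width_ratio_limit=2.0 silently raises the surround width to 2.0.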

Attributes:
filt

Create an on-center/off-surround or off-center/on-surround convolutional filter.

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

property filt: Tensor

Create an on-center/off-surround or off-center/on-surround convolutional filter.

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class plenoptic.simulate.models.naive.Gaussian(kernel_size, std=3.0, pad_mode='reflect', out_channels=1, cache_filt=False)[source]

Bases: Module

Isotropic Gaussian convolutional filter. Kernel elements are normalized and sum to one.

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Size of convolutional kernel.

  • std (Union[float, Tensor]) – Standard deviation of circularly symmetric Gaussian kernel.

  • pad_mode (str) – Padding mode argument to pass to torch.nn.functional.pad.

  • out_channels (int) – Number of filters with which to convolve.

  • cache_filt (bool) – Whether or not to cache the filter. Avoids regenerating filt with each forward pass. Cached to self._filt.
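A plain-Python sketch of the two behaviors described above: a kernel whose elements sum to one, and optional caching so the kernel is built only once. GaussianSketch is a hypothetical stand-in that uses nested lists instead of tensors; it is not plenoptic's class.

```python
import math


class GaussianSketch:
    """Hypothetical stand-in for a Gaussian filter with optional caching.

    Builds an isotropic Gaussian whose elements sum to one; when
    cache_filt is True, the kernel is built once and stored in self._filt.
    """

    def __init__(self, kernel_size, std=3.0, cache_filt=False):
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = kernel_size
        self.std = std
        self.cache_filt = cache_filt
        self._filt = None

    def _make_filt(self):
        h, w = self.kernel_size
        cy, cx = (h - 1) / 2, (w - 1) / 2
        k = [[math.exp(-((y - cy) ** 2 + (x - cx) ** 2) / (2 * self.std ** 2))
              for x in range(w)] for y in range(h)]
        total = sum(map(sum, k))
        return [[v / total for v in row] for row in k]

    @property
    def filt(self):
        if self._filt is not None:  # reuse the cached kernel
            return self._filt
        filt = self._make_filt()
        if self.cache_filt:
            self._filt = filt
        return filt


g = GaussianSketch(5, std=1.5, cache_filt=True)
print(sum(map(sum, g.filt)))  # sums to one, up to floating-point rounding
```

Caching trades memory for speed; it is only safe when std does not change between forward passes.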

Attributes:
filt

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x, **conv2d_kwargs)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

property filt

forward(x, **conv2d_kwargs)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class plenoptic.simulate.models.naive.Identity(name=None)[source]

Bases: Module

Simple class that just returns a copy of the image.

We use this as a “dummy model” for metrics whose representation we don’t have: the model passes the image through unchanged, and the metric itself is implemented by changing the objective function.
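A plain-Python sketch of the idea: the model is a pass-through copy, so the choice of objective alone defines the metric. Both functions are hypothetical illustrations, and mean squared error is just one example of an objective that could be swapped in.

```python
import copy


def identity_model(img):
    """Return a copy of the image, mirroring the Identity model's behavior."""
    return copy.deepcopy(img)


def mse_objective(rep_a, rep_b):
    """Hypothetical objective: mean squared error between two representations.

    Because identity_model leaves pixels unchanged, this objective is just
    pixel-wise MSE; swapping in another objective swaps in another metric.
    """
    diffs = [(a - b) ** 2 for ra, rb in zip(rep_a, rep_b) for a, b in zip(ra, rb)]
    return sum(diffs) / len(diffs)


img_a = [[0.0, 1.0], [2.0, 3.0]]
img_b = [[1.0, 1.0], [2.0, 5.0]]
print(mse_objective(identity_model(img_a), identity_model(img_b)))
```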

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(img)

Return a copy of the image

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

forward(img)[source]

Return a copy of the image

Parameters:

img (torch.Tensor) – The image to return

Returns:

img – a clone of the input image

Return type:

torch.Tensor

class plenoptic.simulate.models.naive.Linear(kernel_size=(3, 3), pad_mode='circular', default_filters=True)[source]

Bases: Module

Simplistic linear convolutional model that splits the input greyscale image into low and high frequencies.

Parameters:
  • kernel_size (Union[int, Tuple[int, int]]) – Convolutional kernel size.

  • pad_mode (str) – Mode with which to pad image using nn.functional.pad().

  • default_filters (bool) – Initialize the filters to a low-pass and a band-pass.
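A plain-Python sketch of a frequency split with 3x3 kernels and circular padding (the default pad_mode). The filters here are assumptions: a binomial low-pass and its exact complement, which is a high-pass rather than the band-pass that default_filters initializes. The complement is chosen so the two outputs visibly sum back to the input.

```python
def circular_filter2d(img, kernel):
    """Cross-correlate img with kernel using circular (wrap-around) padding.

    For the symmetric kernels used below, this equals convolution.
    """
    h, w = len(img), len(img[0])
    kh, kw = len(kernel), len(kernel[0])
    oy, ox = kh // 2, kw // 2
    return [[sum(kernel[i][j] * img[(y + i - oy) % h][(x + j - ox) % w]
                 for i in range(kh) for j in range(kw))
             for x in range(w)] for y in range(h)]


# Hypothetical 3x3 filters: a binomial low-pass and its complement (delta - low-pass).
taps = [1 / 4, 1 / 2, 1 / 4]
lowpass = [[a * b for b in taps] for a in taps]
highpass = [[(1.0 if (i, j) == (1, 1) else 0.0) - lowpass[i][j]
             for j in range(3)] for i in range(3)]

# A checkerboard is pure high frequency: the low-pass keeps only its mean.
img = [[0.0, 1.0, 0.0, 1.0],
       [1.0, 0.0, 1.0, 0.0],
       [0.0, 1.0, 0.0, 1.0],
       [1.0, 0.0, 1.0, 0.0]]
low = circular_filter2d(img, lowpass)
high = circular_filter2d(img, highpass)
```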

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(x)

Define the computation performed at every call.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

plenoptic.simulate.models.portilla_simoncelli module

class plenoptic.simulate.models.portilla_simoncelli.PortillaSimoncelli(im_shape, n_scales=4, n_orientations=4, spatial_corr_width=9, use_true_correlations=True)[source]

Bases: Module

Model for measuring texture statistics originally proposed in [1] for the purpose of synthesizing texture metamers. These statistics are proposed in [1] as a sufficient set of measurements for describing and synthesizing a given visual texture.

Currently we do not support batch measurement of images.

Parameters:
  • im_shape (Tuple[int, int]) – Shape of the input image.

  • n_scales (int, optional) – The number of pyramid scales used to measure the statistics (default=4)

  • n_orientations (int, optional) – The number of orientations used to measure the statistics (default=4)

  • spatial_corr_width (int, optional) – The width of the spatial cross- and auto-correlation statistics in the representation

  • use_true_correlations (bool) – In the original Portilla-Simoncelli model, the statistics in the representation that are labelled correlations were actually covariance matrices (i.e. not properly scaled). To match the original statistics, use_true_correlations must be set to False. To synthesize metamers from this model, use_true_correlations must be set to True (the default).
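To illustrate the difference that use_true_correlations controls, here is a plain-Python sketch (not plenoptic code) contrasting covariance with correlation. Covariance grows with the overall scale of its inputs, while correlation divides that scale out. Population (divide-by-n) statistics are an assumption here.

```python
import math


def covariance(a, b):
    """Population covariance of two equal-length sequences."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    return sum((x - ma) * (y - mb) for x, y in zip(a, b)) / n


def correlation(a, b):
    """Covariance rescaled by both standard deviations, so it lies in [-1, 1]."""
    return covariance(a, b) / math.sqrt(covariance(a, a) * covariance(b, b))


a = [1.0, 2.0, 3.0, 4.0]
b = [2.0, 4.0, 6.0, 8.0]
b_scaled = [10 * v for v in b]
print(covariance(a, b), covariance(a, b_scaled))    # grows with the scale of b
print(correlation(a, b), correlation(a, b_scaled))  # unchanged by rescaling
```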

pyr

The complex steerable pyramid object used to calculate the Portilla-Simoncelli representation.

Type:

SteerablePyramidFreq

pyr_coeffs

The coefficients of the complex steerable pyramid.

Type:

OrderedDict

mag_pyr_coeffs

The magnitude of the pyramid coefficients.

Type:

OrderedDict

real_pyr_coeffs

The real parts of the pyramid coefficients.

Type:

OrderedDict

scales

The names of the unique scales of coefficients in the pyramid.

Type:

list

representation_scales

The scale for each coefficient in its vector form

Type:

list

representation

A dictionary containing the Portilla-Simoncelli statistics

Type:

dictionary

References

[1]

J Portilla and E P Simoncelli. A Parametric Texture Model based on Joint Statistics of Complex Wavelet Coefficients. Int’l Journal of Computer Vision. 40(1):49-71, October, 2000. http://www.cns.nyu.edu/~eero/ABSTRACTS/portilla99-abstract.html http://www.cns.nyu.edu/~lcv/texture/

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

compute_autocorrelation(ch)

Computes the autocorrelation and variance of a given matrix (ch)

compute_crosscorrelation(ch1, ch2, band_num_el)

Computes either the covariance of the two matrices or the cross-correlation depending on the value self.use_true_correlations.

compute_skew_kurtosis(ch, vari)

Computes the skew and kurtosis of ch.

convert_to_dict(vec)

Converts vector of statistics to a dictionary.

convert_to_vector([stats_dict])

Converts dictionary of statistics to a vector (for synthesis).

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

expand(im, mult)

Resize an image (im) by a multiplier (mult).

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(image[, scales])

Generate Texture Statistics representation of an image (see reference [1])

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

kurtosis(X[, mu, var])

Computes the kurtosis of a matrix X.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

plot_representation([data, ax, figsize, ...])

Plot the representation in a human viewable format -- stem plots with data separated out by statistic type.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

skew(X[, mu, var])

Computes the skew of a matrix X.

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

update_plot(axes[, batch_idx, data])

Update the information in our representation plot

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

compute_autocorrelation(ch)[source]

Computes the autocorrelation and variance of a given matrix (ch)

Parameters:

ch (torch.Tensor)

Returns:

  • ac (torch.Tensor) – Autocorrelation of matrix (ch).

  • vari (torch.Tensor) – Variance of matrix (ch).
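A plain-Python sketch of what the returned ac and vari represent: the circular autocorrelation of a mean-subtracted 2D array, cropped to a small window of lags, whose central (zero-lag) entry is the variance. plenoptic may compute this differently (for example via FFTs and with its own normalization), and width here only plays a role analogous to spatial_corr_width.

```python
def autocorrelation_window(ch, width):
    """Circular autocorrelation of a 2D list, cropped to a width x width window of lags.

    The mean is removed first, so the zero-lag (central) entry is the variance.
    """
    h, w = len(ch), len(ch[0])
    n = h * w
    mean = sum(map(sum, ch)) / n
    z = [[v - mean for v in row] for row in ch]
    half = width // 2
    ac = [[sum(z[y][x] * z[(y + dy) % h][(x + dx) % w]
               for y in range(h) for x in range(w)) / n
           for dx in range(-half, half + 1)]
          for dy in range(-half, half + 1)]
    vari = ac[half][half]
    return ac, vari


ch = [[1.0, 2.0, 0.0, 3.0],
      [4.0, 1.0, 2.0, 0.0],
      [0.0, 3.0, 1.0, 2.0],
      [2.0, 0.0, 4.0, 1.0]]
ac, vari = autocorrelation_window(ch, width=3)
```

Because the autocorrelation is symmetric in the lag, only about half of the window carries independent information.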

compute_crosscorrelation(ch1, ch2, band_num_el)[source]

Computes either the covariance of the two matrices or the cross-correlation depending on the value self.use_true_correlations.

Parameters:
  • ch1 (torch.Tensor) – First matrix for cross correlation.

  • ch2 (torch.Tensor) – Second matrix for cross correlation.

  • band_num_el (int) – Number of elements for bands in the scale

Returns:

The covariance or cross-correlation of ch1 and ch2, depending on self.use_true_correlations.

Return type:

torch.Tensor

compute_skew_kurtosis(ch, vari)[source]

Computes the skew and kurtosis of ch.

Skew and kurtosis of ch are computed. If the ratio of its variance (vari) to the pixel variance of the original image is below a threshold (1e-6), skew and kurtosis are instead assigned the default values 0 and 3, respectively.

Parameters:
  • ch (torch.Tensor)

  • vari (torch.Tensor) – variance of ch

Returns:

  • skew (torch.Tensor) – skew of ch or default value (0)

  • kurtosis (torch.Tensor) – kurtosis of ch or default value (3)
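A plain-Python sketch of the thresholding rule described above, operating on a flat list rather than a tensor. The function name and the img_var argument (standing in for the original image's pixel variance) are hypothetical; (0, 3) are the skew and kurtosis of a Gaussian.

```python
def skew_kurtosis_or_default(ch, vari, img_var, threshold=1e-6):
    """Skew and kurtosis of ch, or (0, 3) if ch is nearly constant.

    'Nearly constant' means vari / img_var falls below threshold, where
    img_var is the pixel variance of the original image (assumed > 0).
    """
    if vari / img_var < threshold:
        return 0.0, 3.0  # default values: zero skew, kurtosis of a Gaussian
    n = len(ch)
    mu = sum(ch) / n
    skew = sum((x - mu) ** 3 for x in ch) / n / vari ** 1.5
    kurt = sum((x - mu) ** 4 for x in ch) / n / vari ** 2
    return skew, kurt


print(skew_kurtosis_or_default([5.0, 5.0, 5.0], vari=0.0, img_var=1.0))
print(skew_kurtosis_or_default([1.0, 2.0, 3.0, 4.0], vari=1.25, img_var=1.0))
```

The threshold avoids dividing by a near-zero variance, which would make the higher moments numerically meaningless.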

convert_to_dict(vec)[source]

Converts vector of statistics to a dictionary.

Parameters:

vec – Flattened 1d vector of statistics.

Return type:

Dictionary of representation, with informative keys.

convert_to_vector(stats_dict=None)[source]

Converts dictionary of statistics to a vector (for synthesis).

Parameters:

stats_dict (optional) – Dictionary of representation. If None, we use self.representation.

Return type:

Flattened 1d vector of statistics.
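A plain-Python sketch of the round trip between a dictionary of statistics and a flat vector. These free functions only mimic the idea: template stands in for self.representation (only its key order and per-key lengths are used), and the key names are hypothetical examples.

```python
def convert_to_vector(stats_dict):
    """Flatten a dict of lists into one list, in the dict's insertion order."""
    return [v for values in stats_dict.values() for v in values]


def convert_to_dict(vec, template):
    """Split a flat list back into a dict shaped like `template`.

    `template` stands in for self.representation: only its keys, order,
    and per-key lengths are used.
    """
    out, i = {}, 0
    for key, values in template.items():
        out[key] = list(vec[i:i + len(values)])
        i += len(values)
    return out


stats = {"pixel_statistics": [0.5, 0.1, 0.0, 3.0], "magnitude_means": [1.2, 0.8]}
vec = convert_to_vector(stats)
print(convert_to_dict(vec, stats) == stats)
```

The flat form is what an optimizer works on during synthesis; the dictionary form keeps the statistics readable.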

static expand(im, mult)[source]

Resize an image (im) by a multiplier (mult).

Parameters:
  • im (torch.Tensor) – An image for expansion.

  • mult (int) – Multiplier by which to resize image.

Returns:

im_large – resized image

Return type:

torch.Tensor
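A plain-Python sketch of resizing an image by an integer multiplier. It uses nearest-neighbor pixel replication purely as a stand-in; this is not necessarily the interpolation plenoptic's expand uses, and the function name is hypothetical.

```python
def expand_nearest(im, mult):
    """Upsample a 2D list by an integer factor using pixel replication.

    Nearest-neighbor replication is a stand-in here; it is not necessarily
    the interpolation plenoptic's expand uses.
    """
    return [[v for v in row for _ in range(mult)] for row in im for _ in range(mult)]


print(expand_nearest([[1, 2], [3, 4]], 2))
```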

forward(image, scales=None)[source]

Generate Texture Statistics representation of an image (see reference [1])

Parameters:
  • image (torch.Tensor) – A tensor containing the image to analyze. Following the PyTorch convention, it should be 4d (batch, channel, height, width). Currently, only single-batch and single-channel images are supported.

  • scales (list, optional) – Which scales to include in the returned representation. If None (the default), we include all scales. Otherwise, it can contain a subset of the values present in this model’s scales attribute, and the returned vector will then contain only the subset of the full representation corresponding to those scales.

Returns:

representation_vector – 3d tensor containing the measured representation statistics.

Return type:

torch.Tensor
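A plain-Python sketch of how a scales argument can select part of a flat representation, using a per-entry list of scale labels like the representation_scales attribute. The function and the label values are hypothetical; plenoptic's own selection logic may differ.

```python
def select_scales(vec, representation_scales, scales=None):
    """Keep the entries of vec whose scale label is in `scales`.

    representation_scales gives the scale label of each entry of vec
    (like the representation_scales attribute). scales=None keeps everything.
    """
    if scales is None:
        return list(vec)
    return [v for v, s in zip(vec, representation_scales) if s in scales]


vec = [0.1, 0.2, 0.3, 0.4, 0.5]
rep_scales = ["pixel_statistics", 0, 0, 1, "residual_lowpass"]  # hypothetical labels
print(select_scales(vec, rep_scales, scales=[0]))
```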

static kurtosis(X, mu=None, var=None)[source]

Computes the kurtosis of a matrix X.

Parameters:
  • X (torch.Tensor) – matrix to compute the kurtosis of.

  • mu (torch.Tensor) – pre-computed mean. If None, we compute it.

  • var (torch.Tensor) – pre-computed variance. If None, we compute it.

Returns:

kurtosis – kurtosis of the matrix X

Return type:

torch.Tensor
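A plain-Python sketch of the kurtosis computation with optional precomputed mean and variance, using lists instead of tensors. It assumes the non-excess definition (3 for a Gaussian, matching the default value used elsewhere) and population variance; plenoptic's tensor version may differ in such details.

```python
def kurtosis(X, mu=None, var=None):
    """Kurtosis E[(X - mu)^4] / var^2 of a flat list X (3 for a Gaussian).

    mu and var may be passed in to avoid recomputing them.
    """
    n = len(X)
    if mu is None:
        mu = sum(X) / n
    if var is None:
        var = sum((x - mu) ** 2 for x in X) / n
    return sum((x - mu) ** 4 for x in X) / n / var ** 2


X = [1.0, 2.0, 3.0, 4.0]
print(kurtosis(X), kurtosis(X, mu=2.5, var=1.25))  # same value either way
```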

plot_representation(data=None, ax=None, figsize=(15, 15), ylim=None, batch_idx=0, title=None)[source]

Plot the representation in a human viewable format – stem plots with data separated out by statistic type.

Parameters:
  • data (torch.Tensor, dict, or None, optional) – The data to show on the plot. If None, we use self.representation. Else, should look like self.representation, with the exact same structure (e.g., as returned by metamer.representation_error() or another instance of this class).

  • ax – axis where we will plot the data

  • figsize ((int, int), optional) – the size of the figure

  • ylim ((int,int) or None, optional)

  • batch_idx (int, optional) – Which index to take from the batch dimension (the first one)

  • title (string) – title for the plot

Returns:

data – The data that was plotted.

Return type:

torch.Tensor, dict, or None

static skew(X, mu=None, var=None)[source]

Computes the skew of a matrix X.

Parameters:
  • X (torch.Tensor) – matrix to compute the skew of.

  • mu (torch.Tensor or None, optional) – pre-computed mean. If None, we compute it.

  • var (torch.Tensor or None, optional) – pre-computed variance. If None, we compute it.

Returns:

skew – skew of the matrix X

Return type:

torch.Tensor
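A plain-Python sketch of skew with optional precomputed mean and variance, on lists instead of tensors (population variance assumed). Symmetric data gives zero skew; a long right tail gives positive skew.

```python
def skew(X, mu=None, var=None):
    """Skew E[(X - mu)^3] / var^1.5 of a flat list X (0 for symmetric data)."""
    n = len(X)
    if mu is None:
        mu = sum(X) / n
    if var is None:
        var = sum((x - mu) ** 2 for x in X) / n
    return sum((x - mu) ** 3 for x in X) / n / var ** 1.5


print(skew([1.0, 2.0, 3.0]))       # symmetric: zero skew
print(skew([0.0, 0.0, 0.0, 3.0]))  # long right tail: positive skew
```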

to(*args, **kwargs)[source]

Moves and/or casts the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)

Its signature is similar to torch.Tensor.to(), but it only accepts floating point dtypes as the desired dtype. In addition, this method will only cast the floating point parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

Note

This method modifies the module in-place.

Parameters:
  • device (torch.device) – the desired device of the parameters and buffers in this module

  • dtype (torch.dtype) – the desired floating point type of the floating point parameters and buffers in this module

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

Returns:

self

Return type:

Module

update_plot(axes, batch_idx=0, data=None)[source]

Update the information in our representation plot

This is used for creating an animation of the representation over time. In order to create the animation, we need to know how to update the matplotlib Artists, and this provides a simple way of doing that. It relies on the fact that we’ve used plot_representation to create the plots we want to update and so know that they’re stem plots.

We take the axes containing the representation information (note that this is probably a subset of the total number of axes in the figure, if we’re showing other information, as done by Metamer.animate), grab the representation from plotting and, since these are both lists, iterate through them, updating as we go.

We can optionally accept a data argument, in which case it should look just like the representation of this model.

In order for this to be used by FuncAnimation, we need to return Artists, so we return a list of the relevant artists, the markerline and stemlines from the StemContainer.

Parameters:
  • axes (list) – A list of axes to update. We assume that these are the axes created by plot_representation and so contain stem plots in the correct order.

  • batch_idx (int, optional) – Which index to take from the batch dimension (the first one)

  • data (torch.Tensor, dict, or None, optional) – The data to show on the plot. If None, we use self.representation. Else, should look like self.representation, with the exact same structure (e.g., as returned by metamer.representation_error() or another instance of this class).

Returns:

stem_artists – A list of the artists used to update the information on the stem plots

Return type:

list

plenoptic.simulate.models.portilla_simoncelli_full module

class plenoptic.simulate.models.portilla_simoncelli_full.PortillaSimoncelliFull(image_shape, n_scales=4, n_orientations=4, spatial_corr_width=9)[source]

Bases: Module

Portilla-Simoncelli texture statistics.

The Portilla-Simoncelli (PS) texture statistics are a set of image statistics, first described in [1], that are proposed as a sufficient set of measurements for describing visual textures. That is, if two texture images have the same values for all PS texture stats, humans should consider them as belonging to the same family of texture.

The PS stats are computed based on the steerable pyramid [2]. They consist of the local auto-correlations, cross-scale (within-orientation) correlations, and cross-orientation (within-scale) correlations of both the pyramid coefficients and the local energy (as computed by those coefficients). Additionally, they include the first four global moments (mean, variance, skew, and kurtosis) of the image and down-sampled versions of that image. See the paper and the accompanying notebook for a fuller description.
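A plain-Python sketch of the last ingredient, the first four global moments of an image and of down-sampled copies of it. Averaging 2x2 blocks is a simplified stand-in for however PortillaSimoncelliFull actually down-samples, and all function names are hypothetical. The sketch assumes every level is non-constant, so the variance is never zero.

```python
def moments(values):
    """Mean, variance, skew and kurtosis of a flat list (population moments)."""
    n = len(values)
    mu = sum(values) / n
    var = sum((v - mu) ** 2 for v in values) / n  # assumed non-zero
    skew = sum((v - mu) ** 3 for v in values) / n / var ** 1.5
    kurt = sum((v - mu) ** 4 for v in values) / n / var ** 2
    return mu, var, skew, kurt


def downsample(img):
    """Average non-overlapping 2x2 blocks (a crude stand-in for proper down-sampling)."""
    return [[(img[y][x] + img[y][x + 1] + img[y + 1][x] + img[y + 1][x + 1]) / 4
             for x in range(0, len(img[0]) - 1, 2)]
            for y in range(0, len(img) - 1, 2)]


def multiscale_moments(img, n_levels):
    """Moments of the image and of n_levels - 1 successively down-sampled copies."""
    out = []
    for _ in range(n_levels):
        out.append(moments([v for row in img for v in row]))
        img = downsample(img)
    return out


img = [[0.0, 1.0, 2.0, 3.0],
       [1.0, 2.0, 3.0, 4.0],
       [2.0, 3.0, 4.0, 5.0],
       [3.0, 4.0, 5.0, 6.0]]
for level, m in enumerate(multiscale_moments(img, n_levels=2)):
    print(level, m)
```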

Parameters:
  • image_shape (Tuple[int, int]) – Shape of input image.

  • n_scales (int) – The number of pyramid scales used to measure the statistics (default=4)

  • n_orientations (int) – The number of orientations used to measure the statistics (default=4)

  • spatial_corr_width (int) – The width of the spatial cross- and auto-correlation statistics in the representation

scales

The names of the unique scales of coefficients in the pyramid.

Type:

list

representation_scales

The scale for each coefficient in its vector form

Type:

list

References

[1]

J Portilla and E P Simoncelli. A Parametric Texture Model based on Joint Statistics of Complex Wavelet Coefficients. Int’l Journal of Computer Vision. 40(1):49-71, October, 2000. http://www.cns.nyu.edu/~eero/ABSTRACTS/portilla99-abstract.html http://www.cns.nyu.edu/~lcv/texture/

[2]

E P Simoncelli and W T Freeman, “The Steerable Pyramid: A Flexible Architecture for Multi-Scale Derivative Computation,” Second Int’l Conf on Image Processing, Washington, DC, Oct 1995.

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

convert_to_dict(vec)

Converts vector of statistics to a dictionary.

convert_to_vector(stats_dict)

Converts dictionary of statistics to a vector.

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(image[, scales])

Generate Texture Statistics representation of an image.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

plot_representation(data[, ax, figsize, ...])

Plot the representation in a human viewable format -- stem plots with data separated out by statistic type.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

remove_scales(representation_vector, ...)

Remove statistics not associated with scales.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

update_plot(axes, data[, batch_idx])

Update the information in our representation plot.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

convert_to_dict(vec)[source]

Converts vector of statistics to a dictionary.

While the vector representation is required by plenoptic’s synthesis objects, the dictionary representation is easier to manually inspect.

Parameters:

vec (Tensor) – 3d tensor of statistics.

Returns:

Dictionary of representation, with informative keys.

Return type:

OrderedDict

See also

convert_to_vector

Convert dictionary representation to vector.

convert_to_vector(stats_dict)[source]

Converts dictionary of statistics to a vector.

While the dictionary representation is easier to manually inspect, the vector representation is required by plenoptic’s synthesis objects.

Parameters:

stats_dict (OrderedDict) – Dictionary of representation.

Returns:

3d tensor of statistics.

Return type:

Tensor

See also

convert_to_dict

Convert vector representation to dictionary.
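As a rough illustration of why the two representations can be converted back and forth without loss, the sketch below flattens an ordered dictionary of statistics into one vector and splits it back using the dictionary as a template. It is a minimal sketch assuming plain Python lists with one flat list per key; plenoptic's own convert_to_vector and convert_to_dict work on torch tensors, and their key names, ordering, and handling of redundant statistics may differ. The helpers dict_to_vector and vector_to_dict, and the example keys, are hypothetical.

```python
from collections import OrderedDict
from typing import List

# Hypothetical helpers: plenoptic's convert_to_vector / convert_to_dict operate
# on torch tensors, and their keys and ordering may differ from this sketch.


def dict_to_vector(stats: "OrderedDict[str, List[float]]") -> List[float]:
    """Concatenate every statistic, in key order, into one flat vector."""
    vec: List[float] = []
    for values in stats.values():
        vec.extend(values)
    return vec


def vector_to_dict(vec: List[float],
                   template: "OrderedDict[str, List[float]]") -> "OrderedDict[str, List[float]]":
    """Split a flat vector back into named statistics.

    The template supplies each key and the number of values stored under it,
    so it must use the same key order as dict_to_vector did.
    """
    out: "OrderedDict[str, List[float]]" = OrderedDict()
    start = 0
    for key, values in template.items():
        out[key] = vec[start:start + len(values)]
        start += len(values)
    if start != len(vec):
        raise ValueError("vector length does not match the template")
    return out


# Illustrative keys only; not necessarily the names plenoptic uses.
stats = OrderedDict([("pixel_statistics", [0.5, 0.1, 0.0, 3.0]),
                     ("magnitude_means", [1.2, 0.8])])
vec = dict_to_vector(stats)
restored = vector_to_dict(vec, stats)
```

Because the template supplies the keys and lengths, the round trip only works if the key order is fixed, which is one reason an ordered dictionary is a natural fit for the inspectable form.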

forward(image, scales=None)[source]

Generate Texture Statistics representation of an image.

Note that separate batches and channels are analyzed in parallel.

Parameters:
  • image (Tensor) – A 4d tensor (batch, channel, height, width) containing the image(s) to analyze.

  • scales (Optional[List[Union[int, Literal['pixel_statistics', 'residual_lowpass', 'residual_highpass']]]]) – Which scales to include in the returned representation. If None, we include all scales. Otherwise, can contain a subset of the values present in this model’s scales attribute, and the returned vector will then contain the subset of the full representation corresponding to those scales.

Returns:

representation_vector – 3d tensor of shape (B, C, S) containing the measured texture statistics.

Return type:

Tensor
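The shape contract above can be sketched without torch: every (batch, channel) image is summarized independently, and the results are stacked into a B × C × S nested list. This is a minimal sketch assuming nested Python lists in place of tensors; mean_and_max is a hypothetical stand-in for the actual texture statistics, and per_channel_representation is not part of plenoptic.

```python
from typing import Callable, List

Image = List[List[float]]                  # height x width
Batch = List[List[Image]]                  # batch x channel x height x width
Representation = List[List[List[float]]]   # batch x channel x statistics


def per_channel_representation(images: Batch,
                               stat_fn: Callable[[Image], List[float]]) -> Representation:
    """Apply stat_fn independently to every (batch, channel) image.

    The result has shape (B, C, S), where S = len(stat_fn(image)).
    """
    return [[stat_fn(channel) for channel in sample] for sample in images]


def mean_and_max(img: Image) -> List[float]:
    """Hypothetical two-number summary standing in for the texture statistics."""
    pixels = [p for row in img for p in row]
    return [sum(pixels) / len(pixels), max(pixels)]


# One batch element with two channels -> shape (1, 2, 2).
images = [[[[0.0, 1.0], [2.0, 3.0]],
           [[1.0, 1.0], [1.0, 1.0]]]]
rep = per_channel_representation(images, mean_and_max)
print(rep)  # [[[1.5, 3.0], [1.0, 1.0]]]
```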

plot_representation(data, ax=None, figsize=(15, 15), ylim=None, batch_idx=0, title=None)[source]

Plot the representation in a human-viewable format – stem plots with data separated out by statistic type.

Currently, this averages over all channels in the representation.

Parameters:
  • data (Tensor) – The data to show on the plot. Should look like the output of self.forward(img), with the exact same structure (e.g., as returned by metamer.representation_error() or another instance of this class).

  • ax (Optional[Axes]) – Axes where we will plot the data. If an mpl.axes.Axes, will subdivide into 7 new axes. If None, we create a new figure.

  • figsize (Tuple[float, float]) – The size of the figure. Ignored if ax is not None.

  • ylim (Optional[Tuple[float, float]]) – The y-limits of the plot.

  • batch_idx (int) – Which index to take from the batch dimension (the first dimension).

  • title (string) – Title for the plot.

Return type:

Tuple[Figure, List[Axes]]

Returns:

  • fig – Figure containing the plot

  • axes – List of 7 axes containing the plot

remove_scales(representation_vector, scales_to_keep)[source]

Remove statistics not associated with scales.

For a given representation_vector and a list of scales_to_keep, this method removes all statistics not associated with those scales.

Note that calling this method will always remove statistics.

Parameters:
  • representation_vector (Tensor) – 3d tensor containing the measured representation statistics.

  • scales_to_keep (List[Union[int, Literal['pixel_statistics', 'residual_lowpass', 'residual_highpass']]]) – Which scales to include in the returned representation. Can contain a subset of the values present in this model’s scales attribute, and the returned vector will then contain the subset of the full representation corresponding to those scales.

Returns:

limited_representation_vector – Representation vector with some statistics removed.

Return type:

Tensor
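One way to picture the filtering uses the representation_scales attribute, which gives the scale of each entry in the vector: build a mask from the scale labels and keep the matching entries. The sketch below assumes a 1d Python list standing in for the last dimension of the 3d representation tensor; remove_scales_sketch is a hypothetical helper, not the library's implementation.

```python
from typing import List, Sequence, Union

Scale = Union[int, str]  # e.g. 0, 1, ..., 'pixel_statistics', 'residual_lowpass'


def remove_scales_sketch(representation_vector: Sequence[float],
                         representation_scales: Sequence[Scale],
                         scales_to_keep: Sequence[Scale]) -> List[float]:
    """Keep only the statistics whose scale label appears in scales_to_keep.

    Assumes representation_scales[i] is the scale of representation_vector[i],
    as described for the representation_scales attribute.
    """
    if len(representation_vector) != len(representation_scales):
        raise ValueError("each statistic needs exactly one scale label")
    keep = set(scales_to_keep)
    return [value
            for value, scale in zip(representation_vector, representation_scales)
            if scale in keep]


vec = [0.1, 0.2, 0.3, 0.4, 0.5]
labels = ["pixel_statistics", 0, 0, 1, "residual_lowpass"]
print(remove_scales_sketch(vec, labels, [0, "residual_lowpass"]))  # [0.2, 0.3, 0.5]
```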

to(*args, **kwargs)[source]

Moves and/or casts the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)

Its signature is similar to torch.Tensor.to(), but it only accepts floating point desired dtypes. In addition, this method only casts the floating point parameters and buffers to dtype (if given). The integral parameters and buffers are moved to device, if that is given, but with their dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.


Note

This method modifies the module in-place.

Parameters:
  • device (torch.device) – The desired device of the parameters and buffers in this module.

  • dtype (torch.dtype) – The desired floating point type of the floating point parameters and buffers in this module.

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module.

Returns:

Module: self

update_plot(axes, data, batch_idx=0)[source]

Update the information in our representation plot.

This is used for creating an animation of the representation over time. In order to create the animation, we need to know how to update the matplotlib Artists, and this provides a simple way of doing that. It relies on the fact that we’ve used plot_representation to create the plots we want to update and so know that they’re stem plots.

We take the axes containing the representation information (note that this is probably a subset of the total number of axes in the figure, if we’re showing other information, as done by Metamer.animate), grab the representation from plotting and, since these are both lists, iterate through them, updating as we go.

We can optionally accept a data argument, in which case it should look just like the representation of this model.

In order for this to be used by FuncAnimation, we need to return Artists, so we return a list of the relevant artists, the markerline and stemlines from the StemContainer.

Currently, this averages over all channels in the representation.

Parameters:
  • axes (List[Axes]) – A list of axes to update. We assume that these are the axes created by plot_representation and so contain stem plots in the correct order.

  • batch_idx (int) – Which index to take from the batch dimension (the first dimension).

  • data (Tensor) – The data to show on the plot. Should look like the output of self.forward(img), with the exact same structure (e.g., as returned by metamer.representation_error() or another instance of this class).

Returns:

stem_artists – A list of the artists used to update the information on the stem plots.

Return type:

list

plenoptic.simulate.models.portilla_simoncelli_intermed module

class plenoptic.simulate.models.portilla_simoncelli_intermed.PortillaSimoncelliIntermed(image_shape, n_scales=4, n_orientations=4, spatial_corr_width=9, use_true_correlations=True)[source]

Bases: Module

Portilla-Simoncelli texture statistics.

The Portilla-Simoncelli (PS) texture statistics are a set of image statistics, first described in [1], that are proposed as a sufficient set of measurements for describing visual textures. That is, if two texture images have the same values for all PS texture stats, humans should consider them as belonging to the same family of texture.

The PS stats are computed based on the steerable pyramid [2]. They consist of the local auto-correlations, cross-scale (within-orientation) correlations, and cross-orientation (within-scale) correlations of both the pyramid coefficients and the local energy (as computed from those coefficients). Additionally, they include the first four global moments (mean, variance, skew, and kurtosis) of the image and down-sampled versions of that image. See the paper and notebook for more details.
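For the pixel statistics, the four global moments of a single image can be computed as below. This is a minimal sketch assuming population (divide-by-N) moments and non-excess kurtosis (a Gaussian gives 3); the normalization plenoptic uses, and how it applies these moments to the down-sampled images, may differ. four_moments is a hypothetical helper.

```python
import math
from typing import List, Sequence


def four_moments(pixels: Sequence[float]) -> List[float]:
    """Return [mean, variance, skew, kurtosis] of a flat sequence of pixel values.

    Uses population (divide-by-N) moments. Kurtosis is the fourth standardized
    moment, so a Gaussian gives 3 (this is not excess kurtosis).
    """
    n = len(pixels)
    mean = sum(pixels) / n
    var = sum((p - mean) ** 2 for p in pixels) / n
    if var == 0:
        raise ValueError("skew and kurtosis are undefined for a constant image")
    std = math.sqrt(var)
    skew = sum((p - mean) ** 3 for p in pixels) / (n * std ** 3)
    kurt = sum((p - mean) ** 4 for p in pixels) / (n * var ** 2)
    return [mean, var, skew, kurt]


print(four_moments([1.0, 2.0, 3.0, 4.0]))  # approximately [2.5, 1.25, 0.0, 1.64]
```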

Parameters:
  • image_shape (Tuple[int, int]) – Shape of input image.

  • n_scales (int) – The number of pyramid scales used to measure the statistics (default=4)

  • n_orientations (int) – The number of orientations used to measure the statistics (default=4)

  • spatial_corr_width (int) – The width of the spatial cross- and auto-correlation statistics in the representation.

  • use_true_correlations (bool) – In the original Portilla-Simoncelli model, the statistics in the representation labelled correlations were actually covariance matrices (i.e., not properly scaled). To match the original statistics, use_true_correlations must be set to False. To synthesize metamers from this model, however, use_true_correlations must be set to True (note that, in this case, the diagonal entries are not rescaled, i.e., they are still the covariances).
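To make spatial_corr_width concrete, the sketch below computes a spatial_corr_width × spatial_corr_width window of a mean-subtracted image's circular auto-covariance, centered on zero shift. It is a sketch under stated assumptions: it uses direct wrap-around shifts rather than whatever (e.g., FFT-based) computation plenoptic may use, it normalizes by the pixel count, and it requires an odd width only so that the window has a well-defined center. central_autocorrelation is a hypothetical helper.

```python
from typing import List


def central_autocorrelation(img: List[List[float]],
                            spatial_corr_width: int) -> List[List[float]]:
    """Central window of the circular auto-covariance of img.

    Entry [r + dy][r + dx], with r = spatial_corr_width // 2, is the average over
    all pixels of centered[y][x] * centered[y + dy][x + dx] (indices wrap around).
    """
    # Restriction of this sketch only: an odd width gives the window a center.
    if spatial_corr_width % 2 == 0:
        raise ValueError("spatial_corr_width should be odd in this sketch")
    h, w = len(img), len(img[0])
    n = h * w
    mean = sum(sum(row) for row in img) / n
    centered = [[p - mean for p in row] for row in img]
    r = spatial_corr_width // 2
    window = []
    for dy in range(-r, r + 1):
        row_out = []
        for dx in range(-r, r + 1):
            total = 0.0
            for y in range(h):
                for x in range(w):
                    total += centered[y][x] * centered[(y + dy) % h][(x + dx) % w]
            row_out.append(total / n)
        window.append(row_out)
    return window


img = [[1.0, 2.0, 3.0],
       [4.0, 5.0, 6.0],
       [7.0, 8.0, 9.0]]
ac = central_autocorrelation(img, spatial_corr_width=3)
print(len(ac), len(ac[0]))  # 3 3
```

The center entry is the variance of the image, and the window is point-symmetric (shift (dy, dx) equals shift (-dy, -dx)), so roughly half of its entries are redundant.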

pyr

The complex steerable pyramid object used to calculate the Portilla-Simoncelli representation.

Type:

SteerablePyramidFreq

pyr_coeffs

The coefficients of the complex steerable pyramid.

Type:

OrderedDict

mag_pyr_coeffs

The magnitude of the pyramid coefficients.

Type:

OrderedDict

real_pyr_coeffs

The real parts of the pyramid coefficients.

Type:

OrderedDict

scales

The names of the unique scales of coefficients in the pyramid.

Type:

list

representation_scales

The scale for each coefficient in its vector form.

Type:

list

representation

A dictionary containing the Portilla-Simoncelli statistics

Type:

dictionary

References

[1]

J Portilla and E P Simoncelli. A Parametric Texture Model based on Joint Statistics of Complex Wavelet Coefficients. Int’l Journal of Computer Vision. 40(1):49-71, October, 2000. http://www.cns.nyu.edu/~eero/ABSTRACTS/portilla99-abstract.html http://www.cns.nyu.edu/~lcv/texture/

[2]

E P Simoncelli and W T Freeman, “The Steerable Pyramid: A Flexible Architecture for Multi-Scale Derivative Computation,” Second Int’l Conf on Image Processing, Washington, DC, Oct 1995.

Methods

add_module(name, module)

Add a child module to the current module.

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

compute_crosscorrelation(ch1, ch2, band_num_el)

Computes either the covariance of the two matrices or the cross-correlation, depending on the value of self.use_true_correlations.

convert_to_dict(vec)

Converts vector of statistics to a dictionary.

convert_to_vector(stats_dict)

Converts dictionary of statistics to a vector.

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Set the module in evaluation mode.

extra_repr()

Set the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(image[, scales])

Generate Texture Statistics representation of an image.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.

modules()

Return an iterator over all modules in the network.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Return an iterator over module parameters.

plot_representation(data[, ax, figsize, ...])

Plot the representation in a human viewable format -- stem plots with data separated out by statistic type.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_pre_hook(hook)

Register a pre-hook for the load_state_dict() method.

remove_scales(representation_vector, ...)

Remove statistics not associated with scales.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

share_memory()

See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

update_plot(axes[, data, batch_idx])

Update the information in our representation plot.

xpu([device])

Move all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Reset gradients of all model parameters.

__call__

compute_crosscorrelation(ch1, ch2, band_num_el)[source]

Computes either the covariance of the two matrices or the cross-correlation, depending on the value of self.use_true_correlations.

Parameters:
  • ch1 (torch.Tensor) – First matrix for cross correlation.

  • ch2 (torch.Tensor) – Second matrix for cross correlation.

  • band_num_el (int) – Number of elements for bands in the scale.

Returns:

The cross-correlation if self.use_true_correlations is True, otherwise the covariance.

Return type:

torch.Tensor
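The difference between the two modes can be sketched on flattened bands: the covariance divides the summed product by band_num_el, and the correlation additionally divides by the product of the two standard deviations. The matrix helper leaves diagonal entries as variances, following the note on use_true_correlations in the class description. This is a minimal sketch assuming zero-mean bands stored as Python lists; cross_stat and cross_matrix are hypothetical names, and plenoptic's exact normalization may differ.

```python
import math
from typing import List, Sequence


def cross_stat(ch1: Sequence[float], ch2: Sequence[float], band_num_el: int,
               use_true_correlations: bool) -> float:
    """Covariance of two zero-mean bands, rescaled to a correlation if requested."""
    cov = sum(a * b for a, b in zip(ch1, ch2)) / band_num_el
    if not use_true_correlations:
        return cov
    std1 = math.sqrt(sum(a * a for a in ch1) / band_num_el)
    std2 = math.sqrt(sum(b * b for b in ch2) / band_num_el)
    return cov / (std1 * std2)


def cross_matrix(bands: Sequence[Sequence[float]], band_num_el: int,
                 use_true_correlations: bool) -> List[List[float]]:
    """Pairwise cross statistics between bands.

    Off-diagonal entries become correlations when use_true_correlations is True;
    diagonal entries always stay variances (covariance of a band with itself).
    """
    n = len(bands)
    return [[cross_stat(bands[i], bands[j], band_num_el,
                        use_true_correlations and i != j)
             for j in range(n)]
            for i in range(n)]


bands = [[1.0, -1.0, 1.0, -1.0],
         [2.0, -2.0, 2.0, -2.0]]
print(cross_matrix(bands, 4, use_true_correlations=False))  # [[1.0, 2.0], [2.0, 4.0]]
print(cross_matrix(bands, 4, use_true_correlations=True))   # [[1.0, 1.0], [1.0, 4.0]]
```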

convert_to_dict(vec)[source]

Converts vector of statistics to a dictionary.

While the vector representation is required by plenoptic’s synthesis objects, the dictionary representation is easier to manually inspect.

Parameters:

vec (Tensor) – 3d tensor of statistics.

Returns:

Dictionary of representation, with informative keys.

Return type:

OrderedDict

See also

convert_to_vector

Convert dictionary representation to vector.

convert_to_vector(stats_dict)[source]

Converts dictionary of statistics to a vector.

While the dictionary representation is easier to manually inspect, the vector representation is required by plenoptic’s synthesis objects.

Parameters:

stats_dict (OrderedDict) – Dictionary of representation.

Returns:

3d tensor of statistics.

Return type:

Tensor

See also

convert_to_dict

Convert vector representation to dictionary.

forward(image, scales=None)[source]

Generate Texture Statistics representation of an image.

Note that separate batches and channels are analyzed in parallel.

Parameters:
  • image (Tensor) – A 4d tensor (batch, channel, height, width) containing the image(s) to analyze.

  • scales (Optional[List[Union[int, Literal['pixel_statistics', 'residual_lowpass', 'residual_highpass']]]]) – Which scales to include in the returned representation. If None, we include all scales. Otherwise, can contain a subset of the values present in this model’s scales attribute, and the returned vector will then contain the subset of the full representation corresponding to those scales.

Returns:

representation_vector – 3d tensor of shape (B, C, S) containing the measured texture statistics.

Return type:

Tensor

plot_representation(data, ax=None, figsize=(15, 15), ylim=None, batch_idx=0, title=None)[source]

Plot the representation in a human-viewable format – stem plots with data separated out by statistic type.

Parameters:
  • data (Union[Tensor, OrderedDict]) – The data to show on the plot. Should look like the output of self.forward(img), with the exact same structure (e.g., as returned by metamer.representation_error() or another instance of this class).

  • ax (Optional[Axes]) – Axes where we will plot the data. If an mpl.axes.Axes, will subdivide into 9 new axes. If None, we create a new figure.

  • figsize (Tuple[float, float]) – The size of the figure. Ignored if ax is not None.

  • ylim (Optional[Tuple[float, float]]) – The y-limits of the plot.

  • batch_idx (int) – Which index to take from the batch dimension (the first dimension).

  • title (string) – Title for the plot.

Return type:

Tuple[Figure, List[Axes]]

Returns:

  • fig – Figure containing the plot

  • axes – List of 9 axes containing the plot

remove_scales(representation_vector, scales_to_keep)[source]

Remove statistics not associated with scales.

For a given representation_vector and a list of scales_to_keep, this method removes all statistics not associated with those scales.

Note that calling this method will always remove statistics.

Parameters:
  • representation_vector (Tensor) – 3d tensor containing the measured representation statistics.

  • scales_to_keep (List[Union[int, Literal['pixel_statistics', 'residual_lowpass', 'residual_highpass']]]) – Which scales to include in the returned representation. Can contain a subset of the values present in this model’s scales attribute, and the returned vector will then contain the subset of the full representation corresponding to those scales.

Returns:

limited_representation_vector – Representation vector with some statistics removed.

Return type:

Tensor

to(*args, **kwargs)[source]

Moves and/or casts the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)

Its signature is similar to torch.Tensor.to(), but it only accepts floating point desired dtypes. In addition, this method only casts the floating point parameters and buffers to dtype (if given). The integral parameters and buffers are moved to device, if that is given, but with their dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.


Note

This method modifies the module in-place.

Parameters:
  • device (torch.device) – The desired device of the parameters and buffers in this module.

  • dtype (torch.dtype) – The desired floating point type of the floating point parameters and buffers in this module.

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module.

Returns:

Module: self

update_plot(axes, data=None, batch_idx=0)[source]

Update the information in our representation plot.

This is used for creating an animation of the representation over time. In order to create the animation, we need to know how to update the matplotlib Artists, and this provides a simple way of doing that. It relies on the fact that we’ve used plot_representation to create the plots we want to update and so know that they’re stem plots.

We take the axes containing the representation information (note that this is probably a subset of the total number of axes in the figure, if we’re showing other information, as done by Metamer.animate), grab the representation from plotting and, since these are both lists, iterate through them, updating as we go.

We can optionally accept a data argument, in which case it should look just like the representation of this model.

In order for this to be used by FuncAnimation, we need to return Artists, so we return a list of the relevant artists, the markerline and stemlines from the StemContainer.

Parameters:
  • axes (List[Axes]) – A list of axes to update. We assume that these are the axes created by plot_representation and so contain stem plots in the correct order.

  • batch_idx (int) – Which index to take from the batch dimension (the first dimension).

  • data (Union[Tensor, OrderedDict, None]) – The data to show on the plot. If given, should look like the output of self.forward(img), with the exact same structure (e.g., as returned by metamer.representation_error() or another instance of this class).

Returns:

stem_artists – A list of the artists used to update the information on the stem plots.

Return type:

list

plenoptic.simulate.models.portilla_simoncelli_old module

Module contents