plenoptic.loss.portilla_simoncelli_loss_factory

plenoptic.loss.portilla_simoncelli_loss_factory(model, image, reweighting_dict=None)

Create the loss function required for PortillaSimoncelli metamer synthesis.

This loss factory returns a callable to pass as the loss_function argument when initializing Metamer to synthesize metamers with the PortillaSimoncelli model. Before computing the L2-norm, the returned loss zeroes the model's representation of the images' min/max pixel values and increases the weight on the variance of the highpass residuals.
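A minimal pure-Python sketch of the idea: weight each group of statistics, then take the L2 norm of the difference between the two weighted representations. This is not plenoptic's implementation; the function name weighted_l2_loss, the use of flat lists instead of tensors, and the highpass weight of 100 are hypothetical assumptions chosen for illustration.

```python
import math


def weighted_l2_loss(rep1, rep2, weights):
    """Weighted L2 distance between two representations (illustrative sketch).

    rep1, rep2: dicts mapping statistic-group names to flat lists of floats,
        a simplified stand-in for the tensors returned by convert_to_dict.
    weights: dict mapping group names to per-element weight lists; groups
        without an entry get weight 1.
    """
    total = 0.0
    for key, values1 in rep1.items():
        values2 = rep2[key]
        w = weights.get(key, [1.0] * len(values1))
        for wi, a, b in zip(w, values1, values2):
            # Weight both representations, then accumulate the squared difference.
            total += (wi * a - wi * b) ** 2
    return math.sqrt(total)


# Hypothetical default weights: the last two of the six pixel statistics
# (min/max) get weight 0, and the highpass variance is upweighted by an
# ASSUMED factor of 100 (not necessarily the value plenoptic uses).
default_weights = {
    "pixel_statistics": [1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
    "var_highpass_residual": [100.0],
}

rep_a = {"pixel_statistics": [0.5, 0.1, 0.0, 3.0, 0.0, 1.0],
         "var_highpass_residual": [0.01]}
rep_b = {"pixel_statistics": [0.4, 0.1, 0.0, 3.0, 0.2, 0.9],
         "var_highpass_residual": [0.02]}

# Only the first pixel statistic (0.1 apart) and the upweighted highpass
# variance (100 * 0.01 apart) contribute: sqrt(0.1**2 + 1.0**2) ≈ 1.005.
loss = weighted_l2_loss(rep_a, rep_b, default_weights)
print(loss)
```

Because the weights multiply both representations, changing the min/max entries of either one leaves this sketch's loss unchanged, while a plain unweighted L2 norm would pick those differences up.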

The optional reweighting_dict argument lets users adjust these weights. If not None, its keys must be a subset of the keys in the output of convert_to_dict, and its values must be Tensors or numbers broadcastable to the shape of the corresponding values in the convert_to_dict output; each value is multiplied by the corresponding group of statistics. Thus, a weight greater than 1 increases that group's contribution to the loss, a weight less than 1 decreases it, and 0 removes it from the calculation entirely. reweighting_dict takes precedence over the default weights, so, e.g., if it includes a "pixel_statistics" key, that entry dictates how the min/max pixel values are weighted.
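The precedence rule described above, together with the unknown-key check listed under Raises below, can be illustrated with a short sketch. It is a minimal pure-Python illustration, not plenoptic's implementation: representations are assumed to be flat lists rather than tensors, only scalar or same-length weights are supported in place of full tensor broadcasting, and the function name build_weights and the default weight values are hypothetical.

```python
def build_weights(rep, defaults, reweighting_dict=None):
    """Combine default and user-supplied weights per statistic group (sketch).

    rep: dict mapping group names to flat lists of values, a simplified
        stand-in for model.convert_to_dict(model(image)).
    defaults: dict of default weights (hypothetical values).
    reweighting_dict: optional user weights; its entries override defaults.
    """
    reweighting_dict = reweighting_dict or {}
    unknown = set(reweighting_dict) - set(rep)
    if unknown:
        raise ValueError(f"keys not in representation: {sorted(unknown)}")
    weights = {}
    for key, values in rep.items():
        # User weights take precedence over defaults; otherwise the weight is 1.
        w = reweighting_dict.get(key, defaults.get(key, 1.0))
        if isinstance(w, (int, float)):
            w = [float(w)] * len(values)  # broadcast a scalar over the group
        w = [float(x) for x in w]
        if len(w) != len(values):
            raise ValueError(
                f"weight for {key!r} has {len(w)} entries, expected {len(values)}"
            )
        weights[key] = w
    return weights


rep = {"pixel_statistics": [0.0] * 6,
       "var_highpass_residual": [0.0],
       "magnitude_std": [0.0, 0.0, 0.0]}
# Hypothetical defaults: zero weight on min/max, extra weight on highpass variance.
defaults = {"pixel_statistics": [1, 1, 1, 1, 0, 0],
            "var_highpass_residual": [100]}

# A scalar 1 for "pixel_statistics" replaces the default, so min/max count again.
weights = build_weights(rep, defaults, {"pixel_statistics": 1, "magnitude_std": 100})
print(weights)
```

Groups the user does not mention (here "var_highpass_residual") keep their default weights, which mirrors the "takes precedence" wording rather than any confirmed internal detail.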

To understand how the returned loss works and see how to write your own loss factory, see Portilla-Simoncelli optimization details.

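Parameters:

  • model – The PortillaSimoncelli model used for metamer synthesis.

  • image – The image whose model representation (model.convert_to_dict(model(image))) determines the keys and shapes against which the weights are checked.

  • reweighting_dict – Optional dictionary of weights for groups of statistics, as described above. Defaults to None.
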
Return type:

Callable[[Tensor, Tensor], Tensor]

Returns:

loss_func – A callable to use as your loss function for PortillaSimoncelli metamer synthesis.

Raises:
  • ValueError – If reweighting_dict contains keys not found in the model representation (model.convert_to_dict(model(image))).

  • ValueError – If the model representation (model.convert_to_dict(model(image))) includes the key "pixel_statistics" but the corresponding tensor does not have shape[-1] == 6, or includes the key "var_highpass_residual" but the corresponding tensor does not have shape[-1] == 1. In either case, the error is raised only if that key is not explicitly included in reweighting_dict.

Warns:

UserWarning – If the model representation (model.convert_to_dict(model(image))) does not include the keys "pixel_statistics" or "var_highpass_residual".

Examples

Create the loss function.

>>> import plenoptic as po
>>> import torch
>>> po.set_seed(0)
>>> img = po.data.einstein()
>>> img2 = torch.rand_like(img)
>>> model = po.models.PortillaSimoncelli(img.shape[-2:])
>>> loss = po.loss.portilla_simoncelli_loss_factory(model, img)
>>> loss(model(img), model(img2))
tensor(30.9390)
>>> po.loss.l2_norm(model(img), model(img2))
tensor(30.5549)

Use the loss function for metamer synthesis.

>>> import plenoptic as po
>>> img = po.data.einstein()
>>> model = po.models.PortillaSimoncelli(img.shape[-2:])
>>> loss = po.loss.portilla_simoncelli_loss_factory(model, img)
>>> met = po.Metamer(img, model, loss_function=loss)

Use reweighting_dict to increase the weight on the image pixel moments while keeping the min/max out of the loss. The model computes 6 pixel statistics (see Understanding Portilla-Simoncelli model statistics for details); the last two are the min and max, so their weights are set to 0.

>>> import plenoptic as po
>>> import torch
>>> po.set_seed(0)
>>> img = po.data.einstein()
>>> img2 = torch.rand_like(img)
>>> model = po.models.PortillaSimoncelli(img.shape[-2:])
>>> rep = model.convert_to_dict(model(img))
>>> pixel_stats = torch.as_tensor([10, 10, 10, 10, 0, 0])
>>> pixel_stats = pixel_stats * torch.ones_like(rep["pixel_statistics"])
>>> reweighting_dict = {"pixel_statistics": pixel_stats}
>>> loss = po.loss.portilla_simoncelli_loss_factory(model, img, reweighting_dict)
>>> loss(model(img), model(img2))
tensor(35.1118)

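Use reweighting_dict to include the min/max in the loss and increase the importance of the standard deviations of the magnitude bands. Because reweighting_dict takes precedence, the scalar weight of 1 for "pixel_statistics" replaces the default zero weight on min/max.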

>>> import plenoptic as po
>>> import torch
>>> po.set_seed(0)
>>> img = po.data.einstein().to(torch.float64)
>>> img2 = torch.rand_like(img)
>>> model = po.models.PortillaSimoncelli(img.shape[-2:])
>>> reweighting_dict = {"pixel_statistics": 1, "magnitude_std": 100}
>>> loss = po.loss.portilla_simoncelli_loss_factory(model, img, reweighting_dict)
>>> loss(model(img), model(img2))
tensor(253.2572, dtype=torch.float64)