plenoptic.loss.portilla_simoncelli_loss_factory
- plenoptic.loss.portilla_simoncelli_loss_factory(model, image, reweighting_dict=None)
Create the loss function required for PortillaSimoncelli metamer synthesis.
This loss factory returns a callable that should be passed as the loss_function argument when initializing Metamer to synthesize metamers with the PortillaSimoncelli model. Before computing the L2-norm, the returned loss zeroes the model's representation of the images' minimum and maximum pixel values and increases the weight on the variance of the highpass residuals.
The optional reweighting_dict argument lets users tweak these weights. If it is not None, its keys must be a subset of the keys in the output of convert_to_dict, and each value must be a Tensor (broadcastable to the shape of the corresponding value in the convert_to_dict output) or a float. Each value is multiplied by the corresponding group of statistics, so a weight greater than 1 increases that group's importance in the loss, a weight less than 1 decreases it, and 0 removes it from the calculation entirely. reweighting_dict takes precedence over the defaults: for example, if it includes a "pixel_statistics" key, that value dictates how the min/max pixel values are weighted.
To understand how the returned loss works and to see how to write your own loss factory, see Portilla-Simoncelli optimization details.
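To make the factory-plus-closure pattern concrete, here is a minimal pure-Python sketch of a reweighting loss factory. It is not plenoptic's implementation: make_reweighted_l2_loss and HYPOTHETICAL_HIGHPASS_WEIGHT are hypothetical names, representations are modeled as dicts of float lists rather than Tensors, and the highpass weight of 100.0 is an assumption because this page does not state the actual factor. It assumes, as the description above suggests, that the weights multiply each group of statistics before the L2-norm of the difference is taken.

```python
import math
from typing import Callable, Dict, List, Optional, Union

# Stand-in for the dict representation: group name -> flat list of floats.
Rep = Dict[str, List[float]]
Weight = Union[float, List[float]]

# Assumption for illustration only: the real upweighting factor is not given here.
HYPOTHETICAL_HIGHPASS_WEIGHT = 100.0


def make_reweighted_l2_loss(
    reweighting: Optional[Dict[str, Weight]] = None,
) -> Callable[[Rep, Rep], float]:
    """Hypothetical loss factory: returns a closure over the chosen weights."""
    # Defaults mirroring the description: drop min/max (the last two of the
    # six pixel statistics) and upweight the highpass residual variance.
    weights: Dict[str, Weight] = {
        "pixel_statistics": [1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
        "var_highpass_residual": HYPOTHETICAL_HIGHPASS_WEIGHT,
    }
    # User-supplied weights take precedence over the defaults.
    weights.update(reweighting or {})

    def _flatten_weighted(rep: Rep) -> List[float]:
        out: List[float] = []
        for key, values in rep.items():
            w = weights.get(key, 1.0)  # groups without a weight keep weight 1
            # A scalar weight applies to every element of the group.
            ws = [float(w)] * len(values) if isinstance(w, (int, float)) else list(w)
            if len(ws) != len(values):
                raise ValueError(f"weight for {key!r} does not match its size")
            out.extend(wi * vi for wi, vi in zip(ws, values))
        return out

    def loss(x: Rep, y: Rep) -> float:
        # Assumes x and y share the same keys in the same order.
        # L2-norm of the difference between the reweighted representations.
        diffs = zip(_flatten_weighted(x), _flatten_weighted(y))
        return math.sqrt(sum((a - b) ** 2 for a, b in diffs))

    return loss
```

With two representations that differ only in their min/max entries, make_reweighted_l2_loss() returns 0.0, whereas make_reweighted_l2_loss({"pixel_statistics": 1.0}) returns a nonzero loss, because the explicit weight replaces the default zeroing. Building the weights once inside the factory lets the returned callable keep the two-argument (representation, representation) signature used in the examples below.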
- Parameters:
model (PortillaSimoncelli) – An instantiated PortillaSimoncelli model.
image (Tensor) – The target image for metamer synthesis, or an image with the same shape, dtype, and device.
reweighting_dict (dict[str, Tensor | float] | None (default: None)) – Dictionary specifying reweighting. See above for details.
- Return type:
- Returns:
loss_func – A callable to use as your loss function for PortillaSimoncelli metamer synthesis.
- Raises:
ValueError – If reweighting_dict contains keys not found in the model representation (model.convert_to_dict(model(image))).
ValueError – If the model representation (model.convert_to_dict(model(image))) includes the key "pixel_statistics" but the corresponding tensor does not have shape[-1] == 6, or includes the key "var_highpass_residual" but the corresponding tensor does not have shape[-1] == 1, and in either case that key is not explicitly included in reweighting_dict.
- Warns:
UserWarning – If the model representation (model.convert_to_dict(model(image))) does not include the key "pixel_statistics" or the key "var_highpass_residual".
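The Raises and Warns conditions above amount to a small validation step. The sketch below is a hypothetical helper, check_representation (with a hypothetical EXPECTED_LAST_DIM table), not plenoptic's code. To run without torch it takes only the trailing size (shape[-1]) of each key in the dict representation, and it assumes one warning per missing key.

```python
import warnings
from typing import List, Mapping, Optional

# Hypothetical table: trailing sizes the default weights assume.
EXPECTED_LAST_DIM = {"pixel_statistics": 6, "var_highpass_residual": 1}


def check_representation(
    last_dims: Mapping[str, int],
    reweighting: Optional[Mapping[str, object]] = None,
) -> List[str]:
    """Hypothetical check mirroring the documented Raises/Warns conditions.

    last_dims maps each key of the dict representation to its shape[-1].
    Returns the default-weighted keys that are missing (each one also warns).
    """
    reweighting = dict(reweighting or {})
    # Reweighting a group the model does not produce is an error.
    unknown = sorted(set(reweighting) - set(last_dims))
    if unknown:
        raise ValueError(f"reweighting_dict has keys not in the representation: {unknown}")
    missing: List[str] = []
    for key, size in EXPECTED_LAST_DIM.items():
        if key not in last_dims:
            # Nothing to reweight by default: warn rather than fail.
            warnings.warn(f"representation has no {key!r} key", UserWarning)
            missing.append(key)
        elif last_dims[key] != size and key not in reweighting:
            # The default weights assume a fixed layout; an explicit weight lifts this.
            raise ValueError(
                f"{key!r} has shape[-1] == {last_dims[key]}, expected {size}; "
                "pass an explicit weight for it in reweighting_dict"
            )
    return missing
```

For instance, check_representation({"pixel_statistics": 4, "var_highpass_residual": 1}) raises, but passing {"pixel_statistics": 1} as the second argument accepts the same shapes, matching the "not included explicitly in reweighting_dict" clause.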
Examples
Create the loss function.
>>> import plenoptic as po
>>> import torch
>>> po.set_seed(0)
>>> img = po.data.einstein()
>>> img2 = torch.rand_like(img)
>>> model = po.models.PortillaSimoncelli(img.shape[-2:])
>>> loss = po.loss.portilla_simoncelli_loss_factory(model, img)
>>> loss(model(img), model(img2))
tensor(30.9390)
>>> po.loss.l2_norm(model(img), model(img2))
tensor(30.5549)
Use the loss function for metamer synthesis.
>>> import plenoptic as po
>>> img = po.data.einstein()
>>> model = po.models.PortillaSimoncelli(img.shape[-2:])
>>> loss = po.loss.portilla_simoncelli_loss_factory(model, img)
>>> met = po.Metamer(img, model, loss_function=loss)
Use reweighting_dict to increase the weight on the image pixel moments while keeping min/max out of the loss. The model includes 6 pixel statistics (see Understanding Portilla-Simoncelli model statistics for details).
>>> import plenoptic as po
>>> import torch
>>> po.set_seed(0)
>>> img = po.data.einstein()
>>> img2 = torch.rand_like(img)
>>> model = po.models.PortillaSimoncelli(img.shape[-2:])
>>> rep = model.convert_to_dict(model(img))
>>> pixel_stats = torch.as_tensor([10, 10, 10, 10, 0, 0])
>>> pixel_stats = pixel_stats * torch.ones_like(rep["pixel_statistics"])
>>> reweighting_dict = {"pixel_statistics": pixel_stats}
>>> loss = po.loss.portilla_simoncelli_loss_factory(model, img, reweighting_dict)
>>> loss(model(img), model(img2))
tensor(35.1118)
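Because the weights scale the statistics themselves (assuming they multiply both representations before the difference is taken), a weight of 10 should scale that statistic's squared contribution to the L2-norm by 100, that is, 10 squared. The pure-Python sketch below checks this arithmetic on made-up numbers. squared_terms is a hypothetical helper, and the order mean, variance, skewness, kurtosis, min, max is an assumption inferred from the example's weights [10, 10, 10, 10, 0, 0].

```python
from typing import List, Sequence


def squared_terms(
    x: Sequence[float], y: Sequence[float], w: Sequence[float]
) -> List[float]:
    """Hypothetical helper: squared differences after scaling both inputs by w."""
    if not len(x) == len(y) == len(w):
        raise ValueError("x, y and w must have the same length")
    return [(wi * xi - wi * yi) ** 2 for wi, xi, yi in zip(w, x, y)]


# Made-up pixel statistics for two images, in the assumed order
# mean, variance, skewness, kurtosis, min, max.
x = [0.50, 0.04, 0.10, 3.0, 0.0, 1.0]
y = [0.60, 0.04, 0.10, 3.0, 0.2, 0.9]

plain = squared_terms(x, y, [1, 1, 1, 1, 1, 1])
moments_up = squared_terms(x, y, [10, 10, 10, 10, 0, 0])

# The mean's squared term grows by a factor of about 100 (up to rounding) ...
print(moments_up[0] / plain[0])
# ... while the zero weights drop min and max entirely.
print(moments_up[4:])  # [0.0, 0.0]
```

So a weight vector like the one above emphasizes the moments much more strongly than the raw factor of 10 might suggest.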
Use reweighting_dict to include min/max in the loss and increase the importance of the standard deviations of the magnitude bands.
>>> import plenoptic as po
>>> import torch
>>> po.set_seed(0)
>>> img = po.data.einstein().to(torch.float64)
>>> img2 = torch.rand_like(img)
>>> model = po.models.PortillaSimoncelli(img.shape[-2:])
>>> reweighting_dict = {"pixel_statistics": 1, "magnitude_std": 100}
>>> loss = po.loss.portilla_simoncelli_loss_factory(model, img, reweighting_dict)
>>> loss(model(img), model(img2))
tensor(253.2572, dtype=torch.float64)