plenoptic.loss.groupwise_relative_l2_norm_factory

plenoptic.loss.groupwise_relative_l2_norm_factory(model, image, reweighting_dict=None)

Create a loss function that computes the groupwise relative L2 norm for synthesis.

This loss factory returns a callable intended to be used as the loss_function when initializing Metamer, where it should make metamer synthesis easier to optimize. The returned loss function normalizes each group within the representation by the L2 norm of that group computed on image, which should be the target image for that synthesis. As a result, groups whose values differ greatly in magnitude contribute to the loss on a comparable scale.
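To make the normalization concrete, here is a minimal sketch in plain PyTorch. It is not plenoptic's implementation: it assumes the returned loss is the ordinary L2 norm of the difference after each group has been divided by that group's L2 norm on the target, and it does not handle groups whose target norm is zero. The helper make_groupwise_loss and the two-group split function are hypothetical.

```python
from collections import OrderedDict

import torch


def make_groupwise_loss(convert_to_dict, target_rep):
    """Hypothetical sketch of a groupwise relative L2 loss (not plenoptic's code)."""
    # Per-group normalizers, computed once from the target image's representation.
    norms = {
        k: torch.linalg.vector_norm(v) for k, v in convert_to_dict(target_rep).items()
    }

    def loss(rep, target):
        rep_d, tgt_d = convert_to_dict(rep), convert_to_dict(target)
        # Divide each group's error by that group's norm on the target,
        # then take the L2 norm over all groups together (an assumption).
        sq = sum(((rep_d[k] - tgt_d[k]) / norms[k]).pow(2).sum() for k in norms)
        return sq.sqrt()

    return loss


def split(rep):
    # Hypothetical representation: two groups with very different magnitudes.
    return OrderedDict(small=rep[..., :2], large=rep[..., 2:])


target = torch.tensor([1.0, 1.0, 100.0, 100.0])
loss = make_groupwise_loss(split, target)
a = target + torch.tensor([1.0, 0.0, 0.0, 0.0])  # error of 1.0 in "small"
b = target + torch.tensor([0.0, 0.0, 1.0, 0.0])  # error of 1.0 in "large"
print(loss(a, target))  # about 0.7071, i.e. 1 / sqrt(2)
print(loss(b, target))  # about 0.0071, i.e. 1 / sqrt(20000)
```

With a plain L2 norm both perturbations would cost 1.0; after per-group normalization, the same absolute error in the small-magnitude group counts 100 times more than in the large-magnitude one.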

This requires that model has two methods, convert_to_dict and convert_to_tensor, which convert the representation between a tensor (as returned by forward) and an OrderedDict. The dictionary representation should have keys that define the different groups within the representation, and its values should be tensors (of any shape).
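Because group values may have any shape, a compliant pair of conversion methods need not split along a channel dimension. The sketch below assumes a hypothetical flat representation of shape (batch, 7) holding a 3-element group and a 2-by-2 group; the GROUP_SHAPES layout and the free functions are illustrative only, and in practice they would be methods on the model.

```python
import math
from collections import OrderedDict

import torch

# Hypothetical layout: each group's name and its shape (excluding the batch dim).
GROUP_SHAPES = OrderedDict(means=(3,), cov=(2, 2))


def convert_to_dict(rep):
    """Split a (batch, 7) tensor into named groups of arbitrary shape."""
    out = OrderedDict()
    start = 0
    for name, shape in GROUP_SHAPES.items():
        n = math.prod(shape)
        out[name] = rep[:, start:start + n].reshape(rep.shape[0], *shape)
        start += n
    return out


def convert_to_tensor(rep_dict):
    """Inverse of convert_to_dict: flatten each group, concatenate in key order."""
    return torch.cat([v.flatten(start_dim=1) for v in rep_dict.values()], dim=1)


rep = torch.randn(4, 7)
groups = convert_to_dict(rep)
print({k: tuple(v.shape) for k, v in groups.items()})
# {'means': (4, 3), 'cov': (4, 2, 2)}
print(torch.equal(convert_to_tensor(groups), rep))  # True: the round trip is exact
```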

The optional reweighting_dict argument allows users to further tweak the weights, if necessary. If not None, its keys should be a subset of those found in the output of model.convert_to_dict, and its values should be floats or Tensors (broadcastable to the shape of the corresponding values in the model.convert_to_dict output), which will be multiplied by the corresponding group after normalization. Thus, a number greater than 1 will increase that group's weight in the loss, a number less than 1 will decrease it, and 0 will remove the group from the calculation entirely.
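A minimal sketch of how such reweighting could be applied to groups that have already been normalized, including the key check described under Raises. apply_reweighting is a hypothetical helper written for illustration, not part of plenoptic; it only shows that floats and broadcastable Tensors both work as weights.

```python
import torch


def apply_reweighting(normalized, reweighting_dict):
    """Hypothetical helper: scale normalized groups by user-supplied weights."""
    unknown = set(reweighting_dict) - set(normalized)
    if unknown:
        raise ValueError(f"reweighting_dict has unknown keys: {sorted(unknown)}")
    # Groups without an entry keep weight 1; floats and broadcastable tensors both work.
    return {k: v * reweighting_dict.get(k, 1.0) for k, v in normalized.items()}


normalized = {"channel_0": torch.ones(1, 4, 4), "channel_1": torch.ones(1, 4, 4)}
mask = torch.ones(4)  # shape (4,) broadcasts along the last dimension
mask[2:] = 0  # zero out the right half of channel_0
weighted = apply_reweighting(normalized, {"channel_0": mask, "channel_1": 0.0})
print(weighted["channel_0"].sum())  # tensor(8.): half of the 16 entries survive
print(weighted["channel_1"].sum())  # tensor(0.): weight 0 removes the group
```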

For an example of a compliant model, see the PortillaSimoncelli model.

Parameters:
  • model (Module) – An instantiated model.

  • image (Tensor) – The target image for metamer synthesis.

  • reweighting_dict (dict[str, Tensor | float] | None (default: None)) – Dictionary specifying further reweighting. See above for details.

Return type:

Callable[[Tensor, Tensor], Tensor]

Returns:

loss_func – A callable to use as your loss function for metamer synthesis.

Raises:

ValueError – If reweighting_dict contains keys not found in the model representation (model.convert_to_dict(model(image))).

Warns:

UserWarning – If model.convert_to_dict() does not return an OrderedDict. Because convert_to_dict and convert_to_tensor need to invert each other, you should probably use an OrderedDict, which guarantees that the order of the keys is preserved. You can use validate_convert_tensor_dict to heuristically check whether your model satisfies this constraint.
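To see why key order matters, consider a convert_to_tensor that stacks dictionary values, as the TestModel in the examples does. The sketch below is a hand-written round-trip check, not validate_convert_tensor_dict (whose signature is not documented here): the same groups inserted in a different order silently produce a different tensor.

```python
from collections import OrderedDict

import torch


def convert_to_tensor(rep_dict):
    # Same approach as the TestModel in the examples: stack values in key order.
    return torch.stack(list(rep_dict.values()), dim=1)


rep = torch.arange(8.0).reshape(1, 2, 2, 2)
good = OrderedDict((f"channel_{i}", rep[:, i]) for i in range(2))
# Same keys and values, but inserted in the opposite order.
bad = {f"channel_{i}": rep[:, i] for i in reversed(range(2))}

print(torch.equal(convert_to_tensor(good), rep))  # True
print(torch.equal(convert_to_tensor(bad), rep))  # False: channels are swapped
```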

Examples

Create the loss function with a simple model.

>>> import plenoptic as po
>>> from collections import OrderedDict
>>> import torch
>>> po.set_seed(0)
>>> class TestModel(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.kernel = torch.nn.Conv2d(1, 2, (5, 5), bias=False)
...         self.kernel.weight.detach_()
...
...     def forward(self, x):
...         return self.kernel(x)
...
...     def convert_to_dict(self, rep):
...         return OrderedDict({f"channel_{i}": rep[:, i] for i in range(2)})
...
...     def convert_to_tensor(self, rep_dict):
...         return torch.stack(list(rep_dict.values()), axis=1)
>>> img = po.data.einstein()
>>> img2 = torch.rand_like(img)
>>> model = TestModel()
>>> loss = po.loss.groupwise_relative_l2_norm_factory(model, img)
>>> loss(model(img), model(img2))
tensor(0.6512)
>>> po.loss.l2_norm(model(img), model(img2))
tensor(78.5674)

Use reweighting_dict to further tweak weighting.

>>> import plenoptic as po
>>> from collections import OrderedDict
>>> import torch
>>> po.set_seed(0)
>>> class TestModel(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.kernel = torch.nn.Conv2d(1, 2, (5, 5), bias=False)
...         self.kernel.weight.detach_()
...
...     def forward(self, x):
...         return self.kernel(x)
...
...     def convert_to_dict(self, rep):
...         return OrderedDict({f"channel_{i}": rep[:, i] for i in range(2)})
...
...     def convert_to_tensor(self, rep_dict):
...         return torch.stack(list(rep_dict.values()), axis=1)
>>> img = po.data.einstein()
>>> img2 = torch.rand_like(img)
>>> model = TestModel()
>>> reweighting_dict = {"channel_0": 0.5}
>>> loss = po.loss.groupwise_relative_l2_norm_factory(model, img, reweighting_dict)
>>> loss(model(img), model(img2))
tensor(0.4822)
>>> # channel_0 is of shape (1, 252, 252): the 5x5 unpadded conv shrinks the 256x256 image
>>> channel_0 = torch.ones_like(model.convert_to_dict(model(img))["channel_0"])
>>> channel_0[..., 128:] = 0
>>> reweighting_dict = {"channel_0": channel_0}
>>> loss = po.loss.groupwise_relative_l2_norm_factory(model, img, reweighting_dict)
>>> loss(model(img), model(img2))
tensor(0.5612)