plenoptic.loss.groupwise_relative_l2_norm_factory
- plenoptic.loss.groupwise_relative_l2_norm_factory(model, image, reweighting_dict=None)
Create loss function that computes groupwise relative L2 norm for synthesis.
This loss factory returns a callable which should make optimization easier when used as the loss_function when initializing Metamer for synthesizing metamers. The resulting loss function normalizes each group within the representation by the L2 norm of that group on image, which should be the target image for that synthesis.
This requires that model has two methods, convert_to_dict and convert_to_tensor, which convert the representation between a tensor (as returned by forward) and an OrderedDict. The dictionary representation should have keys that define the different groups within the representation, and its values should be tensors (of any shape).
The optional reweighting_dict argument allows users to further tweak the weights, if necessary. If not None, its keys should be a subset of those found in the output of model.convert_to_dict, and its values should be tensors (broadcastable to the shape of the corresponding values in the model.convert_to_dict output), which will be multiplied by the corresponding group after normalization. Thus, a number greater than 1 will increase that group's weight in the loss, a number less than 1 will decrease it, and 0 will remove the group from the calculation entirely.
For an example of a compliant model, see the PortillaSimoncelli model.
- Parameters:
- Return type:
- Returns:
loss_func – A callable to use as your loss function for metamer synthesis.
- Raises:
ValueError – If reweighting_dict contains keys not found in the model representation (model.convert_to_dict(model(image))).
- Warns:
UserWarning – If model.convert_to_dict() does not return an OrderedDict. convert_to_dict and convert_to_tensor need to invert each other, which means you should probably use an OrderedDict, which guarantees that the order of the keys is preserved. You can use validate_convert_tensor_dict to heuristically check whether your model satisfies this constraint.
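The normalization and reweighting described above can be sketched as follows. This is an illustrative reimplementation under stated assumptions, not plenoptic's actual code: the names to_dict and sketch_groupwise_loss are hypothetical, and taking a plain L2 norm over the concatenated weighted differences is an assumption made for this example.

```python
import torch
from collections import OrderedDict


def to_dict(rep):
    # Hypothetical stand-in for model.convert_to_dict: one group per channel.
    return OrderedDict((f"channel_{i}", rep[:, i]) for i in range(rep.shape[1]))


def sketch_groupwise_loss(target_rep, reweighting_dict=None):
    # Hypothetical helper, not part of plenoptic.
    # Each group's weight is 1 / (L2 norm of that group on the target).
    weights = {
        k: 1 / torch.linalg.vector_norm(v) for k, v in to_dict(target_rep).items()
    }
    # Optional extra weights are multiplied in after normalization.
    for k, w in (reweighting_dict or {}).items():
        if k not in weights:
            raise ValueError(f"unknown group: {k}")
        weights[k] = weights[k] * w

    def loss(x, y):
        x, y = to_dict(x), to_dict(y)
        diffs = [(weights[k] * (x[k] - y[k])).flatten() for k in weights]
        return torch.linalg.vector_norm(torch.cat(diffs))

    return loss


# Two groups whose values differ in scale by a factor of 100.
target = torch.stack([torch.ones(1, 4, 4), 100 * torch.ones(1, 4, 4)], dim=1)
other = torch.zeros_like(target)

loss = sketch_groupwise_loss(target)
# After normalization each group contributes 1 to the squared sum,
# so the result is sqrt(2) despite the difference in scale.
print(loss(target, other))

# A weight of 0 removes channel_0 from the calculation, leaving a loss of 1.
loss_without_0 = sketch_groupwise_loss(target, {"channel_0": 0})
print(loss_without_0(target, other))
```

In this toy setup an unnormalized L2 norm would be dominated by the second group, which is the situation the groupwise normalization is meant to avoid.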
Examples
Create the loss function with a simple model.
>>> import plenoptic as po
>>> from collections import OrderedDict
>>> import torch
>>> po.set_seed(0)
>>> class TestModel(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.kernel = torch.nn.Conv2d(1, 2, (5, 5), bias=False)
...         self.kernel.weight.detach_()
...
...     def forward(self, x):
...         return self.kernel(x)
...
...     def convert_to_dict(self, rep):
...         return OrderedDict({f"channel_{i}": rep[:, i] for i in range(2)})
...
...     def convert_to_tensor(self, rep_dict):
...         return torch.stack(list(rep_dict.values()), axis=1)
>>> img = po.data.einstein()
>>> img2 = torch.rand_like(img)
>>> model = TestModel()
>>> loss = po.loss.groupwise_relative_l2_norm_factory(model, img)
>>> loss(model(img), model(img2))
tensor(0.6512)
>>> po.loss.l2_norm(model(img), model(img2))
tensor(78.5674)
Use reweighting_dict to further tweak weighting.
>>> import plenoptic as po
>>> from collections import OrderedDict
>>> import torch
>>> po.set_seed(0)
>>> class TestModel(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.kernel = torch.nn.Conv2d(1, 2, (5, 5), bias=False)
...         self.kernel.weight.detach_()
...
...     def forward(self, x):
...         return self.kernel(x)
...
...     def convert_to_dict(self, rep):
...         return OrderedDict({f"channel_{i}": rep[:, i] for i in range(2)})
...
...     def convert_to_tensor(self, rep_dict):
...         return torch.stack(list(rep_dict.values()), axis=1)
>>> img = po.data.einstein()
>>> img2 = torch.rand_like(img)
>>> model = TestModel()
>>> reweighting_dict = {"channel_0": 0.5}
>>> loss = po.loss.groupwise_relative_l2_norm_factory(model, img, reweighting_dict)
>>> loss(model(img), model(img2))
tensor(0.4822)
>>> # channel_0 is of shape (1, 256, 256)
>>> channel_0 = torch.ones_like(model.convert_to_dict(model(img))["channel_0"])
>>> channel_0[..., 128:] = 0
>>> reweighting_dict = {"channel_0": channel_0}
>>> loss = po.loss.groupwise_relative_l2_norm_factory(model, img, reweighting_dict)
>>> loss(model(img), model(img2))
tensor(0.5612)
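Because convert_to_dict and convert_to_tensor must invert each other, it can be worth checking the round trip by hand before building the loss. The following is a minimal sketch using a hypothetical stand-in model (TinyModel, not part of plenoptic) and a random input; validate_convert_tensor_dict, mentioned above, is the library's own heuristic check and may test more than this.

```python
import torch
from collections import OrderedDict


class TinyModel(torch.nn.Module):
    # Hypothetical model whose representation has two channel groups.
    def __init__(self):
        super().__init__()
        self.kernel = torch.nn.Conv2d(1, 2, (5, 5), bias=False)

    def forward(self, x):
        return self.kernel(x)

    def convert_to_dict(self, rep):
        # One group per output channel, in a fixed key order.
        return OrderedDict((f"channel_{i}", rep[:, i]) for i in range(2))

    def convert_to_tensor(self, rep_dict):
        # Re-stack the groups along the channel dimension.
        return torch.stack(list(rep_dict.values()), dim=1)


model = TinyModel()
with torch.no_grad():
    rep = model(torch.rand(1, 1, 32, 32))

# Tensor -> dict -> tensor should reproduce the representation exactly.
round_trip = model.convert_to_tensor(model.convert_to_dict(rep))
assert round_trip.shape == rep.shape
assert torch.equal(round_trip, rep)
```

If the dictionary did not preserve key order, the re-stacked channels could come back permuted and the second assertion would fail.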