# Source code for plenoptic.loss

"""Loss functions for image synthesis."""
# numpydoc ignore=ES01

# to avoid circular import error:
# https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/
import warnings
from collections import OrderedDict
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from .models import PortillaSimoncelli

import numpy as np
import torch
from torch import Tensor

# Public API of this module. The module-level ``__dir__`` defined in this file
# returns this list, so a public function missing from it is hidden from
# ``dir()`` / tab completion and skipped by ``from plenoptic.loss import *``.
__all__ = [
    "mse",
    "l2_norm",
    "relative_sse",
    "groupwise_relative_l2_norm_factory",
    "portilla_simoncelli_loss_factory",
]


def __dir__() -> list[str]:
    """
    Return the names advertised as this module's public API.

    Python calls a module-level ``__dir__`` when ``dir()`` is invoked on the
    module, so only the names listed in ``__all__`` are reported.

    Returns
    -------
    names
        The module's ``__all__`` list (the same object, not a copy).
    """
    # numpydoc ignore=ES01
    return __all__


[docs] def set_seed(seed: int | None = None) -> None: """ Set the seed. We call both :func:`torch.manual_seed()` and :func:`numpy.random.seed()`. Parameters ---------- seed The seed to set. If ``None``, do nothing. """ if seed is not None: # random initialization torch.manual_seed(seed) np.random.seed(seed)
def mse(synth_rep: Tensor, ref_rep: Tensor, **kwargs: Any) -> Tensor:
    r"""
    Compute the mean-squared error between ``synth_rep`` and ``ref_rep``.

    For two tensors, :math:`x` and :math:`y`, with :math:`n` values each:

    .. math::

        MSE = \frac{1}{n}\sum_{i=1}^n (x_i - y_i)^2

    Both tensors must have a float dtype.

    Parameters
    ----------
    synth_rep
        The first tensor to compare, model representation of the
        synthesized image.
    ref_rep
        The second tensor to compare, model representation of the
        reference image. Must be the same size as ``synth_rep``.
    **kwargs
        Ignored, only present to absorb extra arguments.

    Returns
    -------
    loss
        The mean-squared error between ``synth_rep`` and ``ref_rep``.
    """
    residual = synth_rep - ref_rep
    return residual.pow(2).mean()
def l2_norm(synth_rep: Tensor, ref_rep: Tensor, **kwargs: Any) -> Tensor:
    r"""
    Compute the L2-norm of ``ref_rep - synth_rep``.

    For two tensors, :math:`x` and :math:`y`, with :math:`n` values each:

    .. math::

        L2 = \sqrt{\sum_{i=1}^n (x_i - y_i)^2}

    Parameters
    ----------
    synth_rep
        The first tensor to compare, model representation of the
        synthesized image.
    ref_rep
        The second tensor to compare, model representation of the
        reference image. Must be the same size as ``synth_rep``.
    **kwargs
        Ignored, only present to absorb extra arguments.

    Returns
    -------
    loss
        The L2-norm of the difference between ``ref_rep`` and ``synth_rep``.
    """
    difference = ref_rep - synth_rep
    return torch.linalg.vector_norm(difference, ord=2)
def relative_sse(synth_rep: Tensor, ref_rep: Tensor, **kwargs: Any) -> Tensor:
    r"""
    Compute the relative sum of squared errors between two tensors.

    This is the squared L2-norm of the difference between the reference and
    synthesized representations, divided by the squared L2-norm of the
    reference representation. For two tensors, :math:`x` and :math:`y`:

    .. math::

        \frac{||x - y||_2^2}{||x||_2^2}

    where :math:`x` is ``ref_rep``, :math:`y` is ``synth_rep``, and
    :math:`||x||_2` is the L2-norm.

    Parameters
    ----------
    synth_rep
        The first tensor to compare, model representation of the
        synthesized image.
    ref_rep
        The second tensor to compare, model representation of the
        reference image. Must be the same size as ``synth_rep``.
    **kwargs
        Ignored, only present to absorb extra arguments.

    Returns
    -------
    loss
        Ratio of the squared L2-norm of the difference between ``ref_rep``
        and ``synth_rep`` to the squared L2-norm of ``ref_rep``.
    """
    error_norm = torch.linalg.vector_norm(ref_rep - synth_rep, ord=2)
    reference_norm = torch.linalg.vector_norm(ref_rep, ord=2)
    return error_norm**2 / reference_norm**2
def _groupwise_l2_norm_weights( model: torch.nn.Module, image: Tensor, reweighting_dict: dict[str, Tensor | float] | None = None, ) -> dict[str, Tensor]: r""" Compute groupwise L2 norm, as a tensor for reweighting. This function returns a tensor that can be used to perform a groupwise reweighting of a model's representation. It is used by :func:`~plenoptic.loss.groupwise_relative_l2_norm_factory` and similar functions, which normalize model representations so that all statistics are roughly the same scale, which makes optimization easier. This requires that ``model`` has a ``convert_to_dict`` method, which converts the representation from a tensor (as returned by ``forward``) to a dictionary. The dictionary representation should have keys that define the different groups within the representation, and its values should be tensors (of any shape). The optional ``reweighting_dict`` argument allows users to further tweak the weights, if necessary. If not ``None``, keys should be a subset of those found in the output of ``model.convert_to_dict``, and whose values are Tensors (broadcastable to the shape of the corresponding values in ``model.convert_to_dict`` output) which will be multiplied by the corresponding group *after* normalization. Thus, a number greater than 1 will increase its weight in the loss, a number less than 1 will decrease the weight, and 0 will remove it from the calculation entirely. For an example of a compliant model, see the :class:`~plenoptic.models.PortillaSimoncelli` model. Parameters ---------- model An instantiated model. image The target image for metamer synthesis. reweighting_dict Dictionary specifying further reweighting. See above for details. Returns ------- weights Dictionary containing the L2 norm of each statistic group. Should probably be passed to ``model.convert_to_tensor``, but left in this form in case any further reweighting needs to be done (e.g., zeroing out some component). 
Raises ------ ValueError If ``reweighting_dict`` contains keys not found in the model representation (``model.convert_to_dict(model(image))``). Warns ----- UserWarning If ``model.convert_to_dict()`` does not return an :class:`~collections.OrderedDict`. ``convert_to_dict`` and ``convert_to_tensor`` need to invert each other, which means you should probably use an :class:`~collections.OrderedDict`, which guarantees that the order of the keys is preserved. You can use :func:`~plenoptic.validate.validate_convert_tensor_dict` to heuristically check whether your model satisfies this constraint. """ if reweighting_dict is None: reweighting_dict = {} weights = {} rep = model.convert_to_dict(model(image)) if not isinstance(rep, OrderedDict): warnings.warn( "model.convert_to_dict did not return an OrderedDict. This might " "not be a problem, but convert_to_dict and convert_to_tensor must" " invert each other. Calling " "plenoptic.validate.validate_convert_tensor_dict(model)" " will attempt to validate this constraint." ) if extra_keys := set(reweighting_dict.keys()) - set(rep.keys()): raise ValueError( "reweighting_dict contains keys not found in model representation!" f" {extra_keys}" ) for k, v in rep.items(): wt = torch.linalg.vector_norm(v[~v.isnan()], ord=2) weights[k] = reweighting_dict.get(k, 1) * torch.ones_like(v) / wt return weights
def groupwise_relative_l2_norm_factory(
    model: torch.nn.Module,
    image: Tensor,
    reweighting_dict: dict[str, Tensor | float] | None = None,
) -> Callable[[Tensor, Tensor], Tensor]:
    r"""
    Build a loss that computes a groupwise relative L2 norm for synthesis.

    The callable returned by this factory should make optimization easier when
    used as the ``loss_function`` when initializing :class:`~plenoptic.Metamer`
    for synthesizing metamers. The returned loss normalizes each group within
    the representation by the L2 norm of that group on ``image``, which should
    be the target image for that synthesis, and then takes the L2 norm of the
    difference between the two normalized representations.

    This requires that ``model`` has two methods, ``convert_to_dict`` and
    ``convert_to_tensor``, which convert the representation between a tensor
    (as returned by ``forward``) and an :class:`~collections.OrderedDict`. The
    dictionary representation should have keys that define the different groups
    within the representation, and its values should be tensors (of any shape).

    The optional ``reweighting_dict`` argument allows users to further tweak
    the weights, if necessary. If not ``None``, keys should be a subset of
    those found in the output of ``model.convert_to_dict``, and whose values
    are Tensors (broadcastable to the shape of the corresponding values in
    ``model.convert_to_dict`` output) which will be multiplied by the
    corresponding group *after* normalization. Thus, a number greater than 1
    will increase its weight in the loss, a number less than 1 will decrease
    the weight, and 0 will remove it from the calculation entirely.

    For an example of a compliant model, see the
    :class:`~plenoptic.models.PortillaSimoncelli` model.

    Parameters
    ----------
    model
        An instantiated model.
    image
        The target image for metamer synthesis.
    reweighting_dict
        Dictionary specifying further reweighting. See above for details.

    Returns
    -------
    loss_func
        A callable to use as your loss function for metamer synthesis.

    Raises
    ------
    ValueError
        If ``reweighting_dict`` contains keys not found in the model
        representation (``model.convert_to_dict(model(image))``).

    Warns
    -----
    UserWarning
        If ``model.convert_to_dict()`` does not return an
        :class:`~collections.OrderedDict`. ``convert_to_dict`` and
        ``convert_to_tensor`` need to invert each other, which means you should
        probably use an :class:`~collections.OrderedDict`, which guarantees
        that the order of the keys is preserved. You can use
        :func:`~plenoptic.validate.validate_convert_tensor_dict` to
        heuristically check whether your model satisfies this constraint.

    Examples
    --------
    Create the loss function with a simple model.

    >>> import plenoptic as po
    >>> from collections import OrderedDict
    >>> import torch
    >>> po.set_seed(0)
    >>> class TestModel(torch.nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.kernel = torch.nn.Conv2d(1, 2, (5, 5), bias=False)
    ...         self.kernel.weight.detach_()
    ...
    ...     def forward(self, x):
    ...         return self.kernel(x)
    ...
    ...     def convert_to_dict(self, rep):
    ...         return OrderedDict({f"channel_{i}": rep[:, i] for i in range(2)})
    ...
    ...     def convert_to_tensor(self, rep_dict):
    ...         return torch.stack(list(rep_dict.values()), axis=1)
    >>> img = po.data.einstein()
    >>> img2 = torch.rand_like(img)
    >>> model = TestModel()
    >>> loss = po.loss.groupwise_relative_l2_norm_factory(model, img)
    >>> loss(model(img), model(img2))
    tensor(0.6512)
    >>> po.loss.l2_norm(model(img), model(img2))
    tensor(78.5674)

    Use ``reweighting_dict`` to further tweak weighting.

    >>> import plenoptic as po
    >>> from collections import OrderedDict
    >>> import torch
    >>> po.set_seed(0)
    >>> class TestModel(torch.nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.kernel = torch.nn.Conv2d(1, 2, (5, 5), bias=False)
    ...         self.kernel.weight.detach_()
    ...
    ...     def forward(self, x):
    ...         return self.kernel(x)
    ...
    ...     def convert_to_dict(self, rep):
    ...         return OrderedDict({f"channel_{i}": rep[:, i] for i in range(2)})
    ...
    ...     def convert_to_tensor(self, rep_dict):
    ...         return torch.stack(list(rep_dict.values()), axis=1)
    >>> img = po.data.einstein()
    >>> img2 = torch.rand_like(img)
    >>> model = TestModel()
    >>> reweighting_dict = {"channel_0": 0.5}
    >>> loss = po.loss.groupwise_relative_l2_norm_factory(model, img, reweighting_dict)
    >>> loss(model(img), model(img2))
    tensor(0.4822)
    >>> # a tensor value reweights individual entries of the group: here the
    >>> # last-axis entries from index 128 onwards are removed from the loss
    >>> channel_0 = torch.ones_like(model.convert_to_dict(model(img))["channel_0"])
    >>> channel_0[..., 128:] = 0
    >>> reweighting_dict = {"channel_0": channel_0}
    >>> loss = po.loss.groupwise_relative_l2_norm_factory(model, img, reweighting_dict)
    >>> loss(model(img), model(img2))
    tensor(0.5612)
    """
    # per-group weights, computed once from the target image and then packed
    # back into the layout of the model's tensor representation
    group_weights = _groupwise_l2_norm_weights(model, image, reweighting_dict)
    weight_tensor = model.convert_to_tensor(group_weights)

    def loss(x: Tensor, y: Tensor) -> Tensor:  # numpydoc ignore=GL08
        weighted_x = weight_tensor * x
        weighted_y = weight_tensor * y
        return l2_norm(weighted_x, weighted_y)

    return loss
[docs] def portilla_simoncelli_loss_factory( model: "PortillaSimoncelli", image: Tensor, reweighting_dict: dict[str, Tensor | float] | None = None, ) -> Callable[[Tensor, Tensor], Tensor]: """ Create the loss function required for ``PortillaSimoncelli`` metamer synthesis. This loss factory returns a callable which should be used as the ``loss_function`` when initializing :class:`~plenoptic.Metamer` for synthesizing metamers with the :class:`~plenoptic.models.PortillaSimoncelli` model. It zeroes the model's representation of the images' min/max pixel values and increases the weight on the variance of the highpass residuals before computing the L2-norm. The optional ``reweighting_dict`` argument allows users to tweak the weights. If not ``None``, keys should be a subset of those found in the output of :func:`~plenoptic.models.PortillaSimoncelli.convert_to_dict` and whose values are Tensors (broadcastable to the shape of the corresponding values in ``convert_to_dict`` output) which will be multiplied by the corresponding group. Thus, a number greater than 1 will increase its weight in the loss, a number less than 1 will decrease the weight, and 0 will remove it from the calculation entirely. ``reweighting_dict`` takes precedence, so e.g., if it includes a ``"pixel_statistics"`` key, that will dictate how min/max pixel values are weighted. To understand how the returned loss works and see how to write your own loss factory, see :ref:`ps-optimization`. Parameters ---------- model An instantiated :class:`~plenoptic.models.PortillaSimoncelli` model. image The target image for metamer synthesis, or an image with the same shape, dtype, and device. reweighting_dict Dictionary specifying reweighting. See above for details. Returns ------- loss_func A callable to use as your loss function for ``PortillaSimoncelli`` metamer synthesis. Raises ------ ValueError If ``reweighting_dict`` contains keys not found in the model representation (``model.convert_to_dict(model(image))``). 
ValueError If model representation (``model.convert_to_dict(model(image))``) includes the key ``"pixel_statistics"`` but the corresponding tensor does not have ``shape[-1] == 6`` or if it includes the key ``"var_highpass_residual"`` but the corresponding tensor does not have ``shape[-1] == 1`` and the corresponding key is not included explicitly in ``reweighting_dict``. Warns ----- UserWarning If model representation (``model.convert_to_dict(model(image))``) does not include the keys ``"pixel_statistics"`` or ``"var_highpass_residual"``. Examples -------- Create the loss function. >>> import plenoptic as po >>> import torch >>> po.set_seed(0) >>> img = po.data.einstein() >>> img2 = torch.rand_like(img) >>> model = po.models.PortillaSimoncelli(img.shape[-2:]) >>> loss = po.loss.portilla_simoncelli_loss_factory(model, img) >>> loss(model(img), model(img2)) tensor(30.9390) >>> po.loss.l2_norm(model(img), model(img2)) tensor(30.5549) Use the loss function for metamer synthesis. >>> import plenoptic as po >>> img = po.data.einstein() >>> model = po.models.PortillaSimoncelli(img.shape[-2:]) >>> loss = po.loss.portilla_simoncelli_loss_factory(model, img) >>> met = po.Metamer(img, model, loss_function=loss) Use ``reweighting_dict`` to increase weight on image pixel moments, while keeping min/max out of the loss. 
The model includes 6 pixel stats (see :ref:`ps-model-stats` for details) >>> import plenoptic as po >>> import torch >>> po.set_seed(0) >>> img = po.data.einstein() >>> img2 = torch.rand_like(img) >>> model = po.models.PortillaSimoncelli(img.shape[-2:]) >>> rep = model.convert_to_dict(model(img)) >>> pixel_stats = torch.as_tensor([10, 10, 10, 10, 0, 0]) >>> pixel_stats = pixel_stats * torch.ones_like(rep["pixel_statistics"]) >>> reweighting_dict = {"pixel_statistics": pixel_stats} >>> loss = po.loss.portilla_simoncelli_loss_factory(model, img, reweighting_dict) >>> loss(model(img), model(img2)) tensor(35.1118) Use ``reweighting_dict`` to include min/max in the loss and increase the importance of the standard deviations of the magnitude bands. >>> import plenoptic as po >>> import torch >>> po.set_seed(0) >>> img = po.data.einstein().to(torch.float64) >>> img2 = torch.rand_like(img) >>> model = po.models.PortillaSimoncelli(img.shape[-2:]) >>> reweighting_dict = {"pixel_statistics": 1, "magnitude_std": 100} >>> loss = po.loss.portilla_simoncelli_loss_factory(model, img, reweighting_dict) >>> loss(model(img), model(img2)) tensor(253.2572, dtype=torch.float64) """ if reweighting_dict is None: reweighting_dict = {} weights = model.convert_to_dict(torch.ones_like(model(image))) # do this before adding defaults for pixel_stats and var_highpass_residual if extra_keys := set(reweighting_dict.keys()) - set(weights.keys()): raise ValueError( "reweighting_dict contains key(s) not found in model representation! " f"{extra_keys}" ) if "pixel_statistics" in weights: pixel_stats = torch.ones_like(weights["pixel_statistics"]) pixel_stats[..., -2:] = 0 if pixel_stats.shape[-1] != 6 and "pixel_statistics" not in reweighting_dict: raise ValueError( "Expected model's 'pixel_statistics' representation " f"to have 6 values, but it has {pixel_stats.shape[-1]}" " values instead! 
Unsure what corresponds to the " "min/max, set this directly in reweighting_dict" ) reweighting_dict.setdefault("pixel_statistics", pixel_stats) else: warnings.warn( "pixel_statistics not found in your model representation, " "continuing without removing them. Hope you know what " "you're doing..." ) if "var_highpass_residual" in weights: n_highpass = weights["var_highpass_residual"].shape[-1] if n_highpass != 1 and "var_highpass_residual" not in reweighting_dict: raise ValueError( "Expected model's 'var_highpass_residual' representation " f"to have 1 value, but it has {n_highpass}" " values instead! Unsure how to handle this," " set directly in reweighting_dict" ) reweighting_dict.setdefault("var_highpass_residual", 100) else: warnings.warn( "var_highpass_residual not found in your model representation, " "continuing without reweighting them. Hope you know what " "you're doing..." ) for k in weights: weights[k] *= reweighting_dict.get(k, 1) weights = model.convert_to_tensor(weights) def loss(x: Tensor, y: Tensor) -> Tensor: # numpydoc ignore=GL08 return l2_norm(weights * x, weights * y) return loss