# Source code for plenoptic.metric._model_metric

"""
Model metrics.

Simple functions to convert models, which can return a tensor of arbitrary shape, into
metrics, which compare two tensors and reduce their difference to a single value.
"""  # numpydoc ignore=EX01

from collections.abc import Callable

import torch

# Public API of this module; ``__dir__`` below exposes exactly these names.
__all__ = ["model_metric_factory"]


def __dir__() -> list[str]:
    """Restrict ``dir()`` on this module to the names listed in ``__all__``."""
    public_api = __all__
    return public_api


def model_metric_factory(
    model: torch.nn.Module,
    *,
    epsilon: float = 1e-10,
) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    r"""
    Create a metric function which returns the root mean squared error in model space.

    The returned callable will compute, for two images, :math:`x` and :math:`y`,
    and model :math:`M`:

    .. math::

        metric = \sqrt{\frac{1}{n}\sum_i (M(x)_i - M(y)_i)^2+\epsilon}

    where :math:`M(x)` and :math:`M(y)` are the model representations of ``x`` and
    ``y``, with :math:`n` elements, and :math:`\epsilon` (``1e-10`` by default) is
    added to stabilize the gradient around zero.

    This allows users to convert models into metrics.

    Parameters
    ----------
    model
        Torch model with defined forward operation.
    epsilon
        Small non-negative constant added inside the square root. The derivative
        of :math:`\sqrt{u}` diverges at :math:`u=0`, so without this term,
        comparing two inputs with identical representations yields non-finite
        gradients.

    Returns
    -------
    metric_func
        A callable which accepts two tensors and returns their root mean squared
        error in model space.

    Examples
    --------
    >>> import plenoptic as po
    >>> einstein_img = po.data.einstein()
    >>> curie_img = po.data.curie()
    >>> model = po.models.Gaussian(30)
    >>> model_metric = po.metric.model_metric_factory(model)
    >>> model_metric(einstein_img, curie_img)
    tensor(0.3128, grad_fn=<SqrtBackward0>)
    >>> # calculate this model metric manually:
    >>> torch.mean((model(einstein_img) - model(curie_img)).pow(2)).sqrt()
    tensor(0.3128, grad_fn=<SqrtBackward0>)
    """

    def metric(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # numpydoc ignore=GL08
        repx = model(x)
        repy = model(y)
        # epsilon keeps sqrt differentiable when repx == repy (stabilizes the
        # gradient around zero for optimization).
        return torch.sqrt(torch.mean((repx - repy) ** 2) + epsilon)

    return metric