# Source code for plenoptic.metric._model_metric
"""
Model metrics.
Simple functions to convert models, which can return a tensor of arbitrary shape, into
metrics, which must return a scalar tensor.
""" # numpydoc ignore=EX01
from collections.abc import Callable

import torch
# Public API of this module; ``__dir__`` below exposes exactly these names.
__all__ = [
    "model_metric_factory",
]


def __dir__() -> list[str]:
    """Limit ``dir()`` on this module to the names declared in ``__all__`` (PEP 562)."""
    public_api = __all__
    return public_api
# [docs]
def model_metric_factory(
    model: torch.nn.Module, *, epsilon: float = 1e-10
) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    r"""
    Create a metric function which returns the root mean squared error in model space.

    The returned callable will compute, for two images, :math:`x` and :math:`y`, and
    model :math:`M`:

    .. math::

        metric = \sqrt{\frac{1}{n}\sum_i (M(x)_i - M(y)_i)^2+\epsilon}

    where :math:`M(x)` and :math:`M(y)` are the model representations of ``x`` and
    ``y``, with :math:`n` elements, and :math:`\epsilon` (``1e-10`` by default) is
    added inside the square root to stabilize the gradient around zero: the
    derivative of :math:`\sqrt{u}` is unbounded at :math:`u=0`, which is where two
    identical representations land.

    This allows users to convert models into metrics.

    Parameters
    ----------
    model
        Torch model with defined forward operation.
    epsilon
        Non-negative constant added to the mean squared error before taking the
        square root. Defaults to ``1e-10``.

    Returns
    -------
    metric_func
        A callable which accepts two tensors and returns their root mean squared error
        in model space.

    Raises
    ------
    ValueError
        If ``epsilon`` is negative.

    Examples
    --------
    >>> import plenoptic as po
    >>> einstein_img = po.data.einstein()
    >>> curie_img = po.data.curie()
    >>> model = po.models.Gaussian(30)
    >>> model_metric = po.metric.model_metric_factory(model)
    >>> model_metric(einstein_img, curie_img)
    tensor(0.3128, grad_fn=<SqrtBackward0>)
    >>> # calculate this model metric manually:
    >>> torch.mean((model(einstein_img) - model(curie_img)).pow(2)).sqrt()
    tensor(0.3128, grad_fn=<SqrtBackward0>)
    """
    # A negative epsilon could make the argument of sqrt negative (NaN output)
    # when the two representations are identical, so reject it up front.
    if epsilon < 0:
        raise ValueError(f"epsilon must be non-negative, but got {epsilon}")

    def metric(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # numpydoc ignore=GL08
        repx = model(x)
        repy = model(y)
        # epsilon keeps the gradient finite when repx == repy (sqrt'(0) is infinite)
        return torch.sqrt(torch.mean((repx - repy) ** 2) + epsilon)

    return metric