plenoptic.metric.model_metric_factory

plenoptic.metric.model_metric_factory(model)

Create a metric function which returns the root mean squared error in model space.

For two images \(x\) and \(y\) and a model \(M\), the returned callable computes:

\[metric = \sqrt{\frac{1}{n}\sum_i (M(x)_i - M(y)_i)^2+\epsilon}\]

where \(M(x)\) and \(M(y)\) are the model representations of \(x\) and \(y\), each with \(n\) elements, and \(\epsilon=10^{-10}\) stabilizes the gradient near zero, where the derivative of the square root is unbounded.
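To see why \(\epsilon\) is needed, write \(d_i = M(x)_i - M(y)_i\) and differentiate the formula above with respect to each \(d_i\). This is a short derivation from the stated formula, not taken from the library source:

\[\frac{\partial\, metric}{\partial d_i} = \frac{d_i}{n\sqrt{\frac{1}{n}\sum_j d_j^2+\epsilon}}\]

With \(\epsilon = 0\), this expression becomes \(0/0\) when \(M(x) = M(y)\): the metric is then a scaled Euclidean norm of \(d\), which is not differentiable at \(d = 0\). With \(\epsilon > 0\), the denominator is at least \(n\sqrt{\epsilon}\), so the gradient is well defined everywhere and equals zero at \(d = 0\).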

This allows users to convert models into metrics.
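The following is a minimal, framework-free sketch of the computation the returned callable performs. It assumes model representations are flat lists of floats. `make_rmse_metric` and `toy_model` are hypothetical names used only for illustration. They are not part of plenoptic, which operates on torch tensors.

```python
import math
from typing import Callable, Sequence

# Type aliases for this illustration: a "model" maps an image (flat list of
# floats) to a representation (another flat sequence of floats).
Representation = Sequence[float]
Model = Callable[[Sequence[float]], Representation]
Metric = Callable[[Sequence[float], Sequence[float]], float]


def make_rmse_metric(model: Model, eps: float = 1e-10) -> Metric:
    """Hypothetical pure-Python analogue of model_metric_factory."""

    def metric(x: Sequence[float], y: Sequence[float]) -> float:
        mx, my = model(x), model(y)
        if len(mx) != len(my):
            raise ValueError("representations must have the same number of elements")
        n = len(mx)
        # Mean of squared differences between the two representations.
        mse = sum((a - b) ** 2 for a, b in zip(mx, my)) / n
        # eps sits inside the square root, matching the formula above.
        return math.sqrt(mse + eps)

    return metric


def toy_model(img: Sequence[float]) -> Representation:
    # Hypothetical stand-in for a real model: scale every pixel by 2.
    return [2.0 * v for v in img]


metric = make_rmse_metric(toy_model)
print(metric([0.0, 1.0], [1.0, 1.0]))  # distance between two different images
print(metric([0.5, 0.5], [0.5, 0.5]))  # identical images: sqrt(eps)
```

Returning an inner function (a closure) binds the model once, so the resulting metric can be passed anywhere a two-argument distance function is expected.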

Parameters:

model (Module) – Torch model with a defined forward operation.

Return type:

Callable

Returns:

metric_func – A callable which accepts two tensors and returns their root mean squared error in model space.

Examples

>>> import plenoptic as po
>>> import torch
>>> einstein_img = po.data.einstein()
>>> curie_img = po.data.curie()
>>> model = po.models.Gaussian(30)
>>> model_metric = po.metric.model_metric_factory(model)
>>> model_metric(einstein_img, curie_img)
tensor(0.3128, grad_fn=<SqrtBackward0>)
>>> # calculate this model metric manually (epsilon omitted; its effect is negligible here):
>>> torch.mean((model(einstein_img) - model(curie_img)).pow(2)).sqrt()
tensor(0.3128, grad_fn=<SqrtBackward0>)