plenoptic.metric.model_metric_factory
- plenoptic.metric.model_metric_factory(model)
Create a metric function which returns the root mean squared error in model space.
For two images \(x\) and \(y\) and a model \(M\), the returned callable computes:
\[\mathrm{metric} = \sqrt{\frac{1}{n}\sum_i \left(M(x)_i - M(y)_i\right)^2 + \epsilon}\]
where \(M(x)\) and \(M(y)\) are the model representations of x and y, each with \(n\) elements, and \(\epsilon = 10^{-10}\) stabilizes the gradient when the error is near zero. This allows users to convert models into metrics.
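The role of \(\epsilon\) can be checked directly. The following is a minimal sketch in plain PyTorch, assuming an identity model so that \(M(x) = x\); `rmse` is a hypothetical helper, not part of plenoptic. When \(x = y\), the derivative of the square root at zero is infinite and gets multiplied by a zero factor, so autograd produces NaN gradients without \(\epsilon\). With \(\epsilon = 10^{-10}\) the gradient is finite (exactly zero here), and the metric evaluates to about \(\sqrt{10^{-10}} = 10^{-5}\) rather than exactly 0.

```python
import torch

def rmse(diff: torch.Tensor, epsilon: float) -> torch.Tensor:
    # Hypothetical helper: sqrt(mean(diff**2) + epsilon), matching the formula above.
    return torch.sqrt(torch.mean(diff ** 2) + epsilon)

torch.manual_seed(0)
y = torch.rand(1, 1, 8, 8)
results = {}
for epsilon in (0.0, 1e-10):
    # Identity "model" (an assumption): x starts as an exact copy of y, so x - y is all zeros.
    x = y.clone().requires_grad_(True)
    value = rmse(x - y, epsilon)
    value.backward()
    results[epsilon] = (value.item(), x.grad.clone())
    print(f"epsilon={epsilon}: value={value.item():.2e}, "
          f"NaN gradient={torch.isnan(x.grad).any().item()}")
```

The practical consequence is that comparing an image with itself yields a small positive number rather than zero, in exchange for gradients that stay usable during optimization.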
- Parameters:
model (Module) – Torch model with a defined forward operation.
- Return type:
- Returns:
metric_func – A callable that accepts two tensors and returns the root mean squared error between their model representations.
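The parameter and return entries describe a factory that closes over a model and hands back a two-argument callable. Below is a minimal sketch of that pattern in plain PyTorch; `make_model_metric` and the toy `LocalMean` module are hypothetical illustrations, not plenoptic's implementation, and any `torch.nn.Module` with a `forward` method could stand in for the model.

```python
import torch
from torch import nn

class LocalMean(nn.Module):
    """Hypothetical toy model: 3x3 local averaging as the 'representation'."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.avg_pool2d(x, kernel_size=3, stride=1, padding=1)

def make_model_metric(model: nn.Module, epsilon: float = 1e-10):
    """Hypothetical stand-in for the factory: returns metric(x, y) -> 0-dim tensor."""
    def metric(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Compare the two images in model space, not pixel space.
        rep_x, rep_y = model(x), model(y)
        return torch.sqrt(torch.mean((rep_x - rep_y) ** 2) + epsilon)
    return metric

torch.manual_seed(0)
model_metric = make_model_metric(LocalMean())
x = torch.rand(1, 1, 16, 16)
y = torch.rand(1, 1, 16, 16)
d = model_metric(x, y)
print(d.shape, d.item())
```

Because the closure only calls `model(x)`, the same pattern works for any module whose outputs can be subtracted elementwise, and the scalar result can be passed to `backward()` when used as a loss.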
Examples
>>> import plenoptic as po
>>> import torch
>>> einstein_img = po.data.einstein()
>>> curie_img = po.data.curie()
>>> model = po.models.Gaussian(30)
>>> model_metric = po.metric.model_metric_factory(model)
>>> model_metric(einstein_img, curie_img)
tensor(0.3128, grad_fn=<SqrtBackward0>)
>>> # calculate this model metric manually (epsilon omitted; it does not change the printed digits):
>>> torch.mean((model(einstein_img) - model(curie_img)).pow(2)).sqrt()
tensor(0.3128, grad_fn=<SqrtBackward0>)