Source code for plenoptic.metric.model_metric
import torch
[docs]
def model_metric(x, y, model, *, epsilon=1e-10):
    """
    Calculate the distance between x and y in model space.

    Both inputs are passed through ``model`` and the distance is the root
    mean squared error between the two resulting representations::

        sqrt(mean((model(x) - model(y)) ** 2) + epsilon)

    Parameters
    ----------
    x : torch.Tensor
        First input, passed to ``model``. Presumably an image of shape
        (B x C x H x W), as the original documentation states; nothing in
        this function enforces that shape.
    y : torch.Tensor
        Second input, passed to ``model`` in the same way as ``x``.
    model : callable
        Torch model (or any callable) mapping an input tensor to its
        representation. It must be differentiable if gradients of the
        distance are needed.
    epsilon : float, optional
        Small constant added under the square root. Defaults to ``1e-10``.

    Returns
    -------
    dist : torch.Tensor
        Scalar (0-dimensional) tensor holding the distance. The mean is taken
        over every element of the representation, batch dimension included.

    Notes
    -----
    The derivative of ``sqrt(z)`` is unbounded as ``z`` approaches zero, so
    without ``epsilon`` the gradient is undefined when the two
    representations coincide. As a consequence the returned value is
    ``sqrt(epsilon)`` rather than exactly zero for identical representations.
    """
    repx = model(x)
    repy = model(y)

    # Mean over all elements of the representation difference.
    mean_squared_error = torch.mean((repx - repy) ** 2)

    # epsilon keeps the argument of sqrt strictly positive, which stabilizes
    # the gradient around zero during optimization.
    return torch.sqrt(mean_squared_error + epsilon)