plenoptic.validate.validate_convert_tensor_dict

plenoptic.validate.validate_convert_tensor_dict(model, image_shape=None, device='cpu', n_checks=100)

Determine if a model can properly convert between tensor and dict representations.

In general, when converting between these two representations, one should use an OrderedDict instead of a regular dictionary, so that the order of the entries is explicit. If, for some reason, you don’t want to do that, this function checks whether model.convert_to_tensor and model.convert_to_dict invert each other (i.e., whether converting a representation in one direction and then back reproduces the original), running the comparison on n_checks independent random images.
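The round-trip idea can be sketched roughly as follows. This is a hypothetical helper, `check_convert_round_trip`, written only for illustration; it is not plenoptic’s implementation and its error messages are made up. It assumes the model returns a single tensor and uses exact equality, on the assumption that the conversions only rearrange values and never compute new ones.

```python
import torch


def check_convert_round_trip(model, image_shape=None, device="cpu", n_checks=100):
    """Hypothetical sketch of a tensor <-> dict round-trip check.

    Illustration only, not plenoptic's implementation. Raises ValueError
    if the two conversions fail to invert each other on any random image.
    """
    if image_shape is None:
        # Same default test-image shape as documented for image_shape.
        image_shape = (1, 1, 16, 16)
    for i in range(n_checks):
        image = torch.rand(image_shape, device=device)
        with torch.no_grad():
            rep = model(image)
        rep_dict = model.convert_to_dict(rep)
        # tensor -> dict -> tensor must give back exactly the same tensor;
        # exact equality is used because the conversions only rearrange values.
        if not torch.equal(model.convert_to_tensor(rep_dict), rep):
            raise ValueError(
                f"Image {i}: convert_to_tensor(convert_to_dict(rep)) != rep"
            )
        # dict -> tensor -> dict must give the same keys, in the same order,
        # with the same values.
        new_dict = model.convert_to_dict(model.convert_to_tensor(rep_dict))
        same_keys = list(new_dict) == list(rep_dict)
        if not same_keys or not all(
            torch.equal(new_dict[k], rep_dict[k]) for k in rep_dict
        ):
            raise ValueError(
                f"Image {i}: convert_to_dict(convert_to_tensor(d)) != d"
            )
```

Passing a model whose convert_to_tensor swaps channels, such as the FailureModel in the examples, to this sketch would be expected to raise on the first image, since the swapped channels almost surely differ.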

WARNING: This is a heuristic and cannot guarantee that the order never changes.

Parameters:
  • model (Module) – The model to validate.

  • image_shape (tuple[int, int, int, int] | None (default: None)) – Some models (e.g., the steerable pyramid) can only accept inputs of a certain shape. If that’s the case for model, use this to specify the expected shape. If None, we use an image of shape (1, 1, 16, 16).

  • device (str | device (default: 'cpu')) – Which device to place the test images on.

  • n_checks (int (default: 100)) – How many independent random images to run the check on.

Raises:

ValueError – If any of the checks fail.

Examples

Check that one of our built-in models works:

>>> import plenoptic as po
>>> model = po.models.PortillaSimoncelli((256, 256))
>>> po.validate.validate_convert_tensor_dict(model, image_shape=(1, 1, 256, 256))

Intentionally fail:

>>> import plenoptic as po
>>> import torch
>>> # this model fails because we intentionally rearrange the channels
>>> # in convert_to_tensor
>>> class FailureModel(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.kernel = torch.nn.Conv2d(1, 2, (5, 5), bias=False)
...         self.kernel.weight.detach_()
...
...     def forward(self, x):
...         return self.kernel(x)
...
...     def convert_to_dict(self, rep):
...         return {f"channel_{i}": rep[:, i] for i in range(2)}
...
...     def convert_to_tensor(self, rep_dict):
...         return torch.stack(
...             [rep_dict["channel_1"], rep_dict["channel_0"]], dim=1
...         )
>>> shape = (1, 1, 256, 256)
>>> model = FailureModel()
>>> po.validate.validate_convert_tensor_dict(model, image_shape=shape)
Traceback (most recent call last):
ValueError: On random image 0, model.convert_to_dict did not invert...
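For contrast, here is a minimal sketch of how FailureModel could be repaired, assuming the fix is simply to build the dictionary as an OrderedDict and to stack its values in iteration order instead of in a hand-written order. The class name FixedModel is hypothetical, and the weights are frozen with requires_grad_(False) rather than detach_(). Because the two conversions are now exact inverses, passing this model to validate_convert_tensor_dict should not raise.

```python
from collections import OrderedDict

import torch


class FixedModel(torch.nn.Module):
    """Hypothetical counterpart to FailureModel whose conversions invert each other."""

    def __init__(self):
        super().__init__()
        self.kernel = torch.nn.Conv2d(1, 2, (5, 5), bias=False)
        # Freeze the weights; the model is only used for inference here.
        self.kernel.weight.requires_grad_(False)

    def forward(self, x):
        return self.kernel(x)

    def convert_to_dict(self, rep):
        # Keys are inserted in channel order, so iteration order is fixed.
        return OrderedDict(
            (f"channel_{i}", rep[:, i]) for i in range(rep.shape[1])
        )

    def convert_to_tensor(self, rep_dict):
        # Stack values in the dict's own order rather than a hand-written one.
        return torch.stack(list(rep_dict.values()), dim=1)


model = FixedModel()
image = torch.rand(1, 1, 16, 16)
with torch.no_grad():
    rep = model(image)
round_trip = model.convert_to_tensor(model.convert_to_dict(rep))
print(torch.equal(round_trip, rep))  # True
```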