plenoptic.validate.validate_convert_tensor_dict
- plenoptic.validate.validate_convert_tensor_dict(model, image_shape=None, device='cpu', n_checks=100)
Determine if a model can properly convert between tensor and dict representations.
In general, when converting between these two representations, one should use OrderedDict instead of a regular dictionary. If, for some reason, you don't want to do that, this function will attempt to check that model.convert_to_tensor and model.convert_to_dict invert each other, running the comparison on n_checks independent random images.
WARNING: This is a heuristic and cannot guarantee that the order never changes.
- Parameters:
model (Module) – The model to validate.
image_shape (tuple[int, int, int, int] | None (default: None)) – Some models (e.g., the steerable pyramid) can only accept inputs of a certain shape. If that's the case for model, use this to specify the expected shape. If None, we use an image of shape (1, 1, 16, 16).
device (str | device (default: 'cpu')) – Which device to place the test image on.
n_checks (int (default: 100)) – How many independent random images to run the check on.
- Raises:
ValueError – If any of the checks fail.
Examples
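The order sensitivity that motivates this check can be seen without any model at all. The sketch below is only an illustration: the key names and tensors are made up, and it uses plain torch calls rather than anything from plenoptic. It shows that rebuilding a tensor from a dictionary's iteration order depends on how the dictionary was filled, while indexing by explicit keys does not.

```python
from collections import OrderedDict

import torch

# Two distinguishable "channels" (hypothetical data for illustration).
a = torch.zeros(1, 4, 4)
b = torch.ones(1, 4, 4)

# Same keys and values, inserted in a different order.
first = OrderedDict([("channel_0", a), ("channel_1", b)])
second = OrderedDict([("channel_1", b), ("channel_0", a)])

# Stacking by iteration order silently swaps the channels.
stacked_first = torch.stack(list(first.values()), dim=1)
stacked_second = torch.stack(list(second.values()), dim=1)
print(torch.equal(stacked_first, stacked_second))  # False

# Stacking by an explicit key order does not depend on insertion order.
keys = ["channel_0", "channel_1"]
by_key_first = torch.stack([first[k] for k in keys], dim=1)
by_key_second = torch.stack([second[k] for k in keys], dim=1)
print(torch.equal(by_key_first, by_key_second))  # True
```

Whichever approach a model takes, the point is that convert_to_dict and convert_to_tensor need to agree on a single, stable ordering.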
Check that one of our built-in models works:
>>> import plenoptic as po
>>> model = po.models.PortillaSimoncelli((256, 256))
>>> po.validate.validate_convert_tensor_dict(model, image_shape=(1, 1, 256, 256))
Intentionally fail:
>>> import plenoptic as po
>>> import torch
>>> # this model fails because we intentionally rearrange the channels
>>> # in convert_to_tensor
>>> class FailureModel(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.kernel = torch.nn.Conv2d(1, 2, (5, 5), bias=False)
...         self.kernel.weight.detach_()
...
...     def forward(self, x):
...         return self.kernel(x)
...
...     def convert_to_dict(self, rep):
...         return {f"channel_{i}": rep[:, i] for i in range(2)}
...
...     def convert_to_tensor(self, rep_dict):
...         return torch.stack(
...             [rep_dict["channel_1"], rep_dict["channel_0"]], axis=1
...         )
>>> shape = (1, 1, 256, 256)
>>> model = FailureModel()
>>> po.validate.validate_convert_tensor_dict(model)
Traceback (most recent call last):
ValueError: On random image 0, model.convert_to_dict did not invert...
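For contrast, here is a minimal sketch of a counterpart whose two conversion methods agree on channel order. PassingModel is a hypothetical name introduced for illustration, and the loop at the end is a hand-rolled round trip written with plain torch calls; it is an assumption about what such a check could look like, not plenoptic's implementation.

```python
from collections import OrderedDict

import torch


class PassingModel(torch.nn.Module):
    """Hypothetical counterpart to FailureModel that keeps channel order."""

    def __init__(self):
        super().__init__()
        self.kernel = torch.nn.Conv2d(1, 2, (5, 5), bias=False)
        self.kernel.weight.detach_()

    def forward(self, x):
        return self.kernel(x)

    def convert_to_dict(self, rep):
        # OrderedDict makes the intended channel order explicit.
        return OrderedDict((f"channel_{i}", rep[:, i]) for i in range(2))

    def convert_to_tensor(self, rep_dict):
        # Same order as convert_to_dict: channel_0 first, then channel_1.
        return torch.stack(
            [rep_dict["channel_0"], rep_dict["channel_1"]], dim=1
        )


model = PassingModel()

# Hand-rolled round trip on a few random images (illustration only).
for _ in range(5):
    rep = model(torch.rand(1, 1, 16, 16))
    restored = model.convert_to_tensor(model.convert_to_dict(rep))
    assert torch.equal(rep, restored)
```

A model written this way would be expected to pass po.validate.validate_convert_tensor_dict as well, though, as the warning above notes, passing the check is a heuristic rather than a guarantee.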