plenoptic.validate.validate_model

plenoptic.validate.validate_model(model, image_shape=None, image_dtype=torch.float32, device='cpu')

Determine whether model can be used for synthesis.

In particular, this function raises the associated error in each of the following cases:

  • If model adds a gradient to an input tensor that had none, which implies that some of its parameters are learnable (ValueError).

  • If model does not return a tensor when given a tensor, which implies that not all computations are done using torch (ValueError).

  • If model strips the gradient from an input that has a gradient attached (ValueError).

  • If model casts an input tensor to something else (e.g., a NumPy array) and converts it back to a tensor before returning it (ValueError).

  • If model changes the precision (dtype) of the input tensor (TypeError).

  • If model changes the device of the input tensor (RuntimeError).
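The checks above can be approximated in plain torch. What follows is a minimal sketch under our own assumptions, not plenoptic's implementation: sketch_validate is a hypothetical name, the test image is drawn with torch.rand, the cast check relies on torch refusing to call .numpy() on a tensor that requires grad, and the dtype check demands an exact match, which may be stricter than the real function.

```python
import torch


def sketch_validate(model, image_shape=(1, 1, 16, 16),
                    image_dtype=torch.float32, device="cpu"):
    """Hypothetical approximation of the error checks; not plenoptic's code."""
    x = torch.rand(image_shape, dtype=image_dtype, device=device)

    # Input without a gradient: the output should be a tensor without one.
    out = model(x)
    if not isinstance(out, torch.Tensor):
        raise ValueError("model does not return a torch.Tensor")
    if out.requires_grad:
        raise ValueError("model adds gradient to input; is a parameter learnable?")

    # Input with a gradient: the gradient should survive the forward pass.
    x.requires_grad_()
    try:
        out = model(x)
    except RuntimeError as err:
        # torch refuses to call .numpy() on a tensor that requires grad.
        if "numpy" in str(err):
            raise ValueError("model casts its input to something other than a tensor") from err
        raise
    if not out.requires_grad:
        raise ValueError("model strips gradient from input")

    # Assumption: the output dtype must match exactly (may be stricter than plenoptic).
    if out.dtype != image_dtype:
        raise TypeError("model changes precision of input")
    if out.device != x.device:
        raise RuntimeError("model changes device of input")
    return out
```

A model such as torch.nn.Conv2d(1, 1, 3) fails the first gradient check, because its weights require gradients even when the input does not.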

Finally, we raise a UserWarning:

  • If model is in training mode. Note that this is different from having learnable parameters; see the PyTorch docs.

  • If model returns an output with other than 3 or 4 dimensions when given a tensor of shape image_shape.
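The two warnings can be sketched the same way. This is again a hypothetical approximation (sketch_warnings is not part of plenoptic); it reads the training attribute that every torch.nn.Module carries and checks the number of output dimensions.

```python
import warnings

import torch


def sketch_warnings(model, image_shape=(1, 1, 16, 16)):
    """Hypothetical approximation of the warning checks; not plenoptic's code."""
    # nn.Module.training is True until eval() is called; it says nothing
    # about whether any parameter has requires_grad=True.
    if getattr(model, "training", False):
        warnings.warn("model is in training mode; consider calling model.eval()",
                      UserWarning)
    with torch.no_grad():
        out = model(torch.rand(image_shape))
    if out.ndim not in (3, 4):
        warnings.warn(f"model output has {out.ndim} dimensions, expected 3 or 4",
                      UserWarning)
    return out
```

Calling eval() on a model silences the first warning but leaves requires_grad untouched on its parameters, so an evaluation-mode model with learnable parameters still fails the gradient check.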

Parameters:
  • model (Module) – The model to validate.

  • image_shape (tuple[int, int, int, int] | None (default: None)) – Some models (e.g., the steerable pyramid) can only accept inputs of a certain shape. If that’s the case for model, use this to specify the expected shape. If None, we use an image of shape (1,1,16,16).

  • image_dtype (dtype (default: torch.float32)) – The dtype of the test image to validate against.

  • device (str | device (default: 'cpu')) – The device on which to place the test image.
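The image_shape parameter matters for models that cannot process the default 16 by 16 image. A minimal sketch, in which FixedSizeModel and make_test_image are hypothetical and creating the test image with torch.rand is our assumption rather than documented behavior:

```python
import torch


class FixedSizeModel(torch.nn.Module):
    """Hypothetical model that only accepts 32x32 images."""

    def __init__(self):
        super().__init__()
        # Fixed 32x32 weighting; a buffer, so nothing here is learnable.
        self.register_buffer("weight", torch.rand(1, 1, 32, 32))

    def forward(self, x):
        # Broadcasting fails unless the spatial size is exactly 32x32.
        return x * self.weight


def make_test_image(image_shape=None, image_dtype=torch.float32, device="cpu"):
    # Follows the documented default shape; torch.rand is an assumption.
    if image_shape is None:
        image_shape = (1, 1, 16, 16)
    return torch.rand(image_shape, dtype=image_dtype, device=device)
```

With the default shape, FixedSizeModel raises a RuntimeError on the first forward pass; passing image_shape=(1, 1, 32, 32) avoids this.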

Raises:
  • ValueError – If model fails one of the checks listed above.

  • TypeError – If model changes the precision of the input tensor.

  • RuntimeError – If model changes the device of the input tensor.

Warns:
  • UserWarning – If model is in training mode.

  • UserWarning – If model returns an output with other than 3 or 4 dimensions.

See also

remove_grad

Helper function for detaching all parameters (in place).
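The effect of freezing a model's parameters can be sketched with plain torch. sketch_remove_grad below is a hypothetical stand-in, not remove_grad itself: it turns off requires_grad on each parameter in place, which has the same visible result for validation (outputs only carry a gradient when the input does), though the real helper may do this differently.

```python
import torch


def sketch_remove_grad(model):
    """Hypothetical stand-in for remove_grad: freeze every parameter in place."""
    for param in model.parameters():
        param.requires_grad_(False)
    return model


# A conv layer normally adds a gradient to its output; after freezing it does not.
conv = sketch_remove_grad(torch.nn.Conv2d(1, 1, 3).eval())
out = conv(torch.rand(1, 1, 16, 16))
```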

Examples

Check that one of our built-in models works:

>>> import plenoptic as po
>>> model = po.models.PortillaSimoncelli((256, 256))
>>> po.validate.validate_model(model, image_shape=(1, 1, 256, 256))
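A custom model can be written with the checks in mind. The sketch below is hypothetical (FixedBlur is not part of plenoptic): it stores its filter as a buffer rather than a parameter, stays in torch, preserves the input dtype and device, returns a 4d tensor, and is put in evaluation mode. We expect validate_model to accept it, but here the conditions are only checked directly with torch.

```python
import torch
import torch.nn.functional as F


class FixedBlur(torch.nn.Module):
    """Hypothetical model: blurs each channel with a fixed 3x3 box filter."""

    def __init__(self):
        super().__init__()
        # A buffer (not a Parameter) follows .to() but is never learned.
        self.register_buffer("kernel", torch.ones(1, 1, 3, 3) / 9)

    def forward(self, x):
        # Fold channels into the batch so one kernel serves any channel count.
        b, c, h, w = x.shape
        kernel = self.kernel.to(dtype=x.dtype, device=x.device)
        out = F.conv2d(x.reshape(b * c, 1, h, w), kernel, padding=1)
        return out.reshape(b, c, h, w)


model = FixedBlur().eval()
```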

Intentionally fail:

>>> import plenoptic as po
>>> import torch
>>> class FailureModel(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...
...     def forward(self, x):
...         x = x.detach().numpy()
...         return torch.as_tensor(x)
>>> po.validate.validate_model(FailureModel())
Traceback (most recent call last):
ValueError: model strips gradient from input, ...
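The failure above comes from the round trip through NumPy, which detaches the input from the autograd graph. A minimal sketch of a repair, assuming the NumPy step stood in for an array operation (here a hypothetical absolute value) that has a torch equivalent; FixedModel is an illustrative name:

```python
import torch


class FixedModel(torch.nn.Module):
    """Hypothetical repair of FailureModel: stay in torch throughout."""

    def forward(self, x):
        # torch.abs replaces an equivalent NumPy call and keeps the autograd graph.
        return torch.abs(x)


x = torch.rand(1, 1, 16, 16, requires_grad=True)
y = FixedModel()(x)
# Gradients now flow back to the input.
y.sum().backward()
```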