plenoptic.validate.validate_model
- plenoptic.validate.validate_model(model, image_shape=None, image_dtype=torch.float32, device='cpu')
Determine whether model can be used for synthesis.
In particular, this function raises an error in each of the following cases (the error type is given in parentheses):
- model adds a gradient to an input tensor, which implies that some of its parameters are learnable (ValueError).
- model does not return a tensor when given a tensor, which implies that not all computations are done using torch (ValueError).
- model strips the gradient from an input that has one attached (ValueError).
- model converts an input tensor to another type and then back to a tensor before returning it (ValueError).
- model changes the precision of the input tensor (TypeError).
- model changes the device of the input tensor (RuntimeError).
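The error checks above can be approximated by hand. The sketch below is minimal and assumes only PyTorch. check_model_sketch is a hypothetical helper written for illustration, not plenoptic's implementation, and it leaves out the type-conversion check.

```python
import torch


def check_model_sketch(model, image_shape=(1, 1, 16, 16),
                       image_dtype=torch.float32, device="cpu"):
    """Hypothetical helper approximating some of validate_model's error checks.

    Illustrative sketch only; this is not plenoptic's actual implementation.
    """
    x = torch.rand(image_shape, dtype=image_dtype, device=device)

    # Input without gradient: the output must be a tensor, and if it
    # requires grad, the model attached one itself (e.g. learnable params).
    y_plain = model(x)
    if not isinstance(y_plain, torch.Tensor):
        raise ValueError("model does not return a tensor")
    if y_plain.requires_grad:
        raise ValueError("model adds gradient to input")

    # Input with gradient: the output should keep it.
    x_grad = x.detach().clone().requires_grad_(True)
    y = model(x_grad)
    if not y.requires_grad:
        raise ValueError("model strips gradient from input")

    # Precision and device must be preserved.
    if y.dtype != image_dtype:
        raise TypeError(f"model changes precision from {image_dtype} to {y.dtype}")
    if y.device != x.device:
        raise RuntimeError(f"model changes device from {x.device} to {y.device}")
```

With this sketch, a convolution whose weights still require grad fails the "adds a gradient" check. A model that round-trips through NumPy, like FailureModel in the examples, fails the "strips gradient" check.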
Finally, we raise a UserWarning if:
- model is in training mode. Note that this is different from having learnable parameters; see the PyTorch docs.
- model returns an output with other than 3 or 4 dimensions when given a tensor of shape image_shape.
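The two warning conditions can be sketched the same way. warning_checks_sketch is a hypothetical helper for illustration and assumes only PyTorch and the standard warnings module. It relies on one PyTorch fact: a freshly constructed torch.nn.Module is in training mode until .eval() is called.

```python
import warnings

import torch


def warning_checks_sketch(model, image_shape=(1, 1, 16, 16)):
    """Hypothetical helper mirroring the two UserWarning conditions.

    Illustrative sketch only; not plenoptic's implementation.
    """
    # nn.Module instances start in training mode; .eval() switches it off.
    if model.training:
        warnings.warn("model is in training mode", UserWarning)
    with torch.no_grad():
        y = model(torch.rand(image_shape))
    # Synthesis expects 3D or 4D outputs.
    if y.ndim not in (3, 4):
        warnings.warn(
            f"model output has {y.ndim} dimensions, expected 3 or 4", UserWarning
        )
```

For example, torch.nn.Flatten() maps a (1, 1, 16, 16) input to shape (1, 256). Even in eval mode, this sketch would warn about its 2D output.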
- Parameters:
  - model (Module) – The model to validate.
  - image_shape (tuple[int, int, int, int] | None (default: None)) – Some models (e.g., the steerable pyramid) can only accept inputs of a certain shape. If that's the case for model, use this to specify the expected shape. If None, we use an image of shape (1, 1, 16, 16).
  - image_dtype (dtype (default: torch.float32)) – What dtype to validate against.
  - device (str | device (default: 'cpu')) – What device to place the test image on.
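The image_shape parameter matters for models that accept only one input size. The sketch below shows such a model. FixedShapeModel is hypothetical, written for illustration only, and is not a plenoptic class.

```python
import torch


class FixedShapeModel(torch.nn.Module):
    """Hypothetical model that, like the steerable pyramid, only accepts one image size."""

    def __init__(self, image_size):
        super().__init__()
        self.image_size = tuple(image_size)

    def forward(self, x):
        # Reject any input whose (height, width) differs from the configured size.
        if tuple(x.shape[-2:]) != self.image_size:
            raise ValueError(
                f"expected images of size {self.image_size}, got {tuple(x.shape[-2:])}"
            )
        return x ** 2
```

With such a model, the default (1, 1, 16, 16) test image would hit the model's own error. One would therefore pass a matching shape, e.g. po.validate.validate_model(FixedShapeModel((64, 64)).eval(), image_shape=(1, 1, 64, 64)).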
- Raises:
  - ValueError – If model fails one of the ValueError checks listed above.
  - TypeError – If model changes the precision of the input tensor.
  - RuntimeError – If model changes the device of the input tensor.
- Warns:
  - UserWarning – If model is in training mode.
  - UserWarning – If model returns an output with other than 3 or 4 dimensions.
See also
remove_grad – Helper function for detaching all parameters (in place).
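The same preparation can be done with plain PyTorch calls. The sketch below is minimal, and prepare_for_validation is a hypothetical helper name; plenoptic's remove_grad may differ in its details. Besides freezing parameters, it also switches the model to evaluation mode, so the training-mode warning is not triggered either.

```python
import torch


def prepare_for_validation(model):
    """Hypothetical helper: freeze all parameters and switch to eval mode, in place."""
    for param in model.parameters():
        # Stop autograd from tracking this parameter.
        param.requires_grad_(False)
    # Module.eval() returns the module itself, so the call can be chained.
    return model.eval()
```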
Examples
Check that one of our built-in models works:
>>> import plenoptic as po
>>> model = po.models.PortillaSimoncelli((256, 256))
>>> po.validate.validate_model(model, image_shape=(1, 1, 256, 256))
Intentionally fail:
>>> import plenoptic as po
>>> import torch
>>> class FailureModel(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...
...     def forward(self, x):
...         x = x.detach().numpy()
...         return torch.as_tensor(x)
>>> po.validate.validate_model(FailureModel())
Traceback (most recent call last):
ValueError: model strips gradient from input, ...
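FailureModel fails because its NumPy round trip discards the autograd graph. A counterpart that keeps every operation in torch preserves the gradient, dtype and device. TorchOnlyModel below is a hypothetical sketch for illustration only.

```python
import torch


class TorchOnlyModel(torch.nn.Module):
    """Hypothetical counterpart to FailureModel: every operation stays in torch."""

    def forward(self, x):
        # Pure torch ops keep the autograd graph, dtype and device intact.
        return torch.nn.functional.avg_pool2d(x, kernel_size=2) ** 2
```

Assuming the checks described above are the whole story, po.validate.validate_model(TorchOnlyModel().eval()) would be expected to pass. The model has no learnable parameters, is in eval mode, and returns a 4D tensor.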