plenoptic.validate.validate_penalty
- plenoptic.validate.validate_penalty(penalty_function, image_shape=None, image_dtype=torch.float32, device='cpu')
Determine whether penalty_function can be used for regularization in synthesis.

In particular, this function checks the following (with the error raised if the check fails):
- Whether penalty_function is callable and accepts a single tensor of shape image_shape as input (TypeError).
- Whether penalty_function returns a scalar when called with a tensor of shape image_shape as input (ValueError).
- Whether penalty_function adds a gradient to an input tensor, which implies that learnable parameters are being used (ValueError).
- Whether penalty_function returns a tensor when given a tensor; failure implies that not all computations are done using torch (ValueError).
- Whether penalty_function strips the gradient from an input that has a gradient attached (ValueError).
- Whether penalty_function casts an input tensor to something else and back to a tensor before returning it (ValueError).
- Whether penalty_function changes the precision of the input tensor (TypeError).
- Whether penalty_function returns a complex output (TypeError).
- Whether penalty_function changes the device of the input (RuntimeError).
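The checks above can be sketched in pure torch. The helper below, sketch_validate_penalty, is hypothetical and written only for illustration: it is not plenoptic's implementation, it covers only some of the checks, and its error messages are made up. It also assumes that the learnable-parameter check works by passing an input without a gradient and seeing whether the output still requires one.

```python
import torch


def sketch_validate_penalty(penalty_function, image_shape=None,
                            image_dtype=torch.float32, device="cpu"):
    # Hypothetical helper: NOT plenoptic's implementation, only a subset of its checks.
    if not callable(penalty_function):
        raise TypeError("penalty_function must be callable")
    if image_shape is None:
        image_shape = (1, 1, 16, 16)  # default shape from the docs above
    # Test image that tracks gradients, like an image being synthesized.
    img = torch.rand(image_shape, dtype=image_dtype, device=device,
                     requires_grad=True)
    out = penalty_function(img)
    if not isinstance(out, torch.Tensor):
        raise ValueError("output must be a torch.Tensor (use torch ops only)")
    if out.numel() != 1:
        raise ValueError("output must be a scalar")
    if not out.requires_grad:
        # A round trip through NumPy or Python floats also lands here.
        raise ValueError("penalty_function stripped the gradient")
    if out.is_complex():
        raise TypeError("output must not be complex")
    if out.dtype != img.dtype:
        raise TypeError("penalty_function should not change precision")
    if out.device != img.device:
        raise RuntimeError("penalty_function should not change device")
    # Assumption: learnable parameters show up as a gradient on the output
    # even when the input itself does not require one.
    if penalty_function(img.detach()).requires_grad:
        raise ValueError("penalty_function appears to have learnable parameters")
```

For instance, sketch_validate_penalty(lambda x: x.pow(2).mean()) returns silently, while a penalty that ends with .to(torch.float64) raises TypeError. The dtype check follows the complex check because a complex output would otherwise be reported as a precision change.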
- Parameters:
  - penalty_function (Module | Callable[[Tensor], Tensor]) – The penalty function to validate.
  - image_shape (tuple[int, int, int, int] | None (default: None)) – Some models (e.g., the steerable pyramid) can only accept inputs of a certain shape. If that's the case for penalty_function, use this to specify the expected shape. If None, we use an image of shape (1, 1, 16, 16).
  - image_dtype (dtype (default: torch.float32)) – What dtype to validate against.
  - device (str | device (default: 'cpu')) – What device to place the test image on.
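The image_shape parameter matters for penalties that only work at one size. The function below, distance_to_reference, is a hypothetical example (not part of plenoptic): it compares the image to a fixed 32x32 reference, so it fails on the default (1, 1, 16, 16) test image because the two shapes cannot broadcast.

```python
import torch

# Hypothetical fixed reference image; the penalty only works at this size.
reference = torch.zeros(1, 1, 32, 32)


def distance_to_reference(img):
    # Mean squared difference; broadcasting fails unless img is 32x32.
    return (img - reference).pow(2).mean()


try:
    distance_to_reference(torch.rand(1, 1, 16, 16))
except RuntimeError as err:
    print("default shape fails:", err)

print(distance_to_reference(torch.rand(1, 1, 32, 32)).shape)  # torch.Size([])
```

Validating a function like this would therefore mean passing the matching shape, e.g. image_shape=(1, 1, 32, 32).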
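- Raises:
  - ValueError – If penalty_function fails one of the checks listed above that raise ValueError.
  - TypeError – If penalty_function is not callable or does not accept a single tensor of shape image_shape.
  - TypeError – If penalty_function changes the precision of the input tensor.
  - TypeError – If penalty_function returns a complex tensor.
  - RuntimeError – If penalty_function changes the device of the input tensor.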
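Several of these errors come from leaving torch partway through the computation. A minimal sketch in plain torch, using two hypothetical penalty functions: converting to a Python float with .item() and back with torch.tensor produces a value with no autograd history, so the gradient that synthesis relies on is lost.

```python
import torch


def roundtrip_penalty(img):
    # Leaves torch: .item() returns a Python float, cutting the autograd graph.
    value = img.pow(2).mean().item()
    return torch.tensor(value)


def torch_penalty(img):
    # Stays in torch, so the result remains connected to img.
    return img.pow(2).mean()


img = torch.rand(1, 1, 16, 16, requires_grad=True)
print(roundtrip_penalty(img).requires_grad)  # False
print(torch_penalty(img).requires_grad)  # True
```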
Examples
Check that one of our built-in penalty functions works:
>>> import plenoptic as po
>>> penalty_fun = po.regularize.penalize_range
>>> po.validate.validate_penalty(penalty_fun)
Intentionally fail:
>>> import plenoptic as po
>>> import torch
>>> def failure_penalty_dtype(img):
...     wrong_dtype = torch.mean(img)
...     return wrong_dtype.to(dtype=torch.float64)
>>> po.validate.validate_penalty(failure_penalty_dtype)
Traceback (most recent call last):
TypeError: penalty_function should not change precision...
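The failing example above should pass once the .to(dtype=torch.float64) call is removed. More generally, building the penalty only from torch operations and Python scalars keeps the input's dtype and device. The range_penalty below is a hypothetical sketch that penalizes pixel values outside [low, high]; it is not plenoptic's po.regularize.penalize_range.

```python
import torch


def range_penalty(img, low=0.0, high=1.0):
    # Hypothetical sketch: squared distance of each pixel outside [low, high].
    # Python float bounds do not change img's dtype or device.
    below = torch.clamp(low - img, min=0)
    above = torch.clamp(img - high, min=0)
    return (below.pow(2) + above.pow(2)).sum()


img = torch.rand(1, 1, 16, 16, dtype=torch.float64, requires_grad=True)
out = range_penalty(2 * img - 0.5)  # values in [-0.5, 1.5)
print(out.dtype, out.ndim, out.requires_grad)  # torch.float64 0 True
```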