plenoptic.validate.validate_penalty

plenoptic.validate.validate_penalty(penalty_function, image_shape=None, image_dtype=torch.float32, device='cpu')

Determine whether penalty_function can be used for regularization in synthesis.

Specifically, this function checks the following, raising the indicated error when a check fails:

  • Whether penalty_function is callable and accepts a single tensor of shape image_shape as input (TypeError).

  • Whether penalty_function returns a scalar when called with a tensor of shape image_shape as input (ValueError).

  • Whether penalty_function adds a gradient to an input tensor that has none, which implies that it uses learnable parameters (ValueError).

  • Whether penalty_function returns a tensor when given a tensor; failure implies that not all computations are done using torch (ValueError).

  • Whether penalty_function strips the gradient from an input that has one attached (ValueError).

  • Whether penalty_function converts an input tensor to another type and back to a tensor before returning it (ValueError).

  • Whether penalty_function changes the precision of the input tensor (TypeError).

  • Whether penalty_function returns a complex output (TypeError).

  • Whether penalty_function changes the device of the input (RuntimeError).
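To make the checks above concrete, here is a minimal sketch of how they could be expressed in plain PyTorch. This is not plenoptic's implementation: the function name sketch_validate_penalty is hypothetical, the error messages are illustrative, and "scalar" is assumed here to mean a single-element tensor.

```python
import torch


def sketch_validate_penalty(penalty_function, image_shape=(1, 1, 16, 16),
                            image_dtype=torch.float32, device="cpu"):
    # Hypothetical sketch of the checks listed above; not plenoptic's code,
    # and the error messages are illustrative only.
    if not callable(penalty_function):
        raise TypeError("penalty_function must be callable")

    # 1. Input without a gradient: the output should not require one either.
    #    If it does, the function likely holds learnable parameters.
    img = torch.rand(image_shape, dtype=image_dtype, device=device)
    out = penalty_function(img)
    if not isinstance(out, torch.Tensor):
        raise ValueError("penalty_function must return a tensor")
    if out.numel() != 1:
        raise ValueError("penalty_function must return a scalar")
    if out.requires_grad:
        raise ValueError("penalty_function adds a gradient to its input")

    # 2. Input with a gradient: the output must keep it. This fails if the
    #    function detaches its input or round-trips it through a non-torch type.
    img_grad = img.clone().requires_grad_(True)
    if not penalty_function(img_grad).requires_grad:
        raise ValueError("penalty_function strips the gradient from its input")

    # 3. Output kind, precision and device must match the input.
    if torch.is_complex(out):
        raise TypeError("penalty_function should not return a complex tensor")
    if out.dtype != image_dtype:
        raise TypeError("penalty_function should not change precision")
    if out.device != img.device:
        raise RuntimeError("penalty_function should not change device")
```

For example, sketch_validate_penalty(lambda x: (x ** 2).mean()) returns without error, while a function that calls x.detach() before reducing fails the gradient check with a ValueError.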

Parameters:
  • penalty_function (Module | Callable[[Tensor], Tensor]) – The penalty function to validate.

  • image_shape (tuple[int, int, int, int] | None (default: None)) – Some functions (e.g., those built on the steerable pyramid) can only accept inputs of a certain shape. If that’s the case for penalty_function, use this to specify the expected shape. If None, we use an image of shape (1, 1, 16, 16).

  • image_dtype (dtype (default: torch.float32)) – The dtype of the test image to validate against.

  • device (str | device (default: 'cpu')) – The device on which to place the test image.
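As an illustration of why image_shape exists, consider a hypothetical penalty that measures the distance of its input from a fixed 32×32 target. It cannot be evaluated on the default (1, 1, 16, 16) test image, so it would need a matching image_shape. The function name distance_to_target and the zero target are made up for this sketch.

```python
import torch

# Hypothetical penalty: mean squared distance from a fixed 32x32 target image.
TARGET = torch.zeros(1, 1, 32, 32)


def distance_to_target(img: torch.Tensor) -> torch.Tensor:
    # Move the target to the input's dtype and device so the penalty does not
    # change precision or device.
    target = TARGET.to(dtype=img.dtype, device=img.device)
    return ((img - target) ** 2).mean()


# The default 16x16 test image cannot be broadcast against the 32x32 target.
try:
    distance_to_target(torch.rand(1, 1, 16, 16))
except RuntimeError as err:
    print("default shape fails:", type(err).__name__)

# With a matching shape the penalty returns a scalar in the input's dtype.
out = distance_to_target(torch.rand(1, 1, 32, 32, dtype=torch.float64))
print(out.shape, out.dtype)
```

With plenoptic installed, the corresponding call would be po.validate.validate_penalty(distance_to_target, image_shape=(1, 1, 32, 32)).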

Raises:
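  • ValueError – If penalty_function fails one of the checks above marked ValueError.

  • TypeError – If penalty_function is not callable or does not accept a single tensor of shape image_shape.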

  • TypeError – If penalty_function changes the precision of the input tensor.

  • TypeError – If penalty_function returns a complex tensor.

  • RuntimeError – If penalty_function changes the device of the input tensor.
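The ValueError for a function that casts its input to another type and back can be reproduced directly in PyTorch. In this sketch, both penalty functions are hypothetical: converting to a Python float with .item() and rebuilding a tensor with torch.tensor() severs the autograd graph, so the output no longer requires a gradient even though the input does.

```python
import torch

img = torch.rand(1, 1, 16, 16, requires_grad=True)


def torch_only_penalty(x: torch.Tensor) -> torch.Tensor:
    # Every operation stays inside torch, so autograd can track it.
    return x.var()


def round_trip_penalty(x: torch.Tensor) -> torch.Tensor:
    # .item() converts to a Python float; torch.tensor() builds a new leaf
    # tensor with no history, so the gradient link to x is lost.
    return torch.tensor(x.var().item())


print(torch_only_penalty(img).requires_grad)  # True
print(round_trip_penalty(img).requires_grad)  # False
```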

Examples

Check that one of our built-in penalty functions works:

>>> import plenoptic as po
>>> penalty_fun = po.regularize.penalize_range
>>> po.validate.validate_penalty(penalty_fun)

Intentionally fail the precision check:

>>> import plenoptic as po
>>> import torch
>>> def failure_penalty_dtype(img):
...     wrong_dtype = torch.mean(img)
...     return wrong_dtype.to(dtype=torch.float64)
>>> po.validate.validate_penalty(failure_penalty_dtype)
Traceback (most recent call last):
TypeError: penalty_function should not change precision...
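If a penalty genuinely benefits from higher internal precision, one way to avoid the failure above is to compute in float64 and cast the result back to the input's dtype. This is a sketch; mean_penalty_fixed is a hypothetical name, not part of plenoptic.

```python
import torch


def mean_penalty_fixed(img: torch.Tensor) -> torch.Tensor:
    # Accumulate in float64 for accuracy, then cast back so the returned
    # precision matches the input. Both casts are differentiable.
    return img.to(torch.float64).mean().to(img.dtype)


img = torch.rand(1, 1, 16, 16, requires_grad=True)
out = mean_penalty_fixed(img)
print(out.dtype, out.requires_grad)
```

Because the output keeps the input's dtype and its gradient, po.validate.validate_penalty(mean_penalty_fixed) would be expected to pass the precision check that failure_penalty_dtype fails.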