# Source code for plenoptic.regularize
"""Regularization functions for image synthesis."""
# numpydoc ignore=ES01
from typing import Any
import torch
from torch import Tensor
# Public API of this module: the names exported by ``from ... import *``.
# Kept as a plain list literal so static analysis tools can read it.
__all__ = ["penalize_range"]
def __dir__() -> list[str]:
    # Module-level ``__dir__`` (PEP 562): ``dir()`` on this module is
    # restricted to the exported names. The ``__all__`` list object itself
    # is handed back, not a copy.
    exported = __all__
    return exported
# [docs]
def penalize_range(
    img: Tensor,
    allowed_range: tuple[float, float] = (0.0, 1.0),
    **kwargs: Any,
) -> Tensor:
    r"""
    Calculate quadratic penalty on values outside of ``allowed_range``.

    Provides a 'soft' pixel-range regularization by imposing a quadratic
    penalty on any values outside the ``allowed_range``. Each value below the
    minimum contributes ``(value - min)**2``, and each value above the maximum
    contributes ``(value - max)**2``. All values within the ``allowed_range``
    (bounds included) have a penalty of 0.

    To use a non-default ``allowed_range`` as a ``penalty_function`` in
    synthesis methods, fix ``allowed_range`` ahead of time. Either use
    ``functools.partial`` or, as in the Examples, a small wrapper function.

    Parameters
    ----------
    img
        The tensor to penalize.
    allowed_range
        2-tuple of values giving the (min, max) allowed values. ``min`` must
        not be greater than ``max``.
    **kwargs
        Ignored, only present to absorb extra arguments.

    Returns
    -------
    penalty
        Penalty for values outside range, summed over every element of
        ``img`` (a 0-dimensional tensor).

    Raises
    ------
    ValueError
        If ``allowed_range`` is not a pair of values, or if its minimum is
        greater than its maximum.

    Examples
    --------
    Initialize metamer with allowed range (0.2, 0.8):

    >>> import plenoptic as po
    >>> img = po.data.einstein()
    >>> model = po.models.Gaussian(30).eval()
    >>> po.remove_grad(model)
    >>> # Make function penalizing values outside (0.2, 0.8)
    >>> def custom_range(x):
    ...     penalty = po.regularize.penalize_range(x, allowed_range=(0.2, 0.8))
    ...     return penalty
    >>> met_default = po.Metamer(img, model)
    >>> met_custom = po.Metamer(img, model, penalty_function=custom_range)
    >>> # Compare the value of the penalties
    >>> met_default.penalty_function(img)
    tensor(0.)
    >>> met_custom.penalty_function(img)
    tensor(49.3881)
    """
    try:
        min_val, max_val = allowed_range
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"allowed_range must be a (min, max) pair, got {allowed_range!r}"
        ) from err
    # A reversed range would silently penalize every value in ``img``, which
    # is almost certainly a caller mistake, so reject it explicitly.
    if min_val > max_val:
        raise ValueError(
            "allowed_range minimum must not be greater than its maximum, "
            f"got {allowed_range!r}"
        )
    # Using clip like this is equivalent to using boolean indexing (e.g.,
    # img[img < min_val]) but much faster. Values inside the range are
    # clipped to 0 and so contribute nothing to the sums.
    below_min = torch.clip(img - min_val, max=0).pow(2).sum()
    above_max = torch.clip(img - max_val, min=0).pow(2).sum()
    return below_min + above_max