"""abstract synthesis super-class."""
import abc
import warnings
import numpy as np
import torch
class Synthesis(abc.ABC):
r"""Abstract super-class for synthesis objects.
All synthesis objects share a variety of similarities and thus need
to have similar methods. Some of these can be implemented here and
simply inherited, some of them will need to be different for each
sub-class and thus are marked as abstract methods here
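
    Examples
    --------
    A minimal sketch of what a concrete sub-class might look like;
    ``ConstantSynthesis`` and its attributes are purely illustrative and not
    part of this package:

    >>> class ConstantSynthesis(Synthesis):
    ...     # toy sub-class that "synthesizes" a constant mid-gray image
    ...     def __init__(self, image_shape):
    ...         self.synthesized = None
    ...         self._image_shape = image_shape
    ...     def synthesize(self):
    ...         self.synthesized = 0.5 * torch.ones(self._image_shape)
    ...     def to(self, *args, **kwargs):
    ...         super().to(*args, attrs=["synthesized"], **kwargs)
    >>> synth = ConstantSynthesis((1, 1, 32, 32))
    >>> synth.synthesize()
    >>> synth.synthesized.shape
    torch.Size([1, 1, 32, 32])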
"""
@abc.abstractmethod
def synthesize(self):
r"""Synthesize something."""
pass
def save(self, file_path: str, attrs: list[str] | None = None):
r"""Save all relevant (non-model) variables in .pt file.
If you leave attrs as None, we grab vars(self) and exclude 'model'.
This is probably correct, but the option is provided to override it
just in case
Parameters
----------
file_path : str
The path to save the synthesis object to
attrs : list or None, optional
List of strs containing the names of the attributes of this
object to save. See above for behavior if attrs is None.
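
        Examples
        --------
        A sketch of typical usage; ``synth`` stands in for an already-run
        instance of some concrete sub-class (``_losses`` is an attribute
        defined by ``OptimizedSynthesis`` below):

        >>> synth.save("synthesis.pt")  # doctest: +SKIP
        >>> # or save only a subset of the attributes:
        >>> synth.save("synthesis.pt", attrs=["_losses"])  # doctest: +SKIP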
"""
if attrs is None:
            # copy the attribute dict so that popping "_model" below doesn't
            # remove the attribute from the object itself
attrs = {k: v for k, v in vars(self).items()}
attrs.pop("_model", None)
save_dict = {}
for k in attrs:
if k == "_model":
warnings.warn(
"Models can be quite large and they don't change"
" over synthesis. Please be sure that you "
"actually want to save the model."
)
attr = getattr(self, k)
# detaching the tensors avoids some headaches like the
# tensors having extra hooks or the like
if isinstance(attr, torch.Tensor):
attr = attr.detach()
save_dict[k] = attr
torch.save(save_dict, file_path)
def load(
self,
file_path: str,
map_location: str | None = None,
check_attributes: list[str] = [],
check_loss_functions: list[str] = [],
**pickle_load_args,
):
r"""Load all relevant attributes from a .pt file.
This should be called by an initialized ``Synthesis`` object -- we will
ensure that the attributes in the ``check_attributes`` arg all match in
the current and loaded object.
Note this operates in place and so doesn't return anything.
Parameters
----------
file_path :
The path to load the synthesis object from
map_location :
map_location argument to pass to ``torch.load``. If you save
stuff that was being run on a GPU and are loading onto a
CPU, you'll need this to make sure everything lines up
properly. This should be structured like the str you would
pass to ``torch.device``
check_attributes :
List of strings we ensure are identical in the current
``Synthesis`` object and the loaded one. Checking the model is
generally not recommended, since it can be hard to do (checking
callable objects is hard in Python) -- instead, checking the
            ``base_representation`` should ensure the model hasn't functionally
changed.
check_loss_functions :
Names of attributes that are loss functions and so must be checked
specially -- loss functions are callables, and it's very difficult
to check python callables for equality so, to get around that, we
instead call the two versions on the same pair of tensors,
and compare the outputs.
pickle_load_args :
any additional kwargs will be added to ``pickle_module.load`` via
``torch.load``, see that function's docstring for details.
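
        Examples
        --------
        A sketch of the intended round trip; ``SomeSynthesis`` and its
        arguments are illustrative stand-ins for a concrete sub-class:

        >>> synth = SomeSynthesis(image, model)  # doctest: +SKIP
        >>> synth.synthesize()  # doctest: +SKIP
        >>> synth.save("synthesis.pt")  # doctest: +SKIP
        >>> # later: re-initialize with the same arguments, then load
        >>> synth = SomeSynthesis(image, model)  # doctest: +SKIP
        >>> synth.load("synthesis.pt", map_location="cpu")  # doctest: +SKIP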
"""
tmp_dict = torch.load(file_path, map_location=map_location, **pickle_load_args)
        if map_location is not None:
            device = map_location
        else:
            # grab the device from the first tensor we find; fall back to the
            # CPU so that ``device`` is always defined
            device = torch.device("cpu")
            for v in tmp_dict.values():
                if isinstance(v, torch.Tensor):
                    device = v.device
                    break
for k in check_attributes:
# The only hidden attributes we'd check are those like
# range_penalty_lambda, where this function is checking the
# hidden version (which starts with '_'), but during
# initialization, the user specifies the version without
# the initial underscore. This is because this function
# needs to be able to set the attribute, which can only be
# done with the hidden version.
display_k = k[1:] if k.startswith("_") else k
if not hasattr(self, k):
raise AttributeError(
"All values of `check_attributes` should be "
"attributes set at initialization, but got "
f"attr {display_k}!"
)
if isinstance(getattr(self, k), torch.Tensor):
                # there are two ways this can fail -- the first is if they're
                # the same shape but different values, and the second (in the
                # except block) is if they differ in shape or dtype.
try:
if not torch.allclose(
getattr(self, k).to(tmp_dict[k].device),
tmp_dict[k],
rtol=5e-2,
):
raise ValueError(
f"Saved and initialized {display_k} are "
f"different! Initialized: {getattr(self, k)}"
f", Saved: {tmp_dict[k]}, difference: "
f"{getattr(self, k) - tmp_dict[k]}"
)
except RuntimeError as e:
# we end up here if dtype or shape don't match
if "The size of tensor a" in e.args[0]:
raise RuntimeError(
f"Attribute {display_k} have different shapes in"
" saved and initialized versions! Initialized"
f": {getattr(self, k).shape}, Saved: "
f"{tmp_dict[k].shape}"
)
elif "did not match" in e.args[0]:
raise RuntimeError(
f"Attribute {display_k} has different dtype in "
"saved and initialized versions! Initialized"
f": {getattr(self, k).dtype}, Saved: "
f"{tmp_dict[k].dtype}"
)
else:
raise e
elif isinstance(getattr(self, k), float):
if not np.allclose(getattr(self, k), tmp_dict[k]):
raise ValueError(
f"Saved and initialized {display_k} are different!"
f" Self: {getattr(self, k)}, "
f"Saved: {tmp_dict[k]}"
)
else:
if getattr(self, k) != tmp_dict[k]:
raise ValueError(
f"Saved and initialized {display_k} are different!"
f" Self: {getattr(self, k)}, "
f"Saved: {tmp_dict[k]}"
)
for k in check_loss_functions:
# same as above
display_k = k[1:] if k.startswith("_") else k
            # create two random tensors with the same shape as the image, so
            # that both loss functions receive valid input
tensor_a, tensor_b = torch.rand(2, *self._image_shape).to(device)
saved_loss = tmp_dict[k](tensor_a, tensor_b)
init_loss = getattr(self, k)(tensor_a, tensor_b)
if not torch.allclose(saved_loss, init_loss, rtol=1e-2):
raise ValueError(
f"Saved and initialized {display_k} are "
"different! On two random tensors: "
f"Initialized: {init_loss}, Saved: "
f"{saved_loss}, difference: "
f"{init_loss-saved_loss}"
)
for k, v in tmp_dict.items():
setattr(self, k, v)
@abc.abstractmethod
def to(self, *args, attrs: list[str] = [], **kwargs):
r"""Moves and/or casts the parameters and buffers.
Similar to ``save``, this is an abstract method only because you
need to define the attributes to call to on.
This can be called as
.. function:: to(device=None, dtype=None, non_blocking=False)
.. function:: to(dtype, non_blocking=False)
.. function:: to(tensor, non_blocking=False)
Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
floating point desired :attr:`dtype` s. In addition, this method will
only cast the floating point parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved to
        :attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices. When calling this method to move tensors
to a CUDA device, items in ``attrs`` that start with "saved_" will not
be moved.
        .. note::
            This method modifies the module in-place.

        Args:
device (:class:`torch.device`): the desired device of the parameters
and buffers in this module
dtype (:class:`torch.dtype`): the desired floating point type of
the floating point parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
attrs (:class:`list`): list of strs containing the attributes of
this object to move to the specified device/dtype
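
        Examples
        --------
        A sketch of typical usage; ``synth`` stands in for an instance of a
        concrete sub-class whose ``to()`` passes its tensor attributes via
        ``attrs``:

        >>> synth.to("cuda")  # doctest: +SKIP
        >>> synth.to(torch.float64)  # doctest: +SKIP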
"""
device, dtype, non_blocking, memory_format = torch._C._nn._parse_to(
*args, **kwargs
)
def move(a, k):
move_device = None if k.startswith("saved_") else device
if memory_format is not None and a.dim() == 4:
return a.to(
move_device,
dtype,
non_blocking,
memory_format=memory_format,
)
else:
return a.to(move_device, dtype, non_blocking)
for k in attrs:
if hasattr(self, k):
attr = getattr(self, k)
if isinstance(attr, torch.Tensor):
attr = move(attr.data, k)
if isinstance(getattr(self, k), torch.nn.Parameter):
attr = torch.nn.Parameter(attr)
if getattr(self, k).requires_grad:
attr = attr.requires_grad_()
setattr(self, k, attr)
elif isinstance(attr, list):
setattr(self, k, [move(a, k) for a in attr])
elif attr is not None:
setattr(self, k, move(attr, k))
class OptimizedSynthesis(Synthesis):
r"""Abstract super-class for synthesis objects that use optimization.
The primary difference between this and the generic Synthesis class is that
these will use an optimizer object to iteratively update their output.
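
    Notes
    -----
    A sketch of the optimization loop that a concrete sub-class's
    ``synthesize()`` might implement with the helpers defined here
    (``"metamer"`` is an illustrative name for the synthesis attribute)::

        def synthesize(self, max_iter=100, optimizer=None, store_progress=False):
            self._initialize_optimizer(optimizer, "metamer")
            self.store_progress = store_progress
            for _ in range(max_iter):
                loss = self.optimizer.step(self._closure)
                self._losses.append(loss.item())
                if self._check_convergence():
                    break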
"""
def __init__(
self,
range_penalty_lambda: float = 0.1,
allowed_range: tuple[float, float] = (0, 1),
):
"""Initialize the properties of OptimizedSynthesis."""
self._losses = []
self._gradient_norm = []
self._pixel_change_norm = []
self._store_progress = None
self._optimizer = None
if range_penalty_lambda < 0:
raise Exception("range_penalty_lambda must be non-negative!")
self._range_penalty_lambda = range_penalty_lambda
self._allowed_range = allowed_range
@abc.abstractmethod
def _initialize(self):
r"""What to start synthesis with."""
pass
@abc.abstractmethod
def objective_function(self):
r"""How good is the current synthesized object.
See ``plenoptic.tools.optim`` for some examples.
"""
pass
@abc.abstractmethod
def _check_convergence(self):
r"""How to determine if synthesis has finished.
See ``plenoptic.tools.convergence`` for some examples.
"""
pass
def _closure(self) -> torch.Tensor:
r"""An abstraction of the gradient calculation, before the optimization step.
This enables optimization algorithms that perform several evaluations
of the gradient before taking a step (ie. second order methods like
LBFGS).
Additionally, this is where ``loss`` is calculated and
``loss.backward()`` is called.
Returns
-------
loss
Loss of the current objective function
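
        Examples
        --------
        A sketch of how an optimizer that re-evaluates the objective (such as
        LBFGS) consumes this closure; ``synth`` stands in for a fully
        initialized sub-class instance:

        >>> loss = synth.optimizer.step(synth._closure)  # doctest: +SKIP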
"""
self.optimizer.zero_grad()
loss = self.objective_function()
loss.backward(retain_graph=False)
return loss
def _initialize_optimizer(
self,
optimizer: torch.optim.Optimizer | None,
synth_name: str,
learning_rate: float = 0.01,
):
"""Initialize optimizer.
First time this is called, optimizer can be:
- None, in which case we create an Adam optimizer with amsgrad=True and
``lr=learning_rate`` with a single parameter, the synthesis attribute
- torch.optim.Optimizer, in which case it must already have the
synthesis attribute (e.g., metamer) as its only parameter.
The synthesis attribute is the one with the name ``synth_name``
Every subsequent time (so, when resuming synthesis), optimizer must be
None (and we use the original optimizer object).
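
        Examples
        --------
        A sketch of passing a custom optimizer on the first call; ``synth`` and
        its ``metamer`` attribute are illustrative:

        >>> opt = torch.optim.SGD([synth.metamer], lr=0.001)  # doctest: +SKIP
        >>> synth._initialize_optimizer(opt, "metamer")  # doctest: +SKIP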
"""
synth_attr = getattr(self, synth_name)
if optimizer is None:
if self.optimizer is None:
self._optimizer = torch.optim.Adam(
[synth_attr], lr=learning_rate, amsgrad=True
)
else:
if self.optimizer is not None:
raise TypeError("When resuming synthesis, optimizer arg must be None!")
params = optimizer.param_groups[0]["params"]
if len(params) != 1 or not torch.equal(params[0], synth_attr):
raise ValueError(
f"For {synth_name} synthesis, optimizer must have one "
f"parameter, the {synth_name} we're synthesizing."
)
self._optimizer = optimizer
@property
    def range_penalty_lambda(self):
        """Strength of the penalty on pixel values outside ``allowed_range``."""
        return self._range_penalty_lambda
@property
    def allowed_range(self):
        """Range of allowed pixel values for the synthesis output."""
        return self._allowed_range
@property
def losses(self):
"""Synthesis loss over iterations."""
return torch.as_tensor(self._losses)
@property
def gradient_norm(self):
"""Synthesis gradient's L2 norm over iterations."""
return torch.as_tensor(self._gradient_norm)
@property
def pixel_change_norm(self):
"""L2 norm change in pixel values over iterations."""
return torch.as_tensor(self._pixel_change_norm)
@property
    def store_progress(self):
        """How often the synthesis output is stored; see the setter for details."""
        return self._store_progress
@store_progress.setter
def store_progress(self, store_progress: bool | int):
"""Initialize store_progress.
Sets the ``self.store_progress`` attribute, as well as changing the
``saved_metamer`` attibute to a list so we can append to them. finally,
adds first value to ``saved_metamer`` if it's empty.
Parameters
----------
store_progress : bool or int, optional
Whether we should store the metamer image in progress on every
iteration. If False, we don't save anything. If True, we save every
iteration. If an int, we save every ``store_progress`` iterations
(note then that 0 is the same as False and 1 the same as True). If
True or int>0, ``self.saved_metamer`` contains the stored images.
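
        Examples
        --------
        A sketch of the accepted values; ``synth`` stands in for an initialized
        sub-class instance, and because this value must stay consistent across
        calls to ``synthesize()``, it should only be set once:

        >>> # store the synthesis output every 10 iterations
        >>> synth.store_progress = 10  # doctest: +SKIP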
"""
        if store_progress is True:
store_progress = 1
if self.store_progress is not None and store_progress != self.store_progress:
# we require store_progress to be the same because otherwise the
# subsampling relationship between attrs that are stored every
# iteration (loss, gradient, etc) and those that are stored every
# store_progress iteration (e.g., saved_metamer) changes partway
# through and that's annoying
            raise ValueError(
                "If you've already run synthesize() before, you must "
                "re-run it with the same store_progress arg. You "
                f"passed {store_progress} instead of "
                f"{self.store_progress} (True is equivalent to 1)"
            )
self._store_progress = store_progress
@property
    def optimizer(self):
        """Optimizer used to update the synthesis output."""
        return self._optimizer