Source code for plenoptic.synthesize.simple_metamer
"""Simple Metamer Class"""
import torch
from deprecated.sphinx import deprecated
from tqdm.auto import tqdm
from ..tools import optim
from ..tools.validate import validate_input, validate_model
from .synthesis import Synthesis
@deprecated(
    "Use :py:class:`plenoptic.synthesize.metamer.Metamer` instead", version="1.1.0"
)
class SimpleMetamer(Synthesis):
r"""Simple version of metamer synthesis.
This doesn't have any of the bells and whistles of the full Metamer class,
but does perform basic metamer synthesis: given a target image and a model,
synthesize a new image (initialized with uniform noise) that has the same
model output.
This is meant as a demonstration of the basic logic of synthesis.
Parameters
----------
image
A 4d tensor, this is the image whose model representation we wish to
match.
model
The visual model whose representation we wish to match.
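
    Examples
    --------
    A minimal sketch of the intended workflow. The "model" below is just an
    average-pooling layer, chosen so the example is self-contained; any
    deterministic module that maps a 4d image to a tensor and has no
    parameters requiring gradients should work the same way:

    >>> import torch
    >>> model = torch.nn.AvgPool2d(4)  # toy model: local spatial averaging
    >>> image = torch.rand(1, 1, 64, 64)  # 4d: batch, channel, height, width
    >>> met = SimpleMetamer(image, model)
    >>> synthesized = met.synthesize(max_iter=5)
    >>> synthesized.shape
    torch.Size([1, 1, 64, 64])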
    """

    def __init__(self, image: torch.Tensor, model: torch.nn.Module):
        validate_model(
            model,
            image_shape=image.shape,
            image_dtype=image.dtype,
            device=image.device,
        )
        self.model = model
        validate_input(image)
        self.image = image
        self.metamer = torch.rand_like(self.image, requires_grad=True)
        self.target_representation = self.model(self.image).detach()
        self.optimizer = None
        self.losses = []

    def synthesize(
        self,
        max_iter: int = 100,
        optimizer: None | torch.optim.Optimizer = None,
    ) -> torch.Tensor:
"""Synthesize a simple metamer.
If called multiple times, will continue where we left off.
Parameters
----------
max_iter
Number of iterations to run synthesis for.
optimizer
The optimizer to use. If None and this is the first time calling
synthesize, we use Adam(lr=.01, amsgrad=True); if synthesize has
been called before, we reuse the previous optimizer.
Returns
-------
metamer
The synthesized metamer
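
        Examples
        --------
        A sketch of continuing synthesis with a custom optimizer, using the
        ``met`` object from the class-level example (the SGD settings here are
        arbitrary, purely for illustration):

        >>> _ = met.synthesize(max_iter=10)  # first call: Adam(lr=.01, amsgrad=True)
        >>> opt = torch.optim.SGD([met.metamer], lr=0.001)
        >>> _ = met.synthesize(max_iter=10, optimizer=opt)  # continue with SGD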
"""
        if optimizer is None:
            if self.optimizer is None:
                self.optimizer = torch.optim.Adam([self.metamer], lr=0.01, amsgrad=True)
        else:
            self.optimizer = optimizer
        pbar = tqdm(range(max_iter))
        for _ in pbar:

            def closure():
                self.optimizer.zero_grad()
                metamer_representation = self.model(self.metamer)
                # We want to make sure our metamer ends up in the range [0, 1],
                # so we penalize all values outside that range in the loss
                # function. You could theoretically also just clamp metamer on
                # each step of the iteration, but the penalty in the loss seems
                # to work better in practice.
                loss = optim.mse(metamer_representation, self.target_representation)
                loss = loss + 0.1 * optim.penalize_range(self.metamer, (0, 1))
                self.losses.append(loss.item())
                loss.backward(retain_graph=False)
                pbar.set_postfix(loss=loss.item())
                return loss

            self.optimizer.step(closure)
        # The return annotation and docstring promise the synthesized image, so
        # hand back the metamer tensor once optimization finishes.
        return self.metamer

    def save(self, file_path: str):
        r"""Save all relevant (non-model) variables in .pt file.

        Parameters
        ----------
        file_path
            The path to save the SimpleMetamer object to.
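
        Examples
        --------
        A sketch; the file name here is arbitrary:

        >>> met.save("simple_metamer.pt")  # ``met`` as in the class-level example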
"""
super().save(file_path, attrs=None)

    def load(self, file_path: str, map_location: str | None = None):
        r"""Load all relevant attributes from a .pt file.

        Note this operates in place and so doesn't return anything.

        Parameters
        ----------
        file_path
            The path to load the synthesis object from.
        map_location
            Passed through to :func:`torch.load` as its ``map_location``
            argument; use this if the object was saved on a different device
            (e.g., saved on a GPU but loaded on a CPU-only machine).
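
        Examples
        --------
        A sketch of a round trip, reusing ``image``, ``model``, and the file
        saved in the earlier examples; the fresh object must be built with the
        same image and model so that the checked attributes match:

        >>> met_copy = SimpleMetamer(image, model)
        >>> met_copy.load("simple_metamer.pt")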
"""
check_attributes = ["target_representation", "image"]
super().load(
file_path,
check_attributes=check_attributes,
map_location=map_location,
)

    def to(self, *args, **kwargs):
        r"""Move and/or cast the parameters and buffers.

        This can be called as

        .. function:: to(device=None, dtype=None, non_blocking=False)

        .. function:: to(dtype, non_blocking=False)

        .. function:: to(tensor, non_blocking=False)

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point :attr:`dtype` s. In addition, this method will only
        cast the floating point parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved to
        :attr:`device`, if that is given, but with dtypes unchanged. When
        :attr:`non_blocking` is set, it tries to convert/move asynchronously
        with respect to the host if possible, e.g., moving CPU Tensors with
        pinned memory to CUDA devices. When calling this method to move tensors
        to a CUDA device, items in ``attrs`` that start with "saved_" will not
        be moved.

        .. note::
            This method modifies the module in-place.

        Args:
            device (:class:`torch.device`): the desired device of the parameters
                and buffers in this module
            dtype (:class:`torch.dtype`): the desired floating point type of
                the floating point parameters and buffers in this module
            tensor (torch.Tensor): Tensor whose dtype and device are the desired
                dtype and device for all parameters and buffers in this module
            attrs (:class:`list`): list of strs containing the attributes of
                this object to move to the specified device/dtype

        Returns:
            Module: self
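
        A sketch of typical calls, on a freshly constructed ``SimpleMetamer``
        (as in the class-level example) and before running ``synthesize``, so
        that the optimizer is created for the moved tensors; the CUDA move is
        guarded because a GPU may not be available::

            >>> _ = met.to(torch.float64)
            >>> if torch.cuda.is_available():
            ...     _ = met.to(device="cuda")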
"""
attrs = ["model", "image", "target_representation", "metamer"]
super().to(*args, attrs=attrs, **kwargs)
return self