Source code for plenoptic.synthesize.metamer

"""Synthesize model metamers."""

import contextlib
import re
import warnings
from collections import OrderedDict
from collections.abc import Callable
from typing import Literal

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import Tensor
from tqdm.auto import tqdm

from ..tools import data, display, optim, signal
from ..tools.convergence import coarse_to_fine_enough, loss_convergence
from ..tools.validate import validate_coarse_to_fine, validate_input, validate_model
from .synthesis import OptimizedSynthesis


[docs] class Metamer(OptimizedSynthesis): r"""Synthesize metamers for image-computable differentiable models. Following the basic idea in [1]_, this class creates a metamer for a given model on a given image. We start with ``initial_image`` and iteratively adjust the pixel values so as to match the representation of the ``metamer`` and ``image``. All ``saved_`` attributes are initialized as empty lists and will be non-empty if the ``store_progress`` arg to ``synthesize()`` is not ``False``. They will be appended to on every iteration if ``store_progress=True`` or every ``store_progress`` iterations if it's an ``int``. Parameters ---------- image : A 4d tensor, this is the image whose representation we wish to match. If this is not a tensor, we try to cast it as one. model : A visual model, see `Metamer` notebook for more details loss_function : the loss function to use to compare the representations of the models in order to determine their loss. Because of the limitations of pickle, you cannot use a lambda function for this if you wish to save the Metamer object (i.e., it must be one of our built-in functions or defined using a `def` statement) range_penalty_lambda : strength of the regularizer that enforces the allowed_range. Must be non-negative. allowed_range : Range (inclusive) of allowed pixel values. Any values outside this range will be penalized. initial_image : 4d Tensor to initialize our metamer with. If None, will draw a sample of uniform noise within ``allowed_range``. Attributes ---------- target_representation : torch.Tensor Whatever is returned by ``model(image)``, this is what we match in order to create a metamer metamer : torch.Tensor The metamer. This may be unfinished depending on how many iterations we've run for. losses : list A list of our loss over iterations. gradient_norm : list A list of the gradient's L2 norm over iterations. pixel_change_norm : list A list containing the L2 norm of the pixel change over iterations (``pixel_change_norm[i]`` is the pixel change norm in ``metamer`` between iterations ``i`` and ``i-1``). saved_metamer : torch.Tensor Saved ``self.metamer`` for later examination. References ---------- .. [1] J Portilla and E P Simoncelli. A Parametric Texture Model based on Joint Statistics of Complex Wavelet Coefficients. Int'l Journal of Computer Vision. 40(1):49-71, October, 2000. https://www.cns.nyu.edu/~eero/ABSTRACTS/portilla99-abstract.html https://www.cns.nyu.edu/~lcv/texture/ """ def __init__( self, image: Tensor, model: torch.nn.Module, loss_function: Callable[[Tensor, Tensor], Tensor] = optim.mse, range_penalty_lambda: float = 0.1, allowed_range: tuple[float, float] = (0, 1), initial_image: Tensor | None = None, ): super().__init__(range_penalty_lambda, allowed_range) validate_input(image, allowed_range=allowed_range) validate_model( model, image_shape=image.shape, image_dtype=image.dtype, device=image.device, ) self._model = model self._image = image self._image_shape = image.shape self._target_representation = self.model(self.image) self.scheduler = None self.loss_function = loss_function self._initialize(initial_image) self._saved_metamer = [] self._store_progress = None def _initialize(self, initial_image: Tensor | None = None): """Initialize the metamer. Set the ``self.metamer`` attribute to be an attribute with the user-supplied data, making sure it's the right shape. Parameters ---------- initial_image : The tensor we use to initialize the metamer. If None (the default), we initialize with uniformly-distributed random noise lying between 0 and 1. """ if initial_image is None: metamer = torch.rand_like(self.image) # rescale metamer to lie within the interval # self.allowed_range metamer = signal.rescale(metamer, *self.allowed_range) metamer.requires_grad_() else: if initial_image.ndimension() < 4: raise ValueError( "initial_image must be torch.Size([n_batch" ", n_channels, im_height, im_width]) but got " f"{initial_image.size()}" ) if initial_image.size() != self.image.size(): raise ValueError("initial_image and image must be same size!") metamer = initial_image.clone().detach() metamer = metamer.to(dtype=self.image.dtype, device=self.image.device) metamer.requires_grad_() self._metamer = metamer
[docs] def synthesize( self, max_iter: int = 100, optimizer: torch.optim.Optimizer | None = None, scheduler: torch.optim.lr_scheduler._LRScheduler | None = None, store_progress: bool | int = False, stop_criterion: float = 1e-4, stop_iters_to_check: int = 50, ): r"""Synthesize a metamer. Update the pixels of ``initial_image`` until its representation matches that of ``image``. We run this until either we reach ``max_iter`` or the change over the past ``stop_iters_to_check`` iterations is less than ``stop_criterion``, whichever comes first Parameters ---------- max_iter : The maximum number of iterations to run before we end synthesis (unless we hit the stop criterion). optimizer : The optimizer to use. If None and this is the first time calling synthesize, we use Adam(lr=.01, amsgrad=True); if synthesize has been called before, this must be None and we reuse the previous optimizer. scheduler : The learning rate scheduler to use. If None, we don't use one. store_progress : Whether we should store the metamer image in progress on every iteration. If False, we don't save anything. If True, we save every iteration. If an int, we save every ``store_progress`` iterations (note then that 0 is the same as False and 1 the same as True). stop_criterion : If the loss over the past ``stop_iters_to_check`` has changed less than ``stop_criterion``, we terminate synthesis. stop_iters_to_check : How many iterations back to check in order to see if the loss has stopped decreasing (for ``stop_criterion``). """ # initialize the optimizer and scheduler self._initialize_optimizer(optimizer, scheduler) # get ready to store progress self.store_progress = store_progress pbar = tqdm(range(max_iter)) for i in pbar: # update saved_* attrs. len(losses) gives the total number of # iterations and will be correct across calls to `synthesize` self._store(len(self.losses)) loss = self._optimizer_step(pbar) if not torch.isfinite(loss): raise ValueError("Found a NaN in loss during optimization.") if self._check_convergence(stop_criterion, stop_iters_to_check): warnings.warn("Loss has converged, stopping synthesis") break pbar.close()
[docs] def objective_function( self, metamer_representation: Tensor | None = None, target_representation: Tensor | None = None, ) -> Tensor: """Compute the metamer synthesis loss. This calls self.loss_function on ``metamer_representation`` and ``target_representation`` and then adds the weighted range penalty. Parameters ---------- metamer_representation : Model response to ``metamer``. If None, we use ``self.model(self.metamer)`` target_representation : Model response to ``image``. If None, we use ``self.target_representation``. Returns ------- loss """ if metamer_representation is None: metamer_representation = self.model(self.metamer) if target_representation is None: target_representation = self.target_representation loss = self.loss_function(metamer_representation, target_representation) range_penalty = optim.penalize_range(self.metamer, self.allowed_range) return loss + self.range_penalty_lambda * range_penalty
def _optimizer_step(self, pbar: tqdm) -> Tensor: r"""Compute and propagate gradients, then step the optimizer to update metamer. Parameters ---------- pbar : A tqdm progress-bar, which we update with a postfix describing the current loss, gradient norm, and learning rate (it already tells us which iteration and the time elapsed). Returns ------- loss : torch.Tensor 1-element tensor containing the loss on this step """ last_iter_metamer = self.metamer.clone() loss = self.optimizer.step(self._closure) self._losses.append(loss.item()) grad_norm = torch.linalg.vector_norm(self.metamer.grad.data, ord=2, dim=None) self._gradient_norm.append(grad_norm.item()) # optionally step the scheduler if self.scheduler is not None: self.scheduler.step(loss.item()) pixel_change_norm = torch.linalg.vector_norm( self.metamer - last_iter_metamer, ord=2, dim=None ) self._pixel_change_norm.append(pixel_change_norm.item()) # add extra info here if you want it to show up in progress bar pbar.set_postfix( OrderedDict( loss=f"{loss.item():.04e}", learning_rate=self.optimizer.param_groups[0]["lr"], gradient_norm=f"{grad_norm.item():.04e}", pixel_change_norm=f"{pixel_change_norm.item():.04e}", ) ) return loss def _check_convergence(self, stop_criterion, stop_iters_to_check): r"""Check whether the loss has stabilized and, if so, return True. Have we been synthesizing for ``stop_iters_to_check`` iterations? | | no yes | '---->Is ``abs(synth.loss[-1] - synth.losses[-stop_iters_to_check]) < stop_criterion``? | no | | | yes <-------' | | '------> return ``True`` | '---------> return ``False`` Parameters ---------- stop_criterion : If the loss over the past ``stop_iters_to_check`` has changed less than ``stop_criterion``, we terminate synthesis. stop_iters_to_check : How many iterations back to check in order to see if the loss has stopped decreasing (for ``stop_criterion``). Returns ------- loss_stabilized : Whether the loss has stabilized or not. """ # noqa: E501 return loss_convergence(self, stop_criterion, stop_iters_to_check) def _initialize_optimizer( self, optimizer: torch.optim.Optimizer | None, scheduler: torch.optim.lr_scheduler._LRScheduler | None, ): """Initialize optimizer and scheduler.""" # this uses the OptimizedSynthesis setter super()._initialize_optimizer(optimizer, "metamer") self.scheduler = scheduler for pg in self.optimizer.param_groups: # initialize initial_lr if it's not here. Scheduler should add it # if it's not None. if "initial_lr" not in pg: pg["initial_lr"] = pg["lr"] def _store(self, i: int) -> bool: """Store metamer, if appropriate. if it's the right iteration, we update ``saved_metamer``. Parameters ---------- i the current iteration Returns ------- stored : True if we stored this iteration, False if not. """ if self.store_progress and (i % self.store_progress == 0): # want these to always be on cpu, to reduce memory use for GPUs self._saved_metamer.append(self.metamer.clone().to("cpu")) stored = True else: stored = False return stored
[docs] def save(self, file_path: str): r"""Save all relevant variables in .pt file. Note that if store_progress is True, this will probably be very large. See ``load`` docstring for an example of use. Parameters ---------- file_path : str The path to save the metamer object to """ super().save(file_path, attrs=None)
[docs] def to(self, *args, **kwargs): r"""Moves and/or casts the parameters and buffers. This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) .. function:: to(dtype, non_blocking=False) .. function:: to(tensor, non_blocking=False) Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point desired :attr:`dtype` s. In addition, this method will only cast the floating point parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples. .. note:: This method modifies the module in-place. Args: device (:class:`torch.device`): the desired device of the parameters and buffers in this module dtype (:class:`torch.dtype`): the desired floating point type of the floating point parameters and buffers in this module tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module """ attrs = ["_image", "_target_representation", "_metamer", "_saved_metamer"] super().to(*args, attrs=attrs, **kwargs) # try to call .to() on model. this should work, but it might fail if e.g., this # a custom model that doesn't inherit torch.nn.Module try: self._model = self._model.to(*args, **kwargs) except AttributeError: warnings.warn("Unable to call model.to(), so we leave it as is.")
[docs] def load( self, file_path: str, map_location: str | None = None, **pickle_load_args, ): r"""Load all relevant stuff from a .pt file. This should be called by an initialized ``Metamer`` object -- we will ensure that ``image``, ``target_representation`` (and thus ``model``), and ``loss_function`` are all identical. Note this operates in place and so doesn't return anything. Parameters ---------- file_path : str The path to load the synthesis object from map_location : str, optional map_location argument to pass to ``torch.load``. If you save stuff that was being run on a GPU and are loading onto a CPU, you'll need this to make sure everything lines up properly. This should be structured like the str you would pass to ``torch.device`` pickle_load_args : any additional kwargs will be added to ``pickle_module.load`` via ``torch.load``, see that function's docstring for details. Examples -------- >>> metamer = po.synth.Metamer(img, model) >>> metamer.synthesize(max_iter=10, store_progress=True) >>> metamer.save('metamers.pt') >>> metamer_copy = po.synth.Metamer(img, model) >>> metamer_copy.load('metamers.pt') Note that you must create a new instance of the Synthesis object and *then* load. """ self._load(file_path, map_location, **pickle_load_args)
def _load( self, file_path: str, map_location: str | None = None, additional_check_attributes: list[str] = [], additional_check_loss_functions: list[str] = [], **pickle_load_args, ): r"""Helper function for loading. Users interact with ``load`` (without the underscore), this is to allow subclasses to specify additional attributes or loss functions to check. """ check_attributes = [ "_image", "_target_representation", "_range_penalty_lambda", "_allowed_range", ] check_attributes += additional_check_attributes check_loss_functions = ["loss_function"] check_loss_functions += additional_check_loss_functions super().load( file_path, map_location=map_location, check_attributes=check_attributes, check_loss_functions=check_loss_functions, **pickle_load_args, ) # make this require a grad again self.metamer.requires_grad_() # these are always supposed to be on cpu, but may get copied over to # gpu on load (which can cause problems when resuming synthesis), so # fix that. if len(self._saved_metamer) and self._saved_metamer[0].device.type != "cpu": self._saved_metamer = [met.to("cpu") for met in self._saved_metamer] @property def model(self): return self._model @property def image(self): return self._image @property def target_representation(self): """Model representation of ``image``, the goal of synthesis is for ``model(metamer)`` to match this value.""" return self._target_representation @property def metamer(self): return self._metamer @property def saved_metamer(self): return torch.stack(self._saved_metamer)
[docs] class MetamerCTF(Metamer): """Synthesize model metamers with coarse-to-fine synthesis. This is a special case of ``Metamer``, which uses the coarse-to-fine synthesis procedure described in [1]_: we start by updating metamer with respect to only a subset of the model's representation (generally, that which corresponds to the lowest spatial frequencies), and changing which subset we consider over the course of synthesis. This is similar to optimizing with a blurred version of the objective function and gradually adding in finer details. It improves synthesis performance for some models. Parameters ---------- image : A 4d tensor, this is the image whose representation we wish to match. If this is not a tensor, we try to cast it as one. model : A visual model, see `Metamer` notebook for more details loss_function : the loss function to use to compare the representations of the models in order to determine their loss. Because of the limitations of pickle, you cannot use a lambda function for this if you wish to save the Metamer object (i.e., it must be one of our built-in functions or defined using a `def` statement) range_penalty_lambda : strength of the regularizer that enforces the allowed_range. Must be non-negative. allowed_range : Range (inclusive) of allowed pixel values. Any values outside this range will be penalized. initial_image : 4d Tensor to initialize our metamer with. If None, will draw a sample of uniform noise within ``allowed_range``. coarse_to_fine : - 'together': start with the coarsest scale, then gradually add each finer scale. - 'separate': compute the gradient with respect to each scale separately (ignoring the others), then with respect to all of them at the end. (see ``Metamer`` tutorial for more details). Attributes ---------- target_representation : torch.Tensor Whatever is returned by ``model(image)``, this is what we match in order to create a metamer metamer : torch.Tensor The metamer. This may be unfinished depending on how many iterations we've run for. losses : list A list of our loss over iterations. gradient_norm : list A list of the gradient's L2 norm over iterations. pixel_change_norm : list A list containing the L2 norm of the pixel change over iterations (``pixel_change_norm[i]`` is the pixel change norm in ``metamer`` between iterations ``i`` and ``i-1``). saved_metamer : torch.Tensor Saved ``self.metamer`` for later examination. scales : list or None The list of scales in optimization order (i.e., from coarse to fine). Will be modified during the course of optimization. scales_loss : list or None The scale-specific loss at each iteration scales_timing : dict or None Keys are the values found in ``scales``, values are lists, specifying the iteration where we started and stopped optimizing this scale. scales_finished : list or None List of scales that we've finished optimizing. """ def __init__( self, image: Tensor, model: torch.nn.Module, loss_function: Callable[[Tensor, Tensor], Tensor] = optim.mse, range_penalty_lambda: float = 0.1, allowed_range: tuple[float, float] = (0, 1), initial_image: Tensor | None = None, coarse_to_fine: Literal["together", "separate"] = "together", ): super().__init__( image, model, loss_function, range_penalty_lambda, allowed_range, initial_image, ) self._init_ctf(coarse_to_fine) def _init_ctf(self, coarse_to_fine: Literal["together", "separate"]): """Initialize stuff related to coarse-to-fine.""" # this will hold the reduced representation of the target image. if coarse_to_fine not in ["separate", "together"]: raise ValueError( f"Don't know how to handle value {coarse_to_fine}!" " Must be one of: 'separate', 'together'" ) self._ctf_target_representation = None validate_coarse_to_fine( self.model, image_shape=self.image.shape, device=self.image.device ) # if self.scales is not None, we're continuing a previous version # and want to continue. this list comprehension creates a new # object, so we don't modify model.scales self._scales = [i for i in self.model.scales[:-1]] if coarse_to_fine == "separate": self._scales += [self.model.scales[-1]] self._scales += ["all"] self._scales_timing = dict((k, []) for k in self.scales) self._scales_timing[self.scales[0]].append(0) self._scales_loss = [] self._scales_finished = [] self._coarse_to_fine = coarse_to_fine
[docs] def synthesize( self, max_iter: int = 100, optimizer: torch.optim.Optimizer | None = None, scheduler: torch.optim.lr_scheduler._LRScheduler | None = None, store_progress: bool | int = False, stop_criterion: float = 1e-4, stop_iters_to_check: int = 50, change_scale_criterion: float | None = 1e-2, ctf_iters_to_check: int = 50, ): r"""Synthesize a metamer. Update the pixels of ``initial_image`` until its representation matches that of ``image``. We run this until either we reach ``max_iter`` or the change over the past ``stop_iters_to_check`` iterations is less than ``stop_criterion``, whichever comes first Parameters ---------- max_iter : The maximum number of iterations to run before we end synthesis (unless we hit the stop criterion). optimizer : The optimizer to use. If None and this is the first time calling synthesize, we use Adam(lr=.01, amsgrad=True); if synthesize has been called before, this must be None and we reuse the previous optimizer. scheduler : The learning rate scheduler to use. If None, we don't use one. store_progress : Whether we should store the metamer image in progress on every iteration. If False, we don't save anything. If True, we save every iteration. If an int, we save every ``store_progress`` iterations (note then that 0 is the same as False and 1 the same as True). stop_criterion : If the loss over the past ``stop_iters_to_check`` has changed less than ``stop_criterion``, we terminate synthesis. stop_iters_to_check : How many iterations back to check in order to see if the loss has stopped decreasing (for ``stop_criterion``). change_scale_criterion Scale-specific analogue of ``change_scale_criterion``: we consider a given scale finished (and move onto the next) if the loss has changed less than this in the past ``ctf_iters_to_check`` iterations. If ``None``, we'll change scales as soon as we've spent ``ctf_iters_to_check`` on a given scale ctf_iters_to_check Scale-specific analogue of ``stop_iters_to_check``: how many iterations back in order to check in order to see if we should switch scales. """ if (change_scale_criterion is not None) and ( stop_criterion >= change_scale_criterion ): raise ValueError( "stop_criterion must be strictly less than " "change_scale_criterion, or things get weird!" ) # initialize the optimizer and scheduler self._initialize_optimizer(optimizer, scheduler) # get ready to store progress self.store_progress = store_progress pbar = tqdm(range(max_iter)) for i in pbar: # update saved_* attrs. len(losses) gives the total number of # iterations and will be correct across calls to `synthesize` self._store(len(self.losses)) loss = self._optimizer_step( pbar, change_scale_criterion, ctf_iters_to_check ) if not torch.isfinite(loss): raise ValueError("Found a NaN in loss during optimization.") if self._check_convergence( i, stop_criterion, stop_iters_to_check, ctf_iters_to_check ): warnings.warn("Loss has converged, stopping synthesis") break pbar.close()
def _optimizer_step( self, pbar: tqdm, change_scale_criterion: float, ctf_iters_to_check: int, ) -> Tensor: r"""Compute and propagate gradients, then step the optimizer to update metamer. Parameters ---------- pbar : A tqdm progress-bar, which we update with a postfix describing the current loss, gradient norm, and learning rate (it already tells us which iteration and the time elapsed). change_scale_criterion : How many iterations back to check to see if the loss has stopped decreasing and we should thus move to the next scale in coarse-to-fine optimization. ctf_iters_to_check : Minimum number of iterations coarse-to-fine must run at each scale. Returns ------- loss : torch.Tensor 1-element tensor containing the loss on this step """ last_iter_metamer = self.metamer.clone() # Check if conditions hold for switching scales: # - Check if loss has decreased below the change_scale_criterion and # - if we've been optimizing this scale for the required number of iterations # - The first check here is because the last scale will be 'all', and # we never remove it if ( len(self.scales) > 1 and len(self.scales_loss) >= ctf_iters_to_check and ( change_scale_criterion is None or abs(self.scales_loss[-1] - self.scales_loss[-ctf_iters_to_check]) < change_scale_criterion ) and ( len(self.losses) - self.scales_timing[self.scales[0]][0] >= ctf_iters_to_check ) ): self._scales_timing[self.scales[0]].append(len(self.losses) - 1) self._scales_finished.append(self._scales.pop(0)) # Only append if scales list is still non-empty after the pop if self.scales: self._scales_timing[self.scales[0]].append(len(self.losses)) # Reset optimizer's learning rate for pg in self.optimizer.param_groups: pg["lr"] = pg["initial_lr"] # Reset ctf target representation for the next update self._ctf_target_representation = None loss, overall_loss = self.optimizer.step(self._closure) self._scales_loss.append(loss.item()) self._losses.append(overall_loss.item()) grad_norm = torch.linalg.vector_norm(self.metamer.grad.data, ord=2, dim=None) self._gradient_norm.append(grad_norm.item()) # optionally step the scheduler if self.scheduler is not None: self.scheduler.step(loss.item()) pixel_change_norm = torch.linalg.vector_norm( self.metamer - last_iter_metamer, ord=2, dim=None ) self._pixel_change_norm.append(pixel_change_norm.item()) # add extra info here if you want it to show up in progress bar pbar.set_postfix( OrderedDict( loss=f"{overall_loss.item():.04e}", learning_rate=self.optimizer.param_groups[0]["lr"], gradient_norm=f"{grad_norm.item():.04e}", pixel_change_norm=f"{pixel_change_norm.item():.04e}", current_scale=self.scales[0], current_scale_loss=f"{loss.item():.04e}", ) ) return overall_loss def _closure(self) -> tuple[Tensor, Tensor]: r"""An abstraction of the gradient calculation, before the optimization step. This enables optimization algorithms that perform several evaluations of the gradient before taking a step (ie. second order methods like LBFGS). Additionally, this is where: - ``metamer_representation`` is calculated, and thus any modifications to the model's forward call (e.g., specifying `scale` kwarg for coarse-to-fine) should happen. - ``loss`` is calculated and ``loss.backward()`` is called. Returns ------- loss Loss of the current objective function overall_loss Loss of the complete model. This differs from ``loss`` if we're doing coarse-to-fine synthesis """ self.optimizer.zero_grad() analyze_kwargs = {} # if we've reached 'all', we use the full model if self.scales[0] != "all": analyze_kwargs["scales"] = [self.scales[0]] # if 'together', then we also want all the coarser # scales if self.coarse_to_fine == "together": analyze_kwargs["scales"] += self.scales_finished metamer_representation = self.model(self.metamer, **analyze_kwargs) # if analyze_kwargs is empty, we can just compare # metamer_representation against our cached target_representation if analyze_kwargs: if self._ctf_target_representation is None: target_rep = self.model(self.image, **analyze_kwargs) self._ctf_target_representation = target_rep else: target_rep = self._ctf_target_representation # this is just for display, so don't compute gradients with torch.no_grad(): overall_loss = self.objective_function(None, None) else: target_rep = None overall_loss = None loss = self.objective_function(metamer_representation, target_rep) loss.backward(retain_graph=False) if overall_loss is None: overall_loss = loss.clone() return loss, overall_loss def _check_convergence( self, i: int, stop_criterion: float, stop_iters_to_check: int, ctf_iters_to_check: int, ) -> bool: r"""Check whether the loss has stabilized and whether we've synthesized all scales. Have we been synthesizing for ``stop_iters_to_check`` iterations? | | no yes | '---->Is ``abs(self.loss[-1] - self.losses[-stop_iters_to_check] < stop_criterion``? | no | | | yes |-------' '---->Have we synthesized all scales and done so for ``ctf_iters_to_check`` iterations? | no | | | yes |---------------' '----> return ``True`` | | | | | | '---------> return ``False`` Parameters ---------- i The current iteration (0-indexed). stop_criterion If the loss over the past ``stop_iters_to_check`` has changed less than ``stop_criterion``, we terminate synthesis. stop_iters_to_check How many iterations back to check in order to see if the loss has stopped decreasing (for ``stop_criterion``). ctf_iters_to_check Minimum number of iterations coarse-to-fine must run at each scale. Returns ------- loss_stabilized : Whether the loss has stabilized and we've synthesized all scales. """ # noqa: E501 loss_conv = loss_convergence(self, stop_criterion, stop_iters_to_check) return loss_conv and coarse_to_fine_enough(self, i, ctf_iters_to_check)
[docs] def load( self, file_path: str, map_location: str | None = None, **pickle_load_args, ): r"""Load all relevant stuff from a .pt file. This should be called by an initialized ``Metamer`` object -- we will ensure that ``image``, ``target_representation`` (and thus ``model``), and ``loss_function`` are all identical. Note this operates in place and so doesn't return anything. Parameters ---------- file_path : str The path to load the synthesis object from map_location : str, optional map_location argument to pass to ``torch.load``. If you save stuff that was being run on a GPU and are loading onto a CPU, you'll need this to make sure everything lines up properly. This should be structured like the str you would pass to ``torch.device`` pickle_load_args : any additional kwargs will be added to ``pickle_module.load`` via ``torch.load``, see that function's docstring for details. Examples -------- >>> metamer = po.synth.Metamer(img, model) >>> metamer.synthesize(max_iter=10, store_progress=True) >>> metamer.save('metamers.pt') >>> metamer_copy = po.synth.Metamer(img, model) >>> metamer_copy.load('metamers.pt') Note that you must create a new instance of the Synthesis object and *then* load. """ super()._load(file_path, map_location, ["_coarse_to_fine"], **pickle_load_args)
@property def coarse_to_fine(self): return self._coarse_to_fine @property def scales(self): return tuple(self._scales) @property def scales_loss(self): return tuple(self._scales_loss) @property def scales_timing(self): return self._scales_timing @property def scales_finished(self): return tuple(self._scales_finished)
[docs] def plot_loss( metamer: Metamer, iteration: int | None = None, ax: mpl.axes.Axes | None = None, **kwargs, ) -> mpl.axes.Axes: """Plot synthesis loss with log-scaled y axis. Plots ``metamer.losses`` over all iterations. Also plots a red dot at ``iteration``, to highlight the loss there. If ``iteration=None``, then the dot will be at the final iteration. Parameters ---------- metamer : Metamer object whose loss we want to plot. iteration : Which iteration to display. If None, the default, we show the most recent one. Negative values are also allowed. ax : Pre-existing axes for plot. If None, we call ``plt.gca()``. kwargs : passed to plt.semilogy Returns ------- ax : The matplotlib axes containing the plot. """ if iteration is None: loss_idx = len(metamer.losses) - 1 elif iteration < 0: # in order to get the x-value of the dot to line up, # need to use this work-around loss_idx = len(metamer.losses) + iteration else: loss_idx = iteration if ax is None: ax = plt.gca() ax.semilogy(metamer.losses, **kwargs) with contextlib.suppress(IndexError): # then there's no loss to plot ax.scatter(loss_idx, metamer.losses[loss_idx], c="r") ax.set(xlabel="Synthesis iteration", ylabel="Loss") return ax
[docs] def display_metamer( metamer: Metamer, batch_idx: int = 0, channel_idx: int | None = None, zoom: float | None = None, iteration: int | None = None, ax: mpl.axes.Axes | None = None, **kwargs, ) -> mpl.axes.Axes: """Display metamer. You can specify what iteration to view by using the ``iteration`` arg. The default, ``None``, shows the final one. We use ``plenoptic.imshow`` to display the metamer and attempt to automatically find the most reasonable zoom value. You can override this value using the zoom arg, but remember that ``plenoptic.imshow`` is opinionated about the size of the resulting image and will throw an Exception if the axis created is not big enough for the selected zoom. Parameters ---------- metamer : Metamer object whose synthesized metamer we want to display. batch_idx : Which index to take from the batch dimension channel_idx : Which index to take from the channel dimension. If None, we assume image is RGB(A) and show all channels. iteration : Which iteration to display. If None, the default, we show the most recent one. Negative values are also allowed. ax : Pre-existing axes for plot. If None, we call ``plt.gca()``. zoom : How much to zoom in / enlarge the metamer, the ratio of display pixels to image pixels. If None (the default), we attempt to find the best value ourselves. kwargs : Passed to ``plenoptic.imshow`` Returns ------- ax : The matplotlib axes containing the plot. """ image = metamer.metamer if iteration is None else metamer.saved_metamer[iteration] if batch_idx is None: raise ValueError("batch_idx must be an integer!") # we're only plotting one image here, so if the user wants multiple # channels, they must be RGB as_rgb = bool(channel_idx is None and image.shape[1] > 1) if ax is None: ax = plt.gca() display.imshow( image, ax=ax, title="Metamer", zoom=zoom, batch_idx=batch_idx, channel_idx=channel_idx, as_rgb=as_rgb, **kwargs, ) ax.xaxis.set_visible(False) ax.yaxis.set_visible(False) return ax
def _representation_error( metamer: Metamer, iteration: int | None = None, **kwargs ) -> Tensor: r"""Get the representation error. This is ``metamer.model(metamer) - target_representation)``. If ``iteration`` is not None, we use ``metamer.model(saved_metamer[iteration])`` instead. Parameters ---------- metamer : Metamer object whose representation error we want to compute. iteration : Which iteration to compute the representation error for. If None, we show the most recent one. Negative values are also allowed. kwargs : Passed to ``metamer.model.forward`` Returns ------- representation_error """ if iteration is not None: metamer_rep = metamer.model( metamer.saved_metamer[iteration].to(metamer.target_representation.device) ) else: metamer_rep = metamer.model(metamer.metamer, **kwargs) return metamer_rep - metamer.target_representation
[docs] def plot_representation_error( metamer: Metamer, batch_idx: int = 0, iteration: int | None = None, ylim: tuple[float, float] | None | Literal[False] = None, ax: mpl.axes.Axes | None = None, as_rgb: bool = False, **kwargs, ) -> list[mpl.axes.Axes]: r"""Plot distance ratio showing how close we are to convergence. We plot ``_representation_error(metamer, iteration)``. For more details, see ``plenoptic.tools.display.plot_representation``. Parameters ---------- metamer : Metamer object whose synthesized metamer we want to display. batch_idx : Which index to take from the batch dimension iteration : Which iteration to display. If None, the default, we show the most recent one. Negative values are also allowed. ylim : If ``ylim`` is ``None``, we sets the axes' y-limits to be ``(-y_max, y_max)``, where ``y_max=np.abs(data).max()``. If it's ``False``, we do nothing. If a tuple, we use that range. ax : Pre-existing axes for plot. If None, we call ``plt.gca()``. as_rgb : bool, optional The representation can be image-like with multiple channels, and we have no way to determine whether it should be represented as an RGB image or not, so the user must set this flag to tell us. It will be ignored if the response doesn't look image-like or if the model has its own plot_representation_error() method. Else, it will be passed to `po.imshow()`, see that methods docstring for details. kwargs : Passed to ``metamer.model.forward`` Returns ------- axes : List of created axes """ representation_error = _representation_error( metamer=metamer, iteration=iteration, **kwargs ) if ax is None: ax = plt.gca() return display.plot_representation( metamer.model, representation_error, ax, title="Representation error", ylim=ylim, batch_idx=batch_idx, as_rgb=as_rgb, )
[docs] def plot_pixel_values( metamer: Metamer, batch_idx: int = 0, channel_idx: int | None = None, iteration: int | None = None, ylim: tuple[float, float] | Literal[False] = False, ax: mpl.axes.Axes | None = None, **kwargs, ) -> mpl.axes.Axes: r"""Plot histogram of pixel values of target image and its metamer. As a way to check the distributions of pixel intensities and see if there's any values outside the allowed range Parameters ---------- metamer : Metamer object with the images whose pixel values we want to compare. batch_idx : Which index to take from the batch dimension channel_idx : Which index to take from the channel dimension. If None, we use all channels (assumed use-case is RGB(A) images). iteration : Which iteration to display. If None, the default, we show the most recent one. Negative values are also allowed. ylim : if tuple, the ylimit to set for this axis. If False, we leave it untouched ax : Pre-existing axes for plot. If None, we call ``plt.gca()``. kwargs : passed to plt.hist Returns ------- ax : Created axes. """ def _freedman_diaconis_bins(a): """Calculate number of hist bins using Freedman-Diaconis rule. copied from seaborn.""" # From https://stats.stackexchange.com/questions/798/ a = np.asarray(a) iqr = np.diff(np.percentile(a, [0.25, 0.75]))[0] if len(a) < 2: return 1 h = 2 * iqr / (len(a) ** (1 / 3)) # fall back to sqrt(a) bins if iqr is 0 if h == 0: return int(np.sqrt(a.size)) else: return int(np.ceil((a.max() - a.min()) / h)) kwargs.setdefault("alpha", 0.4) if iteration is None: met = metamer.metamer[batch_idx] else: met = metamer.saved_metamer[iteration, batch_idx] image = metamer.image[batch_idx] if channel_idx is not None: image = image[channel_idx] image = image[channel_idx] if ax is None: ax = plt.gca() image = data.to_numpy(image).flatten() met = data.to_numpy(met).flatten() ax.hist( met, bins=min(_freedman_diaconis_bins(image), 50), label="metamer", **kwargs, ) ax.hist( image, bins=min(_freedman_diaconis_bins(image), 50), label="target image", **kwargs, ) ax.legend() if ylim: ax.set_ylim(ylim) ax.set_title("Histogram of pixel values") return ax
def _check_included_plots(to_check: list[str] | dict[str, float], to_check_name: str): """Check whether the user wanted us to create plots that we can't. Helper function for plot_synthesis_status and animate. Raises a ValueError to_check contains any values that are not allowed. Parameters ---------- to_check : The variable to check. We ensure that it doesn't contain any extra (not allowed) values. If a list, we check its contents. If a dict, we check its keys. to_check_name : Name of the `to_check` variable, used in the error message. """ allowed_vals = [ "display_metamer", "plot_loss", "plot_representation_error", "plot_pixel_values", "misc", ] try: vals = to_check.keys() except AttributeError: vals = to_check not_allowed = [v for v in vals if v not in allowed_vals] if not_allowed: raise ValueError( f"{to_check_name} contained value(s) {not_allowed}! " f"Only {allowed_vals} are permissible!" ) def _setup_synthesis_fig( fig: mpl.figure.Figure | None = None, axes_idx: dict[str, int] = {}, figsize: tuple[float, float] | None = None, included_plots: list[str] = [ "display_metamer", "plot_loss", "plot_representation_error", ], display_metamer_width: float = 1, plot_loss_width: float = 1, plot_representation_error_width: float = 1, plot_pixel_values_width: float = 1, ) -> tuple[mpl.figure.Figure, list[mpl.axes.Axes], dict[str, int]]: """Set up figure for plot_synthesis_status. Creates figure with enough axes for the all the plots you want. Will also create index in axes_idx for them if you haven't done so already. By default, all axes will be on the same row and have the same width. If you want them to be on different rows, will need to initialize fig yourself and pass that in. For changing width, change the corresponding *_width arg, which gives width relative to other axes. So if you want the axis for the representation_error plot to be twice as wide as the others, set representation_error_width=2. Parameters ---------- fig : The figure to plot on or None. If None, we create a new figure axes_idx : Dictionary specifying which axes contains which type of plot, allows for more fine-grained control of the resulting figure. Probably only helpful if fig is also defined. Possible keys: loss, representation_error, pixel_values, misc. Values should all be ints. If you tell this function to create a plot that doesn't have a corresponding key, we find the lowest int that is not already in the dict, so if you have axes that you want unchanged, place their idx in misc. figsize : The size of the figure to create. It may take a little bit of playing around to find a reasonable value. If None, we attempt to make our best guess, aiming to have relative width=1 correspond to 5 included_plots : Which plots to include. Must be some subset of ``'display_metamer', 'plot_loss', 'plot_representation_error', 'plot_pixel_values'``. display_metamer_width : Relative width of the axis for the synthesized metamer. plot_loss_width : Relative width of the axis for loss plot. plot_representation_error_width : Relative width of the axis for representation error plot. plot_pixel_values_width : Relative width of the axis for image pixel intensities histograms. Returns ------- fig : The figure to plot on axes : List or array of axes contained in fig axes_idx : Dictionary identifying the idx for each plot type """ n_subplots = 0 axes_idx = axes_idx.copy() width_ratios = [] if "display_metamer" in included_plots: n_subplots += 1 width_ratios.append(display_metamer_width) if "display_metamer" not in axes_idx: axes_idx["display_metamer"] = data._find_min_int(axes_idx.values()) if "plot_loss" in included_plots: n_subplots += 1 width_ratios.append(plot_loss_width) if "plot_loss" not in axes_idx: axes_idx["plot_loss"] = data._find_min_int(axes_idx.values()) if "plot_representation_error" in included_plots: n_subplots += 1 width_ratios.append(plot_representation_error_width) if "plot_representation_error" not in axes_idx: axes_idx["plot_representation_error"] = data._find_min_int( axes_idx.values() ) if "plot_pixel_values" in included_plots: n_subplots += 1 width_ratios.append(plot_pixel_values_width) if "plot_pixel_values" not in axes_idx: axes_idx["plot_pixel_values"] = data._find_min_int(axes_idx.values()) if fig is None: width_ratios = np.array(width_ratios) if figsize is None: # we want (5, 5) for each subplot, with a bit of room between # each subplot figsize = ((width_ratios * 5).sum() + width_ratios.sum() - 1, 5) width_ratios = width_ratios / width_ratios.sum() fig, axes = plt.subplots( 1, n_subplots, figsize=figsize, gridspec_kw={"width_ratios": width_ratios}, ) if n_subplots == 1: axes = [axes] else: axes = fig.axes # make sure misc contains all the empty axes misc_axes = axes_idx.get("misc", []) if not hasattr(misc_axes, "__iter__"): misc_axes = [misc_axes] all_axes = [] for i in axes_idx.values(): # so if it's a list of ints if hasattr(i, "__iter__"): all_axes.extend(i) else: all_axes.append(i) misc_axes += [i for i, _ in enumerate(fig.axes) if i not in all_axes] axes_idx["misc"] = misc_axes return fig, axes, axes_idx
[docs] def plot_synthesis_status( metamer: Metamer, batch_idx: int = 0, channel_idx: int | None = None, iteration: int | None = None, ylim: tuple[float, float] | None | Literal[False] = None, vrange: tuple[float, float] | str = "indep1", zoom: float | None = None, plot_representation_error_as_rgb: bool = False, fig: mpl.figure.Figure | None = None, axes_idx: dict[str, int] = {}, figsize: tuple[float, float] | None = None, included_plots: list[str] = [ "display_metamer", "plot_loss", "plot_representation_error", ], width_ratios: dict[str, float] = {}, ) -> tuple[mpl.figure.Figure, dict[str, int]]: r"""Make a plot showing synthesis status. We create several subplots to analyze this. By default, we create three subplots on a new figure: the first one contains the synthesized metamer, the second contains the loss, and the third contains the representation error. There is an optional additional plot: ``plot_pixel_values``, a histogram of pixel values of the metamer and target image. The plots to include are specified by including their name in the ``included_plots`` list. All plots can be created separately using the method with the same name. Parameters ---------- metamer : Metamer object whose status we want to plot. batch_idx : Which index to take from the batch dimension channel_idx : Which index to take from the channel dimension. If None, we use all channels (assumed use-case is RGB(A) image). iteration : Which iteration to display. If None, the default, we show the most recent one. Negative values are also allowed. ylim : The ylimit to use for the representation_error plot. We pass this value directly to ``plot_representation_error`` vrange : The vrange option to pass to ``display_metamer()``. See docstring of ``imshow`` for possible values. zoom : How much to zoom in / enlarge the metamer, the ratio of display pixels to image pixels. If None (the default), we attempt to find the best value ourselves. plot_representation_error_as_rgb : bool, optional The representation can be image-like with multiple channels, and we have no way to determine whether it should be represented as an RGB image or not, so the user must set this flag to tell us. It will be ignored if the response doesn't look image-like or if the model has its own plot_representation_error() method. Else, it will be passed to `po.imshow()`, see that methods docstring for details. fig : if None, we create a new figure. otherwise we assume this is an empty figure that has the appropriate size and number of subplots axes_idx : Dictionary specifying which axes contains which type of plot, allows for more fine-grained control of the resulting figure. Probably only helpful if fig is also defined. Possible keys: ``'display_metamer', 'plot_loss', 'plot_representation_error', 'plot_pixel_values', 'misc'``. Values should all be ints. If you tell this function to create a plot that doesn't have a corresponding key, we find the lowest int that is not already in the dict, so if you have axes that you want unchanged, place their idx in ``'misc'``. figsize : The size of the figure to create. It may take a little bit of playing around to find a reasonable value. If None, we attempt to make our best guess, aiming to have each axis be of size (5, 5) included_plots : Which plots to include. Must be some subset of ``'display_metamer', 'plot_loss', 'plot_representation_error', 'plot_pixel_values'``. width_ratios : By default, all plots axes will have the same width. To change that, specify their relative widths using the keys: ``'display_metamer', 'plot_loss', 'plot_representation_error', 'plot_pixel_values'`` and floats specifying their relative width. Any not included will be assumed to be 1. Returns ------- fig : The figure containing this plot axes_idx : Dictionary giving index of each plot. """ if iteration is not None and not metamer.store_progress: raise ValueError( "synthesis() was run with store_progress=False, " "cannot specify which iteration to plot (only" " last one, with iteration=None)" ) if metamer.metamer.ndim not in [3, 4]: raise ValueError( "plot_synthesis_status() expects 3 or 4d data;" "unexpected behavior will result otherwise!" ) _check_included_plots(included_plots, "included_plots") _check_included_plots(width_ratios, "width_ratios") _check_included_plots(axes_idx, "axes_idx") width_ratios = {f"{k}_width": v for k, v in width_ratios.items()} fig, axes, axes_idx = _setup_synthesis_fig( fig, axes_idx, figsize, included_plots, **width_ratios ) def check_iterables(i, vals): for j in vals: try: # then it's an iterable if i in j: return True except TypeError: # then it's not an iterable if i == j: return True if "display_metamer" in included_plots: display_metamer( metamer, batch_idx=batch_idx, channel_idx=channel_idx, iteration=iteration, ax=axes[axes_idx["display_metamer"]], zoom=zoom, vrange=vrange, ) if "plot_loss" in included_plots: plot_loss(metamer, iteration=iteration, ax=axes[axes_idx["plot_loss"]]) if "plot_representation_error" in included_plots: plot_representation_error( metamer, batch_idx=batch_idx, iteration=iteration, ax=axes[axes_idx["plot_representation_error"]], ylim=ylim, as_rgb=plot_representation_error_as_rgb, ) # this can add a bunch of axes, so this will try and figure # them out new_axes = [ i for i, _ in enumerate(fig.axes) if not check_iterables(i, axes_idx.values()) ] + [axes_idx["plot_representation_error"]] axes_idx["plot_representation_error"] = new_axes if "plot_pixel_values" in included_plots: plot_pixel_values( metamer, batch_idx=batch_idx, channel_idx=channel_idx, iteration=iteration, ax=axes[axes_idx["plot_pixel_values"]], ) return fig, axes_idx
[docs] def animate( metamer: Metamer, framerate: int = 10, batch_idx: int = 0, channel_idx: int | None = None, ylim: str | None | tuple[float, float] | Literal[False] = None, vrange: tuple[float, float] | str = (0, 1), zoom: float | None = None, plot_representation_error_as_rgb: bool = False, fig: mpl.figure.Figure | None = None, axes_idx: dict[str, int] = {}, figsize: tuple[float, float] | None = None, included_plots: list[str] = [ "display_metamer", "plot_loss", "plot_representation_error", ], width_ratios: dict[str, float] = {}, ) -> mpl.animation.FuncAnimation: r"""Animate synthesis progress. This is essentially the figure produced by ``metamer.plot_synthesis_status`` animated over time, for each stored iteration. This functions returns a matplotlib FuncAnimation object. See our documentation (e.g., [Quickstart](https://docs.plenoptic.org/docs/branch/main/tutorials/00_quickstart.html)) for examples on how to view it in a Jupyter notebook. In order to save, use ``anim.save(filename)``. In either case, this can take a while and you'll need the appropriate writer installed and on your path, e.g., ffmpeg, imagemagick, etc). See [matplotlib documentation](https://matplotlib.org/stable/api/animation_api.html) for more details. Parameters ---------- metamer : Metamer object whose synthesis we want to animate. framerate : How many frames a second to display. batch_idx : Which index to take from the batch dimension channel_idx : Which index to take from the channel dimension. If None, we use all channels (assumed use-case is RGB(A) image). ylim : The y-limits of the representation_error plot: * If a tuple, then this is the ylim of all plots * If None, then all plots have the same limits, all symmetric about 0 with a limit of ``np.abs(representation_error).max()`` (for the initial representation_error) * If False, don't modify limits. * If a string, must be 'rescale' or of the form 'rescaleN', where N can be any integer. If 'rescaleN', we rescale the limits every N frames (we rescale as if ylim = None). If 'rescale', then we do this 10 times over the course of the animation vrange : The vrange option to pass to ``display_metamer()``. See docstring of ``imshow`` for possible values. zoom : How much to zoom in / enlarge the metamer, the ratio of display pixels to image pixels. If None (the default), we attempt to find the best value ourselves. plot_representation_error_as_rgb : The representation can be image-like with multiple channels, and we have no way to determine whether it should be represented as an RGB image or not, so the user must set this flag to tell us. It will be ignored if the representation doesn't look image-like or if the model has its own plot_representation_error() method. Else, it will be passed to `po.imshow()`, see that methods docstring for details. since plot_synthesis_status normally sets it up for us fig : If None, create the figure from scratch. Else, should be an empty figure with enough axes (the expected use here is have same-size movies with different plots). axes_idx : Dictionary specifying which axes contains which type of plot, allows for more fine-grained control of the resulting figure. Probably only helpful if fig is also defined. Possible keys: ``'display_metamer', 'plot_loss', 'plot_representation_error', 'plot_pixel_values', 'misc'``. Values should all be ints. If you tell this function to create a plot that doesn't have a corresponding key, we find the lowest int that is not already in the dict, so if you have axes that you want unchanged, place their idx in ``'misc'``. figsize : The size of the figure to create. It may take a little bit of playing around to find a reasonable value. If None, we attempt to make our best guess, aiming to have each axis be of size (5, 5) included_plots : Which plots to include. Must be some subset of ``'display_metamer', 'plot_loss', 'plot_representation_error', 'plot_pixel_values'``. width_ratios : By default, all plots axes will have the same width. To change that, specify their relative widths using the keys: ``'display_metamer', 'plot_loss', 'plot_representation_error', 'plot_pixel_values'`` and floats specifying their relative width. Any not included will be assumed to be 1. Returns ------- anim : The animation object. In order to view, must convert to HTML or save. Notes ----- By default, we use the ffmpeg backend, which requires that you have ffmpeg installed and on your path (https://ffmpeg.org/download.html). To use a different, use the matplotlib rcParams: `matplotlib.rcParams['animation.writer'] = writer`, see https://matplotlib.org/stable/api/animation_api.html#writer-classes for more details. For displaying in a jupyter notebook, ffmpeg appears to be required. """ if not metamer.store_progress: raise ValueError( "synthesize() was run with store_progress=False, cannot animate!" ) if metamer.metamer.ndim not in [3, 4]: raise ValueError( "animate() expects 3 or 4d data; unexpected" " behavior will result otherwise!" ) _check_included_plots(included_plots, "included_plots") _check_included_plots(width_ratios, "width_ratios") _check_included_plots(axes_idx, "axes_idx") if metamer.target_representation.ndimension() == 4: # we have to do this here so that we set the # ylim_rescale_interval such that we never rescale ylim # (rescaling ylim messes up an image axis) ylim = False try: if ylim.startswith("rescale"): try: ylim_rescale_interval = int(ylim.replace("rescale", "")) except ValueError: # then there's nothing we can convert to an int there ylim_rescale_interval = int((metamer.saved_metamer.shape[0] - 1) // 10) if ylim_rescale_interval == 0: ylim_rescale_interval = int(metamer.saved_metamer.shape[0] - 1) ylim = None else: raise ValueError(f"Don't know how to handle ylim {ylim}!") except AttributeError: # this way we'll never rescale ylim_rescale_interval = len(metamer.saved_metamer) + 1 # we run plot_synthesis_status to initialize the figure if either fig is # None or if there are no titles on any axes, which we assume means that # it's an empty figure if fig is None or not any([ax.get_title() for ax in fig.axes]): fig, axes_idx = plot_synthesis_status( metamer=metamer, batch_idx=batch_idx, channel_idx=channel_idx, iteration=0, figsize=figsize, ylim=ylim, vrange=vrange, zoom=zoom, fig=fig, axes_idx=axes_idx, included_plots=included_plots, plot_representation_error_as_rgb=plot_representation_error_as_rgb, width_ratios=width_ratios, ) # grab the artist for the second plot (we don't need to do this for the # metamer or representation plot, because we use the update_plot # function for that) if "plot_loss" in included_plots: scat = fig.axes[axes_idx["plot_loss"]].collections[0] # can have multiple plots if "plot_representation_error" in included_plots: try: rep_error_axes = [ fig.axes[i] for i in axes_idx["plot_representation_error"] ] except TypeError: # in this case, axes_idx['plot_representation_error'] is not iterable and # so is a single value rep_error_axes = [fig.axes[axes_idx["plot_representation_error"]]] else: rep_error_axes = [] # can also have multiple plots if metamer.target_representation.ndimension() == 4: if "plot_representation_error" in included_plots: warnings.warn( "Looks like representation is image-like, haven't fully" " thought out how to best handle rescaling color ranges yet!" ) # replace the bit of the title that specifies the range, # since we don't make any promises about that. we have to do # this here because we need the figure to have been created for ax in rep_error_axes: ax.set_title(re.sub(r"\n range: .* \n", "\n\n", ax.get_title())) def movie_plot(i): artists = [] if "display_metamer" in included_plots: artists.extend( display.update_plot( fig.axes[axes_idx["display_metamer"]], data=metamer.saved_metamer[i], batch_idx=batch_idx, ) ) if "plot_representation_error" in included_plots: rep_error = _representation_error(metamer, iteration=i) # we pass rep_error_axes to update, and we've grabbed # the right things above artists.extend( display.update_plot( rep_error_axes, batch_idx=batch_idx, model=metamer.model, data=rep_error, ) ) # again, we know that rep_error_axes contains all the axes # with the representation ratio info if ( (i + 1) % ylim_rescale_interval == 0 and metamer.target_representation.ndimension() == 3 ): display.rescale_ylim(rep_error_axes, rep_error) if "plot_pixel_values" in included_plots: # this is the dumbest way to do this, but it's simple -- # clearing the axes can cause problems if the user has, for # example, changed the tick locator or formatter. not sure how # to handle this best right now fig.axes[axes_idx["plot_pixel_values"]].clear() plot_pixel_values( metamer, batch_idx=batch_idx, channel_idx=channel_idx, iteration=i, ax=fig.axes[axes_idx["plot_pixel_values"]], ) if "plot_loss" in included_plots: # loss always contains values from every iteration, but everything # else will be subsampled. x_val = i * metamer.store_progress scat.set_offsets((x_val, metamer.losses[x_val])) artists.append(scat) # as long as blitting is True, need to return a sequence of artists return artists # don't need an init_func, since we handle initialization ourselves anim = mpl.animation.FuncAnimation( fig, movie_plot, frames=len(metamer.saved_metamer), blit=True, interval=1000.0 / framerate, repeat=False, ) plt.close(fig) return anim