# Source code for plenoptic.synthesize.geodesic

import warnings
from collections import OrderedDict
from typing import Literal

import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torch.autograd as autograd
from deprecated.sphinx import deprecated
from torch import Tensor
from tqdm.auto import tqdm

from ..tools.convergence import pixel_change_convergence
from ..tools.data import to_numpy
from ..tools.optim import penalize_range
from ..tools.straightness import (
    deviation_from_line,
    make_straight_line,
    sample_brownian_bridge,
)
from ..tools.validate import validate_input, validate_model
from .synthesis import OptimizedSynthesis


@deprecated(
    "Geodesic is not robust enough yet, see https://github.com/plenoptic-org/geodesics for ongoing development",  # noqa: E501
    "1.1.0",
)
class Geodesic(OptimizedSynthesis):
    r"""Synthesize an approximate geodesic between two images according to a model.

    This method can be used to visualize and refine the invariances of a
    model's representation as described in [1]_.

    NOTE: This synthesis method is still under construction. It will run, but
    it might not find the most informative geodesic.

    Parameters
    ----------
    image_a, image_b
        Start and stop anchor points of the geodesic, of shape
        (1, channel, height, width).
    model
        an analysis model that computes representations on signals like
        `image_a`.
    n_steps
        the number of steps (i.e., transitions) in the trajectory between the
        two anchor points.
    initial_sequence
        initialize the geodesic with pixel linear interpolation
        (``'straight'``), or with a brownian bridge between the two anchors
        (``'bridge'``).
    range_penalty_lambda
        strength of the regularizer that enforces the allowed_range. Must be
        non-negative.
    allowed_range
        Range (inclusive) of allowed pixel values. Any values outside this
        range will be penalized.

    Attributes
    ----------
    geodesic : Tensor
        the synthesized sequence of images between the two anchor points that
        minimizes representation path energy, of shape
        ``(n_steps+1, channel, height, width)``. It starts with image_a and
        ends with image_b.
    pixelfade : Tensor
        the straight interpolation between the two anchor points, used as
        reference.
    losses : Tensor
        A list of our loss over iterations.
    gradient_norm : list
        A list of the gradient's L2 norm over iterations.
    pixel_change_norm : list
        A list containing the L2 norm of the pixel change over iterations
        (``pixel_change_norm[i]`` is the pixel change norm in ``geodesic``
        between iterations ``i`` and ``i-1``).
    step_energy : Tensor
        step lengths in representation space, stored along the optimization
        process.
    dev_from_line : Tensor
        deviation of the representation from the straight line interpolation;
        measures distance from the straight line and distance along the
        straight line, stored along the optimization process.

    Notes
    -----
    Manifold prior hypothesis: natural images form a manifold 𝑀ˣ embedded in
    signal space (ℝⁿ), a model warps this manifold to another manifold 𝑀ʸ
    embedded in representation space (ℝᵐ), and thereby induces a different
    local metric.

    This method computes an approximate geodesic by solving an optimization
    problem: it minimizes the path energy (aka. action functional), which has
    the same minimum as minimizing path length and, by Cauchy-Schwarz, reaches
    it with a constant-speed minimizing geodesic.

    Caveat: depending on the geometry of the manifold, geodesics between two
    anchor points may not be unique and may depend on the initialization.

    References
    ----------
    .. [1] Geodesics of learned representations
       O J Hénaff and E P Simoncelli
       Published in Int'l Conf on Learning Representations (ICLR), May 2016.
       https://www.cns.nyu.edu/~lcv/pubs/makeAbs.php?loc=Henaff16b
    """

    def __init__(
        self,
        image_a: Tensor,
        image_b: Tensor,
        model: torch.nn.Module,
        n_steps: int = 10,
        initial_sequence: Literal["straight", "bridge"] = "straight",
        range_penalty_lambda: float = 0.1,
        allowed_range: tuple[float, float] = (0, 1),
    ):
        super().__init__(range_penalty_lambda, allowed_range)
        # both anchors must be single images (no batch dim) within the
        # allowed pixel range
        validate_input(image_a, no_batch=True, allowed_range=allowed_range)
        validate_input(image_b, no_batch=True, allowed_range=allowed_range)
        validate_model(
            model,
            image_shape=image_a.shape,
            image_dtype=image_a.dtype,
            device=image_a.device,
        )
        self.n_steps = n_steps
        self._model = model
        self._image_a = image_a
        self._image_b = image_b
        # straight pixel-space interpolation, kept as a reference path
        self.pixelfade = make_straight_line(image_a, image_b, n_steps)
        self._initialize(initial_sequence, image_a, image_b, n_steps)
        # populated by _store() during synthesis (always kept on cpu)
        self._dev_from_line = []
        self._step_energy = []

    def _initialize(self, initial_sequence, start, stop, n_steps):
        """Initialize the geodesic.

        Parameters
        ----------
        initial_sequence
            initialize the geodesic with pixel linear interpolation
            (``'straight'``), or with a brownian bridge between the two
            anchors (``'bridge'``).
        """
        if initial_sequence == "bridge":
            geodesic = sample_brownian_bridge(start, stop, n_steps)
        elif initial_sequence == "straight":
            geodesic = make_straight_line(start, stop, n_steps)
        else:
            raise ValueError(
                f"Don't know how to handle initial_sequence={initial_sequence}"
            )
        # drop the two anchor frames: only the n_steps-1 interior frames are
        # optimized; self.geodesic re-attaches the anchors
        _, geodesic, _ = torch.split(geodesic, [1, n_steps - 1, 1])
        self._initial_sequence = initial_sequence
        geodesic.requires_grad_()
        self._geodesic = geodesic

    def synthesize(
        self,
        max_iter: int = 1000,
        optimizer: torch.optim.Optimizer | None = None,
        store_progress: bool | int = False,
        stop_criterion: float | None = None,
        stop_iters_to_check: int = 50,
    ):
        """Synthesize a geodesic via optimization.

        Parameters
        ----------
        max_iter
            The maximum number of iterations to run before we end synthesis
            (unless we hit the stop criterion).
        optimizer
            The optimizer to use. If None and this is the first time calling
            synthesize, we use Adam(lr=.001, amsgrad=True); if synthesize has
            been called before, this must be None and we reuse the previous
            optimizer.
        store_progress
            Whether we should store the step energy and deviation of the
            representation from a straight line. If False, we don't save
            anything. If True, we save every iteration. If an int, we save
            every ``store_progress`` iterations (note then that 0 is the same
            as False and 1 the same as True).
        stop_criterion
            If pixel_change_norm (i.e., the norm of the difference in
            ``self.geodesic`` from one iteration to the next) over the past
            ``stop_iters_to_check`` has been less than ``stop_criterion``, we
            terminate synthesis. If None, we pick a default value based on the
            norm of ``self.pixelfade``.
        stop_iters_to_check
            How many iterations back to check in order to see if
            pixel_change_norm has stopped decreasing (for ``stop_criterion``).
        """
        if stop_criterion is None:
            # semi arbitrary default choice of tolerance (scaled by the golden
            # ratio)
            stop_criterion = (
                torch.linalg.vector_norm(self.pixelfade, ord=2)
                / 1e4
                * (1 + 5**0.5)
                / 2
            )
        # NOTE(review): reconstructed from collapsed source — the print below
        # is taken to run unconditionally (reporting the criterion in effect,
        # default or user-supplied); confirm against upstream.
        print(f"\n Stop criterion for pixel_change_norm = {stop_criterion:.5e}")

        self._initialize_optimizer(optimizer, "_geodesic", 0.001)

        # get ready to store progress
        self.store_progress = store_progress

        pbar = tqdm(range(max_iter))
        for _ in pbar:
            # store BEFORE stepping, so the stored values correspond to the
            # iteration count in self.losses
            self._store(len(self.losses))

            loss = self._optimizer_step(pbar)

            if not torch.isfinite(loss):
                raise ValueError("Found a NaN in loss during optimization.")

            if self._check_convergence(stop_criterion, stop_iters_to_check):
                warnings.warn("Pixel change norm has converged, stopping synthesis")
                break

        pbar.close()

    def objective_function(self, geodesic: Tensor | None = None) -> Tensor:
        """Compute geodesic synthesis loss.

        This is the path energy (i.e., squared L2 norm of each step) of the
        geodesic's model representation, with the weighted range penalty.

        Additionally, caches:

        - ``self._geodesic_representation = self.model(geodesic)``

        - ``self._most_recent_step_energy = self._calculate_step_energy(
          self._geodesic_representation)``

        These are cached because we might store them (if
        ``self.store_progress is True``) and don't want to recalculate them.

        Parameters
        ----------
        geodesic
            Geodesic to check. If None, we use ``self.geodesic``.

        Returns
        -------
        loss
        """
        if geodesic is None:
            geodesic = self.geodesic
        self._geodesic_representation = self.model(geodesic)
        self._most_recent_step_energy = self._calculate_step_energy(
            self._geodesic_representation
        )
        loss = self._most_recent_step_energy.mean()
        # note: the range penalty is always computed on self.geodesic, not on
        # the ``geodesic`` argument
        range_penalty = penalize_range(self.geodesic, self.allowed_range)
        return loss + self.range_penalty_lambda * range_penalty

    def _calculate_step_energy(self, z):
        """Calculate the energy (i.e. squared l2 norm) of each step in `z`."""
        velocity = torch.diff(z, dim=0)
        step_energy = torch.linalg.vector_norm(velocity, ord=2, dim=[1, 2, 3]) ** 2
        return step_energy

    def _optimizer_step(self, pbar):
        """Take a single optimizer step and record its effect.

        At each step of the optimization, the following is done:
        - compute the representation
        - compute the loss as a sum of:
            - representation's path energy
            - range constraint (weighted by lambda)
        - compute the gradients
        - make sure that neither the loss or the gradients are NaN
        - let the optimizer take a step in the direction of the gradients
        - display some information
        - store some information
        - return the loss for this iteration
        """
        last_iter_geodesic = self._geodesic.clone()
        loss = self.optimizer.step(self._closure)
        self._losses.append(loss.item())
        grad_norm = torch.linalg.vector_norm(self._geodesic.grad.data, ord=2, dim=None)
        self._gradient_norm.append(grad_norm)

        # norm of the pixel change produced by this step
        pixel_change_norm = torch.linalg.vector_norm(
            self._geodesic - last_iter_geodesic, ord=2, dim=None
        )
        self._pixel_change_norm.append(pixel_change_norm)

        # displaying some information
        pbar.set_postfix(
            OrderedDict(
                [
                    ("loss", f"{loss.item():.4e}"),
                    ("gradient norm", f"{grad_norm.item():.4e}"),
                    ("pixel change norm", f"{pixel_change_norm.item():.5e}"),
                ]
            )
        )
        return loss

    def _check_convergence(
        self, stop_criterion: float, stop_iters_to_check: int
    ) -> bool:
        """Check whether the pixel change norm has stabilized and, if so, return True.

        If we have been synthesizing for fewer than ``stop_iters_to_check``
        iterations, return False. Otherwise, return the value of
        ``(self.pixel_change_norm[-stop_iters_to_check:] <
        stop_criterion).all()``.

        Parameters
        ----------
        stop_criterion
            If the pixel change norm has been less than ``stop_criterion`` for
            all of the past ``stop_iters_to_check``, we terminate synthesis.
        stop_iters_to_check
            How many iterations back to check in order to see if the pixel
            change norm has stopped decreasing (for ``stop_criterion``).

        Returns
        -------
        loss_stabilized :
            Whether the pixel change norm has stabilized or not.
        """  # noqa: E501
        return pixel_change_convergence(self, stop_criterion, stop_iters_to_check)

    def calculate_jerkiness(self, geodesic: Tensor | None = None) -> Tensor:
        """Compute the alignment of representation's acceleration to model local curvature.

        This is the first order optimality condition for a geodesic, and can
        be used to assess the validity of the solution obtained by
        optimization.

        Parameters
        ----------
        geodesic
            Geodesic to check. If None, we use ``self.geodesic``. Must have a
            gradient attached.

        Returns
        -------
        jerkiness
        """
        if geodesic is None:
            geodesic = self.geodesic
        geodesic_representation = self.model(geodesic)
        # first and second finite differences along the path dimension
        velocity = torch.diff(geodesic_representation, dim=0)
        acceleration = torch.diff(velocity, dim=0)
        acc_magnitude = torch.linalg.vector_norm(
            acceleration, ord=2, dim=[1, 2, 3], keepdim=True
        )
        acc_direction = torch.div(acceleration, acc_magnitude)
        # we slice the output of the VJP, rather than slicing geodesic, because
        # slicing interferes with the gradient computation:
        # https://stackoverflow.com/a/54767100
        accJac = self._vector_jacobian_product(
            geodesic_representation[1:-1], geodesic, acc_direction
        )[1:-1]
        step_jerkiness = torch.linalg.vector_norm(accJac, dim=[1, 2, 3], ord=2) ** 2
        return step_jerkiness

    def _vector_jacobian_product(self, y, x, a):
        """Compute vector-jacobian product: $a^T dy/dx = dy/dx^T a$,
        and allow for further gradient computations by retaining,
        and creating the graph.
        """
        accJac = autograd.grad(y, x, a, retain_graph=True, create_graph=True)[0]
        return accJac

    def _store(self, i: int) -> bool:
        """Store step_energy and dev_from_line, if appropriate.

        If it's the right iteration, we update ``step_energy`` and
        ``dev_from_line``.

        Parameters
        ----------
        i
            the current iteration

        Returns
        -------
        stored :
            True if we stored this iteration, False if not.
        """
        if self.store_progress and (i % self.store_progress == 0):
            # want these to always be on cpu, to reduce memory use for GPUs
            try:
                self._step_energy.append(
                    self._most_recent_step_energy.detach().to("cpu")
                )
                self._dev_from_line.append(
                    torch.stack(
                        deviation_from_line(
                            self._geodesic_representation.detach().to("cpu")
                        )
                    ).T
                )
            except AttributeError:
                # the first time _store is called (i.e., before optimizer is
                # stepped for first time) those attributes won't be
                # initialized
                geod_rep = self.model(self.geodesic)
                self._step_energy.append(
                    self._calculate_step_energy(geod_rep).detach().to("cpu")
                )
                self._dev_from_line.append(
                    torch.stack(deviation_from_line(geod_rep.detach().to("cpu"))).T
                )
            stored = True
        else:
            stored = False
        return stored

    def save(self, file_path: str):
        r"""Save all relevant variables in .pt file.

        See ``load`` docstring for an example of use.

        Parameters
        ----------
        file_path : str
            The path to save the Geodesic object to
        """
        # I don't think any of our existing attributes can be used to check
        # whether model has changed (unlike Metamer, which stores
        # target_representation), so we use the following as a proxy
        self._save_check = self.objective_function(self.pixelfade)
        super().save(file_path, attrs=None)

    def to(self, *args, **kwargs):
        r"""Moves and/or casts the parameters and buffers.

        This can be called as

        .. function:: to(device=None, dtype=None, non_blocking=False)

        .. function:: to(dtype, non_blocking=False)

        .. function:: to(tensor, non_blocking=False)

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point desired :attr:`dtype` s. In addition, this method will
        only cast the floating point parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved
        :attr:`device`, if that is given, but with dtypes unchanged. When
        :attr:`non_blocking` is set, it tries to convert/move asynchronously
        with respect to the host if possible, e.g., moving CPU Tensors with
        pinned memory to CUDA devices. See below for examples.

        .. note::
            This method modifies the module in-place.

        Args:
            device (:class:`torch.device`): the desired device of the
                parameters and buffers in this module
            dtype (:class:`torch.dtype`): the desired floating point type of
                the floating point parameters and buffers in this module
            tensor (torch.Tensor): Tensor whose dtype and device are the
                desired dtype and device for all parameters and buffers in
                this module
        """
        attrs = [
            "_image_a",
            "_image_b",
            "_geodesic",
            "_model",
            "_step_energy",
            "_dev_from_line",
            "pixelfade",
        ]
        super().to(*args, attrs=attrs, **kwargs)

    def load(
        self,
        file_path: str,
        map_location: str | None = None,
        **pickle_load_args,
    ):
        r"""Load all relevant stuff from a .pt file.

        This should be called by an initialized ``Geodesic`` object -- we will
        ensure that ``image_a``, ``image_b``, ``model``, ``n_steps``,
        ``initial_sequence``, ``range_penalty_lambda``, ``allowed_range``, and
        ``pixelfade`` are all identical.

        Note this operates in place and so doesn't return anything.

        Parameters
        ----------
        file_path : str
            The path to load the synthesis object from
        map_location : str, optional
            map_location argument to pass to ``torch.load``. If you save stuff
            that was being run on a GPU and are loading onto a CPU, you'll
            need this to make sure everything lines up properly. This should
            be structured like the str you would pass to ``torch.device``
        pickle_load_args :
            any additional kwargs will be added to ``pickle_module.load`` via
            ``torch.load``, see that function's docstring for details.

        Examples
        --------
        >>> geo = po.synth.Geodesic(img_a, img_b, model)
        >>> geo.synthesize(max_iter=10, store_progress=True)
        >>> geo.save('geo.pt')
        >>> geo_copy = po.synth.Geodesic(img_a, img_b, model)
        >>> geo_copy.load('geo.pt')

        Note that you must create a new instance of the Synthesis object and
        *then* load.
        """
        check_attributes = [
            "_image_a",
            "_image_b",
            "n_steps",
            "_initial_sequence",
            "_range_penalty_lambda",
            "_allowed_range",
            "pixelfade",
        ]
        check_loss_functions = []
        # compute the loss on pixelfade with the CURRENT model, to compare
        # against the value cached at save time (see save())
        new_loss = self.objective_function(self.pixelfade)
        super().load(
            file_path,
            map_location=map_location,
            check_attributes=check_attributes,
            check_loss_functions=check_loss_functions,
            **pickle_load_args,
        )
        old_loss = self.__dict__.pop("_save_check")
        if not torch.allclose(new_loss, old_loss, rtol=1e-2):
            raise ValueError(
                "objective_function on pixelfade of saved and initialized"
                " Geodesic object are different! Do they use the same model?"
                f" Self: {new_loss}, Saved: {old_loss}"
            )
        # make this require a grad again
        self._geodesic.requires_grad_()
        # these are always supposed to be on cpu, but may get copied over to
        # gpu on load (which can cause problems when resuming synthesis), so
        # fix that.
        if len(self._dev_from_line) and self._dev_from_line[0].device.type != "cpu":
            self._dev_from_line = [dev.to("cpu") for dev in self._dev_from_line]
        if len(self._step_energy) and self._step_energy[0].device.type != "cpu":
            self._step_energy = [step.to("cpu") for step in self._step_energy]

    @property
    def model(self):
        return self._model

    @property
    def image_a(self):
        return self._image_a

    @property
    def image_b(self):
        return self._image_b

    # self._geodesic contains the portion we're optimizing, but self.geodesic
    # combines this with the end points
    @property
    def geodesic(self):
        return torch.cat([self.image_a, self._geodesic, self.image_b])

    @property
    def step_energy(self):
        """Squared L2 norm of transition between geodesic frames in representation space.

        Has shape ``(np.ceil(synth_iter/store_progress), n_steps)``, where
        ``synth_iter`` is the number of iterations of synthesis that have
        happened.
        """
        return torch.stack(self._step_energy)

    @property
    def dev_from_line(self):
        """Deviation of representation of each frame of ``self.geodesic`` from a straight line.

        Has shape ``(np.ceil(synth_iter/store_progress), n_steps+1, 2)``,
        where ``synth_iter`` is the number of iterations of synthesis that
        have happened. For final dimension, the first element is the Euclidean
        distance along the straight line and the second is the Euclidean
        distance to the line.
        """
        return torch.stack(self._dev_from_line)
@deprecated(
    "Geodesic is not robust enough yet, see https://github.com/plenoptic-org/geodesics for ongoing development",  # noqa: E501
    "1.1.0",
)
def plot_loss(
    geodesic: Geodesic, ax: mpl.axes.Axes | None = None, **kwargs
) -> mpl.axes.Axes:
    """Plot the geodesic's synthesis loss over iterations, on a log y-axis.

    Parameters
    ----------
    geodesic :
        Geodesic object whose synthesis loss we want to plot.
    ax :
        If not None, the axis to plot this representation on. If None, we
        call ``plt.gca()``
    kwargs :
        passed to plt.semilogy

    Returns
    -------
    ax :
        Axes containing the plot.
    """
    target_ax = plt.gca() if ax is None else ax
    target_ax.semilogy(geodesic.losses, **kwargs)
    target_ax.set(xlabel="Synthesis iteration", ylabel="Loss")
    return target_ax
@deprecated(
    "Geodesic is not robust enough yet, see https://github.com/plenoptic-org/geodesics for ongoing development",  # noqa: E501
    "1.1.0",
)
def plot_deviation_from_line(
    geodesic: Geodesic,
    natural_video: Tensor | None = None,
    ax: mpl.axes.Axes | None = None,
) -> mpl.axes.Axes:
    """Visual diagnostic of geodesic linearity in representation space.

    This plot illustrates the deviation from the straight line connecting the
    representations of a pair of images, for different paths in representation
    space: the pixel-space interpolation (``pixelfade``), the synthesized
    geodesic, and, optionally, a natural video bridging the two anchors.

    Parameters
    ----------
    geodesic :
        Geodesic object to visualize.
    natural_video :
        Natural video that bridges the anchor points, for comparison.
    ax :
        If not None, the axis to plot this representation on. If None, we
        call ``plt.gca()``

    Returns
    -------
    ax :
        Axes containing the plot

    Notes
    -----
    Axes are in the same units, normalized by the distance separating the end
    point representations. Knots along each curve indicate samples used to
    compute the path.

    When the representation is non-linear it may not be feasible for the
    geodesic to be straight (for example if the representation is normalized,
    all paths are constrained to live on a hypersphere). Nevertheless, if the
    representation is able to linearize the transformation between the anchor
    images, then we expect that both the ground truth natural video sequence
    and the geodesic will deviate from straight line similarly. By contrast
    the pixel-based interpolation will deviate significantly more from a
    straight line.
    """
    if ax is None:
        ax = plt.gca()

    def _add_curve(representation, fmt, label):
        # deviation_from_line yields (distance along line, distance from line)
        coords = deviation_from_line(representation)
        ax.plot(*[to_numpy(c) for c in coords], fmt, label=label)

    _add_curve(geodesic.model(geodesic.pixelfade), "g-o", "pixelfade")
    _add_curve(geodesic.model(geodesic.geodesic).detach(), "r-o", "geodesic")
    if natural_video is not None:
        _add_curve(geodesic.model(natural_video), "b-o", "natural video")

    ax.set(
        xlabel="Distance along representation line",
        ylabel="Distance from representation line",
        title="Deviation from the straight line",
    )
    ax.legend(loc=1)
    return ax