Source code for plenoptic.models.feature_extractor

"""Adapter for compatibility with torchvision and timm models."""
# numpydoc ignore=EX01

import warnings
from collections import OrderedDict
from typing import Literal

import einops
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch

try:
    from torchvision.models import feature_extraction
except ImportError:
    feature_extraction = None

from ..plot import display


[docs] class FeatureExtractorModel(torch.nn.Module): """ Return features from model. This adapter combines a torch model with a feature extractor and optional transform, allowing us to target the output of one or more particular layers in a deep neural network for use with synthesis objects. This adapter is intended to work with :external+torchvision:ref:`TorchVision <models>` and :external+timm:doc:`timm <models>`, two model zoos from the deep learning community that contain a large number of models. For more details on the node naming conventions used here, please see the :external+torchvision:ref:`About Node Names <about-node-names>` heading in the :external+torchvision:doc:`torchvision documentation <feature_extraction>`. .. attention:: This model requires the optional dependency ``torchvision``. Make sure it is installed before initializing this model. Parameters ---------- model The pytorch module to use. return_nodes The names of the nodes to return. See Examples and :external+torchvision:doc:`torchvision documentation <feature_extraction>`. transform Pre-processing transform to apply to image before passing to model. If ``None``, will not apply any transform. Raises ------ ImportError If torchvision is not installed. Examples -------- Use with a torchvision model: .. plot:: :context: reset >>> import plenoptic as po >>> import torchvision >>> weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1 >>> tv_model = torchvision.models.resnet50(weights=weights).eval() >>> # This model's transform consists of resizing, cropping, and normalizing. >>> # We recommend only including the normalizing in the transform. >>> tv_transform = weights.transforms() >>> tv_transform ImageClassification( crop_size=[224] resize_size=[256] mean=[0.485, 0.456, 0.406] std=[0.229, 0.224, 0.225] interpolation=InterpolationMode.BILINEAR ) >>> norm = torchvision.transforms.Normalize(tv_transform.mean, tv_transform.std) >>> model = po.models.FeatureExtractorModel(tv_model, "layer2", norm) >>> # this model requires a 3d input, and expects it to have a certain input >>> # size. >>> img = po.process.center_crop( ... po.data.einstein(False), tv_transform.crop_size[0] ... ) >>> img.shape torch.Size([1, 3, 224, 224]) >>> model(img).shape torch.Size([1, 401408]) >>> po.remove_grad(model) >>> po.validate.validate_model(model, image_shape=img.shape) Use with timm a model. The primary difference is in the syntax for retrieving the model and the transform: .. attention:: The following block requires the additional package ``timm``. >>> import timm >>> from timm.data import resolve_data_config >>> from timm.data.transforms_factory import create_transform >>> timm_model = timm.create_model("timm/resnet50.tv_in1k", pretrained=True).eval() >>> # Create Transform >>> timm_transform = create_transform( ... **resolve_data_config(timm_model.pretrained_cfg, model=timm_model) ... ) >>> # This model has the same resizing, cropping, normalizing transform as above, >>> # but timm allows us to explicitly select the different steps >>> timm_transform Compose( Resize(size=256, interpolation=bilinear, max_size=None, antialias=True) CenterCrop(size=(224, 224)) MaybeToTensor() Normalize(mean=tensor([0.4850, ...]), std=tensor([0.2290, ...])) ) >>> timm_crop = timm_transform.transforms[1] >>> timm_norm = timm_transform.transforms[-1] >>> model = po.models.FeatureExtractorModel(timm_model, "layer2", timm_norm) >>> # this model requires a 3d input, and expects it to have a certain input size. >>> img = timm_crop(po.data.einstein(False)) >>> img.shape torch.Size([1, 3, 224, 224]) >>> model(img).shape torch.Size([1, 401408]) >>> po.remove_grad(model) >>> po.validate.validate_model(model, image_shape=img.shape) The torchvision function :external+torchvision:func:`torchvision.models.feature_extraction.get_graph_node_names` allows us to view possible node names: >>> from torchvision.models import feature_extraction >>> # This function returns two lists, one for nodes in train mode, one for those in >>> # eval mode. We want the eval mode list: >>> node_names = feature_extraction.get_graph_node_names(tv_model)[1] >>> len(node_names) 176 >>> node_names[77:81] ['layer2.3.add', 'layer2.3.relu_2', 'layer3.0.conv1', 'layer3.0.bn1'] >>> model = po.models.FeatureExtractorModel(tv_model, node_names[78], norm) >>> model(img).shape torch.Size([1, 401408]) We can even pass multiple node names, in which case all corresponding outputs are concatenated together. >>> model = po.models.FeatureExtractorModel(tv_model, ["layer2", "layer4"], norm) >>> model(img).shape torch.Size([1, 501760]) The order of elements in ``return_nodes`` does not matter: the outputs are always returned based on their order in ``model``. >>> rep = model(img) >>> model = po.models.FeatureExtractorModel(tv_model, ["layer4", "layer2"], norm) >>> rep[0, 0] == model(img)[0, 0] tensor(True) The function :meth:`convert_to_dict` will convert the output of :meth:`forward` to a dictionary and return its elements to their original shape. This may be useful for plotting or investigation. >>> [(k, v.shape) for k, v in model.convert_to_dict(model(img)).items()] [('layer2', torch.Size([1, 512, 28, 28])), ('layer4', torch.Size([1, 2048, 7, 7]))] Visualize model representation with :meth:`plot_representation`: .. plot:: :context: close-figs >>> fig, axes = model.plot_representation(model(img)) """ def __init__( self, model: torch.nn.Module, return_nodes: str | list[str] | dict[str, str], transform: torch.nn.Module | None = None, ): super().__init__() self.transform = transform if isinstance(return_nodes, str): return_nodes = [return_nodes] if feature_extraction is None: raise ImportError( "Missing optional dependency torchvision, which is needed for " "FeatureExtractorModel, please install it!" ) self.extractor = feature_extraction.create_feature_extractor( model, return_nodes ) self.model = model if hasattr(model, "training") and model.training: warnings.warn( "model is in training mode, you probably want to call eval()" " to switch to evaluation mode" ) elif hasattr(model, "training") and not model.training: # by default, all torch modules are in training mode. make sure # FeatureExtractor mode matches that of the underlying model self.eval() self._out_keys = None self._packed_shapes = None
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Compute feature activity of an input. We flatten across all dimensions except the batch / first dimension. This allows us to support returning features of different shapes and dimensionality (as is common across layers in deep nets), while still returning only a single tensor, as is necessary for our synthesis methods. Parameters ---------- x The tensor to analyze. Returns ------- representation_tensor The feature activity as a 2d tensor, of shape ``(batch, representation)``. See Also -------- convert_to_dict Convert tensor representation to a dictionary, whose keys are the feature names, with their original shapes. Examples -------- >>> import plenoptic as po >>> import torchvision >>> weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1 >>> tv_model = torchvision.models.resnet50(weights=weights) >>> # This model's transform consists of resizing, cropping, and normalizing. >>> # We recommend only including the normalizing in the transform. >>> tv_transform = weights.transforms() >>> norm = torchvision.transforms.Normalize(tv_transform.mean, tv_transform.std) >>> model = po.models.FeatureExtractorModel(tv_model, "layer2", norm).eval() >>> # this model requires a 3d input, and expects it to have a certain input >>> # size. >>> img = po.process.center_crop( ... po.data.einstein(False), tv_transform.crop_size[0] ... ) >>> model(img).shape torch.Size([1, 401408]) """ if self.transform is not None: x = self.transform(x) original_out = self.extractor(x) return self.convert_to_tensor(original_out)
[docs] def convert_to_tensor( self, representation_dict: OrderedDict[str, torch.Tensor] ) -> torch.Tensor: """ Convert dictionary of statistics to a tensor. The output has shape ``(batch, representation)``, flattening and concatenating across all representation features, channels, and additional dimensions. The dictionary representation may be easier to make sense of. Parameters ---------- representation_dict Dictionary of representation, whose keys are the feature names and whose values are tensors in the original shape. Returns ------- representation_tensor Feature activity as a 2d tensor, of shape ``(batch, representation)``. See Also -------- convert_to_dict Convert tensor representation to a dictionary, whose keys are the feature names, with their original shapes. Examples -------- >>> import plenoptic as po >>> import torchvision >>> weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1 >>> tv_model = torchvision.models.resnet50(weights=weights) >>> # This model's transform consists of resizing, cropping, and normalizing. >>> # We recommend only including the normalizing in the transform. >>> tv_transform = weights.transforms() >>> norm = torchvision.transforms.Normalize(tv_transform.mean, tv_transform.std) >>> model = po.models.FeatureExtractorModel( ... tv_model, ["layer2", "layer4"], norm ... ) >>> # this model requires a 3d input, and expects it to have a certain input >>> # size. >>> img = po.process.center_crop( ... po.data.einstein(False), tv_transform.crop_size[0] ... ) >>> representation_tensor = model(img) >>> representation_dict = model.convert_to_dict(representation_tensor) >>> representation_tensor_new = model.convert_to_tensor(representation_dict) >>> torch.equal(representation_tensor, representation_tensor_new) True """ self._out_keys = representation_dict.keys() packed_out, self._packed_shapes = einops.pack( list(representation_dict.values()), "b *" ) return packed_out
[docs] def convert_to_dict( self, representation_tensor: torch.Tensor ) -> OrderedDict[str, torch.Tensor]: """ Convert tensor of statistics to a dictionary. The output of :meth:`forward` is flattened so as to allow us to return a single tensor, regardless of the specified features. This function undoes that flattening, returning a dictionary whose keys are the feature names and whose values have the original shape. This may be useful for investigation or plotting. This function requires calling either :func:`forward` or :func:`convert_to_tensor` first, so that it knows how to properly reshape the input. Parameters ---------- representation_tensor 2d tensor of model representation, e.g., the output of :meth:`forward`. Returns ------- representation_dict Dictionary of representation, with informative keys. Raises ------ ValueError If :func:`forward` or :func:`convert_to_tensor` was not called before this one, because then we don't know how to properly reshape. See Also -------- convert_to_tensor Convert dictionary representation to a 2d tensor. Examples -------- >>> import plenoptic as po >>> import torchvision >>> weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1 >>> tv_model = torchvision.models.resnet50(weights=weights) >>> # This model's transform consists of resizing, cropping, and normalizing. >>> # We recommend only including the normalizing in the transform. >>> tv_transform = weights.transforms() >>> norm = torchvision.transforms.Normalize(tv_transform.mean, tv_transform.std) >>> model = po.models.FeatureExtractorModel( ... tv_model, ["layer2", "layer4"], norm ... ) >>> # this model requires a 3d input, and expects it to have a certain input >>> # size. >>> img = po.process.center_crop( ... po.data.einstein(False), tv_transform.crop_size[0] ... ) >>> representation_dict = model.convert_to_dict(model(img)) >>> [(k, v.shape) for k, v in representation_dict.items()] [('layer2', torch.Size([1, 512, 28, 28])), ('layer4', torch.Size([1, 2048, 7, 7]))] """ if self._packed_shapes is None or self._out_keys is None: raise ValueError( "Call forward or convert_to_tensor before this function," " otherwise we don't know how to properly reshape!" ) unpacked = einops.unpack(representation_tensor, self._packed_shapes, "b *") return OrderedDict({k: v for k, v in zip(self._out_keys, unpacked)})
[docs] def update_plot( self, axes: mpl.axes.Axes | list[mpl.axes.Axes], data: torch.Tensor | dict, batch_idx: int = 0, rescale_ylim: bool = False, ) -> list: """ Update representation plot (for an animation). This is a helper function for creating an animation over time. Parameters ---------- axes The list of axes to update. We assume that these are the axes created by :func:`plot_representation` and so contain artists in the correct order. data The new data to use for updating the plot. Should look like the output of :meth:`forward` or :meth:`convert_to_dict`, with the exact same structure (e.g., as returned by another instance of this class). batch_idx Which index to take from the batch dimension. rescale_ylim Whether to rescale the ylimits of the per-channel plot or not. Returns ------- artists A list of the artists used to update the information on the plots. See Also -------- plot_representation Create plots to summarize model representation, which we assume created the axes passed to this function for updating. :func:`plenoptic.plot.update_plot` Generic ``update_plot`` function. :func:`plenoptic.plot.synthesis_animate` Function which creates a video of synthesis process over time, makes use of this function. Examples -------- .. plot:: :context: reset >>> import plenoptic as po >>> import torchvision >>> weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1 >>> tv_model = torchvision.models.resnet50(weights=weights) >>> # This model's transform consists of resizing, cropping, and normalizing. >>> # We recommend only including the normalizing in the transform. >>> tv_transform = weights.transforms() >>> norm = torchvision.transforms.Normalize( ... tv_transform.mean, tv_transform.std ... ) >>> model = po.models.FeatureExtractorModel(tv_model, "layer2", norm) >>> # this model requires a 3d input, and expects it to have a certain input >>> # size. >>> img = po.process.center_crop( ... po.data.einstein(False), tv_transform.crop_size[0] ... ) >>> fig, axes = model.plot_representation(model(img)) >>> img = po.process.center_crop( ... po.data.curie(False), tv_transform.crop_size[0] ... ) >>> model.update_plot(axes, model(img)) [<matplotlib...>] """ if isinstance(data, torch.Tensor): data = self.convert_to_dict(data) artists = [] per_channel_reps = [] for i, (k, v) in enumerate(data.items()): # Average representation across channels avg_channel_rep = v.mean(dim=1, keepdim=True) # Average representation across additional dimensions (probably space) per_channel_rep = v.mean(dim=tuple(np.arange(2, v.ndim))) while per_channel_rep.ndim < 3: per_channel_rep = per_channel_rep.unsqueeze(0) per_channel_reps.append(per_channel_rep) art = display.update_plot( axes[2 * i : 2 * (i + 1)], {"00": avg_channel_rep, "01": per_channel_rep}, batch_idx=batch_idx, ) artists.extend(art) if rescale_ylim: display._rescale_ylim(axes[1::2], torch.cat(per_channel_reps, -1)) return artists
[docs] def plot_representation( self, data: torch.Tensor | dict[str, torch.Tensor], ax: plt.Axes | None = None, figsize: tuple[float, float] | None = None, ylim: tuple[float, float] | Literal[False] | None = False, batch_idx: int = 0, title: str | None = None, ) -> tuple[plt.Figure, list[plt.Axes]]: """ Plot model representation. This creates two plots: one containing the representation averaged across all channels, and one containing the per-channel representation, i.e., the representation averaged across all dimensions *except* channels. Intended for neural networks, e.g., models whose output at each node has many channels and one or two additional dimensions. Parameters ---------- data The data to show on the plot. Should look like the output of :meth:`forward` or :meth:`convert_to_dict`, with the exact same structure (e.g., as returned by another instance of this class). ax Axes where we will plot the data. If a ``plt.Axes`` instance, will subdivide into 6 or 8 new axes (depending on self.n_scales). If ``None``, we create a new figure. figsize The size of the figure to create. Must be ``None`` if ax is not ``None``. If both figsize and ax are ``None``, then we set ``figsize=(7, 5)``. ylim If not ``None``, the y-limits to use for this plot. If ``None``, we adjust y-limits to be symmetrical about 0. If ``False``, do not change y-limits. batch_idx Which index to take from the batch dimension (the first one). title Title for the plot. Returns ------- fig Figure containing the plot. axes List of axes containing the plot. Number of axes will be two per node. Raises ------ ValueError If both ``figsize`` and ``ax`` are not ``None``. Examples -------- .. plot:: :context: reset >>> import plenoptic as po >>> import torchvision >>> weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1 >>> tv_model = torchvision.models.resnet50(weights=weights) >>> # This model's transform consists of resizing, cropping, and normalizing. >>> # We recommend only including the normalizing in the transform. >>> tv_transform = weights.transforms() >>> norm = torchvision.transforms.Normalize( ... tv_transform.mean, tv_transform.std ... ) >>> model = po.models.FeatureExtractorModel(tv_model, "layer2", norm) >>> # this model requires a 3d input, and expects it to have a certain input >>> # size. >>> img = po.process.center_crop( ... po.data.einstein(False), tv_transform.crop_size[0] ... ) >>> model.plot_representation(model(img)) (<Figure ...>, [<Axes...>, <Axes...>]) This function creates two axes per node, one showing the representation averaged across channels, one showing it per channel (averaging across any additional dimensions): .. plot:: :context: close-figs >>> model = po.models.FeatureExtractorModel( ... tv_model, ["layer2", "layer4"], norm ... ) >>> model.plot_representation(model(img)) (<Figure ...>, [<Axes...>, <Axes...>, <Axes...>, <Axes...>]) Plot the dictionary representation: .. plot:: :context: close-figs >>> model.plot_representation(model.convert_to_dict(model(img))) (<Figure ...>, [<Axes...>, <Axes...>, <Axes...>, <Axes...>]) Plot on an existing axes object: .. plot:: :context: close-figs >>> fig, axes = plt.subplots(1, 2) >>> model.plot_representation(model.convert_to_dict(model(img)), ax=axes[1]) (<Figure ...>, [<Axes...>, <Axes...>, <Axes...>, <Axes...>]) """ if ax is None and figsize is None: figsize = (7, 5) elif ax is not None and figsize is not None: raise ValueError("figsize can't be set if ax is not None") if isinstance(data, torch.Tensor): data = self.convert_to_dict(data) # Determine figure layout n_cols = len(data) axes = [] if ax is None: fig = plt.figure(figsize=figsize) gs = mpl.gridspec.GridSpec(1, n_cols, fig) else: ax = display._clean_up_axes( ax, False, ["top", "right", "bottom", "left"], ["x", "y"] ) gs = ax.get_subplotspec().subgridspec(1, n_cols) fig = ax.figure for i, (k, v) in enumerate(data.items()): ax = fig.add_subplot(gs[i]) # Average representation across channels avg_channel_rep = v.mean(dim=1, keepdim=True) # Average representation across additional dimensions (probably space) per_channel_rep = v.mean(dim=tuple(np.arange(2, v.ndim))) while per_channel_rep.ndim < 3: per_channel_rep = per_channel_rep.unsqueeze(0) if avg_channel_rep.ndim == 3: height_ratios = [1, 1] elif avg_channel_rep.ndim == 4: height_ratios = [2, 1] # this warning is not relevant here with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message="data has keys, so we're ignoring title" ) ax = display.plot_representation( data={ f"{k} avg across channels": avg_channel_rep, f"{k} per channel (n={v.shape[1]})": per_channel_rep, }, ax=ax, batch_idx=batch_idx, axes_direction="vertical", gridspec_kwargs={"height_ratios": height_ratios}, ylim=ylim, ) axes.extend(ax) if title is not None: fig.suptitle(title) return fig, axes