Run this notebook yourself!

Download the executed notebook: Demo_Eigendistortion.ipynb!

Run it in your browser: Demo_Eigendistortion.ipynb!

Attention

The eigendistortion synthesis investigated in this notebook takes a long time to run, especially if you don’t have a GPU available. Therefore, we have cached the results of these syntheses online, and this notebook only downloads them for inspection.

Reproducing Berardino et al., 2017 (Eigendistortions)

Author: Lyndon Duong, Jan 2021

In this demo, we reproduce the eigendistortions first presented in Berardino et al. (2017). We’ll be using a Front End model of the human visual system (called “On-Off” in the paper), as well as an early layer of VGG16. The Front End model is a simple convolutional neural network with a normalization nonlinearity, loosely based on biological retinal/geniculate circuitry.

Front-end model

This signal-flow diagram shows an input being decomposed into two channels, each of which is luminance- and contrast-normalized and then passed through a ReLU.

What do eigendistortions tell us?

Our perception is influenced by our internal representation (neural responses) of the external world. Eigendistortions are directions in image space, rank-ordered by how sensitive a model’s responses are to perturbations along them: the top eigendistortion is the distortion the model predicts to be most noticeable, and the bottom one the least noticeable. Plenoptic’s Eigendistortion provides an easy way to synthesize eigendistortions for any PyTorch model.
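To make the idea of rank-ordered directions concrete, here is a minimal sketch on a tiny toy model. It assumes, as in Berardino et al. (2017), that the responses are corrupted by additive Gaussian noise, so that the Fisher information matrix is F = JᵀJ, where J is the Jacobian of the responses with respect to the input. The names `toy_model` and `weights` are hypothetical and chosen only for illustration; this is not how plenoptic computes eigendistortions for real images, where F is far too large to build explicitly.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy "model": a fixed linear map followed by a pointwise tanh.
weights = torch.randn(6, 4, dtype=torch.float64)


def toy_model(x):
    return torch.tanh(weights @ x)


x = torch.randn(4, dtype=torch.float64)

# Jacobian of the responses with respect to the input, shape (6, 4)
jacobian = torch.autograd.functional.jacobian(toy_model, x)

# Fisher information matrix under the additive Gaussian noise assumption
fisher = jacobian.T @ jacobian

# eigh returns eigenvalues in ascending order, with eigenvectors as columns
eigenvalues, eigenvectors = torch.linalg.eigh(fisher)
least_noticeable = eigenvectors[:, 0]
most_noticeable = eigenvectors[:, -1]

# Nudging the input along each direction changes the responses by different amounts
step = 1e-3
change_most = (toy_model(x + step * most_noticeable) - toy_model(x)).norm()
change_least = (toy_model(x + step * least_noticeable) - toy_model(x)).norm()
print(f"response change, most noticeable:  {change_most.item():.2e}")
print(f"response change, least noticeable: {change_least.item():.2e}")
```

For a small step, the response change along an eigenvector is approximately the step size times the square root of its eigenvalue, which is why the eigenvalues give the rank ordering.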

import torch

import plenoptic as po

# this notebook uses torchvision, which is an optional dependency.
# if this fails, install torchvision in your plenoptic environment
# and restart the notebook kernel.
try:
    from torchvision import models
    from torchvision.models import feature_extraction
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "optional dependency torchvision not found!"
        " please install it in your plenoptic environment "
        "and restart the notebook kernel"
    )

# we do not actually run synthesis in this notebook, so the cpu is fine.
DEVICE = torch.device("cpu")

Input preprocessing

Let’s load the parrot image used in the paper and display it:

# load the image, then center-crop it to be square:
image_tensor = po.data.parrot().to(DEVICE).to(torch.float64)
image_tensor = po.process.center_crop(image_tensor, min(image_tensor.shape[-2:]))

print("Torch image shape:", image_tensor.shape)

po.plot.imshow(image_tensor);
Torch image shape: torch.Size([1, 1, 254, 254])

Since the Front End OnOff model has only two output channels, we can easily visualize its feature maps. We’ll apply a circular mask to this model’s inputs to avoid edge artifacts in the synthesis.
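For intuition about what a circular mask does, the sketch below builds a hard-edged disk and multiplies it into an image-shaped tensor. The helper `circular_mask` is hypothetical and is not plenoptic’s implementation, which is enabled through `apply_mask=True` and may differ in details such as edge smoothing.

```python
import torch


def circular_mask(height, width, dtype=torch.float64):
    """Hypothetical helper: 1 inside the largest centered disk, 0 outside."""
    ys = torch.arange(height, dtype=dtype) - (height - 1) / 2
    xs = torch.arange(width, dtype=dtype) - (width - 1) / 2
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    radius = min(height, width) / 2
    return ((yy**2 + xx**2) <= radius**2).to(dtype)


mask = circular_mask(254, 254)

# the (254, 254) mask broadcasts over the batch and channel dimensions
example_image = torch.rand(1, 1, 254, 254, dtype=torch.float64)
masked_image = example_image * mask
```

Pixels in the corners are zeroed out, so anything outside the disk cannot contribute to the masked result.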

mdl_f = po.models.OnOff(
    kernel_size=(31, 31), pretrained=True, apply_mask=True, cache_filt=True
)
po.remove_grad(mdl_f)
mdl_f = mdl_f.to(DEVICE).to(image_tensor.dtype)
mdl_f.eval()

response_f = mdl_f(image_tensor)
po.plot.imshow(
    response_f,
    title=["on channel response", "off channel response"],
);

Synthesizing eigendistortions

Front-end model: eigendistortion synthesis

Now that we have our Front End model set up, we can synthesize eigendistortions! To do so, instantiate the Eigendistortion object and then call its synthesize method. We’ll synthesize the top and bottom k eigendistortions, representing the most- and least-noticeable distortions for this model.

The paper synthesizes the top and bottom k=1 eigendistortions, but we’ll set k>1 so that the algorithm converges faster and more stably.

# synthesize the top and bottom k distortions
eigendist_f = po.Eigendistortion(image=image_tensor, model=mdl_f)
# this synthesis takes a long time to run, so we load in a cached version.
# see the following admonition for how to run this yourself
eigendist_f.load(
    po.data.fetch_data("berardino_onoff.pt"),
    tensor_equality_atol=1e-7,
    map_location=DEVICE,
)
Downloading file 'berardino_onoff.pt' from 'https://osf.io/download/uqfa8' to '/home/jenkins/agent/workspace/CCN_neurorse_plenoptic_main/.cache/plenoptic'.
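As a rough sketch of how the top and bottom k eigenvectors can be found iteratively, the example below runs a block power (orthogonal) iteration on a small symmetric matrix, then reuses the same routine on a shifted matrix to recover the smallest eigenvalues. Everything here is an assumption made for illustration: `top_k_eigenpairs` is a hypothetical helper, the 6×6 matrix is a stand-in for a Fisher matrix with known eigenvalues, and this is not plenoptic’s synthesis code. For real images the matrix would be far too large to build, and the matrix-vector products would instead be computed with automatic differentiation.

```python
import torch


def top_k_eigenpairs(matrix, k, n_iter=500, seed=0):
    """Block power (orthogonal) iteration for the k largest eigenpairs
    of a symmetric positive semi-definite matrix."""
    generator = torch.Generator().manual_seed(seed)
    basis = torch.randn(matrix.shape[0], k, dtype=matrix.dtype, generator=generator)
    basis, _ = torch.linalg.qr(basis)
    for _ in range(n_iter):
        # multiply by the matrix, then re-orthonormalize the k columns
        basis, _ = torch.linalg.qr(matrix @ basis)
    # Rayleigh quotients of the converged columns estimate the eigenvalues
    eigenvalues = torch.diag(basis.T @ matrix @ basis)
    return eigenvalues, basis


# Hypothetical 6x6 stand-in for a Fisher matrix with known eigenvalues
generator = torch.Generator().manual_seed(1)
q, _ = torch.linalg.qr(torch.randn(6, 6, dtype=torch.float64, generator=generator))
true_eigenvalues = torch.tensor([10.0, 7.0, 4.0, 2.0, 1.0, 0.5], dtype=torch.float64)
fisher = q @ torch.diag(true_eigenvalues) @ q.T

k = 2
top_values, top_vectors = top_k_eigenpairs(fisher, k)

# Shifting by the largest eigenvalue flips the spectrum, so the same routine
# now finds the smallest eigenvalues of the original matrix
shifted = top_values[0] * torch.eye(6, dtype=torch.float64) - fisher
shifted_values, bottom_vectors = top_k_eigenpairs(shifted, k)
bottom_values = top_values[0] - shifted_values
```

The columns of `top_vectors` play the role of the most-noticeable distortions and those of `bottom_vectors` the least-noticeable ones.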

Front-end model: eigendistortion display

Once synthesized, we can plot the distortions on the image using eigendistortion_imshow_all. Feel free to adjust distortion_scale, a constant that controls how strongly each distortion is added to the image.
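As a rough illustration of what a distortion scale controls, the sketch below adds a unit-norm distortion to an image at a few scales. The helper `add_scaled_distortion` is hypothetical and is not plenoptic’s plotting code; it simply assumes that the displayed image is the original image plus the scaled distortion.

```python
import torch

torch.manual_seed(0)


def add_scaled_distortion(image, distortion, distortion_scale):
    """Hypothetical helper: add a scaled, unit-norm distortion to an image."""
    unit_distortion = distortion / distortion.norm()
    return image + distortion_scale * unit_distortion


# small stand-ins for an image and a synthesized distortion
image = torch.full((1, 1, 8, 8), 0.5, dtype=torch.float64)
distortion = torch.randn(1, 1, 8, 8, dtype=torch.float64)

for scale in (1, 3, 10):
    distorted = add_scaled_distortion(image, distortion, scale)
    # the norm of the added perturbation equals the chosen scale
    print(scale, (distorted - image).norm().item())
```

Larger scales make a distortion easier to see, which is useful when comparing distortions that would otherwise be too faint to display.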

po.plot.eigendistortion_imshow_all(
    eigendist_f,
    distortion_scale=3,
    suptitle="OnOff",
);

VGG16: eigendistortion synthesis

Following the lead of Berardino et al. (2017), let’s compare the Front End model’s eigendistortions to those of an early layer of VGG16! VGG16 takes color images as input, so we’ll need to repeat the grayscale parrot along the RGB color dimension.

# Create a class that returns the output of a named node (layer) of a given model
class TorchVision(torch.nn.Module):
    def __init__(self, model, return_node: str):
        super().__init__()
        self.extractor = feature_extraction.create_feature_extractor(
            model, return_nodes=[return_node]
        )
        self.model = model
        self.return_node = return_node

    def forward(self, x):
        return self.extractor(x)[self.return_node]

VGG16 was trained on pre-processed ImageNet images with approximately zero mean and unit standard deviation, so we standardize our parrot image in a similar way, using its own mean and standard deviation.

# VGG16
def normalize(img_tensor):
    """standardize the image for vgg16"""
    return (img_tensor - img_tensor.mean()) / img_tensor.std()


# store these for later so we can un-normalize the image for display purposes
orig_mean = image_tensor.mean().detach()
orig_std = image_tensor.std().detach()
image_tensor = normalize(image_tensor)

image_tensor3 = image_tensor.repeat(1, 3, 1, 1)

# "layer 3" according to Berardino et al (2017)
vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1, progress=False)
mdl_v = TorchVision(vgg, "features.11").to(DEVICE).to(image_tensor.dtype)
po.remove_grad(mdl_v)
mdl_v.eval()

eigendist_v = po.Eigendistortion(image=image_tensor3, model=mdl_v)
# this synthesis takes a long time to run, so we load in a cached version.
# see the following admonition for how to run this yourself
eigendist_v.load(
    po.data.fetch_data("berardino_vgg16.pt"),
    tensor_equality_atol=1e-7,
    map_location=DEVICE,
)
/home/jenkins/agent/workspace/CCN_neurorse_plenoptic_main/lib/python3.12/site-packages/plenoptic/validate.py:115: UserWarning: input_tensor range is (-1.973452970558101, 2.107675021080057); plenoptic's methods have largely been tested with the range [0, 1]. Synthesis should still work, but if you have any problems, please open an issue.
  warnings.warn(

VGG16: eigendistortion display

We can now display the most- and least-noticeable eigendistortions as before, then compare them qualitatively with those of the Front End model.

Since the distortions here were synthesized using a pre-processed (normalized) image, we pass a function that undoes the normalization before the image is displayed.

# create an image processing function to unnormalize the image
def unnormalize(x):
    return x * orig_std.to(x.device) + orig_mean.to(x.device)


po.plot.eigendistortion_imshow_all(
    eigendist_v,
    distortion_scale=[15, 100],
    suptitle="VGG16",
    process_image=unnormalize,
);
/home/jenkins/agent/workspace/CCN_neurorse_plenoptic_main/lib/python3.12/site-packages/plenoptic/plot/eigendistortion.py:279: UserWarning: Adding 0.5 to distortions to plot as RGB image, else matplotlib clipping will result in a strange looking image...
  warnings.warn(
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.05490196123717117..1.0000000000000222].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.4372402923781634..1.0000000000001918].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.06562489934552418..1.2138206822764188].

Final thoughts

To rigorously test which of these models’ representations is more human-like, we’d have to conduct a perceptual experiment. For now, we’ll just leave it to you to eyeball and decide which distortions are more or less noticeable!