"""
Tools to deal with data from outside plenoptic.
For example, images synthesized using the code from another paper.
"""
import pathlib
from typing import Any
import imageio.v3 as iio
import matplotlib as mpl
import matplotlib.lines as lines
import numpy as np
import pyrtools as pt
import scipy.io as sio
from .data import fetch_data
# Names exported by ``from <module> import *``; the module-level ``__dir__``
# returns this same list.
__all__ = [
"plot_MAD_results",
]
def __dir__() -> list[str]:
    """Return the module's public names, i.e. the ``__all__`` list itself."""
    public_names = __all__
    return public_names
[docs]
def plot_MAD_results(
    original_image: str,
    noise_levels: list[int] | None = None,
    results_dir: str | pathlib.Path | None = None,
    ssim_images_dir: str | pathlib.Path | None = None,
    zoom: int | float = 3,
    vrange: str = "indep1",
    **kwargs: Any,
) -> tuple[mpl.figure.Figure, dict[str, dict[str, float | np.ndarray]]]:
    r"""
    Plot original MAD results, provided by Zhou Wang.

    Plot the results of original MAD Competition, as provided in .mat
    files. The figure created shows the results for one reference image
    and multiple noise levels. The reference image is plotted on the
    first row, followed by a separate row for each noise level, which
    will show the initial (noisy) image and the four synthesized images,
    with their respective losses for the two metrics (MSE and SSIM).

    We also return a dictionary that contains the losses, noise levels, and
    original image name for each plotted noise level.

    This code can probably be adapted to other uses, but requires that
    all images are the same size and assumes they're all 64 x 64 pixels.

    Parameters
    ----------
    original_image
        Which of the sample images to plot. Must be of the form ``f"samp{i}"``
        where ``i`` is an integer between 1 and 10 (inclusive).
    noise_levels
        Which noise levels to plot. If ``None``, will plot all. If a list,
        elements must be :math:`2^i` where :math:`i\in [1, 10]`.
    results_dir
        Path to the results directory containing the results.mat files. If
        ``None``, we download them.
    ssim_images_dir
        Path to the directory containing the .tif images used in SSIM paper. If
        ``None``, we download them.
    zoom
        Ratio of display pixels to image pixels, passed to
        :func:`~plenoptic.plot.imshow`. If >1, must be an integer. If <1, must
        be 1/d where d is a divisor of the size of the largest image.
    vrange
        How to map image values to colormap. In addition to the values accepted by
        :func:`~plenoptic.plot.imshow`, we also accept ``"row0/1/2/3"``, which
        is the same as ``"auto0/1/2/3"``, except that we do it on a per-row basis
        (all images with same noise level).
    **kwargs
        Passed to :func:`~plenoptic.plot.imshow`. Note that we call imshow
        separately on each image and so any argument that relies on imshow having
        access to all images will probably not work as expected.

    Returns
    -------
    fig :
        Figure containing the images.
    results :
        Dictionary containing the errors for each noise level. To
        convert to a well-structured pandas DataFrame, run
        ``pandas.DataFrame(results).T``.

    Raises
    ------
    ValueError
        If ``original_image`` takes an illegal value. This is checked before
        any data is downloaded or read.
    """
    # Layout constants shared by the stacking, figure creation and annotation
    # steps below; they must stay in sync with each other.
    n_cols = 5
    vert_pct = 0.75
    # Height of a full grid cell (image + title area) in axes coordinates.
    cell_height = 1 / vert_pct

    # Validate first, so that an illegal name fails before we potentially
    # download anything.
    allowed_vals = [f"samp{i}" for i in range(1, 11)]
    if original_image not in allowed_vals:
        err_msg = f"original_image must be one of {allowed_vals}"
        raise ValueError(err_msg)

    if results_dir is None:
        results_dir = fetch_data("MAD_results.tar.gz")
    else:
        results_dir = pathlib.Path(results_dir).expanduser()
    if ssim_images_dir is None:
        ssim_images_dir = fetch_data("ssim_images.tar.gz")
    else:
        ssim_images_dir = pathlib.Path(ssim_images_dir).expanduser()

    if noise_levels is None:
        noise_levels = [2**i for i in range(1, 11)]

    orig_img = iio.imread(ssim_images_dir / f"{original_image}.tif")
    # Four all-ones placeholders fill out the first row next to the original
    # image; they are recognized (and skipped) when plotting.
    blanks = np.ones((*orig_img.shape, n_cols - 1))

    # Collect everything in lists and stack once at the end, rather than
    # re-copying the growing array for every image.
    planes = [orig_img, blanks]
    titles = ["Original image"] + (n_cols - 1) * [None]
    super_titles = n_cols * [None]
    results = {}
    keys = [
        "im_init",
        "im_fixmse_maxssim",
        "im_fixmse_minssim",
        "im_fixssim_minmse",
        "im_fixssim_maxmse",
    ]
    for level in noise_levels:
        mat = sio.loadmat(
            str(results_dir / f"{original_image}_L{level}_results.mat"),
            squeeze_me=True,
        )
        # remove these metadata keys
        for meta_key in ("__header__", "__version__", "__globals__"):
            mat.pop(meta_key, None)
        titles.extend(
            [
                f"Noise level: {level}",
                f"Best SSIM: {mat['maxssim']:.05f}",
                f"Worst SSIM: {mat['minssim']:.05f}",
                f"Best MSE: {mat['minmse']:.05f}",
                f"Worst MSE: {mat['maxmse']:.05f}",
            ]
        )
        super_titles.extend(
            [
                None,
                f"Fix MSE: {mat['FIX_MSE']:.0f}",
                None,
                f"Fix SSIM: {mat['FIX_SSIM']:.05f}",
                None,
            ]
        )
        # Popping the images leaves ``mat`` holding only the loss information.
        planes.extend([mat.pop(k) for k in keys])
        mat.update({"noise_level": level, "original_image": original_image})
        results[f"L{level}"] = mat

    # Stack along a new last axis, then move it first: (n_images, height, width).
    images = np.dstack(planes).transpose((2, 0, 1))
    n_rows = len(images) // n_cols

    if vrange.startswith("row"):
        # Same as "auto*", but computed separately for each row of the figure.
        row_vrange = vrange.replace("row", "auto")
        vrange_list = []
        for row in range(n_rows):
            vr, cmap = pt.tools.display.colormap_range(
                images[n_cols * row : n_cols * (row + 1)], row_vrange
            )
            vrange_list.extend(vr)
    else:
        vrange_list, cmap = pt.tools.display.colormap_range(images, vrange)

    # this is a bit of hack to do the same thing imshow does, but with
    # slightly more space dedicated to the title
    fig = pt.tools.display.make_figure(
        n_rows,
        n_cols,
        [zoom * i + 1 for i in images.shape[-2:]],
        vert_pct=vert_pct,
    )
    # these are the acceptable keys for the fontdict passed to fig.text
    fontdict_keys = ("family", "color", "weight", "size", "style")
    for img, ax, t, vr, s in zip(images, fig.axes, titles, vrange_list, super_titles):
        # these are the blanks
        if (img == 1).all():
            continue
        pt.imshow(img, ax=ax, title=t, zoom=zoom, vrange=vr, cmap=cmap, **kwargs)
        if s is None:
            continue
        # Copy the axes title's font settings, so the super title matches it.
        # The properties are stored with underscores in their attribute
        # names, which we strip to get the fontdict key names.
        font = {
            k.replace("_", ""): v
            for k, v in ax.title.get_font_properties().__dict__.items()
        }
        font = {k: v for k, v in font.items() if k in fontdict_keys}
        # for some reason, this (with passing the transform) is
        # different (and looks better) than using ax.text. We also
        # slightly adjust the placement of the text to account for
        # different zoom levels: the offset of 5 pixels (half of the 10
        # pixels the original comment says separate rows and columns) is
        # converted to axes coordinates by dividing by the axes width.
        img_size = ax.bbox.size
        fig.text(
            1 + (5 / img_size[0]),
            cell_height,
            s,
            fontdict=font,
            transform=ax.transAxes,
            ha="center",
            va="top",
        )
        # linewidth of 1.5 looks good with bbox of 192, 192; never go below 1
        linewidth = max(1.5 * np.mean(img_size / 192), 1)
        # Vertical line just left of this axes, spanning the full cell height.
        line = lines.Line2D(
            2 * [0 - ((5 + linewidth / 2) / img_size[0])],
            [0, cell_height],
            transform=ax.transAxes,
            figure=fig,
            linewidth=linewidth,
        )
        fig.lines.append(line)
    return fig, results