Source code for plenoptic.io
"""Helper functions for saving/loading.""" # numpydoc ignore=ES01
import importlib
import torch
__all__ = [
"examine_saved_synthesis",
"LoadWarning",
]
def __dir__() -> list[str]:
return __all__
[docs]
class LoadWarning(UserWarning):
"""
Custom warning to raise if there's an issue with loading.
And we do not want it to result in an error.
Examples
--------
>>> import plenoptic as po
>>> import warnings
>>> model = po.models.Gaussian((31, 31))
>>> model.eval()
Gaussian()
>>> po.remove_grad(model)
>>> met = po.Metamer(po.data.einstein(), model)
>>> met.synthesize(2)
>>> met.save("load_warning_example.pt")
>>> # this loss function has a different name but the same behavior
>>> met = po.Metamer(po.data.einstein(), model, lambda *args: po.loss.mse(*args))
>>> with warnings.catch_warnings(record=True) as warned:
... met.load("load_warning_example.pt", raise_on_checks=False)
... print(len(warned), warned[0].category)
1 <class 'plenoptic.io.LoadWarning'>
"""
pass
def _parse_save_io_attr_name(
synth_object: dict, input_names: tuple[str]
) -> tuple[list[torch.Tensor], list[str]]:
"""
Parse names of save_io_attrs, allowing for more complex behavior.
The strings specified in ``input_names`` must either be the names of this
object's attributes or of the form ``x * name``, where ``x`` is a float and
``name`` is a string as above, in which case we multiply that attribute by
``x``.
Parameters
----------
synth_object
Dictionary containing tensors corresponding to ``input_names`` from.
input_names
The second element from the tuple ``save_io_attrs`` input to
:func:`save`.
Returns
-------
tensors
The tensors to pass to the corresponding ``save_io_attr``.
input_names_test
List of strings of attributes that we ensure we save.
"""
tensors = []
input_names_test = []
for t in input_names:
t = t.split("*")
if len(t) == 2:
name = t[1].strip()
scale = float(t[0].strip())
else:
name = t[0]
scale = 1
input_names_test.append(name)
tensors.append(scale * synth_object[name])
return tensors, input_names_test
[docs]
def examine_saved_synthesis(file_path: str, map_location: str | None = None):
"""
Examine saved synthesis object.
This is used for debugging, it will print out information about the versions used,
names of the callable attributes, shapes of tensor attributes, etc.
Parameters
----------
file_path
The path to load the synthesis object from.
map_location
Argument to pass to :func:`torch.load` as ``map_location``. If you
save stuff that was being run on a GPU and are loading onto a
CPU, you'll need this to make sure everything lines up
properly. This should be structured like the str you would
pass to :class:`torch.device`.
"""
load_dict = torch.load(file_path, map_location=map_location, weights_only=True)
metadata = load_dict.pop("save_metadata")
print("Metadata\n--------")
print(
f"plenoptic version : {metadata['plenoptic_version']} "
f"(installed: {importlib.metadata.version('plenoptic')})"
)
print(
f"torch version : {metadata['torch_version']} "
f"(installed: {importlib.metadata.version('torch')})"
)
print(f"Saved object type : {metadata['synthesis_object']}")
print("\nCallables attributes\n--------------------")
callables = [
(k, v)
for k, v in load_dict.items()
if isinstance(v, tuple) and (isinstance(v[0], str) or v[0] is None)
]
pad_len = max([len(k[1:] if k.startswith("_") else k) for k, v in callables]) + 1
for k, v in callables:
display_k = k[1:] if k.startswith("_") else k
load_dict.pop(k)
# then this is state_dict attribute
if len(v) == 2:
print(f"{display_k:<{pad_len}}: {v[0]}")
# then this is an io attribute
else:
tensors, _ = _parse_save_io_attr_name(load_dict, v[1])
print(
f"{display_k:<{pad_len}}: {v[0]}, "
f"{[t.shape for t in tensors]} -> {v[2].shape}"
)
print("\nTensor attributes\n-----------------")
tensors = [(k, v) for k, v in load_dict.items() if isinstance(v, torch.Tensor)]
pad_len = max([len(k[1:] if k.startswith("_") else k) for k, v in tensors]) + 1
for k, v in tensors:
display_k = k[1:] if k.startswith("_") else k
load_dict.pop(k)
print(f"{display_k:<{pad_len}}: {v.dtype}, shape {v.shape}")
print("\nOther attributes\n----------------")
pad_len = max([len(k[1:] if k.startswith("_") else k) for k in load_dict]) + 1
for k, v in load_dict.items():
display_k = k[1:] if k.startswith("_") else k
if hasattr(v, "__len__") and not isinstance(v, str):
print(f"{display_k:<{pad_len}}: {type(v)} with length {len(v)}")
else:
print(f"{display_k:<{pad_len}}: {v}")