Coverage for tvo/utils/H5Logger.py: 84%
69 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
5import torch.distributed as dist
6import torch as to
7import h5py
8from typing import Union, Iterable, Dict, Any
9from os import path, rename
def _append_to_dict(d: Dict[str, to.Tensor], k: str, t: to.Tensor):
    """Stack tensor `t` as a new leading row of the tensor stored in `d[k]`.

    :param d: Dictionary holding one stacked tensor per variable name. Modified in place.
    :param k: Name of the variable to extend.
    :param t: Tensor to add; must have the same shape as every earlier entry for `k`.

    Raises AssertionError if the shape of `t` differs from the rows already stored.
    """
    if k in d:
        history = d[k]
    else:
        # Start from zero rows so that the very first append is an ordinary concatenation.
        history = to.empty((0, *t.shape))
        d[k] = history

    row_shape = history.shape[1:]
    assert row_shape == t.shape, f"variable {k} changed shape between appends"

    # `.to(t)` aligns the stored rows with the dtype and device of the incoming tensor.
    new_row = t.unsqueeze(0)
    d[k] = to.cat((history.to(t), new_row))
class H5Logger:
    """Collect tensors (or dictionaries of tensors) in memory and dump them to an HDF5 file."""

    def __init__(self, output: str, blacklist: Iterable[str] = (), verbose: bool = False):
        """Utility class to iteratively write to HD5 files.

        :param output: Output filename or file path. If a file with this name already exists
                       when `write` is called, it is first renamed to `<output>.old`.
        :param blacklist: Variables in `blacklist` are ignored and never get logged.
                          A single string is treated as one variable name.
        :param verbose: Whether to print variable names after appending/setting

        If `torch.distributed` is initialized, operations on H5Logger are no-op for all
        processes except for the process with rank 0.
        """
        self._rank = dist.get_rank() if dist.is_initialized() else 0
        self._fname = output
        # Top-level entries are either a tensor or a dict of tensors (written as a group).
        self._data: Dict[str, Union[to.Tensor, Dict[str, to.Tensor]]] = {}
        if isinstance(blacklist, str):
            # A bare string would otherwise be split into single characters below.
            blacklist = (blacklist,)
        # Materialise once: a one-shot iterable (e.g. a generator) would be exhausted by the
        # first membership test and silently stop filtering afterwards.
        self._blacklist = frozenset(blacklist)
        self._verbose = verbose

    def append(self, **kwargs: Union[to.Tensor, Dict[str, to.Tensor]]):
        """Append arguments to log. Arguments can be torch.Tensors or dictionaries thereof.

        The output HDF5 file will contain one dataset for each of the tensors and one group
        for each of the dictionaries.

        Blacklisted names are skipped. Raises TypeError for any other kind of value.
        """
        if self._rank != 0:
            return

        data = self._data
        for k, v in kwargs.items():
            if k in self._blacklist:
                continue

            if isinstance(v, to.Tensor):
                _append_to_dict(data, k, v)
            elif isinstance(v, dict):
                # One nested dict per key; it becomes an HDF5 group when written.
                group = data.setdefault(k, {})
                for name, tensor in v.items():
                    _append_to_dict(group, name, tensor)
            else:  # pragma: no cover
                msg = (
                    "Arguments must be torch.Tensors or dictionaries thereof "
                    f"but '{k}' is {type(v)}."
                )
                raise TypeError(msg)

            if self._verbose:
                print(f"Appended {k} to {self._fname}")

    def set(self, **kwargs: Union[to.Tensor, Dict[str, to.Tensor]]):
        """Set or reset arguments to desired value in log.

        Arguments can be torch.Tensors or dictionaries thereof.
        The output HDF5 file will contain one dataset for each of the tensors and one group
        for each of the dictionaries.

        Blacklisted names are skipped. Raises TypeError for any other kind of value.
        The value is stored as given (no copy is made).
        """
        if self._rank != 0:
            return

        for k, v in kwargs.items():
            if k in self._blacklist:
                continue

            if not isinstance(v, (to.Tensor, dict)):  # pragma: no cover
                msg = (
                    "Arguments must be torch.Tensors or dictionaries thereof "
                    f"but '{k}' is {type(v)}."
                )
                raise TypeError(msg)

            self._data[k] = v

            if self._verbose:
                print(f"Set {k} to {self._fname}")

    def write(self) -> None:
        """Write logged data to output file.

        If a file with this name already exists (e.g. because of a previous call to this method)
        the old file is renamed to `<fname>.old`.
        """
        if self._rank != 0:
            return

        fname = self._fname

        if path.exists(fname):
            # NOTE(review): os.rename raises on Windows when `<fname>.old` already exists;
            # os.replace would overwrite the backup on every platform.
            rename(fname, fname + ".old")

        with h5py.File(fname, "w") as f:
            for k, v in self._data.items():
                H5Logger._write_one(f, k, v)

    @staticmethod
    def _write_one(f: "h5py.Group", key: str, value: Any) -> None:
        """Write `value` into `f` under `key`.

        Tensors are detached, moved to CPU and stored as a dataset. Dictionaries become a
        group whose entries are written recursively. Any other value is handed to
        `create_dataset` directly and, if that raises TypeError, stored as `str(value)`.
        """
        if isinstance(value, to.Tensor):
            f.create_dataset(key, data=value.detach().cpu())
        elif isinstance(value, dict):
            g = f.create_group(key)
            for k, v in value.items():
                H5Logger._write_one(g, k, v)
        else:
            try:
                f.create_dataset(key, data=value)
            except TypeError:
                f.create_dataset(key, data=str(value))

    def append_and_write(self, **kwargs: Union[to.Tensor, Dict[str, to.Tensor]]):
        """Jointly append and write arguments. See docs of `append` and `write`."""
        self.append(**kwargs)
        self.write()

    def set_and_write(self, **kwargs: Union[to.Tensor, Dict[str, to.Tensor]]):
        """Jointly set and write arguments. See docs of `set` and `write`."""
        self.set(**kwargs)
        self.write()