Coverage for tvo/utils/H5Logger.py: 84%

69 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg. 

3# Licensed under the Academic Free License version 3.0 

4 

5import torch.distributed as dist 

6import torch as to 

7import h5py 

8from typing import Union, Iterable, Dict, Any 

9from os import path, rename 

10 

11 

12def _append_to_dict(d: Dict[str, to.Tensor], k: str, t: to.Tensor): 

13 """Append tensor t to dict d at key k.""" 

14 if k not in d: 

15 # the extra 0-sized dimension will be used for concatenation 

16 d[k] = to.empty((0, *t.shape)) 

17 assert d[k].shape[1:] == t.shape, f"variable {k} changed shape between appends" 

18 d[k] = to.cat((d[k].to(t), t.unsqueeze(0))) 

19 

20 

class H5Logger:
    """Collect tensors (or dicts of tensors) in memory and dump them to an HDF5 file.

    Only the process with rank 0 records and writes anything. The rank is taken from
    ``torch.distributed`` when a process group is initialized, otherwise it is 0. On every
    other rank, all public methods return immediately.
    """

    def __init__(self, output: str, blacklist: Iterable[str] = (), verbose: bool = False):
        """Utility class to iteratively write to HDF5 files.

        :param output: Output filename or file path. If a file with this name already exists
                       when `write` is called, it is moved to `<output>.old`.
        :param blacklist: Variables in `blacklist` are ignored and never get logged.
        :param verbose: Whether to print variable names after appending/setting

        If torch.distributed is initialized (e.g. because tvo.get_run_policy() is 'mpi'),
        operations on H5Logger are no-op for all processes except for the process with rank 0.
        """
        self._rank = dist.get_rank() if dist.is_initialized() else 0
        self._fname = output
        # Values are tensors, or dicts mapping names to tensors (see `append` / `set`).
        self._data: Dict[str, Any] = {}
        # Materialized once: a one-shot iterator (e.g. a generator) would otherwise be
        # exhausted by the first membership test, letting later blacklisted names through.
        self._blacklist = frozenset(blacklist)
        self._verbose = verbose

    def append(self, **kwargs: Union[to.Tensor, Dict[str, to.Tensor]]):
        """Append arguments to log. Arguments can be torch.Tensors or dictionaries thereof.

        Each call stacks the new values onto the previously appended ones along a new
        leading dimension. The output HDF5 file will contain one dataset for each of the
        tensors and one group (with one dataset per entry) for each of the dictionaries.

        :raises TypeError: if an argument is neither a torch.Tensor nor a dict.
        """
        if self._rank != 0:
            return

        data = self._data
        for k, v in kwargs.items():
            if k in self._blacklist:
                continue

            if isinstance(v, to.Tensor):
                _append_to_dict(data, k, v)
            elif isinstance(v, dict):
                group = data.setdefault(k, {})
                for name, tensor in v.items():
                    _append_to_dict(group, name, tensor)
            else:  # pragma: no cover
                msg = (
                    "Arguments must be torch.Tensors or dictionaries thereof "
                    f"but '{k}' is {type(v)}."
                )
                raise TypeError(msg)
            if self._verbose:
                print(f"Appended {k} to {self._fname}")

    def set(self, **kwargs: Union[to.Tensor, Dict[str, to.Tensor]]):
        """Set or reset arguments to desired value in log.

        Arguments can be torch.Tensors or dictionaries thereof. Any value previously
        logged under the same name is replaced.
        The output HDF5 file will contain one dataset for each of the tensors and one group
        for each of the dictionaries.

        :raises TypeError: if an argument is neither a torch.Tensor nor a dict.
        """
        if self._rank != 0:
            return

        for k, v in kwargs.items():
            if k in self._blacklist:
                continue

            if not isinstance(v, (to.Tensor, dict)):  # pragma: no cover
                msg = (
                    "Arguments must be torch.Tensors or dictionaries thereof "
                    f"but '{k}' is {type(v)}."
                )
                raise TypeError(msg)

            # Keep a shallow copy of dicts: a later `append` to this key extends the
            # stored dict in place, which must not mutate the caller's object.
            self._data[k] = dict(v) if isinstance(v, dict) else v

            if self._verbose:
                print(f"Set {k} to {self._fname}")

    def write(self) -> None:
        """Write logged data to output file.

        If a file with this name already exists (e.g. because of a previous call to this method)
        the old file is renamed to `<fname>.old`, replacing any earlier `.old` file.
        """
        if self._rank != 0:
            return

        # os.replace (unlike os.rename) also overwrites an existing destination on Windows,
        # so repeated writes do not fail once `<fname>.old` exists.
        from os import replace

        fname = self._fname

        if path.exists(fname):
            replace(fname, fname + ".old")

        with h5py.File(fname, "w") as f:
            for k, v in self._data.items():
                H5Logger._write_one(f, k, v)

    @staticmethod
    def _write_one(f: "h5py.Group", key: str, value: Any) -> None:
        """Write `value` under `key` into the HDF5 group/file `f`.

        Tensors become datasets (detached and moved to CPU first). Dicts become groups whose
        entries are written recursively. Any other value is stored as a dataset, falling back
        to its `str()` representation if h5py raises TypeError for it.
        """
        if isinstance(value, to.Tensor):
            f.create_dataset(key, data=value.detach().cpu())
        elif isinstance(value, dict):
            g = f.create_group(key)
            for k, v in value.items():
                H5Logger._write_one(g, k, v)
        else:
            try:
                f.create_dataset(key, data=value)
            except TypeError:
                f.create_dataset(key, data=str(value))

    def append_and_write(self, **kwargs: Union[to.Tensor, Dict[str, to.Tensor]]):
        """Jointly append and write arguments. See docs of `append` and `write`."""
        self.append(**kwargs)
        self.write()

    def set_and_write(self, **kwargs: Union[to.Tensor, Dict[str, to.Tensor]]):
        """Jointly set and write arguments. See docs of `set` and `write`."""
        self.set(**kwargs)
        self.write()