Coverage for tvo/exp/_experiments.py: 97%

154 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg. 

3# Licensed under the Academic Free License version 3.0 

4 

5from abc import ABC, abstractmethod 

6from tvo.utils.data import TVODataLoader 

7from tvo.utils.model_protocols import Trainable 

8from tvo.utils.parallel import ( 

9 pprint, 

10 init_processes, 

11 gather_from_processes, 

12 get_h5_dataset_to_processes, 

13) 

14from tvo.exp._utils import make_var_states 

15from tvo.utils import get, H5Logger 

16from tvo.trainer import Trainer 

17from tvo.exp._EStepConfig import EStepConfig 

18from tvo.exp._ExpConfig import ExpConfig 

19from tvo.exp._EpochLog import EpochLog 

20from tvo.variational import TVOVariationalStates 

21import tvo 

22 

23import math 

24from typing import Dict, Any, Generator 

25import torch as to 

26import torch.distributed as dist 

27import time 

28from pathlib import Path 

29import os 

30from munch import Munch 

31 

32 

class Experiment(ABC):
    """Common interface shared by every TVO experiment.

    Concrete experiments must implement :meth:`run` as a generator of
    :class:`EpochLog` objects.
    """

    @abstractmethod
    def run(self, epochs: int) -> Generator[EpochLog, None, None]:
        """Run the experiment for `epochs` epochs, yielding EpochLog objects."""
        ...  # pragma: no cover

39 

40 

class _TrainingAndOrValidation(Experiment):
    def __init__(
        self,
        conf: ExpConfig,
        estep_conf: EStepConfig,
        model: Trainable,
        train_dataset: to.Tensor = None,
        test_dataset: to.Tensor = None,
    ):
        """Helper class to avoid code repetition between Training and Testing.

        It performs training and/or validation/testing depending on what input is provided.
        Results on `test_dataset` are logged under the "valid" prefix when a training
        dataset is given as well, and under the "test" prefix otherwise.

        :param conf: Experiment configuration.
        :param estep_conf: E-step configuration, used to build the variational states.
        :param model: Model to train and/or evaluate; must be an instance of Trainable.
        :param train_dataset: Training data, or None if no training data is provided.
        :param test_dataset: Validation/test data, or None.
        """
        # Check the model type before reading any of its attributes, so that an unsuitable
        # object fails with this assertion rather than an AttributeError on `model.shape`.
        assert isinstance(model, Trainable)
        self.model = model
        # Number of latent units handed to the variational states: sum of model.shape[1:].
        # NOTE(review): assumes model.shape[0] is the observed dimension — confirm.
        H = sum(model.shape[1:])
        self._conf = Munch(conf.as_dict())
        self._conf.model = type(model).__name__
        self._conf.device = tvo.get_device().type
        self._estep_conf = Munch(estep_conf.as_dict())
        self._precision = model.precision

        self.train_data = None
        self.train_states = None
        if train_dataset is not None:
            self.train_data = self._make_dataloader(train_dataset, conf)
            # might differ between processes: last process might have smaller N and less states
            # (but TVODataLoader+ShufflingSampler make sure the number of batches is the same)
            N = train_dataset.shape[0]
            self.train_states = self._make_states(N, H, self._precision, estep_conf)

        self.test_data = None
        self.test_states = None
        if test_dataset is not None:
            self.test_data = self._make_dataloader(test_dataset, conf)
            N = test_dataset.shape[0]
            self.test_states = self._make_states(N, H, self._precision, estep_conf)

        # tell the Trainer whether any (warm-up or EM) epoch may request reconstructions
        will_reconstruct = (
            self._conf.reco_epochs is not None or self._conf.warmup_reco_epochs is not None
        )
        self.trainer = Trainer(
            self.model,
            self.train_data,
            self.train_states,
            self.test_data,
            self.test_states,
            rollback_if_F_decreases=self._conf.rollback_if_F_decreases,
            will_reconstruct=will_reconstruct,
            eval_F_at_epoch_end=self._conf.eval_F_at_epoch_end,
            data_transform=self._conf.data_transform,
        )
        self.logger = H5Logger(self._conf.output, blacklist=self._conf.log_blacklist)

    def _make_dataloader(self, dataset: to.Tensor, conf: ExpConfig) -> TVODataLoader:
        """Move `dataset` to the TVO device and wrap it in a TVODataLoader.

        Data that is not uint8 is cast to the model precision; uint8 data keeps its dtype.
        Batch size, shuffling and drop_last are taken from `conf`.
        """
        if dataset.dtype != to.uint8:
            dataset = dataset.to(dtype=self._precision)
        dataset = dataset.to(device=tvo.get_device())
        return TVODataLoader(
            dataset, batch_size=conf.batch_size, shuffle=conf.shuffle, drop_last=conf.drop_last
        )

    def _make_states(
        self, N: int, H: int, precision: to.dtype, estep_conf: EStepConfig
    ) -> TVOVariationalStates:
        """Build the variational states for `N` data points and `H` latents via make_var_states."""
        states = make_var_states(estep_conf, N, H, precision)
        return states

    @property
    def config(self) -> Dict[str, Any]:
        """Shallow plain-dict copy of the experiment configuration (incl. model and device)."""
        return dict(self._conf)

    @property
    def estep_config(self) -> Dict[str, Any]:
        """Shallow plain-dict copy of the E-step configuration."""
        return dict(self._estep_conf)

    def run(self, epochs: int) -> Generator[EpochLog, None, None]:
        """Run training and/or testing.

        The method performs the following steps in order:

        1. Log the configurations.
        2. Run the configured warm-up E-steps, logging each one.
        3. Yield an EpochLog for epoch 0 with the initial free energies.
        4. Yield one EpochLog per EM epoch.

        Once the generator is exhausted, the leftover "<output>.old" file is removed on
        rank 0. ``self.trainer`` is then set to None, so ``run`` cannot be called again.

        :param epochs: Number of epochs to train for
        """
        trainer = self.trainer
        logger = self.logger

        self._log_confs(logger)

        # warm-up E-steps
        if self._conf.warmup_Esteps > 0:
            pprint("Warm-up E-steps")
        for e in range(self._conf.warmup_Esteps):
            compute_reconstruction = (
                self._conf.warmup_reco_epochs is not None and e in self._conf.warmup_reco_epochs
            )
            d = trainer.e_step(compute_reconstruction)
            self._log_epoch(logger, d)

        # log initial free energies (after warm-up E-steps if any)
        if self._conf.warmup_Esteps == 0:
            d = trainer.eval_free_energies()
        # NOTE(review): when warm-up E-steps ran, `d` is the last warm-up result, which was
        # already logged inside the loop above and is logged again here as epoch 0 — confirm
        # this duplication is intended.
        self._log_epoch(logger, d)
        yield EpochLog(epoch=0, results=d)

        # EM steps
        for e in range(epochs):
            start_t = time.time()
            compute_reconstruction = (
                self._conf.reco_epochs is not None and e in self._conf.reco_epochs
            )
            d = trainer.em_step(compute_reconstruction)
            epoch_runtime = time.time() - start_t
            self._log_epoch(logger, d)
            yield EpochLog(e + 1, d, epoch_runtime)

        # remove leftover ".old" logfiles produced by the logger
        rank = dist.get_rank() if dist.is_initialized() else 0
        leftover_logfile = self._conf.output + ".old"
        if rank == 0 and Path(leftover_logfile).is_file():
            os.remove(leftover_logfile)

        # put trainer into undefined state after the experiment is finished
        self.trainer = None  # type: ignore

    def _log_confs(self, logger: H5Logger):
        """Dump experiment+estep+model configuration to screen and save it to output file."""
        titles = ["Experiment", "E-step"]
        confs = [self.config, self.estep_config]
        logger.set(exp_config=self.config)
        logger.set(estep_config=self.estep_config)

        # NOTE(review): accessing the model's config may raise; the exception is not caught
        # here and propagates to the caller of run().
        model_conf = self.model.config
        logger.set(model_config=model_conf)
        confs.append(model_conf)
        titles.append("Model")

        for title, conf in zip(titles, confs):
            pprint(f"\n{title} configuration:")
            for k, v in conf.items():
                pprint(f"\t{k:<20}: {v}")

    def _log_epoch(self, logger: H5Logger, epoch_results: Dict[str, float]):
        """Log the results of one epoch and write them to file. Returns nothing.

        For each of "train" and "test" that has an "<kind>_F" entry in `epoch_results`:

        - append F and subs to the logger;
        - set the latest states.K and states.lpj, unless they are blacklisted;
        - if `keep_best_states` is enabled, store the states with the highest F seen so far;
        - set the reconstruction, if one is present and not blacklisted.

        Finally, model.theta is set (``log_only_latest_theta``) or appended, and the logger
        is written.

        :param logger: the logger for this run
        :param epoch_results: dictionary returned by Trainer.e_step or Trainer.em_step
        :raises AssertionError: if a free energy is NaN or infinite
        """
        for data_kind in "train", "test":
            if data_kind + "_F" not in epoch_results:
                continue

            # log_kind is one of "train", "valid" or "test"
            # (while data_kind is one of "train" or "test")
            log_kind = "valid" if data_kind == "test" and self.train_data is not None else data_kind

            # log F and subs to stdout and file
            F, subs = get(epoch_results, f"{data_kind}_F", f"{data_kind}_subs")
            assert math.isfinite(F), f"{log_kind} free energy is invalid!"
            F_and_subs_dict = {f"{log_kind}_F": to.tensor(F), f"{log_kind}_subs": to.tensor(subs)}
            logger.append(**F_and_subs_dict)

            # log latest states and lpj to file
            states = getattr(self, f"{data_kind}_states")
            if f"{log_kind}_states" not in self._conf.log_blacklist:
                K = gather_from_processes(states.K)
                logger.set(**{f"{log_kind}_states": K})
            else:
                K = None
            if f"{log_kind}_lpj" not in self._conf.log_blacklist:
                logger.set(**{f"{log_kind}_lpj": gather_from_processes(states.lpj)})

            if self._conf.keep_best_states:
                best_F_name = f"best_{log_kind}_F"
                best_F = getattr(self, f"_{best_F_name}", None)
                if best_F is None or F > best_F:
                    rank = dist.get_rank() if dist.is_initialized() else 0
                    if K is None:
                        K = gather_from_processes(states.K)
                    if rank == 0:
                        assert isinstance(K, to.Tensor)  # to make mypy happy
                        best_states_dict = {
                            best_F_name: to.tensor(F),
                            f"best_{log_kind}_states": K.cpu().clone(),
                        }
                        logger.set(**best_states_dict)
                    # Updated on every rank, not only rank 0, so all ranks take the same branch
                    # next time (presumably required because gather_from_processes has to be
                    # called on all processes — confirm).
                    setattr(self, f"_{best_F_name}", F)

            # log data reconstructions
            reco_dict = {}
            if (
                f"{log_kind}_reconstruction" not in self._conf.log_blacklist
                and f"{data_kind}_rec" in epoch_results
            ):
                reco_dict[f"{log_kind}_reconstruction"] = gather_from_processes(
                    epoch_results[f"{data_kind}_rec"]
                )
            logger.set(**reco_dict)

        log_theta_fn = logger.set if self._conf.log_only_latest_theta else logger.append
        log_theta_fn(theta=self.model.theta)
        logger.write()

239 

240 

class Training(_TrainingAndOrValidation):
    def __init__(
        self,
        conf: ExpConfig,
        estep_conf: EStepConfig,
        model: Trainable,
        train_data_file: str,
        val_data_file: str = None,
    ):
        """Train model on the given dataset.

        The number of epochs is passed to :meth:`run`, not to the constructor.

        :param conf: Experiment configuration. Note: its ``train_dataset`` and ``val_dataset``
                     attributes are overwritten with the given file paths.
        :param estep_conf: Instance of a class inheriting from EStepConfig.
        :param model: model to train
        :param train_data_file: Path to an HDF5 file containing the training dataset.
                                Datasets with name "train_data" and "data" will be
                                searched in the file, in this order.
        :param val_data_file: Optional path to an HDF5 file containing the validation
                              dataset. Datasets with name "val_data" and "data" will be
                              searched in the file, in this order.

        On the validation dataset, Training only performs E-steps without updating
        the model parameters.
        """
        if tvo.get_run_policy() == "mpi":
            init_processes()
        train_dataset = get_h5_dataset_to_processes(train_data_file, ("train_data", "data"))
        val_dataset = (
            None
            if val_data_file is None
            else get_h5_dataset_to_processes(val_data_file, ("val_data", "data"))
        )

        # Record the data sources on the config object before the base class reads
        # conf.as_dict() (presumably so they end up in the logged configuration — confirm).
        setattr(conf, "train_dataset", train_data_file)
        setattr(conf, "val_dataset", val_data_file)
        super().__init__(conf, estep_conf, model, train_dataset, val_dataset)

277 

278 

class Testing(_TrainingAndOrValidation):
    def __init__(self, conf: ExpConfig, estep_conf: EStepConfig, model: Trainable, data_file: str):
        """Test given model on the given dataset.

        The number of epochs is passed to :meth:`run`, not to the constructor.

        :param conf: Experiment configuration. Note: its ``test_dataset`` attribute is
                     overwritten with `data_file`.
        :param estep_conf: Instance of a class inheriting from EStepConfig.
        :param model: model to test
        :param data_file: Path to an HDF5 file containing the test dataset. Datasets with name
                          "test_data" and "data" will be searched in the file, in this order.

        Only E-steps are run. Model parameters are not updated.
        """
        if tvo.get_run_policy() == "mpi":
            init_processes()
        dataset = get_h5_dataset_to_processes(data_file, ("test_data", "data"))

        # Record the data source on the config object before the base class reads
        # conf.as_dict() (presumably so it ends up in the logged configuration — confirm).
        setattr(conf, "test_dataset", data_file)
        # no training data: the base class only sets up test states/data
        super().__init__(conf, estep_conf, model, None, dataset)