Coverage for tvo/trainer/Trainer.py: 98%

180 statements  

coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import tvo
from tvo.utils.model_protocols import Trainable, Optimized, Reconstructor
from tvo.variational import TVOVariationalStates
from tvo.utils.data import TVODataLoader
from tvo.utils.parallel import all_reduce
from typing import Dict, Any, Sequence, Union, Callable
import torch as to


class Trainer:
    def __init__(
        self,
        model: Trainable,
        train_data: Union[TVODataLoader, to.Tensor] = None,
        train_states: TVOVariationalStates = None,
        test_data: Union[TVODataLoader, to.Tensor] = None,
        test_states: TVOVariationalStates = None,
        rollback_if_F_decreases: Sequence[str] = [],
        will_reconstruct: bool = False,
        eval_F_at_epoch_end: bool = False,
        data_transform: Callable[[to.Tensor], to.Tensor] = None,
    ):
        """Train and/or test a given model.

        :param model: an object of a concrete type satisfying the Trainable protocol
        :param train_data: the contained dataset should have shape (N,D)
        :param train_states: TVOVariationalStates with shape (N,S,H)
        :param test_data: validation or test dataset. The contained dataset should have
                          shape (M,D)
        :param test_states: TVOVariationalStates with shape (M,Z,H)
        :param rollback_if_F_decreases: see ExpConfig docs
        :param will_reconstruct: True if data will be reconstructed by the Trainer
        :param eval_F_at_epoch_end: By default, the trainer evaluates the model free energy
                                    batch by batch, accumulating the values over the course of
                                    the epoch. If this option is set to `True`, the free energy
                                    is evaluated once at the end of each epoch instead.
        :param data_transform: A transformation to be applied to datapoints before they are
                               passed to the model for training/evaluation.

        Both train_data and train_states must be provided, or neither.
        The same holds for test_data and test_states.
        At least one of these two pairs of arguments must be present.

        Training steps on test_data only perform E-steps, i.e. model parameters are
        not updated but test_states are. Therefore test_data can also be used for validation.
        """

        for data, states in ((train_data, train_states), (test_data, test_states)):
            assert (data is not None) == (
                states is not None
            ), "Please provide both dataset and variational states, or neither"
        train_data = TVODataLoader(train_data) if isinstance(train_data, to.Tensor) else train_data
        test_data = TVODataLoader(test_data) if isinstance(test_data, to.Tensor) else test_data
        self.can_train = train_data is not None and train_states is not None
        self.can_test = test_data is not None and test_states is not None
        if not self.can_train and not self.can_test:  # pragma: no cover
            raise RuntimeError("Please provide at least one pair of dataset and variational states")

        _d, _s = (train_data, train_states) if self.can_train else (test_data, test_states)
        assert _d is not None and _s is not None
        if isinstance(model, Optimized):
            model.init_storage(_s.config["S"], _s.config["S_new"], _d.batch_size)

        self.model = model
        self.train_data = train_data
        self.train_states = train_states
        self.test_data = test_data
        self.test_states = test_states
        self.will_reconstruct = will_reconstruct
        self.eval_F_at_epoch_end = eval_F_at_epoch_end
        if train_data is not None:
            self.N_train = to.tensor(len(train_data.dataset))
            all_reduce(self.N_train)
            self.N_train = self.N_train.item()
            if self.will_reconstruct:
                self.train_reconstruction = train_data.dataset.tensors[1].clone()
        if test_data is not None:
            self.N_test = to.tensor(len(test_data.dataset))
            all_reduce(self.N_test)
            self.N_test = self.N_test.item()
            if self.will_reconstruct:
                self.test_reconstruction = test_data.dataset.tensors[1].clone()
        self._to_rollback = rollback_if_F_decreases
        self.data_transform = data_transform if data_transform is not None else lambda x: x
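
    # Illustrative construction sketch (an added comment, not part of the original
    # module; `my_model` and `my_states` are hypothetical stand-ins for a concrete
    # Trainable model and a concrete TVOVariationalStates instance):
    #
    #     data = to.rand(100, 16)  # (N, D) training set
    #     trainer = Trainer(my_model, train_data=data, train_states=my_states)
    #
    # A plain tensor is accepted because __init__ wraps it in a TVODataLoader;
    # providing only one half of a (data, states) pair trips the assertion above.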

    @staticmethod
    def _do_e_step(
        data: TVODataLoader,
        states: TVOVariationalStates,
        model: Trainable,
        N: int,
        data_transform,
        reconstruction: to.Tensor = None,
    ):
        if reconstruction is not None and not isinstance(model, Reconstructor):
            raise NotImplementedError(
                f"reconstruction not implemented for model {type(model).__name__}"
            )
        F = to.tensor(0.0)
        subs = to.tensor(0)
        if isinstance(model, Optimized):
            model.init_epoch()
        for idx, batch in data:
            batch = data_transform(batch)
            if isinstance(model, Optimized):
                model.init_batch()
            subs += states.update(idx, batch, model)
            F += model.free_energy(idx, batch, states)
            if reconstruction is not None:
                # full data estimation
                reconstruction[idx] = model.data_estimator(idx, batch, states)  # type: ignore
        all_reduce(F)
        all_reduce(subs)
        return F.item() / N, subs.item() / N, reconstruction
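
    # Added remark: F and subs are summed across processes by all_reduce and then
    # divided by the global dataset size N, so the first two return values are
    # per-datapoint averages over the full (possibly distributed) dataset. For
    # example, a total free energy of -2500.0 over N=1000 datapoints is -2.5.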

    def e_step(self, compute_reconstruction: bool = False) -> Dict[str, Any]:
        """Run one epoch of E-steps on training and/or test data, depending on what is available.

        Only E-steps are executed.

        :returns: a dictionary containing 'train_F', 'train_subs', 'test_F', 'test_subs'
                  (keys might be missing depending on what is available)
        """
        ret = {}
        model = self.model
        train_data, train_states = self.train_data, self.train_states
        test_data, test_states = self.test_data, self.test_states
        train_reconstruction = (
            self.train_reconstruction
            if (compute_reconstruction and hasattr(self, "train_reconstruction"))
            else None
        )
        test_reconstruction = (
            self.test_reconstruction
            if (compute_reconstruction and hasattr(self, "test_reconstruction"))
            else None
        )

        # Training #
        if self.can_train:
            assert train_data is not None and train_states is not None  # to make mypy happy
            ret["train_F"], ret["train_subs"], train_rec = self._do_e_step(
                train_data,
                train_states,
                model,
                self.N_train,
                self.data_transform,
                train_reconstruction,
            )
            if train_rec is not None:
                ret["train_rec"] = train_rec

        # Validation/Testing #
        if self.can_test:
            assert test_data is not None and test_states is not None  # to make mypy happy
            ret["test_F"], ret["test_subs"], test_rec = self._do_e_step(
                test_data, test_states, model, self.N_test, self.data_transform, test_reconstruction
            )
            if test_rec is not None:
                ret["test_rec"] = test_rec

        return ret
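
    # Illustrative use of e_step (added comment; assumes `trainer` was built with
    # both the train and the test pair):
    #
    #     res = trainer.e_step()
    #     print(res["train_F"], res["test_F"])  # free energies per datapoint
    #     print(res["train_subs"])              # substituted states per datapoint
    #
    # Keys belonging to a (data, states) pair that was not provided are absent.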

    def em_step(self, compute_reconstruction: bool = False) -> Dict[str, Any]:
        """Run one training and/or test epoch, depending on what data is available.

        Both E-step and M-step are executed. Reconstructions are computed along the way
        if requested.

        :returns: a dictionary containing 'train_F', 'train_subs', 'test_F', 'test_subs'
                  (keys might be missing depending on what is available). The free energy values
                  are calculated per batch, so if the model updates its parameters in
                  `update_param_epoch`, the free energies reported at epoch X are calculated
                  using the weights of epoch X-1.
        """
        # NOTE:
        # For models that update the parameters in update_param_epoch, the free energy reported
        # at each epoch is the one after the E-step and before the M-step (K sets of epoch X and
        # \Theta of epoch X-1 yield the free energy of epoch X).
        # For models that update the parameters in update_param_batch, the free energy reported
        # at each epoch does not correspond to a fixed set of parameters: each batch had a
        # different set of parameters and the reported free energy is more of an average of the
        # free energies yielded by all the sets of parameters spanned during an epoch.

        ret_dict = {}

        # Training #
        if self.can_train:
            F, subs, reco = self._train_epoch(compute_reconstruction)
            all_reduce(F)
            ret_dict["train_F"] = F.item() / self.N_train
            all_reduce(subs)
            ret_dict["train_subs"] = subs.item() / self.N_train
            if reco is not None:
                ret_dict["train_rec"] = reco

        # Validation/Testing #
        if self.can_test:
            test_data, test_states, test_reconstruction = (
                self.test_data,
                self.test_states,
                self.test_reconstruction
                if (compute_reconstruction and hasattr(self, "test_reconstruction"))
                else None,
            )
            model = self.model

            assert test_data is not None and test_states is not None  # to make mypy happy
            res = self._do_e_step(
                test_data, test_states, model, self.N_test, self.data_transform, test_reconstruction
            )
            ret_dict["test_F"], ret_dict["test_subs"], test_rec = res
            if test_rec is not None:
                ret_dict["test_rec"] = test_rec

        return ret_dict
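
    # Added remark on reporting: with epoch-wise M-steps the values arrive with a
    # one-epoch lag, i.e. the first em_step() reports F under the initial
    # parameters, the second reports F under the parameters after the first
    # M-step, and so on, as the NOTE above spells out.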

    def _train_epoch(self, compute_reconstruction: bool):
        model = self.model
        train_data, train_states, train_reconstruction = (
            self.train_data,
            self.train_states,
            self.train_reconstruction
            if (compute_reconstruction and hasattr(self, "train_reconstruction"))
            else None,
        )

        assert train_data is not None and train_states is not None  # to make mypy happy
        F = to.tensor(0.0, device=tvo.get_device())
        subs = to.tensor(0)
        if isinstance(model, Optimized):
            model.init_epoch()
        for idx, batch in train_data:
            batch = self.data_transform(batch)
            if isinstance(model, Optimized):
                model.init_batch()
            with to.no_grad():
                subs += train_states.update(idx, batch, model)
                if train_reconstruction is not None:
                    assert isinstance(model, Reconstructor)
                    train_reconstruction[idx] = model.data_estimator(
                        idx, batch, train_states
                    )  # full data estimation
                    if to.isnan(batch).any():
                        missing_data_mask = to.isnan(batch)
                        batch[missing_data_mask] = train_reconstruction[idx][missing_data_mask]
                        train_reconstruction[idx] = batch
            batch_F = model.update_param_batch(idx, batch, train_states)
            if not self.eval_F_at_epoch_end:
                if batch_F is None:
                    batch_F = model.free_energy(idx, batch, train_states)
                F += batch_F
        self._update_parameters_with_rollback()
        return F, subs, train_reconstruction
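
    # Added remark on the loop above: when a reconstruction buffer is present, NaN
    # entries of a batch are treated as missing data and are overwritten with the
    # model's current estimate before update_param_batch sees the batch, e.g.
    #
    #     batch = to.tensor([[0.2, float("nan")]])
    #     # after imputation, batch[0, 1] holds the model's data estimate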

    def eval_free_energies(self) -> Dict[str, Any]:
        """Return a dictionary with the same contents as e_step/em_step, without training.

        :returns: a dictionary containing 'train_F', 'train_subs', 'test_F', 'test_subs'
                  (keys might be missing depending on what is available)
        """
        m = self.model
        train_data, train_states = self.train_data, self.train_states
        test_data, test_states = self.test_data, self.test_states
        lpj_fn = m.log_pseudo_joint if isinstance(m, Optimized) else m.log_joint
        ret = {}

        if self.can_train:
            assert train_data is not None and train_states is not None  # to make mypy happy
            F = to.tensor(0.0)
            if isinstance(m, Optimized):
                m.init_epoch()
            for idx, batch in train_data:
                batch = self.data_transform(batch)
                if isinstance(m, Optimized):
                    m.init_batch()
                train_states.lpj[idx] = lpj_fn(batch, train_states.K[idx])
                F += m.free_energy(idx, batch, train_states)
            all_reduce(F)
            ret["train_F"] = F.item() / self.N_train
            ret["train_subs"] = 0

        if self.can_test:
            assert test_data is not None and test_states is not None  # to make mypy happy
            F = to.tensor(0.0)
            if isinstance(m, Optimized):
                m.init_epoch()
            for idx, batch in test_data:
                batch = self.data_transform(batch)
                if isinstance(m, Optimized):
                    m.init_batch()
                test_states.lpj[idx] = lpj_fn(batch, test_states.K[idx])
                F += m.free_energy(idx, batch, test_states)
            all_reduce(F)
            ret["test_F"] = F.item() / self.N_test
            ret["test_subs"] = 0

        return ret
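
    # Added remark: eval_free_energies refreshes the (pseudo) log-joints for the
    # current parameters but never calls states.update or an M-step; this is why
    # 'train_subs' and 'test_subs' are reported as 0 here.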

    def _update_parameters_with_rollback(self) -> None:
        """Update model parameters calling `update_param_epoch`; roll back if F decreases."""

        if len(self._to_rollback) == 0:
            # nothing to roll back, fall back to a simple parameter update
            self.model.update_param_epoch()
            return

        m = self.model
        lpj_fn = m.log_pseudo_joint if isinstance(m, Optimized) else m.log_joint

        assert self.train_data is not None and self.train_states is not None  # to make mypy happy
        all_data = self.train_data.dataset.tensors[1]
        states = self.train_states

        old_params = {p: m.theta[p].clone() for p in self._to_rollback}
        old_F = m.free_energy(idx=to.arange(all_data.shape[0]), batch=all_data, states=states)
        all_reduce(old_F)
        old_lpj = states.lpj.clone()
        m.update_param_epoch()
        states.lpj[:] = lpj_fn(all_data, states.K)
        new_F = m.free_energy(idx=to.arange(all_data.shape[0]), batch=all_data, states=states)
        all_reduce(new_F)
        if new_F < old_F:
            for p in self._to_rollback:
                m.theta[p][:] = old_params[p]
            states.lpj[:] = old_lpj
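
# Illustrative end-to-end usage (added comment; `model` and `states` are
# hypothetical stand-ins for a concrete Trainable and a concrete
# TVOVariationalStates instance, neither of which this module defines):
#
#     trainer = Trainer(model, train_data=data, train_states=states)
#     for epoch in range(100):
#         log = trainer.em_step()       # one E-step + M-step over train_data
#         print(epoch, log["train_F"])  # free energy per datapoint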