# Coverage: tvo/trainer/Trainer.py, 98% of 180 statements covered (coverage.py v7.4.3, 2024-03-01).
# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import tvo
from tvo.utils.model_protocols import Trainable, Optimized, Reconstructor
from tvo.variational import TVOVariationalStates
from tvo.utils.data import TVODataLoader
from tvo.utils.parallel import all_reduce
from typing import Dict, Any, Sequence, Union, Callable
import torch as to


class Trainer:
    def __init__(
        self,
        model: Trainable,
        train_data: Union[TVODataLoader, to.Tensor] = None,
        train_states: TVOVariationalStates = None,
        test_data: Union[TVODataLoader, to.Tensor] = None,
        test_states: TVOVariationalStates = None,
        rollback_if_F_decreases: Sequence[str] = [],
        will_reconstruct: bool = False,
        eval_F_at_epoch_end: bool = False,
        data_transform: Callable[[to.Tensor], to.Tensor] = None,
    ):
        """Train and/or test a given model.

        :param model: an object of a concrete type satisfying the Trainable protocol
        :param train_data: the contained dataset should have shape (N,D)
        :param train_states: TVOVariationalStates with shape (N,S,H)
        :param test_data: validation or test dataset. The contained dataset should have shape (M,D)
        :param test_states: TVOVariationalStates with shape (M,Z,H)
        :param rollback_if_F_decreases: see ExpConfig docs
        :param will_reconstruct: True if data will be reconstructed by the Trainer
        :param eval_F_at_epoch_end: By default, the trainer evaluates the model free energy batch
                                    by batch, accumulating the values over the course of the epoch.
                                    If this option is set to `True`, the free energy will be
                                    evaluated at the end of an epoch instead.
        :param data_transform: A transformation to be applied to datapoints before they are passed
                               to the model for training/evaluation.

        Both train_data and train_states must be provided, or neither.
        The same holds for test_data and test_states.
        At least one of these two pairs of arguments must be present.

        Training steps on test_data only perform E-steps, i.e. model parameters are
        not updated but test_states are. Therefore test_data can also be used for validation.
        """
        for data, states in ((train_data, train_states), (test_data, test_states)):
            assert (data is not None) == (
                states is not None
            ), "Please provide both dataset and variational states, or neither"
        train_data = TVODataLoader(train_data) if isinstance(train_data, to.Tensor) else train_data
        test_data = TVODataLoader(test_data) if isinstance(test_data, to.Tensor) else test_data
        self.can_train = train_data is not None and train_states is not None
        self.can_test = test_data is not None and test_states is not None
        if not self.can_train and not self.can_test:  # pragma: no cover
            raise RuntimeError("Please provide at least one pair of dataset and variational states")

        _d, _s = (train_data, train_states) if self.can_train else (test_data, test_states)
        assert _d is not None and _s is not None
        if isinstance(model, Optimized):
            model.init_storage(_s.config["S"], _s.config["S_new"], _d.batch_size)

        self.model = model
        self.train_data = train_data
        self.train_states = train_states
        self.test_data = test_data
        self.test_states = test_states
        self.will_reconstruct = will_reconstruct
        self.eval_F_at_epoch_end = eval_F_at_epoch_end
        if train_data is not None:
            self.N_train = to.tensor(len(train_data.dataset))
            all_reduce(self.N_train)
            self.N_train = self.N_train.item()
            if self.will_reconstruct:
                self.train_reconstruction = train_data.dataset.tensors[1].clone()
        if test_data is not None:
            self.N_test = to.tensor(len(test_data.dataset))
            all_reduce(self.N_test)
            self.N_test = self.N_test.item()
            if self.will_reconstruct:
                self.test_reconstruction = test_data.dataset.tensors[1].clone()
        self._to_rollback = rollback_if_F_decreases
        self.data_transform = data_transform if data_transform is not None else lambda x: x
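
    # Construction sketch (hypothetical: `MyModel` and `data` are placeholders for a
    # concrete Trainable implementation and an (N, D) data tensor; the states object
    # would be one of the concrete TVOVariationalStates implementations):
    #
    #     model = MyModel(...)
    #     states = make_variational_states(...)  # hypothetical factory
    #     trainer = Trainer(model, train_data=data, train_states=states)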

    @staticmethod
    def _do_e_step(
        data: TVODataLoader,
        states: TVOVariationalStates,
        model: Trainable,
        N: int,
        data_transform: Callable[[to.Tensor], to.Tensor],
        reconstruction: to.Tensor = None,
    ):
        if reconstruction is not None and not isinstance(model, Reconstructor):
            raise NotImplementedError(
                f"reconstruction not implemented for model {type(model).__name__}"
            )
        F = to.tensor(0.0)
        subs = to.tensor(0)
        if isinstance(model, Optimized):
            model.init_epoch()
        for idx, batch in data:
            batch = data_transform(batch)
            if isinstance(model, Optimized):
                model.init_batch()
            subs += states.update(idx, batch, model)
            F += model.free_energy(idx, batch, states)
            if reconstruction is not None:
                # full data estimation
                reconstruction[idx] = model.data_estimator(idx, batch, states)  # type: ignore
        # sum the free energy and the number of substituted states over all workers
        all_reduce(F)
        all_reduce(subs)
        return F.item() / N, subs.item() / N, reconstruction
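
    # Reading `_do_e_step`'s return value: with F accumulated over all batches and
    # summed across workers by `all_reduce`, the first entry is F / N, an estimate of
    # the free energy per datapoint; the second is the average number of variational
    # states substituted per datapoint; the third is the (possibly None) reconstruction.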

    def e_step(self, compute_reconstruction: bool = False) -> Dict[str, Any]:
        """Run one epoch of E-steps on training and/or test data, depending on what is available.

        Only E-steps are executed.

        :returns: a dictionary containing 'train_F', 'train_subs', 'test_F', 'test_subs'
                  (keys might be missing depending on what is available)
        """
        ret = {}
        model = self.model
        train_data, train_states = self.train_data, self.train_states
        test_data, test_states = self.test_data, self.test_states
        train_reconstruction = (
            self.train_reconstruction
            if (compute_reconstruction and hasattr(self, "train_reconstruction"))
            else None
        )
        test_reconstruction = (
            self.test_reconstruction
            if (compute_reconstruction and hasattr(self, "test_reconstruction"))
            else None
        )

        # Training #
        if self.can_train:
            assert train_data is not None and train_states is not None  # to make mypy happy
            ret["train_F"], ret["train_subs"], train_rec = self._do_e_step(
                train_data,
                train_states,
                model,
                self.N_train,
                self.data_transform,
                train_reconstruction,
            )
            if train_rec is not None:
                ret["train_rec"] = train_rec

        # Validation/Testing #
        if self.can_test:
            assert test_data is not None and test_states is not None  # to make mypy happy
            ret["test_F"], ret["test_subs"], test_rec = self._do_e_step(
                test_data, test_states, model, self.N_test, self.data_transform, test_reconstruction
            )
            if test_rec is not None:
                ret["test_rec"] = test_rec

        return ret
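
    # Usage sketch (hypothetical, assuming `trainer` was constructed as sketched above):
    #
    #     log = trainer.e_step()  # E-steps only: model parameters stay fixed
    #     train_F, train_subs = log["train_F"], log["train_subs"]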

    def em_step(self, compute_reconstruction: bool = False) -> Dict[str, Any]:
        """Run one training and/or test epoch, depending on what data is available.

        Both E-step and M-step are executed. If requested, reconstructions are computed
        along the way.

        :returns: a dictionary containing 'train_F', 'train_subs', 'test_F', 'test_subs'
                  (keys might be missing depending on what is available). The free energy values
                  are calculated per batch, so if the model updates its parameters in
                  `update_param_epoch`, the free energies reported at epoch X are calculated
                  using the weights of epoch X-1.
        """
        # NOTE:
        # For models that update the parameters in update_param_epoch, the free energy reported
        # at each epoch is the one after the E-step and before the M-step (K sets of epoch X and
        # \Theta of epoch X-1 yield the free energy of epoch X).
        # For models that update the parameters in update_param_batch, the free energy reported
        # at each epoch does not correspond to a fixed set of parameters: each batch had a
        # different set of parameters and the reported free energy is more of an average of the
        # free energies yielded by all the sets of parameters spanned during an epoch.
        ret_dict = {}

        # Training #
        if self.can_train:
            F, subs, reco = self._train_epoch(compute_reconstruction)
            all_reduce(F)
            ret_dict["train_F"] = F.item() / self.N_train
            all_reduce(subs)
            ret_dict["train_subs"] = subs.item() / self.N_train
            if reco is not None:
                ret_dict["train_rec"] = reco

        # Validation/Testing #
        if self.can_test:
            test_data, test_states, test_reconstruction = (
                self.test_data,
                self.test_states,
                self.test_reconstruction
                if (compute_reconstruction and hasattr(self, "test_reconstruction"))
                else None,
            )
            model = self.model

            assert test_data is not None and test_states is not None  # to make mypy happy
            res = self._do_e_step(
                test_data, test_states, model, self.N_test, self.data_transform, test_reconstruction
            )
            ret_dict["test_F"], ret_dict["test_subs"], test_rec = res
            if test_rec is not None:
                ret_dict["test_rec"] = test_rec

        return ret_dict
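
    # Training-loop sketch (hypothetical; `n_epochs` and the logging are illustrative):
    #
    #     for epoch in range(n_epochs):
    #         log = trainer.em_step()
    #         print(f"epoch {epoch}: train_F={log['train_F']:.3f}")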

    def _train_epoch(self, compute_reconstruction: bool):
        model = self.model
        train_data, train_states, train_reconstruction = (
            self.train_data,
            self.train_states,
            self.train_reconstruction
            if (compute_reconstruction and hasattr(self, "train_reconstruction"))
            else None,
        )

        assert train_data is not None and train_states is not None  # to make mypy happy
        F = to.tensor(0.0, device=tvo.get_device())
        subs = to.tensor(0)
        if isinstance(model, Optimized):
            model.init_epoch()
        for idx, batch in train_data:
            batch = self.data_transform(batch)
            if isinstance(model, Optimized):
                model.init_batch()
            with to.no_grad():
                subs += train_states.update(idx, batch, model)
            if train_reconstruction is not None:
                assert isinstance(model, Reconstructor)
                train_reconstruction[idx] = model.data_estimator(
                    idx, batch, train_states
                )  # full data estimation
                # impute missing (NaN) entries with the model's estimate before the M-step
                if to.isnan(batch).any():
                    missing_data_mask = to.isnan(batch)
                    batch[missing_data_mask] = train_reconstruction[idx][missing_data_mask]
                    train_reconstruction[idx] = batch
            batch_F = model.update_param_batch(idx, batch, train_states)
            if not self.eval_F_at_epoch_end:
                if batch_F is None:
                    batch_F = model.free_energy(idx, batch, train_states)
                F += batch_F
        self._update_parameters_with_rollback()
        return F, subs, train_reconstruction
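
    # Missing-data sketch (hypothetical; `mask` marks entries to treat as missing).
    # NaN entries in the training data are imputed from the model's reconstruction
    # before the parameter update:
    #
    #     data[mask] = float("nan")
    #     trainer = Trainer(model, train_data=data, train_states=states, will_reconstruct=True)
    #     log = trainer.em_step(compute_reconstruction=True)
    #     imputed = log["train_rec"]  # reconstruction with NaNs filled in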

    def eval_free_energies(self) -> Dict[str, Any]:
        """Return a dictionary with the same contents as e_step/em_step, without training the model.

        :returns: a dictionary containing 'train_F', 'train_subs', 'test_F', 'test_subs'
                  (keys might be missing depending on what is available)
        """
        m = self.model
        train_data, train_states = self.train_data, self.train_states
        test_data, test_states = self.test_data, self.test_states
        lpj_fn = m.log_pseudo_joint if isinstance(m, Optimized) else m.log_joint
        ret = {}

        if self.can_train:
            assert train_data is not None and train_states is not None  # to make mypy happy
            F = to.tensor(0.0)
            if isinstance(m, Optimized):
                m.init_epoch()
            for idx, batch in train_data:
                batch = self.data_transform(batch)
                if isinstance(m, Optimized):
                    m.init_batch()
                # refresh log (pseudo-)joints for the current variational states
                # without substituting any of them
                train_states.lpj[idx] = lpj_fn(batch, train_states.K[idx])
                F += m.free_energy(idx, batch, train_states)
            all_reduce(F)
            ret["train_F"] = F.item() / self.N_train
            ret["train_subs"] = 0

        if self.can_test:
            assert test_data is not None and test_states is not None  # to make mypy happy
            F = to.tensor(0.0)
            if isinstance(m, Optimized):
                m.init_epoch()
            for idx, batch in test_data:
                batch = self.data_transform(batch)
                if isinstance(m, Optimized):
                    m.init_batch()
                test_states.lpj[idx] = lpj_fn(batch, test_states.K[idx])
                F += m.free_energy(idx, batch, test_states)
            all_reduce(F)
            ret["test_F"] = F.item() / self.N_test
            ret["test_subs"] = 0

        return ret
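
    # Evaluation sketch (hypothetical): log a free-energy baseline without updating
    # model parameters or substituting variational states:
    #
    #     baseline = trainer.eval_free_energies()
    #     print(baseline.get("train_F"), baseline.get("test_F"))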

    def _update_parameters_with_rollback(self) -> None:
        """Update model parameters calling `update_param_epoch`, roll back if F decreases."""
        if len(self._to_rollback) == 0:
            # nothing to rollback, fall back to simple parameter update
            self.model.update_param_epoch()
            return

        m = self.model
        lpj_fn = m.log_pseudo_joint if isinstance(m, Optimized) else m.log_joint

        assert self.train_data is not None and self.train_states is not None  # to make mypy happy
        all_data = self.train_data.dataset.tensors[1]
        states = self.train_states

        # snapshot the parameters that might have to be rolled back, together with
        # the current free energy and log (pseudo-)joints
        old_params = {p: m.theta[p].clone() for p in self._to_rollback}
        old_F = m.free_energy(idx=to.arange(all_data.shape[0]), batch=all_data, states=states)
        all_reduce(old_F)
        old_lpj = states.lpj.clone()
        m.update_param_epoch()
        states.lpj[:] = lpj_fn(all_data, states.K)
        new_F = m.free_energy(idx=to.arange(all_data.shape[0]), batch=all_data, states=states)
        all_reduce(new_F)
        if new_F < old_F:
            # the M-step decreased the free energy: restore the snapshotted parameters
            for p in self._to_rollback:
                m.theta[p][:] = old_params[p]
            states.lpj[:] = old_lpj
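
    # Rollback sketch (hypothetical: "W" and "sigma2" stand for whatever keys the
    # concrete model keeps in its `theta` dictionary):
    #
    #     trainer = Trainer(model, train_data=data, train_states=states,
    #                       rollback_if_F_decreases=["W", "sigma2"])
    #
    # With this option, an epoch-level M-step that lowers the training free energy is
    # undone: the listed parameters and the log (pseudo-)joints are restored.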