Coverage for tvo/exp/_experiments.py: 97%
154 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
5from abc import ABC, abstractmethod
6from tvo.utils.data import TVODataLoader
7from tvo.utils.model_protocols import Trainable
8from tvo.utils.parallel import (
9 pprint,
10 init_processes,
11 gather_from_processes,
12 get_h5_dataset_to_processes,
13)
14from tvo.exp._utils import make_var_states
15from tvo.utils import get, H5Logger
16from tvo.trainer import Trainer
17from tvo.exp._EStepConfig import EStepConfig
18from tvo.exp._ExpConfig import ExpConfig
19from tvo.exp._EpochLog import EpochLog
20from tvo.variational import TVOVariationalStates
21import tvo
23import math
24from typing import Dict, Any, Generator
25import torch as to
26import torch.distributed as dist
27import time
28from pathlib import Path
29import os
30from munch import Munch
class Experiment(ABC):
    """Abstract base class for all experiments.

    Concrete experiments implement :meth:`run`.
    """

    @abstractmethod
    def run(self, epochs: int) -> Generator[EpochLog, None, None]:
        """Run the experiment.

        :param epochs: Number of epochs to run for.
        :returns: Per the annotation, a generator of ``EpochLog`` objects.
        """
class _TrainingAndOrValidation(Experiment):
    def __init__(
        self,
        conf: ExpConfig,
        estep_conf: EStepConfig,
        model: Trainable,
        train_dataset: to.Tensor = None,
        test_dataset: to.Tensor = None,
    ):
        """Helper class to avoid code repetition between Training and Testing.

        It performs training and/or validation/testing depending on what input is provided.

        :param conf: Experiment configuration.
        :param estep_conf: E-step configuration, forwarded to `make_var_states`.
        :param model: The model to use; must be an instance of `Trainable`.
        :param train_dataset: Training data, or None. Its first dimension is used as
                              the number of datapoints when creating variational states.
        :param test_dataset: Validation/test data, or None. Same layout as `train_dataset`.
        """
        # Validate the model before reading any of its attributes.
        assert isinstance(model, Trainable)
        H = sum(model.shape[1:])
        self.model = model
        self._conf = Munch(conf.as_dict())
        self._conf.model = type(model).__name__
        self._conf.device = tvo.get_device().type
        self._estep_conf = Munch(estep_conf.as_dict())
        self._precision = model.precision

        self.train_data = None
        self.train_states = None
        if train_dataset is not None:
            self.train_data = self._make_dataloader(train_dataset, conf)
            # might differ between processes: last process might have smaller N and less states
            # (but TVODataLoader+ShufflingSampler make sure the number of batches is the same)
            N = train_dataset.shape[0]
            self.train_states = self._make_states(N, H, self._precision, estep_conf)

        self.test_data = None
        self.test_states = None
        if test_dataset is not None:
            self.test_data = self._make_dataloader(test_dataset, conf)
            N = test_dataset.shape[0]
            self.test_states = self._make_states(N, H, self._precision, estep_conf)

        # reconstructions are requested if either list of reconstruction epochs is set
        will_reconstruct = (
            self._conf.reco_epochs is not None or self._conf.warmup_reco_epochs is not None
        )
        self.trainer = Trainer(
            self.model,
            self.train_data,
            self.train_states,
            self.test_data,
            self.test_states,
            rollback_if_F_decreases=self._conf.rollback_if_F_decreases,
            will_reconstruct=will_reconstruct,
            eval_F_at_epoch_end=self._conf.eval_F_at_epoch_end,
            data_transform=self._conf.data_transform,
        )
        self.logger = H5Logger(self._conf.output, blacklist=self._conf.log_blacklist)

    @staticmethod
    def _get_rank() -> int:
        """Return the rank of this process, or 0 if torch.distributed is not initialized."""
        return dist.get_rank() if dist.is_initialized() else 0

    def _make_dataloader(self, dataset: to.Tensor, conf: ExpConfig) -> TVODataLoader:
        """Move `dataset` to the tvo device and wrap it in a TVODataLoader.

        Data that is not of dtype uint8 is first cast to the model precision.

        :param dataset: The data tensor to wrap.
        :param conf: Provides `batch_size`, `shuffle` and `drop_last` for the data loader.
        """
        if dataset.dtype is not to.uint8:
            dataset = dataset.to(dtype=self._precision)
        dataset = dataset.to(device=tvo.get_device())
        return TVODataLoader(
            dataset, batch_size=conf.batch_size, shuffle=conf.shuffle, drop_last=conf.drop_last
        )

    def _make_states(
        self, N: int, H: int, precision: to.dtype, estep_conf: EStepConfig
    ) -> TVOVariationalStates:
        """Create variational states through `make_var_states`."""
        return make_var_states(estep_conf, N, H, precision)

    @property
    def config(self) -> Dict[str, Any]:
        """The experiment configuration as a plain dictionary (a new dict on every access)."""
        return dict(self._conf)

    @property
    def estep_config(self) -> Dict[str, Any]:
        """The E-step configuration as a plain dictionary (a new dict on every access)."""
        return dict(self._estep_conf)

    def run(self, epochs: int) -> Generator[EpochLog, None, None]:
        """Run training and/or testing.

        Yields one EpochLog for epoch 0 (before any EM step) and one for each EM step.
        The statements after the last EM step (log file cleanup, invalidation of the
        trainer) only execute once the generator has been exhausted by the caller.

        :param epochs: Number of epochs to train for
        """
        trainer = self.trainer
        logger = self.logger

        self._log_confs(logger)

        # warm-up E-steps
        if self._conf.warmup_Esteps > 0:
            pprint("Warm-up E-steps")
        for e in range(self._conf.warmup_Esteps):
            compute_reconstruction = (
                self._conf.warmup_reco_epochs is not None and e in self._conf.warmup_reco_epochs
            )
            d = trainer.e_step(compute_reconstruction)
            self._log_epoch(logger, d)

        # log initial free energies (after warm-up E-steps if any)
        if self._conf.warmup_Esteps == 0:
            d = trainer.eval_free_energies()
        # NOTE(review): if warm-up E-steps ran, `d` is the result of the last warm-up step,
        # which was already passed to _log_epoch above -- confirm the repeat is intended.
        self._log_epoch(logger, d)
        yield EpochLog(epoch=0, results=d)

        # EM steps
        for e in range(epochs):
            start_t = time.time()
            compute_reconstruction = (
                self._conf.reco_epochs is not None and e in self._conf.reco_epochs
            )
            d = trainer.em_step(compute_reconstruction)
            epoch_runtime = time.time() - start_t
            self._log_epoch(logger, d)
            yield EpochLog(e + 1, d, epoch_runtime)

        # remove leftover ".old" logfiles produced by the logger
        leftover_logfile = self._conf.output + ".old"
        if self._get_rank() == 0 and Path(leftover_logfile).is_file():
            os.remove(leftover_logfile)

        # put trainer into undefined state after the experiment is finished
        self.trainer = None  # type: ignore

    def _log_confs(self, logger: H5Logger):
        """Dump experiment+estep+model configuration to screen and save it to output file."""
        titles = ["Experiment", "E-step"]
        confs = [self.config, self.estep_config]
        logger.set(exp_config=self.config)
        logger.set(estep_config=self.estep_config)

        model_conf = self.model.config  # could raise
        logger.set(model_config=model_conf)
        confs.append(model_conf)
        titles.append("Model")

        for title, conf in zip(titles, confs):
            pprint(f"\n{title} configuration:")
            for k, v in conf.items():
                pprint(f"\t{k:<20}: {v}")

    def _log_epoch(self, logger: H5Logger, epoch_results: Dict[str, float]):
        """Log F, subs, model.theta, states.K and states.lpj to file.

        Entries listed in the `log_blacklist` of the configuration are skipped.
        Nothing is returned.

        :param logger: the logger for this run
        :param epoch_results: dictionary returned by Trainer.e_step or Trainer.em_step
        """
        for data_kind in "train", "test":
            if data_kind + "_F" not in epoch_results:
                continue

            # log_kind is one of "train", "valid" or "test"
            # (while data_kind is one of "train" or "test")
            log_kind = "valid" if data_kind == "test" and self.train_data is not None else data_kind

            # log F and subs to stdout and file
            F, subs = get(epoch_results, f"{data_kind}_F", f"{data_kind}_subs")
            assert math.isfinite(F), f"{log_kind} free energy is invalid!"
            F_and_subs_dict = {f"{log_kind}_F": to.tensor(F), f"{log_kind}_subs": to.tensor(subs)}
            logger.append(**F_and_subs_dict)

            # log latest states and lpj to file
            states = getattr(self, f"{data_kind}_states")
            if f"{log_kind}_states" not in self._conf.log_blacklist:
                K = gather_from_processes(states.K)
                logger.set(**{f"{log_kind}_states": K})
            else:
                K = None
            if f"{log_kind}_lpj" not in self._conf.log_blacklist:
                logger.set(**{f"{log_kind}_lpj": gather_from_processes(states.lpj)})

            if self._conf.keep_best_states:
                best_F_name = f"best_{log_kind}_F"
                best_F = getattr(self, f"_{best_F_name}", None)
                if best_F is None or F > best_F:
                    rank = self._get_rank()
                    if K is None:
                        # states were blacklisted above, so they have not been gathered yet
                        K = gather_from_processes(states.K)
                    if rank == 0:
                        assert isinstance(K, to.Tensor)  # to make mypy happy
                        best_states_dict = {
                            best_F_name: to.tensor(F),
                            f"best_{log_kind}_states": K.cpu().clone(),
                        }
                        logger.set(**best_states_dict)
                    # remembered on every process, not only on rank 0
                    setattr(self, f"_{best_F_name}", F)

            # log data reconstructions
            reco_dict = {}
            if (
                f"{log_kind}_reconstruction" not in self._conf.log_blacklist
                and f"{data_kind}_rec" in epoch_results
            ):
                reco_dict[f"{log_kind}_reconstruction"] = gather_from_processes(
                    epoch_results[f"{data_kind}_rec"]
                )
            logger.set(**reco_dict)

        # theta is either overwritten (only latest value kept) or appended to its history
        log_theta_fn = logger.set if self._conf.log_only_latest_theta else logger.append
        log_theta_fn(theta=self.model.theta)
        logger.write()
class Training(_TrainingAndOrValidation):
    def __init__(
        self,
        conf: ExpConfig,
        estep_conf: EStepConfig,
        model: Trainable,
        train_data_file: str,
        val_data_file: str = None,
    ):
        """Train model on given dataset for the given number of epochs.

        :param conf: Experiment configuration.
        :param estep_conf: Instance of a class inheriting from EStepConfig.
        :param model: model to train
        :param train_data_file: Path to an HDF5 file containing the training dataset.
                                Datasets with name "train_data" and "data" will be
                                searched in the file, in this order.
        :param val_data_file: Path to an HDF5 file containing the validation dataset.
                              Datasets with name "val_data" and "data" will be searched
                              in the file, in this order. If None, no validation is done.

        On the validation dataset, Training only performs E-steps without updating
        the model parameters.

        .. _DataLoader docs: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
        """
        run_policy = tvo.get_run_policy()
        if run_policy == "mpi":
            init_processes()

        train_dataset = get_h5_dataset_to_processes(train_data_file, ("train_data", "data"))
        val_dataset = (
            None
            if val_data_file is None
            else get_h5_dataset_to_processes(val_data_file, ("val_data", "data"))
        )

        # store the paths of the data files as attributes of the configuration object
        for attr_name, file_path in (
            ("train_dataset", train_data_file),
            ("val_dataset", val_data_file),
        ):
            setattr(conf, attr_name, file_path)

        super().__init__(conf, estep_conf, model, train_dataset, val_dataset)
class Testing(_TrainingAndOrValidation):
    def __init__(self, conf: ExpConfig, estep_conf: EStepConfig, model: Trainable, data_file: str):
        """Test given model on given dataset for the given number of epochs.

        :param conf: Experiment configuration.
        :param estep_conf: Instance of a class inheriting from EStepConfig.
        :param model: model to test
        :param data_file: Path to an HDF5 file containing the test dataset. Datasets with name
                          "test_data" and "data" will be searched in the file, in this order.

        Only E-steps are run. Model parameters are not updated.
        """
        run_policy = tvo.get_run_policy()
        if run_policy == "mpi":
            init_processes()

        test_dataset = get_h5_dataset_to_processes(data_file, ("test_data", "data"))

        # store the path of the data file as an attribute of the configuration object
        setattr(conf, "test_dataset", data_file)

        # no training data: the parent class only gets a test dataset
        super().__init__(conf, estep_conf, model, None, test_dataset)