Coverage for tvo/exp/_ExpConfig.py: 100%
20 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
5import os
7from typing import Iterable, Sequence, Dict, Any, Callable
8import torch as to
class ExpConfig:
    """Plain container for the settings consumed by Experiment classes."""

    def __init__(
        self,
        batch_size: int = 1,
        shuffle: bool = True,
        drop_last: bool = False,
        warmup_Esteps: int = 0,
        output: str = None,
        log_blacklist: Iterable[str] = None,
        log_only_latest_theta: bool = False,
        rollback_if_F_decreases: Sequence[str] = None,
        warmup_reco_epochs: Iterable[int] = None,
        reco_epochs: Iterable[int] = None,
        keep_best_states: bool = False,
        eval_F_at_epoch_end: bool = False,
        data_transform: Callable[[to.Tensor], to.Tensor] = None,
    ):
        """Configuration object for Experiment classes.

        :param batch_size: Batch size for the data loaders.
        :param shuffle: Whether data should be reshuffled at every epoch.
                        See also torch's `DataLoader docs`_.
        :param drop_last: set to True to drop the last incomplete batch, if the dataset size is not
                          divisible by the batch size. See also torch's `DataLoader docs`_.
        :param warmup_Esteps: Number of warm-up E-steps to perform.
        :param output: Name or path of output HDF5 file. The default filename is "tvo_exp_<PID>.h5"
                       where PID is the process ID. It is overwritten if it already exists.
        :param log_blacklist: By default, experiments log all available quantities. These are:

                              - "{train,valid,test}_F": one or more of training/validation/test
                                free energy, depending on the experiment
                              - "{train,valid,test}_subs": average variational state substitutions
                                per datapoint (which ones are available depends on the experiment)
                              - "{train,valid,test}_states": latest snapshot of variational states
                                per datapoint
                              - "{train,valid,test}_lpj": latest snapshot of log-pseudo-joints
                                per datapoint
                              - "theta": a group containing logs of whatever model.theta contains

                              If one of these names appears in `log_blacklist`, the corresponding
                              quantity will not be logged. If omitted (or None), a new empty list
                              is created for this instance.
        :param log_only_latest_theta: Log only the most recent snapshot of the model parameters
                                      (use H5Logger.set instead of H5Logger.append)
        :param rollback_if_F_decreases: names of model parameters (corresponding to those in
                                        model.theta) that should be rolled back (i.e. not
                                        updated) if the free energy value before and after
                                        `model.update_param_epoch` decreases for a given epoch.
                                        This is only useful for models that perform the actual
                                        update of those parameters in `update_param_epoch` and not
                                        in `update_param_batch`.
                                        BSC and NoisyOR are such models. This feature is useful,
                                        for example, to prevent NoisyOR's M-step equation from
                                        oscillating away from the fixed point (i.e. the optimum).
                                        If omitted (or None), a new empty list is created for
                                        this instance.
        :param warmup_reco_epochs: List of warmup_Estep indices at which to compute data
                                   reconstructions.
        :param reco_epochs: List of epoch indices at which to compute data reconstructions.
        :param keep_best_states: If true, the experiment log will contain extra entries "best_*_F"
                                 and "best_*_states" (where * is one of "train", "valid", "test")
                                 corresponding to the best free energy value reached during training
                                 and the variational states at that epoch respectively.
        :param eval_F_at_epoch_end: By default, the framework evaluates the model free energy batch
                                    by batch during training, accumulating the values over the
                                    course of the epoch. If this option is set to `True`, the free
                                    energy will be evaluated at the end of each epoch instead.
        :param data_transform: A transformation to be applied to datapoints before they are passed
                               to the model for training or evaluation.
        """
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.warmup_Esteps = warmup_Esteps
        # Fall back to a per-process default file name when no output path is given.
        self.output = output if output is not None else f"tvo_exp_{os.getpid()}.h5"
        # A literal `[]` default would be one list object shared by every instance
        # built without this argument; build a fresh list per instance instead.
        # Lists supplied by the caller are stored as-is (no copy).
        self.log_blacklist = [] if log_blacklist is None else log_blacklist
        self.log_only_latest_theta = log_only_latest_theta
        self.rollback_if_F_decreases = (
            [] if rollback_if_F_decreases is None else rollback_if_F_decreases
        )
        self.warmup_reco_epochs = warmup_reco_epochs
        self.reco_epochs = reco_epochs
        self.keep_best_states = keep_best_states
        self.eval_F_at_epoch_end = eval_F_at_epoch_end
        self.data_transform = data_transform

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration as a mapping from attribute name to value.

        The result is `vars(self)`, i.e. the instance's own attribute dictionary
        rather than a copy, so modifying it also modifies this object.
        """
        return vars(self)