Coverage for tvo/exp/_ExpConfig.py: 100%

20 statements  

coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import os

from typing import Iterable, Sequence, Dict, Any, Callable, Optional
import torch as to


class ExpConfig:
    def __init__(
        self,
        batch_size: int = 1,
        shuffle: bool = True,
        drop_last: bool = False,
        warmup_Esteps: int = 0,
        output: Optional[str] = None,
        log_blacklist: Iterable[str] = [],
        log_only_latest_theta: bool = False,
        rollback_if_F_decreases: Sequence[str] = [],
        warmup_reco_epochs: Optional[Iterable[int]] = None,
        reco_epochs: Optional[Iterable[int]] = None,
        keep_best_states: bool = False,
        eval_F_at_epoch_end: bool = False,
        data_transform: Optional[Callable[[to.Tensor], to.Tensor]] = None,
    ):
        """Configuration object for Experiment classes.

        :param batch_size: Batch size for the data loaders.
        :param shuffle: Whether data should be reshuffled at every epoch.
                        See also torch's `DataLoader docs`_.
        :param drop_last: Set to True to drop the last incomplete batch if the dataset size is
                          not divisible by the batch size. See also torch's `DataLoader docs`_.
        :param warmup_Esteps: Number of warm-up E-steps to perform.
        :param output: Name or path of the output HDF5 file. The default filename is
                       "tvo_exp_<PID>.h5", where PID is the process ID. The file is overwritten
                       if it already exists.
        :param log_blacklist: By default, experiments log all available quantities. These are:

                              - "{train,valid,test}_F": one or more of training/validation/test
                                free energy, depending on the experiment
                              - "{train,valid,test}_subs": average variational state substitutions
                                per datapoint (which ones are available depends on the experiment)
                              - "{train,valid,test}_states": latest snapshot of variational states
                                per datapoint
                              - "{train,valid,test}_lpj": latest snapshot of log-pseudo-joints
                                per datapoint
                              - "theta": a group containing logs of whatever model.theta contains

                              If one of these names appears in `log_blacklist`, the corresponding
                              quantity will not be logged.
        :param log_only_latest_theta: Log only the most recent snapshot of the model parameters
                                      (use H5Logger.set instead of H5Logger.append).
        :param rollback_if_F_decreases: Names of model parameters (corresponding to those in
                                        model.theta) that should be rolled back (i.e. not updated)
                                        if the free energy decreases over a call to
                                        `model.update_param_epoch` in a given epoch. This is only
                                        useful for models that perform the actual update of those
                                        parameters in `update_param_epoch` rather than in
                                        `update_param_batch`; BSC and NoisyOR are such models.
                                        This feature is useful, for example, to prevent NoisyOR's
                                        M-step equation from oscillating away from the fixed point
                                        (i.e. the optimum).
        :param warmup_reco_epochs: List of warm-up E-step indices at which to compute data
                                   reconstructions.
        :param reco_epochs: List of epoch indices at which to compute data reconstructions.
        :param keep_best_states: If True, the experiment log will contain extra entries "best_*_F"
                                 and "best_*_states" (where * is one of "train", "valid", "test"),
                                 corresponding to the best free energy value reached during
                                 training and the variational states at that epoch, respectively.
        :param eval_F_at_epoch_end: By default, the framework evaluates the model free energy
                                    batch by batch during training, accumulating the values over
                                    the course of the epoch. If this option is set to `True`, the
                                    free energy is instead evaluated at the end of each epoch.
        :param data_transform: A transformation applied to datapoints before they are passed to
                               the model for training or evaluation.
        """
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.warmup_Esteps = warmup_Esteps
        self.output = output if output is not None else f"tvo_exp_{os.getpid()}.h5"
        self.log_blacklist = log_blacklist
        self.log_only_latest_theta = log_only_latest_theta
        self.rollback_if_F_decreases = rollback_if_F_decreases
        self.warmup_reco_epochs = warmup_reco_epochs
        self.reco_epochs = reco_epochs
        self.keep_best_states = keep_best_states
        self.eval_F_at_epoch_end = eval_F_at_epoch_end
        self.data_transform = data_transform

    def as_dict(self) -> Dict[str, Any]:
        return vars(self)
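
For orientation, a minimal usage sketch follows. It is not part of the covered module: it assumes ExpConfig is re-exported from the tvo.exp package, uses arbitrary illustrative parameter values, and the theta entry name passed to rollback_if_F_decreases is hypothetical.

# Illustrative usage sketch; not part of tvo/exp/_ExpConfig.py.
# Assumes ExpConfig is importable from the tvo.exp package.
import torch as to
from tvo.exp import ExpConfig

config = ExpConfig(
    batch_size=32,
    warmup_Esteps=10,
    log_blacklist=["train_lpj"],           # skip logging of log-pseudo-joints
    rollback_if_F_decreases=["pies"],      # hypothetical model.theta entry name
    data_transform=lambda x: x.to(dtype=to.float32),
)
print(config.output)     # defaults to "tvo_exp_<PID>.h5"
print(config.as_dict())  # plain dict view of all configuration fields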