Coverage for tvo/exp/_EStepConfig.py: 69%

59 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg. 

3# Licensed under the Academic Free License version 3.0 

4 

5from abc import ABC, abstractmethod 

6from typing import Dict, Any 

7 

8 

class EStepConfig(ABC):
    """Abstract base configuration object for experiments' E-steps.

    Concrete subclasses must implement :meth:`as_dict`.
    """

    def __init__(self, n_states: int):
        """Store the setting shared by all E-step configurations.

        :param n_states: Number of variational states per datapoint to keep in memory.
        """
        self.n_states = n_states

    @abstractmethod
    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration as a dictionary.

        :raises NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError  # pragma: no cover

20 

21 

class EVOConfig(EStepConfig):
    """Configuration object for the EVO E-step."""

    def __init__(
        self,
        n_states: int,
        n_parents: int,
        n_generations: int,
        parent_selection: str = "fitness",
        crossover: bool = True,
        n_children: int = None,
        mutation: str = "uniform",
        bitflip_frequency: float = None,
        K_init_file: str = None,
    ):
        """Validate and store the EVO settings.

        :param n_states: Number of variational states per datapoint to keep in memory.
        :param n_parents: Number of parent states to select at each EVO generation.
                          Must be <= n_states.
        :param n_generations: Number of EVO generations.
        :param parent_selection: Parent selection algorithm for EVO. Must be one of:

                                 - 'fitness': fitness-proportional parent selection
                                 - 'uniform': random uniform parent selection
        :param crossover: Whether crossover should be applied or not.
                          Must be False if n_children is specified.
        :param n_children: Number of children per parent to generate via mutation
                           at each EVO generation. Required if crossover is False.
        :param mutation: Mutation algorithm for EVO. Must be one of:

                         - 'sparsity': bits are flipped so that states tend
                           towards current model sparsity.
                         - 'uniform': random uniform selection of bits to flip.
        :param bitflip_frequency: Probability of flipping a bit during the mutation step (e.g.
                                  2/H for an average of 2 bitflips per mutation). Required when
                                  using the 'sparsity' mutation algorithm.
        :param K_init_file: Full path to H5 file providing initial states
        :raises AssertionError: if one of the checks below fails.
        """
        # NOTE(review): the docstring calls n_children required when crossover is False,
        # but crossover=False together with n_children=None is not rejected by this check.
        assert (
            not crossover or n_children is None
        ), "Exactly one of n_children and crossover may be provided."
        valid_selections = ("fitness", "uniform")
        assert parent_selection in valid_selections, f"Unknown parent selection {parent_selection}"
        valid_mutations = ("sparsity", "uniform")
        assert mutation in valid_mutations, f"Unknown mutation {mutation}"
        # The bound is inclusive: n_parents == n_states is accepted.
        assert (
            n_parents <= n_states
        ), f"n_parents ({n_parents}) must be lower than or equal to n_states ({n_states})"
        assert (
            mutation != "sparsity" or bitflip_frequency is not None
        ), "bitflip_frequency is required for mutation algorithm 'sparsity'"

        self.n_parents = n_parents
        self.n_children = n_children
        self.n_generations = n_generations
        self.parent_selection = parent_selection
        self.crossover = crossover
        self.mutation = mutation
        self.bitflip_frequency = bitflip_frequency
        self.K_init_file = K_init_file

        # Called last so that n_states stays the final key of as_dict().
        super().__init__(n_states)

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration attributes as a dictionary.

        A shallow copy is returned, so modifying the result does not alter this object.
        """
        return dict(vars(self))

85 

86 

class TVSConfig(EStepConfig):
    """Configuration object for the TVS E-step."""

    def __init__(
        self,
        n_states: int,
        n_prior_samples: int,
        n_marginal_samples: int,
        K_init_file: str = None,
    ):
        """Validate and store the TVS settings.

        :param n_states: Number of variational states per datapoint to keep in memory.
        :param n_prior_samples: Number of new variational states to be sampled from prior.
        :param n_marginal_samples: Number of new variational states to be sampled from
                                   approximated marginal p(s_h=1|vec{y}^{(n)}, Theta).
        :param K_init_file: Full path to H5 file providing initial states
        :raises AssertionError: if any of the three counts is not greater than zero.
        """
        assert n_states > 0, f"n_states must be positive integer ({n_states})"
        assert n_prior_samples > 0, f"n_prior_samples must be positive integer ({n_prior_samples})"
        assert (
            n_marginal_samples > 0
        ), f"n_marginal_samples must be positive integer ({n_marginal_samples})"

        self.n_prior_samples = n_prior_samples
        self.n_marginal_samples = n_marginal_samples
        self.K_init_file = K_init_file

        # Called last so that n_states stays the final key of as_dict().
        super().__init__(n_states)

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration attributes as a dictionary.

        A shallow copy is returned, so modifying the result does not alter this object.
        """
        return dict(vars(self))

117 

118 

class FullEMConfig(EStepConfig):
    """Full EM configuration."""

    def __init__(self, n_latents: int):
        """Set the number of states to ``2**n_latents``.

        :param n_latents: Exponent used to compute the number of states kept per datapoint.
        """
        super().__init__(2**n_latents)

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration attributes as a dictionary.

        A shallow copy is returned, so modifying the result does not alter this object.
        """
        return dict(vars(self))

126 

127 

class FullEMSingleCauseConfig(EStepConfig):
    """Full EM configuration (single-cause variant)."""

    def __init__(self, n_latents: int):
        """Set the number of states to ``n_latents``.

        Unlike :class:`FullEMConfig`, which keeps ``2**n_latents`` states, this keeps
        exactly ``n_latents`` states.

        :param n_latents: Number of states kept per datapoint.
        """
        # NOTE(review): presumably one state per latent (a single active cause) -- confirm.
        super().__init__(n_latents)

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration attributes as a dictionary.

        A shallow copy is returned, so modifying the result does not alter this object.
        """
        return dict(vars(self))

135 

136 

class RandomSamplingConfig(EStepConfig):
    """Configuration object for random sampling."""

    def __init__(
        self, n_states: int, n_samples: int, sparsity: float = 0.5, K_init_file: str = None
    ):
        """Validate and store the random sampling settings.

        :param n_states: Number of variational states per datapoint to keep in memory.
        :param n_samples: Number of new variational states to randomly draw.
        :param sparsity: average fraction of active units in sampled states.
                         Must lie strictly between 0 and 1.
        :param K_init_file: Full path to H5 file providing initial states
        :raises AssertionError: if n_states or n_samples is not greater than zero, or if
                                sparsity is outside the open interval (0, 1).
        """
        assert n_states > 0, f"n_states must be positive integer ({n_states})"
        assert n_samples > 0, f"n_samples must be positive integer ({n_samples})"
        # Both bounds are exclusive, so the message names the open interval.
        assert 0 < sparsity < 1, f"sparsity must be in (0, 1) ({sparsity})"

        self.n_samples = n_samples
        self.sparsity = sparsity
        self.K_init_file = K_init_file

        # Called last so that n_states stays the final key of as_dict().
        super().__init__(n_states)

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration attributes as a dictionary.

        A shallow copy is returned, so modifying the result does not alter this object.
        """
        return dict(vars(self))