Coverage for tvo/exp/_EStepConfig.py: 69%
59 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
class EStepConfig(ABC):
    """Abstract base class shared by all E-step configuration objects.

    Subclasses must implement :meth:`as_dict`.
    """

    def __init__(self, n_states: int):
        """Abstract base configuration object for experiments' E-steps.

        :param n_states: Number of variational states per datapoint to keep in memory.
        """
        self.n_states = n_states

    @abstractmethod
    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration as a dictionary (implemented by subclasses)."""
        raise NotImplementedError  # pragma: no cover
class EVOConfig(EStepConfig):
    """E-step configuration for the EVO algorithm."""

    def __init__(
        self,
        n_states: int,
        n_parents: int,
        n_generations: int,
        parent_selection: str = "fitness",
        crossover: bool = True,
        n_children: Optional[int] = None,
        mutation: str = "uniform",
        bitflip_frequency: Optional[float] = None,
        K_init_file: Optional[str] = None,
    ):
        """Configuration object for EVO E-step.

        :param n_states: Number of variational states per datapoint to keep in memory.
        :param n_parents: Number of parent states to select at each EVO generation.
                          Must be <= n_states.
        :param n_generations: Number of EVO generations. Stored as given; not validated here.
        :param parent_selection: Parent selection algorithm for EVO. Must be one of:

                                 - 'fitness': fitness-proportional parent selection
                                 - 'uniform': random uniform parent selection
        :param crossover: Whether crossover should be applied or not.
                          Must be False if n_children is specified.
        :param n_children: Number of children per parent to generate via mutation
                           at each EVO generation. Required if crossover is False
                           (note: this constructor only rejects n_children being given
                           together with crossover=True; it does not check that
                           n_children is set when crossover is False).
        :param mutation: Mutation algorithm for EVO. Must be one of:

                         - 'sparsity': bits are flipped so that states tend
                           towards current model sparsity.
                         - 'uniform': random uniform selection of bits to flip.
        :param bitflip_frequency: Probability of flipping a bit during the mutation step (e.g.
                                  2/H for an average of 2 bitflips per mutation). Required when
                                  using the 'sparsity' mutation algorithm.
        :param K_init_file: Full path to H5 file providing initial states
        :raises AssertionError: If any of the constraints above is violated.
        """
        assert (
            not crossover or n_children is None
        ), "Exactly one of n_children and crossover may be provided."
        valid_selections = ("fitness", "uniform")
        assert parent_selection in valid_selections, f"Unknown parent selection {parent_selection}"
        valid_mutations = ("sparsity", "uniform")
        assert mutation in valid_mutations, f"Unknown mutation {mutation}"
        # Equality is allowed, so the message says "lower than or equal to".
        assert (
            n_parents <= n_states
        ), f"n_parents ({n_parents}) must be lower than or equal to n_states ({n_states})"
        assert (
            mutation != "sparsity" or bitflip_frequency is not None
        ), "bitflip_frequency is required for mutation algorithm 'sparsity'"

        self.n_parents = n_parents
        self.n_children = n_children
        self.n_generations = n_generations
        self.parent_selection = parent_selection
        self.crossover = crossover
        self.mutation = mutation
        self.bitflip_frequency = bitflip_frequency
        self.K_init_file = K_init_file

        super().__init__(n_states)

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration attributes as a new dictionary.

        A shallow copy is returned so that modifying the result does not
        modify this configuration object.
        """
        return dict(vars(self))
class TVSConfig(EStepConfig):
    """E-step configuration for the TVS algorithm."""

    def __init__(
        self,
        n_states: int,
        n_prior_samples: int,
        n_marginal_samples: int,
        K_init_file: Optional[str] = None,
    ):
        """Configuration object for TVS E-step.

        :param n_states: Number of variational states per datapoint to keep in memory.
        :param n_prior_samples: Number of new variational states to be sampled from prior.
        :param n_marginal_samples: Number of new variational states to be sampled from
                                   approximated marginal p(s_h=1|vec{y}^{(n)}, Theta).
        :param K_init_file: Full path to H5 file providing initial states
        :raises AssertionError: If n_states, n_prior_samples or n_marginal_samples
                                is not strictly positive.
        """
        assert n_states > 0, f"n_states must be positive integer ({n_states})"
        assert n_prior_samples > 0, f"n_prior_samples must be positive integer ({n_prior_samples})"
        assert (
            n_marginal_samples > 0
        ), f"n_marginal_samples must be positive integer ({n_marginal_samples})"

        self.n_prior_samples = n_prior_samples
        self.n_marginal_samples = n_marginal_samples
        self.K_init_file = K_init_file

        super().__init__(n_states)

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration attributes as a new dictionary.

        A shallow copy is returned so that modifying the result does not
        modify this configuration object.
        """
        return dict(vars(self))
class FullEMConfig(EStepConfig):
    """E-step configuration for full EM."""

    def __init__(self, n_latents: int):
        """Full EM configuration.

        :param n_latents: Number of latents; ``n_states`` is set to ``2**n_latents``.
                          ``n_latents`` itself is not stored on the object.
        """
        super().__init__(2**n_latents)

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration attributes as a new dictionary.

        A shallow copy is returned so that modifying the result does not
        modify this configuration object.
        """
        return dict(vars(self))
class FullEMSingleCauseConfig(EStepConfig):
    """E-step configuration for full EM in the single-cause case."""

    def __init__(self, n_latents: int):
        """Full EM (single cause) configuration.

        :param n_latents: Number of latents; ``n_states`` is set to ``n_latents``.
        """
        super().__init__(n_latents)

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration attributes as a new dictionary.

        A shallow copy is returned so that modifying the result does not
        modify this configuration object.
        """
        return dict(vars(self))
class RandomSamplingConfig(EStepConfig):
    """E-step configuration for random sampling of variational states."""

    def __init__(
        self,
        n_states: int,
        n_samples: int,
        sparsity: float = 0.5,
        K_init_file: Optional[str] = None,
    ):
        """Configuration object for random sampling.

        :param n_states: Number of variational states per datapoint to keep in memory.
        :param n_samples: Number of new variational states to randomly draw.
        :param sparsity: average fraction of active units in sampled states.
                         Must lie strictly between 0 and 1.
        :param K_init_file: Full path to H5 file providing initial states
        :raises AssertionError: If n_states or n_samples is not strictly positive,
                                or sparsity is outside the open interval (0, 1).
        """
        assert n_states > 0, f"n_states must be positive integer ({n_states})"
        assert n_samples > 0, f"n_samples must be positive integer ({n_samples})"
        # Both bounds are excluded, hence the open interval in the message.
        assert 0 < sparsity < 1, f"sparsity must be in (0, 1) ({sparsity})"

        self.n_samples = n_samples
        self.sparsity = sparsity
        self.K_init_file = K_init_file

        super().__init__(n_states)

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration attributes as a new dictionary.

        A shallow copy is returned so that modifying the result does not
        modify this configuration object.
        """
        return dict(vars(self))