Coverage for tvo/exp/_utils.py: 71%
21 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
5import torch as to
6from typing import Union
8from tvo.variational import (
9 FullEM,
10 EVOVariationalStates,
11 FullEMSingleCauseModels,
12 TVSVariationalStates,
13 RandomSampledVarStates,
14)
15from tvo.exp._EStepConfig import (
16 FullEMConfig,
17 EVOConfig,
18 EStepConfig,
19 FullEMSingleCauseConfig,
20 TVSConfig,
21 RandomSamplingConfig,
22)
def make_var_states(
    conf: EStepConfig, N: int, H: int, precision: to.dtype
) -> Union[
    EVOVariationalStates,
    FullEM,
    FullEMSingleCauseModels,
    TVSVariationalStates,
    RandomSampledVarStates,
]:
    """Build the variational-states object that matches an E-step configuration.

    The concrete class is selected by the type of ``conf``:

    - ``FullEMConfig`` -> ``FullEM``
    - ``FullEMSingleCauseConfig`` -> ``FullEMSingleCauseModels``
    - ``EVOConfig`` -> ``EVOVariationalStates``
    - ``TVSConfig`` -> ``TVSVariationalStates``
    - ``RandomSamplingConfig`` -> ``RandomSampledVarStates``

    Args:
        conf: E-step configuration; its concrete type selects the variational states.
        N: Forwarded to the variational-states constructor (presumably the number
            of data points -- confirm against the constructors).
        H: Forwarded to the constructor; for the full-EM configurations it is also
            checked against ``conf.n_states``.
        precision: Torch dtype forwarded to the constructor.

    Returns:
        A newly constructed variational-states object.

    Raises:
        AssertionError: if a ``FullEMConfig`` has ``n_states != 2**H`` or a
            ``FullEMSingleCauseConfig`` has ``n_states != H``.
        NotImplementedError: if ``conf`` is of an unsupported type.
    """
    # The consistency checks below are raised explicitly instead of via ``assert``
    # so they are not stripped under ``python -O``. AssertionError is kept so
    # existing callers that catch it keep working.
    if isinstance(conf, FullEMConfig):
        # Full EM requires exactly 2**H states for this H.
        if conf.n_states != 2**H:
            raise AssertionError("FullEMConfig and model have different H")
        return FullEM(N, H, precision)
    elif isinstance(conf, FullEMSingleCauseConfig):
        # The single-cause variant requires exactly H states.
        if conf.n_states != H:
            raise AssertionError("FullEMSingleCauseConfig and model have different H")
        return FullEMSingleCauseModels(N, H, precision)
    elif isinstance(conf, EVOConfig):
        return _make_EVO_var_states(conf, N, H, precision)
    elif isinstance(conf, TVSConfig):
        return TVSVariationalStates(
            N,
            H,
            conf.n_states,
            precision,
            conf.n_prior_samples,
            conf.n_marginal_samples,
            conf.K_init_file,
        )
    elif isinstance(conf, RandomSamplingConfig):
        return RandomSampledVarStates(
            N, H, conf.n_states, precision, conf.n_samples, conf.sparsity, conf.K_init_file
        )
    else:  # pragma: no cover
        raise NotImplementedError(
            f"No variational states available for E-step config of type "
            f"{type(conf).__name__}"
        )
def _make_EVO_var_states(conf: EVOConfig, N: int, H: int, precision: to.dtype):
    """Construct ``EVOVariationalStates`` from an ``EVOConfig``.

    The config's option names are translated into the names expected by
    ``EVOVariationalStates``:

    - ``parent_selection``: ``"fitness"`` -> ``"batch_fitparents"``,
      ``"uniform"`` -> ``"randparents"``
    - ``mutation``: ``"sparsity"`` -> ``"sparseflip"``,
      ``"uniform"`` -> ``"randflip"``

    Any other option value raises ``KeyError`` from the table lookup.
    When ``conf.crossover`` is truthy, ``n_children`` is passed as ``None``
    instead of ``conf.n_children``.
    """
    parent_selection_by_option = {
        "fitness": "batch_fitparents",
        "uniform": "randparents",
    }
    mutation_by_option = {
        "sparsity": "sparseflip",
        "uniform": "randflip",
    }

    # Lookups happen in this order so an invalid parent_selection is reported first.
    chosen_selection = parent_selection_by_option[conf.parent_selection]
    chosen_mutation = mutation_by_option[conf.mutation]

    # With crossover enabled, the configured n_children is not forwarded.
    children = None if conf.crossover else conf.n_children

    return EVOVariationalStates(
        N=N,
        H=H,
        S=conf.n_states,
        precision=precision,
        parent_selection=chosen_selection,
        mutation=chosen_mutation,
        n_parents=conf.n_parents,
        n_generations=conf.n_generations,
        n_children=children,
        crossover=conf.crossover,
        bitflip_frequency=conf.bitflip_frequency,
        K_init_file=conf.K_init_file,
    )