Coverage for tvo/exp/_utils.py: 71%

21 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg. 

3# Licensed under the Academic Free License version 3.0 

4 

5import torch as to 

6from typing import Union 

7 

8from tvo.variational import ( 

9 FullEM, 

10 EVOVariationalStates, 

11 FullEMSingleCauseModels, 

12 TVSVariationalStates, 

13 RandomSampledVarStates, 

14) 

15from tvo.exp._EStepConfig import ( 

16 FullEMConfig, 

17 EVOConfig, 

18 EStepConfig, 

19 FullEMSingleCauseConfig, 

20 TVSConfig, 

21 RandomSamplingConfig, 

22) 

23 

24 

def make_var_states(
    conf: EStepConfig, N: int, H: int, precision: to.dtype
) -> Union[
    EVOVariationalStates,
    FullEM,
    FullEMSingleCauseModels,
    TVSVariationalStates,
    RandomSampledVarStates,
]:
    """Build the variational-states object that matches an E-step configuration.

    The concrete class is selected by the type of ``conf``:

    - ``FullEMConfig``            -> ``FullEM``
    - ``FullEMSingleCauseConfig`` -> ``FullEMSingleCauseModels``
    - ``EVOConfig``               -> ``EVOVariationalStates`` (via ``_make_EVO_var_states``)
    - ``TVSConfig``               -> ``TVSVariationalStates``
    - ``RandomSamplingConfig``    -> ``RandomSampledVarStates``

    :param conf: E-step configuration instance; its type decides what is built.
    :param N: Forwarded as the first argument to every constructor
              (presumably the number of data points -- confirm with callers).
    :param H: Model size that the full-EM configs must agree with: a
              ``FullEMConfig`` must have ``2**H`` states, a
              ``FullEMSingleCauseConfig`` must have ``H`` states.
    :param precision: torch dtype forwarded to the variational-states constructor.
    :returns: A newly constructed variational-states object.
    :raises AssertionError: if a full-EM config's ``n_states`` disagrees with ``H``
        (only while assertions are enabled).
    :raises NotImplementedError: if ``conf`` is not one of the supported config types.
    """
    if isinstance(conf, FullEMConfig):
        # NOTE(review): `assert` is stripped under `python -O`, so this consistency
        # check disappears in optimized runs; kept as-is because callers may rely
        # on AssertionError.
        assert conf.n_states == 2**H, "FullEMConfig and model have different H"
        return FullEM(N, H, precision)
    elif isinstance(conf, FullEMSingleCauseConfig):
        assert conf.n_states == H, "FullEMSingleCauseConfig and model have different H"
        return FullEMSingleCauseModels(N, H, precision)
    elif isinstance(conf, EVOConfig):
        return _make_EVO_var_states(conf, N, H, precision)
    elif isinstance(conf, TVSConfig):
        return TVSVariationalStates(
            N,
            H,
            conf.n_states,
            precision,
            conf.n_prior_samples,
            conf.n_marginal_samples,
            conf.K_init_file,
        )
    elif isinstance(conf, RandomSamplingConfig):
        return RandomSampledVarStates(
            N, H, conf.n_states, precision, conf.n_samples, conf.sparsity, conf.K_init_file
        )
    else:  # pragma: no cover
        # Name the offending type so a misconfigured experiment is easy to diagnose.
        raise NotImplementedError(
            f"No variational states available for E-step config of type "
            f"{type(conf).__name__!r}"
        )

58 

59 

def _make_EVO_var_states(conf: EVOConfig, N: int, H: int, precision: to.dtype):
    """Construct ``EVOVariationalStates`` from an ``EVOConfig``.

    The config's option names are translated into the identifiers that
    ``EVOVariationalStates`` is called with:

    - ``parent_selection``: ``"fitness"`` -> ``"batch_fitparents"``,
      ``"uniform"`` -> ``"randparents"``
    - ``mutation``: ``"sparsity"`` -> ``"sparseflip"``,
      ``"uniform"`` -> ``"randflip"``

    Any other option value raises ``KeyError`` from the table lookup
    (``parent_selection`` is resolved first). ``n_children`` is forwarded only
    when ``conf.crossover`` is falsy; otherwise ``None`` is passed.
    """
    parent_selection_ids = {"fitness": "batch_fitparents", "uniform": "randparents"}
    mutation_ids = {"sparsity": "sparseflip", "uniform": "randflip"}

    # Resolve parent selection before mutation so an invalid config fails on the
    # same key as before.
    chosen_selection = parent_selection_ids[conf.parent_selection]
    chosen_mutation = mutation_ids[conf.mutation]

    if conf.crossover:
        children = None
    else:
        children = conf.n_children

    return EVOVariationalStates(
        N=N,
        H=H,
        S=conf.n_states,
        precision=precision,
        parent_selection=chosen_selection,
        mutation=chosen_mutation,
        n_parents=conf.n_parents,
        n_generations=conf.n_generations,
        n_children=children,
        crossover=conf.crossover,
        bitflip_frequency=conf.bitflip_frequency,
        K_init_file=conf.K_init_file,
    )