Coverage for tvo/variational/fullem.py: 93%

43 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg. 

3# Licensed under the Academic Free License version 3.0 

4 

5import torch as to 

6 

7from torch import Tensor 

8 

9import tvo 

10from tvo.utils.model_protocols import Trainable, Optimized 

11from tvo.variational.TVOVariationalStates import TVOVariationalStates 

12 

13 

def state_matrix(H: int, device: to.device = None):
    """Get full combinatorics of an H-dimensional binary vector.

    Row ``i`` of the result is the binary representation of the integer ``i``,
    zero-padded to ``H`` digits with the most significant bit first.

    :param H: vector length
    :param device: torch.device of output Tensor. Defaults to tvo.get_device().
    :returns: uint8 tensor containing full combinatorics, shape (2**H,H)
    """
    if device is None:
        device = tvo.get_device()

    # Enumerate all 2**H integers once and extract their bits with a broadcast
    # shift-and-mask instead of building one small tensor per row in Python.
    state_ids = to.arange(2**H, dtype=to.int64, device=device)
    # Shift amounts H-1, ..., 0 so that column 0 holds the most significant bit.
    # For H == 0 this is empty and the result has the documented shape (1, 0).
    shifts = to.arange(H - 1, -1, -1, dtype=to.int64, device=device)
    bits = (state_ids[:, None] >> shifts[None, :]) & 1
    return bits.to(to.uint8)

29 

30 

class FullEM(TVOVariationalStates):
    def __init__(self, N: int, H: int, precision: to.dtype, K_init=None):
        """Variational states that enumerate every binary latent configuration.

        :param N: Number of datapoints
        :param H: Number of latent variables
        :param precision: The floating point precision of the lpj values.
                          Must be one of to.float32 or to.float64
        :param K_init: Optional initialization of states. Accepted for interface
                       compatibility; this constructor does not read it.
        """
        settings = {"N": N, "S": None, "S_new": None, "H": H, "precision": precision}
        # N, H and precision are mandatory; S and S_new are deliberately left unset.
        for key in ("N", "H", "precision"):
            assert key in settings and settings[key] is not None
        self.config = settings
        # One lpj slot per datapoint and per state: there are 2**H states.
        self.lpj = to.empty((N, 2**H), dtype=precision, device=tvo.get_device())
        self.precision = precision
        # All datapoints share the same state matrix: add a leading axis and
        # expand it to N without materialising N copies.
        all_states = state_matrix(H)
        self.K = all_states[None, :, :].expand(N, -1, -1)

    def update(self, idx: Tensor, batch: Tensor, model: Trainable) -> int:
        """Recompute the stored lpj values for the datapoints in ``idx``.

        :param idx: indices of the datapoints contained in ``batch``
        :param batch: the datapoints to evaluate
        :param model: model providing ``log_pseudo_joint`` (used when it is an
                      ``Optimized`` instance) or ``log_joint`` otherwise
        :returns: always 0 (presumably the number of replaced states, since the
                  state set never changes here -- confirm against the base class)
        """
        if isinstance(model, Optimized):
            evaluate = model.log_pseudo_joint
        else:
            evaluate = model.log_joint

        states = self.K
        scores = self.lpj

        scores[idx] = evaluate(batch, states[idx])

        return 0

59 

60 

class FullEMSingleCauseModels(FullEM):
    def __init__(self, N: int, H: int, precision: to.dtype):
        """Full EM class for single causes models.

        The state set is the H x H identity matrix: one state per latent
        variable, each with exactly that one latent active.

        :param N: Number of datapoints
        :param H: Number of latent variables
        :param precision: The floating point precision of the lpj values.
                          Must be one of to.float32 or to.float64
        """
        conf = dict(N=N, S=None, S_new=None, H=H, precision=precision)
        required_keys = ("N", "H", "precision")
        for c in required_keys:
            assert c in conf and conf[c] is not None
        self.config = conf
        # One lpj slot per datapoint and per single-cause state (H states).
        self.lpj = to.empty((N, H), dtype=precision, device=tvo.get_device())
        self.precision = precision
        # Identity matrix shared by all datapoints via a broadcast view, shape (N, H, H).
        self.K = to.eye(H, dtype=precision, device=tvo.get_device())[None, :, :].expand(N, -1, -1)

    def update(self, idx: Tensor, batch: Tensor, model: Trainable) -> int:
        """Check the single-cause invariant, then recompute lpj for ``idx``.

        :param idx: indices of the datapoints contained in ``batch``
        :param batch: the datapoints to evaluate
        :param model: model forwarded to ``FullEM.update``
        :returns: always 0
        :raises AssertionError: if any state does not have exactly one active latent
        """
        # Every state (last axis of K) must contain exactly one active latent.
        # The check has to hold for all states, not merely for some of them,
        # and the sum must run over the latent axis rather than the state axis.
        assert to.all(self.K.sum(dim=-1) == 1), "Multiple causes detected."
        super().update(idx, batch, model)
        return 0