Coverage for tvo/variational/fullem.py: 93%
43 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
5import torch as to
7from torch import Tensor
9import tvo
10from tvo.utils.model_protocols import Trainable, Optimized
11from tvo.variational.TVOVariationalStates import TVOVariationalStates
def state_matrix(H: int, device: to.device = None):
    """Get the full combinatorics of an H-dimensional binary vector.

    Row ``i`` of the result is the binary representation of ``i`` with the
    most significant bit first, so the rows enumerate all 2**H binary states
    in ascending order.

    :param H: vector length
    :param device: torch.device of output Tensor. Defaults to tvo.get_device().
    :returns: uint8 tensor containing full combinatorics, shape (2**H, H)
    """
    if device is None:
        device = tvo.get_device()

    # Vectorized bit extraction: shift every state index right by H-1, ..., 0
    # and keep the lowest bit. This avoids creating 2**H small tensors (each
    # copied to `device`) in a Python loop. For H == 0 the shift vector is
    # empty and the result is the single empty state, shape (1, 0).
    shifts = to.arange(H - 1, -1, -1, device=device)
    states = to.arange(2**H, device=device).unsqueeze(1)
    return ((states >> shifts) & 1).to(to.uint8)
class FullEM(TVOVariationalStates):
    """Variational states that assign every binary latent configuration to each datapoint.

    ``K`` holds all 2**H binary states for each of the N datapoints and
    ``lpj`` has one column per state (shape (N, 2**H)), so memory grows
    exponentially in H.
    """

    def __init__(self, N: int, H: int, precision: to.dtype, K_init=None):
        """Full EM class.

        :param N: Number of datapoints
        :param H: Number of latent variables
        :param precision: The floating point precision of the lpj values.
                          Must be one of to.float32 or to.float64
        :param K_init: Accepted for interface compatibility but currently
                       ignored: the states are always the full set of 2**H
                       binary vectors.
        """
        # S and S_new are not used here: the state set is fixed to all 2**H vectors.
        conf = dict(N=N, S=None, S_new=None, H=H, precision=precision)
        for key in ("N", "H", "precision"):
            assert conf[key] is not None, f"FullEM requires '{key}' to be set"
        self.config = conf
        self.lpj = to.empty((N, 2**H), dtype=precision, device=tvo.get_device())
        self.precision = precision
        # expand() returns a view: all N datapoints share one (2**H, H) storage
        # instead of holding N copies of the state matrix.
        self.K = state_matrix(H)[None, :, :].expand(N, -1, -1)

    def update(self, idx: Tensor, batch: Tensor, model: Trainable) -> int:
        """Recompute the lpj values of the datapoints selected by ``idx``.

        :param idx: indices (into the N datapoints) of the rows of ``batch``
        :param batch: data batch whose lpj values are evaluated
        :param model: model whose ``log_pseudo_joint`` (if it is ``Optimized``)
                      or ``log_joint`` is evaluated on all states
        :returns: 0 -- this method never changes the state set ``K``
                  (presumably the number of newly accepted states expected by
                  TVOVariationalStates; confirm)
        """
        lpj_fn = model.log_pseudo_joint if isinstance(model, Optimized) else model.log_joint
        self.lpj[idx] = lpj_fn(batch, self.K[idx])
        return 0
class FullEMSingleCauseModels(FullEM):
    """Full EM over the H single-cause states, i.e. states with exactly one active latent."""

    def __init__(self, N: int, H: int, precision: to.dtype):
        """Full EM class for single causes models.

        :param N: Number of datapoints
        :param H: Number of latent variables (equal to the number of states)
        :param precision: The floating point precision of the lpj values.
                          Must be one of to.float32 or to.float64
        """
        # S and S_new are not used here: the state set is fixed to the H single-cause vectors.
        conf = dict(N=N, S=None, S_new=None, H=H, precision=precision)
        for key in ("N", "H", "precision"):
            assert conf[key] is not None, f"FullEMSingleCauseModels requires '{key}' to be set"
        self.config = conf
        self.lpj = to.empty((N, H), dtype=precision, device=tvo.get_device())
        self.precision = precision
        # Row h of the identity matrix is the state in which only latent h is
        # active; expand() shares that (H, H) storage across all N datapoints.
        self.K = to.eye(H, dtype=precision, device=tvo.get_device())[None, :, :].expand(N, -1, -1)

    def update(self, idx: Tensor, batch: Tensor, model: Trainable) -> int:
        """Recompute the lpj values of the datapoints selected by ``idx``.

        :param idx: indices (into the N datapoints) of the rows of ``batch``
        :param batch: data batch whose lpj values are evaluated
        :param model: model evaluated on the single-cause states
        :returns: 0 -- the state set ``K`` is never changed here
        :raises AssertionError: if any state used in this update does not have
                                exactly one active latent
        """
        # Sum over the last (latent) axis counts the active causes per state.
        # Every state must have exactly one, so the check must hold for *all* states.
        assert to.all(self.K[idx].sum(dim=-1) == 1), "Multiple causes detected."
        super().update(idx, batch, model)
        return 0