Coverage for tvo/models/gmm.py: 91%
121 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2021 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
6import torch as to
7import math
8from torch.distributions.one_hot_categorical import OneHotCategorical
10from torch import Tensor
11from typing import Union, Tuple
13import tvo
14from tvo.utils.parallel import pprint, all_reduce, broadcast
15from tvo.variational.TVOVariationalStates import TVOVariationalStates
16from tvo.variational._utils import mean_posterior
17from tvo.utils.model_protocols import Optimized, Sampler, Reconstructor
18from tvo.utils.sanity import fix_theta
21class GMM(Optimized, Sampler, Reconstructor):
22 def __init__(
23 self,
24 H: int,
25 D: int,
26 W_init: Tensor = None,
27 sigma2_init: Tensor = None,
28 pies_init: Tensor = None,
29 precision: to.dtype = to.float64,
30 ):
31 """Gaussian Mixture model (GMM).
33 :param H: Number of hidden units.
34 :param D: Number of observables.
35 :param W_init: Tensor with shape (D,H), initializes GM weights.
36 :param pies_init: Tensor with shape (H,), initializes GM priors.
37 :param precision: Floating point precision required. Must be one of torch.float32 or
38 torch.float64.
40 """
41 assert precision in (to.float32, to.float64), "precision must be one of torch.float{32,64}"
42 self._precision = precision
44 device = tvo.get_device()
46 if W_init is not None:
47 assert W_init.shape == (D, H)
48 W_init = W_init.to(dtype=precision, device=device)
49 else:
50 W_init = to.rand((D, H), dtype=precision, device=device)
51 broadcast(W_init)
53 if pies_init is not None:
54 assert pies_init.shape == (H,)
55 pies_init = pies_init.to(dtype=precision, device=device)
56 else:
57 pies_init = to.full((H,), 1.0 / H, dtype=precision, device=device)
59 if sigma2_init is not None:
60 assert sigma2_init.shape == (1,)
61 sigma2_init = sigma2_init.to(dtype=precision, device=device)
62 else:
63 sigma2_init = to.tensor([1.0], dtype=precision, device=device)
65 self._theta = {"pies": pies_init, "W": W_init, "sigma2": sigma2_init}
66 eps, inf = 1.0e-5, math.inf
67 self.policy = {
68 "W": [None, to.full_like(self._theta["W"], -inf), to.full_like(self._theta["W"], inf)],
69 "pies": [
70 None,
71 to.full_like(self._theta["pies"], eps),
72 to.full_like(self._theta["pies"], 1.0 - eps),
73 ],
74 "sigma2": [
75 None,
76 to.full_like(self._theta["sigma2"], eps),
77 to.full_like(self._theta["sigma2"], inf),
78 ],
79 }
81 self.my_Wp = to.zeros((D, H), dtype=precision, device=device)
82 self.my_Wq = to.zeros((H), dtype=precision, device=device)
83 self.my_pies = to.zeros(H, dtype=precision, device=device)
84 self.my_sigma2 = to.zeros(1, dtype=precision, device=device)
85 self.my_N = to.tensor([0], dtype=to.int, device=device)
86 self._config = dict(H=H, D=D, precision=precision, device=device)
87 self._shape = self.theta["W"].shape
89 def log_pseudo_joint(self, data: Tensor, states: Tensor) -> Tensor: # type: ignore
90 """Evaluate log-pseudo-joints for GMM."""
91 Kfloat = states.to(
92 dtype=self.theta["W"].dtype
93 ) # N,C,C # TODO Find solution to avoid byte->float casting
94 Wbar = to.matmul(
95 Kfloat, self.theta["W"].t()
96 ) # N,C,D # TODO Pre-allocate tensor and use `out` argument of to.matmul
97 lpj = to.mul(
98 to.sum(to.pow(Wbar - data[:, None, :], 2), dim=2), -1 / 2 / self.theta["sigma2"]
99 ) + to.matmul(Kfloat, to.log(self.theta["pies"]))
100 return lpj.to(device=states.device)
102 def log_joint(self, data: Tensor, states: Tensor, lpj: Tensor = None) -> Tensor:
103 """Evaluate log-joints for GMM."""
104 if lpj is None:
105 lpj = self.log_pseudo_joint(data, states)
106 D = self.shape[0]
107 return lpj - D / 2 * to.log(2 * math.pi * self.theta["sigma2"])
109 def update_param_batch(self, idx: Tensor, batch: Tensor, states: Tensor) -> None:
110 lpj = states.lpj[idx]
111 K = states.K[idx]
112 batch_size, S, _ = K.shape
114 Kfloat = K.to(dtype=lpj.dtype) # TODO Find solution to avoid byte->float casting
115 Wbar = to.matmul(
116 Kfloat, self.theta["W"].t()
117 ) # N,S,D # TODO Find solution to re-use evaluations from E-step
119 batch_s_pjc = mean_posterior(Kfloat, lpj) # is (batch_size,H) mean_posterior(Kfloat, lpj)
120 batch_Wp = batch.unsqueeze(2) * batch_s_pjc.unsqueeze(1) # is (batch_size,D,H)
121 batch_sigma2 = mean_posterior(to.sum((batch[:, None, :] - Wbar) ** 2, dim=2), lpj)
123 self.my_pies.add_(to.sum(batch_s_pjc, dim=0))
124 self.my_Wp.add_(to.sum(batch_Wp, dim=0))
125 self.my_Wq.add_(to.sum(batch_s_pjc, dim=0))
126 self.my_sigma2.add_(to.sum(batch_sigma2))
127 self.my_N.add_(batch_size)
129 return None
131 def update_param_epoch(self) -> None:
132 theta = self.theta
133 policy = self.policy
135 all_reduce(self.my_Wp)
136 all_reduce(self.my_Wq)
137 all_reduce(self.my_pies)
138 all_reduce(self.my_sigma2)
139 all_reduce(self.my_N)
141 N = self.my_N.item()
142 D = self.shape[0]
144 # Calculate updated W
145 Wold_noisy = theta["W"] + 0.1 * to.randn_like(theta["W"])
146 broadcast(Wold_noisy)
147 theta_new = {}
148 try:
149 theta_new["W"] = self.my_Wp / self.my_Wq[None, :]
150 except RuntimeError:
151 pprint("Inversion error. Will not update W but add some noise instead.")
152 theta_new["W"] = Wold_noisy
154 # Calculate updated pi
155 theta_new["pies"] = self.my_pies / N
157 # Calculate updated sigma^2
158 theta_new["sigma2"] = self.my_sigma2 / N / D
160 policy["W"][0] = Wold_noisy
161 policy["pies"][0] = theta["pies"]
162 policy["sigma2"][0] = theta["sigma2"]
163 fix_theta(theta_new, policy)
164 for key in theta:
165 theta[key] = theta_new[key]
167 self.my_Wp[:] = 0.0
168 self.my_Wq[:] = 0.0
169 self.my_pies[:] = 0.0
170 self.my_sigma2[:] = 0.0
171 self.my_N[:] = 0.0
173 @property
174 def shape(self) -> Tuple[int, ...]:
175 return self.theta["W"].shape
177 def generate_data(
178 self, N: int = None, hidden_state: to.Tensor = None
179 ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
180 precision, device = self.precision, tvo.get_device()
181 D, H = self.shape
183 if hidden_state is None:
184 assert N is not None
185 pies = self.theta["pies"]
186 hidden_state = OneHotCategorical(probs=pies).sample([N]) == 1
187 must_return_hidden_state = True
188 else:
189 shape = hidden_state.shape
190 if N is None:
191 N = shape[0]
192 assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
193 must_return_hidden_state = False
195 Wbar = to.zeros((N, D), dtype=precision, device=device)
197 for n in range(N):
198 for h in range(H):
199 if hidden_state[n, h]:
200 Wbar[n] += self.theta["W"][:, h]
202 # Add noise according to the model parameters
203 Y = Wbar + to.sqrt(self.theta["sigma2"]) * to.randn((N, D), dtype=precision, device=device)
205 return (Y, hidden_state) if must_return_hidden_state else Y
207 def data_estimator(self, idx: Tensor, batch: Tensor, states: TVOVariationalStates) -> Tensor:
208 # Not yet implemented
210 """Estimator used for data reconstruction. Data reconstruction can only be supported
211 by a model if it implements this method. The estimator to be implemented is defined
212 as follows:""" r"""
213 :math:`\\langle \langle y_d \rangle_{p(y_d|\vec{s},\Theta)} \rangle_{q(\vec{s}|\mathcal{K},\Theta)}` # noqa
214 """
215 # Not
216 K = states.K[idx]
217 # TODO Find solution to avoid byte->float casting of `K`
218 # TODO Pre-allocate tensor and use `out` argument of to.matmul
219 return mean_posterior(
220 to.matmul(K.to(dtype=self.precision), self.theta["W"].t()), states.lpj[idx]
221 )