Coverage for tvo/models/pmm.py: 91%
107 statements
coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
# -*- coding: utf-8 -*-
# Copyright (C) 2021 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0


import torch as to
import math
from torch.distributions.one_hot_categorical import OneHotCategorical

from torch import Tensor
from typing import Union, Tuple

import tvo
from tvo.utils.parallel import pprint, all_reduce, broadcast
from tvo.variational.TVOVariationalStates import TVOVariationalStates
from tvo.variational._utils import mean_posterior
from tvo.utils.model_protocols import Optimized, Sampler, Reconstructor
from tvo.utils.sanity import fix_theta


class PMM(Optimized, Sampler, Reconstructor):
    def __init__(
        self,
        H: int,
        D: int,
        W_init: Tensor = None,
        pies_init: Tensor = None,
        precision: to.dtype = to.float64,
    ):
30 """Poisson Mixture Model (PMM).
32 :param H: Number of hidden units.
33 :param D: Number of observables.
34 :param W_init: Tensor with shape (D,H), initializes PMM weights.
35 :param pies_init: Tensor with shape (H,), initializes PMM priors.
36 :param precision: Floating point precision required. Must be one of torch.float32 or
37 torch.float64.
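
        Example (illustrative sketch, using ``generate_data`` defined below)::

            model = PMM(H=3, D=5)
            data, hidden_state = model.generate_data(N=100)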
39 """
        assert precision in (to.float32, to.float64), "precision must be one of torch.float{32,64}"
        self._precision = precision

        device = tvo.get_device()

        if W_init is not None:
            assert W_init.shape == (D, H)
            W_init = W_init.to(dtype=precision, device=device)
        else:
            W_init = to.rand((D, H), dtype=precision, device=device)
            broadcast(W_init)

        if pies_init is not None:
            assert pies_init.shape == (H,)
            pies_init = pies_init.to(dtype=precision, device=device)
        else:
            pies_init = to.full((H,), 1.0 / H, dtype=precision, device=device)

        self._theta = {"pies": pies_init, "W": W_init}
        eps, inf = 1.0e-5, math.inf
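        # Per-parameter sanity policy [replacement, lower bound, upper bound] consumed by
        # fix_theta in update_param_epoch; the replacement slot is filled in there.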
        self.policy = {
            "W": [None, to.full_like(self._theta["W"], eps), to.full_like(self._theta["W"], inf)],
            "pies": [
                None,
                to.full_like(self._theta["pies"], eps),
                to.full_like(self._theta["pies"], 1.0 - eps),
            ],
        }

        self.my_Wp = to.zeros((D, H), dtype=precision, device=device)
        self.my_Wq = to.zeros((H,), dtype=precision, device=device)
        self.my_pies = to.zeros(H, dtype=precision, device=device)
        self.my_N = to.tensor([0], dtype=to.int, device=device)
        self._config = dict(H=H, D=D, precision=precision, device=device)
        self._shape = self.theta["W"].shape

    def log_pseudo_joint(self, data: Tensor, states: Tensor) -> Tensor:  # type: ignore
        """Evaluate log-pseudo-joints for PMM."""
        Kfloat = states.to(
            dtype=self.theta["W"].dtype
        )  # N,S,H  # TODO Find solution to avoid byte->float casting

        Wbar = (
            to.matmul(Kfloat, self.theta["W"].t()) + to.finfo(to.float32).tiny
        )  # N,S,D  # TODO Pre-allocate tensor and use `out` argument of to.matmul
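
        # Per-state Poisson log-likelihood plus log-prior: for a one-hot state selecting
        # component c,
        #   log p(y, s_c) = sum_d [ y_d * log(Wbar_d) - Wbar_d - log(y_d!) ] + log(pi_c);
        # the state-independent -log(y_d!) term is omitted here and restored in log_joint.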
        lpj = (
            to.sum(data[:, None, :] * to.log(Wbar), dim=2)
            - to.sum(Wbar, dim=2)
            + to.matmul(Kfloat, to.log(self.theta["pies"]))
        )
        return lpj.to(device=states.device)

    def log_joint(self, data: Tensor, states: Tensor, lpj: Tensor = None) -> Tensor:
        """Evaluate log-joints for PMM."""
        if lpj is None:
            lpj = self.log_pseudo_joint(data, states)
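        # to.lgamma(data + 1) equals log(y_d!), the Poisson normalizer omitted in log_pseudo_joint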
        return lpj - to.sum(to.lgamma(data + 1), dim=1)[:, None]

    def update_param_batch(self, idx: Tensor, batch: Tensor, states: Tensor) -> None:
        lpj = states.lpj[idx]
        K = states.K[idx]
        batch_size, S, _ = K.shape

        Kfloat = K.to(dtype=lpj.dtype)  # TODO Find solution to avoid byte->float casting

        batch_s_pjc = mean_posterior(Kfloat, lpj)  # is (batch_size,H)
        batch_Wp = batch.unsqueeze(2) * batch_s_pjc.unsqueeze(1)  # is (batch_size,D,H)

        self.my_pies.add_(to.sum(batch_s_pjc, dim=0))
        self.my_Wp.add_(to.sum(batch_Wp, dim=0))
        self.my_Wq.add_(to.sum(batch_s_pjc, dim=0))
        self.my_N.add_(batch_size)

        return None

    def update_param_epoch(self) -> None:
        theta = self.theta
        policy = self.policy
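
        # Aggregate the per-worker sufficient statistics across all processes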
        all_reduce(self.my_Wp)
        all_reduce(self.my_Wq)
        all_reduce(self.my_pies)
        all_reduce(self.my_N)

        N = self.my_N.item()
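
        # Closed-form M-step of the Poisson mixture, with <s_h> the posterior mean of
        # component h accumulated in update_param_batch:
        #   W_{d,h} = sum_n <s_h>_n * y_{n,d} / sum_n <s_h>_n
        #   pi_h    = sum_n <s_h>_n / N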
        # Calculate updated W
        Wold_noisy = theta["W"] + 0.1 * to.randn_like(theta["W"])
        broadcast(Wold_noisy)
        theta_new = {}
        try:
            theta_new["W"] = self.my_Wp / self.my_Wq[None, :]
        except RuntimeError:
            pprint("Inversion error. Will not update W but add some noise instead.")
            theta_new["W"] = Wold_noisy

        # Calculate updated pi
        theta_new["pies"] = self.my_pies / N

        policy["W"][0] = Wold_noisy
        policy["pies"][0] = theta["pies"]
        fix_theta(theta_new, policy)
        for key in theta:
            theta[key] = theta_new[key]

        self.my_Wp[:] = 0.0
        self.my_Wq[:] = 0.0
        self.my_pies[:] = 0.0
        self.my_N[:] = 0

    @property
    def shape(self) -> Tuple[int, ...]:
        return self.theta["W"].shape

    def generate_data(
        self, N: int = None, hidden_state: to.Tensor = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        precision, device = self.precision, tvo.get_device()
        D, H = self.shape

        if hidden_state is None:
            assert N is not None
            pies = self.theta["pies"]
            hidden_state = OneHotCategorical(probs=pies).sample([N]) == 1
            must_return_hidden_state = True
        else:
            shape = hidden_state.shape
            if N is None:
                N = shape[0]
            assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
            must_return_hidden_state = False

        Wbar = to.zeros((N, D), dtype=precision, device=device)

        for n in range(N):
            for h in range(H):
                if hidden_state[n, h]:
                    Wbar[n] += self.theta["W"][:, h]
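
        # (A vectorized equivalent of the loop above would be
        # Wbar = hidden_state.to(dtype=precision) @ self.theta["W"].t(), assuming
        # hidden_state lives on `device`.)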
        # Add noise according to the model parameters
        Y = to.poisson(Wbar)

        return (Y, hidden_state) if must_return_hidden_state else Y

    def data_estimator(self, idx: Tensor, batch: Tensor, states: TVOVariationalStates) -> Tensor:
        r"""Estimator used for data reconstruction. Data reconstruction can only be supported
        by a model if it implements this method. The estimator to be implemented is defined
        as follows:

        :math:`\langle \langle y_d \rangle_{p(y_d|\vec{s},\Theta)} \rangle_{q(\vec{s}|\mathcal{K},\Theta)}`  # noqa
        """
        K = states.K[idx]
        # TODO Find solution to avoid byte->float casting of `K`
        # TODO Pre-allocate tensor and use `out` argument of to.matmul
        return mean_posterior(
            to.matmul(K.to(dtype=self.precision), self.theta["W"].t()), states.lpj[idx]
        )