Coverage for tvo/models/sssc.py: 73%
307 statements
# -*- coding: utf-8 -*-
# Copyright (C) 2021 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import torch as to
from typing import Dict, Optional, Tuple, Union, Any
from math import pi as MATH_PI
from tvo import get_device
from tvo.utils.model_protocols import Sampler, Optimized, Reconstructor
from tvo.utils.parallel import broadcast, all_reduce, pprint
from tvo.variational.TVOVariationalStates import TVOVariationalStates
from tvo.variational._utils import mean_posterior


def _get_hash(x: to.Tensor) -> int:
    # Hash the raw bytes of a state vector; used as a key for the E-step storage.
    return hash(x.detach().cpu().numpy().tobytes())
class SSSC(Sampler, Optimized, Reconstructor):
    def __init__(
        self,
        H: int,
        D: int,
        W_init: Optional[to.Tensor] = None,
        sigma2_init: Optional[to.Tensor] = None,
        mus_init: Optional[to.Tensor] = None,
        Psi_init: Optional[to.Tensor] = None,
        pies_init: Optional[to.Tensor] = None,
        reformulated_lpj: bool = True,
        use_storage: bool = True,
        reformulated_psi_update: bool = False,
        precision: to.dtype = to.float32,
    ):
        """Spike-and-Slab Sparse Coding (SSSC) model.

        :param H: Number of hidden units.
        :param D: Number of observables.
        :param W_init: Tensor with shape (D, H), initializes SSSC weights.
        :param sigma2_init: Tensor initializing SSSC observable variance.
        :param mus_init: Tensor with shape (H,), initializes SSSC latent means.
        :param Psi_init: Tensor with shape (H, H), initializes SSSC latent variance.
        :param pies_init: Tensor with shape (H,), initializes SSSC priors.
        :param reformulated_lpj: Use a looped instead of a batchified E-step and a
                                 mathematically reformulated form of the log-pseudo-joint
                                 formula (exploiting the matrix determinant lemma and the
                                 Woodbury matrix identity). Yields more accurate solutions
                                 in large dimensions (i.e. large D and H).
        :param use_storage: Whether to memorize state-vector-dependent and datapoint-
                            independent terms computed in the E-step. Terms will be looked
                            up rather than recomputed if a datapoint evaluates a state that
                            has already been evaluated for another datapoint. The storage
                            is cleared after each epoch.
        :param reformulated_psi_update: Whether to update Psi using the reformulated form
                                        of the update equation.
        :param precision: Floating point precision to be used. Must be one of
                          torch.float32 or torch.float64.
        """
        assert precision in (to.float32, to.float64), "precision must be one of torch.float{32,64}"
        device = get_device()
        self._precision = precision
        self._shape = (D, H)
        self._reformulated_lpj = reformulated_lpj
        self._use_storage = use_storage
        self._reformulated_psi_update = reformulated_psi_update

        self._theta: Dict[str, to.Tensor] = {}
        self._theta["W"] = self._init_W(W_init)
        self._theta["sigma2"] = self._init_sigma2(sigma2_init)
        self._theta["mus"] = self._init_mus(mus_init)
        self._theta["Psi"] = self._init_Psi(Psi_init)
        self._theta["pies"] = self._init_pies(pies_init)

        self._log2pi = to.log(to.tensor([2.0 * MATH_PI], dtype=precision, device=device))

        # Accumulators for the M-step sufficient statistics; reset after every epoch.
        self._my_sum_y_szT = to.zeros((D, H), dtype=precision, device=device)
        self._my_sum_xpt_sz_xpt_szT = to.zeros((H, H), dtype=precision, device=device)
        self._my_sum_xpt_szszT = to.zeros((H, H), dtype=precision, device=device)
        self._my_sum_xpt_s = to.zeros((H,), dtype=precision, device=device)
        self._my_sum_xpt_sz = to.zeros((H,), dtype=precision, device=device)
        self._my_sum_xpt_ssT = to.zeros((H, H), dtype=precision, device=device)
        self._my_sum_xpt_ssz = (
            to.zeros((H, H), dtype=precision, device=device) if reformulated_psi_update else None
        )
        self._my_sum_diag_yyT = to.zeros((D,), dtype=precision, device=device)
        self._my_N = to.tensor([0], dtype=to.int, device=device)
        self._eps_eyeH = to.eye(H, dtype=precision, device=device) * 1e-6
        self._storage: Optional[Dict[int, to.Tensor]] = {} if use_storage else None

        self._config = dict(
            shape=self._shape,
            reformulated_lpj=reformulated_lpj,
            reformulated_psi_update=reformulated_psi_update,
            use_storage=use_storage,
            precision=precision,
            device=device,
        )
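
    # A minimal construction sketch (illustration only, not part of the original file);
    # parameter values below are arbitrary, shapes follow the __init__ docstring:
    #
    #     model = SSSC(H=8, D=16)                          # random W, mus, pies; Psi = I
    #     model = SSSC(H=8, D=16, W_init=to.rand(16, 8),
    #                  precision=to.float64)               # explicit weights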
    def _init_W(self, init: Optional[to.Tensor]):
        D, H = self.shape
        if init is not None:
            assert init.shape == (D, H)
            return init.to(dtype=self.precision, device=get_device())
        else:
            W_init = to.rand((D, H), dtype=self.precision, device=get_device())
            broadcast(W_init)
            return W_init

    def _init_sigma2(self, init: Optional[to.Tensor]):
        if init is not None:
            assert init.shape == (1,)
            return init.to(dtype=self.precision, device=get_device())
        else:
            return to.tensor([1.0], dtype=self.precision, device=get_device())

    def _init_mus(self, init: Optional[to.Tensor]):
        H = self.shape[1]
        if init is not None:
            assert init.shape == (H,)
            return init.to(dtype=self.precision, device=get_device())
        else:
            mus_init = to.normal(
                mean=to.zeros(H, dtype=self.precision, device=get_device()),
                std=to.ones(H, dtype=self.precision, device=get_device()),
            )
            broadcast(mus_init)
            return mus_init

    def _init_Psi(self, init: Optional[to.Tensor]):
        H = self.shape[1]
        if init is not None:
            assert init.shape == (H, H)
            return init.to(dtype=self.precision, device=get_device())
        else:
            return to.eye(H, dtype=self.precision, device=get_device())

    def _init_pies(self, init: Optional[to.Tensor]):
        H = self.shape[1]
        if init is not None:
            assert init.shape == (H,)
            return init.to(dtype=self.precision, device=get_device())
        else:
            return 0.1 + 0.5 * to.rand(H, dtype=self.precision, device=get_device())
    def generate_data(
        self, N: Optional[int] = None, hidden_state: Optional[to.Tensor] = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        precision, device = self.precision, get_device()
        D, H = self.shape

        if hidden_state is None:
            assert N is not None
            pies = self.theta["pies"]
            hidden_state = to.rand((N, H), dtype=precision, device=device) < pies
            must_return_hidden_state = True
        else:
            shape = hidden_state.shape
            if N is None:
                N = shape[0]
            assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
            must_return_hidden_state = False

        Z = to.distributions.multivariate_normal.MultivariateNormal(
            loc=self.theta["mus"],
            covariance_matrix=self.theta["Psi"],
        )

        Wbar = to.einsum("dh,nh->nd", self.theta["W"], hidden_state * Z.sample((N,)))

        Y = to.distributions.multivariate_normal.MultivariateNormal(
            loc=Wbar,
            covariance_matrix=self.theta["sigma2"]
            * to.eye(D, dtype=self.theta["W"].dtype, device=get_device()),
        )

        return (Y.sample(), hidden_state) if must_return_hidden_state else Y.sample()
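
    # Generative process sampled above, read directly off the code:
    #   s_h ~ Bernoulli(pi_h),  z ~ N(mus, Psi),  y | s, z ~ N(W (s * z), sigma2 * I_D),
    # i.e. each datapoint is a sparse (spike-and-slab) superposition of columns of W.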
    def _lpj_fn(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        """Straightforward batchified implementation of the log-pseudo-joint for SSSC."""
        precision = self.precision
        W, sigma2, _pies, mus, Psi = (
            self.theta["W"],
            self.theta["sigma2"],
            self.theta["pies"],
            self.theta["mus"],
            self.theta["Psi"],
        )
        pies = _pies.clamp(1e-2, 1.0 - 1e-2)
        Kfloat = states.type_as(pies)
        N, D, S, H = data.shape + states.shape[1:]
        eyeD = to.eye(D, dtype=precision, device=get_device())

        s1 = Kfloat @ to.log(pies / (1.0 - pies))  # (N, S)
        Ws = Kfloat.unsqueeze(2) * W.unsqueeze(0).unsqueeze(1)  # (N, S, D, H)
        data_norm = data.unsqueeze(1) - Ws @ mus  # (N, S, D)
        data_norm[to.isnan(data_norm)] = 0.0

        WsPsi = Ws @ Psi  # (N, S, D, H)
        WsPsiWsT = (
            to.matmul(WsPsi, Ws.permute([0, 1, 3, 2]))
            if precision == to.float32
            else to.einsum("nsxh,nsyh->nsxy", WsPsi, Ws)
        )  # (N, S, D, D)
        C_s = WsPsiWsT + sigma2 * eyeD.unsqueeze(0).unsqueeze(1)  # (N, S, D, D)
        log_det_C_s = to.linalg.slogdet(C_s)[1]
        try:
            Inv_C_s = to.linalg.inv(C_s)
        except Exception:
            Inv_C_s = to.linalg.pinv(C_s)

        return (
            s1
            - 0.5 * log_det_C_s
            - 0.5 * (to.einsum("nsx,nsxd->nsd", data_norm, Inv_C_s) * data_norm).sum(dim=2)
        )
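
    # Log-pseudo-joint evaluated above (state-independent constants are added later in
    # log_joint):
    #   lpj(s, y) = s^T log(pi / (1 - pi)) - 1/2 log|C_s|
    #               - 1/2 (y - W_s mus)^T C_s^{-1} (y - W_s mus),
    # with C_s = W_s Psi W_s^T + sigma2 I_D and W_s = W with inactive columns zeroed.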
    def _common_e_m_step_terms(
        self, state: to.Tensor, inds_d_not_isnan: to.Tensor
    ) -> Tuple[to.Tensor, to.Tensor, to.Tensor, to.Tensor, to.Tensor, to.Tensor]:
        W = self.theta["W"]
        sigma2 = self.theta["sigma2"]
        Psi = self.theta["Psi"]
        mus = self.theta["mus"]

        W_s = W[inds_d_not_isnan][:, state]
        Psi_s = Psi[state, :][:, state]
        mus_s = mus[state]

        try:
            Inv_Psi_s = to.linalg.inv(Psi_s)
        except Exception:
            Inv_Psi_s = to.linalg.pinv(Psi_s)

        Inv_Lambda_s = W_s.t() @ W_s / sigma2 + Inv_Psi_s  # (|state|, |state|)
        try:
            Lambda_s = to.linalg.inv(Inv_Lambda_s)
        except Exception:
            Lambda_s = to.linalg.pinv(Inv_Lambda_s)

        Lambda_s_W_s_sigma2inv = Lambda_s @ W_s.t() / sigma2  # (|state|, D_nonnan)

        return (
            W_s,
            mus_s,
            Psi_s,
            Inv_Lambda_s,
            Lambda_s,
            Lambda_s_W_s_sigma2inv,
        )
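
    # Shared E/M-step quantities returned above (linear-Gaussian posterior over the
    # active latents z_s given y and s):
    #   Lambda_s^{-1} = W_s^T W_s / sigma2 + Psi_s^{-1}   (posterior precision)
    #   Lambda_s_W_s_sigma2inv = Lambda_s W_s^T / sigma2  (maps residuals to posterior means)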
    def _check_if_storage_reliable(self, incomplete: bool, batch_size: int):
        """Disable the storage logic (by setting `self._use_storage = False`) if the data is
        incomplete and the specified batch size is larger than one. The stored terms are only
        data-independent if the data does not contain missing values; otherwise, the
        data-dependent indices of missing values enter the computations of the respective
        terms.

        :param incomplete: Boolean indicating whether the data contains missing values
        :param batch_size: Batch size
        """
        use_storage = self._use_storage if not (incomplete and batch_size > 1) else False
        if self._use_storage != use_storage:
            pprint("Disabled storage (inaccurate for incomplete data and batch_size > 1)")
        self._use_storage = use_storage
    def _reformulated_lpj_fn(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        """Looped (per datapoint and state) implementation of the log-pseudo-joint for SSSC,
        using the matrix determinant lemma and the Woodbury matrix identity to compute the
        determinant and inverse of the matrix C_s.
        """
        precision = self.precision
        sigma2, _pies = (
            self.theta["sigma2"],
            self.theta["pies"],
        )
        pies = _pies.clamp(1e-2, 1.0 - 1e-2)
        Kbool = states.to(dtype=to.bool)
        Kfloat = states.to(dtype=precision)
        batch_size, S = data.shape[0], Kbool.shape[1]

        self._check_if_storage_reliable(
            incomplete=bool(to.isnan(data).any()), batch_size=batch_size
        )
        use_storage = self._use_storage

        notnan = to.logical_not(to.isnan(data))

        lpj = Kfloat @ to.log(pies / (1.0 - pies))  # initial allocation, (N, S)
        for n in range(batch_size):
            for s in range(S):
                hsh = _get_hash(Kbool[n, s])
                datapoint_notnan, D_notnan = data[n][notnan[n]], notnan[n].sum()
                if use_storage and self._storage is not None and hsh in self._storage:
                    W_s, mus_s, log_det_C_s_wo_last_term, Inv_C_s = (
                        self._storage[hsh]["W_s"],
                        self._storage[hsh]["mus_s"],
                        self._storage[hsh]["log_det_C_s_wo_last_term"],
                        self._storage[hsh]["Inv_C_s"],
                    )
                else:
                    (
                        W_s,
                        mus_s,
                        Psi_s,
                        Inv_Lambda_s,
                        Lambda_s,
                        Lambda_s_W_s_sigma2inv,
                    ) = self._common_e_m_step_terms(Kbool[n, s], notnan[n])

                    Inv_C_s = (
                        to.eye(D_notnan, dtype=self.precision, device=get_device()) / sigma2
                        - W_s @ Lambda_s_W_s_sigma2inv / sigma2
                    )  # (D_nonnan, D_nonnan)
                    log_det_C_s_wo_last_term = (
                        to.linalg.slogdet(Inv_Lambda_s)[1] + to.linalg.slogdet(Psi_s)[1]
                    )  # matrix determinant lemma, last term added in log_joint (1,)

                    if use_storage:
                        assert self._storage is not None
                        self._storage[hsh] = {
                            "W_s": W_s,
                            "mus_s": mus_s,
                            "Lambda_s": Lambda_s,
                            "Lambda_s_W_s_sigma2inv": Lambda_s_W_s_sigma2inv,
                            "log_det_C_s_wo_last_term": log_det_C_s_wo_last_term,
                            "Inv_C_s": Inv_C_s,
                        }

                datapoint_norm = datapoint_notnan - W_s @ mus_s  # (D_nonnan,)

                lpj[n, s] -= (
                    0.5
                    * (
                        log_det_C_s_wo_last_term
                        + (datapoint_norm * (Inv_C_s @ datapoint_norm)).sum()
                    ).item()
                )

        return lpj
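
    # Identities exploited above (sigma2 scalar, d = number of non-NaN observables):
    #   matrix determinant lemma: log|C_s| = log|Lambda_s^{-1}| + log|Psi_s| + d log(sigma2),
    #     where the d*log(sigma2) "last term" is deferred to log_joint;
    #   Woodbury identity: C_s^{-1} = I/sigma2 - W_s Lambda_s W_s^T / sigma2^2,
    # so only |state| x |state| inversions are required instead of d x d ones.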
    def log_pseudo_joint(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        """Evaluate log-pseudo-joints for SSSC."""
        lpj_fn = self._reformulated_lpj_fn if self._reformulated_lpj else self._lpj_fn
        lpj = lpj_fn(data, states)
        min_ = to.finfo(self.precision).min
        lpj[to.isnan(lpj)] = min_
        lpj[to.isinf(lpj)] = min_
        return lpj
    def log_joint(self, data: to.Tensor, states: to.Tensor, lpj=None) -> to.Tensor:
        """Evaluate log-joints for SSSC."""
        assert states.dtype == to.uint8
        notnan = to.logical_not(to.isnan(data))
        if lpj is None:
            lpj = self.log_pseudo_joint(data, states)
        # TODO: could pre-evaluate the constant factor once per epoch
        pies = self.theta["pies"].clamp(1e-2, 1.0 - 1e-2)
        D = to.sum(notnan, dim=1)  # (N,)
        logjoints = (
            lpj
            + to.log(1.0 - pies).sum()
            - D.unsqueeze(1) / 2.0 * (self._log2pi + to.log(self.theta["sigma2"]))
        )
        assert logjoints.shape == lpj.shape
        assert not to.isnan(logjoints).any() and not to.isinf(logjoints).any()
        return logjoints
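
    # As implemented above:
    #   log p(y, s) = lpj(s, y) + sum_h log(1 - pi_h) - D/2 (log(2 pi) + log(sigma2)),
    # with D the number of non-NaN entries of y; the log(sigma2) term is the "last term"
    # deferred by _reformulated_lpj_fn.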
    def update_param_batch(
        self,
        idx: to.Tensor,
        batch: to.Tensor,
        states: TVOVariationalStates,
        **kwargs: Dict[str, Any],
    ) -> None:
        precision = self.precision
        lpj = states.lpj[idx]
        Kbool = states.K[idx].to(dtype=to.bool)
        Kfloat = states.K[idx].to(dtype=lpj.dtype)
        batch_size, S, H = Kbool.shape

        use_storage = self._use_storage and self._storage is not None and len(self._storage) > 0

        # TODO: Add option to neglect reconstructed values
        notnan = to.ones_like(batch, dtype=to.bool, device=batch.device)

        batch_kappas = to.zeros((batch_size, S, H), dtype=precision, device=get_device())
        batch_Lambdas_plus_kappas_kappasT = to.zeros(
            (batch_size, S, H, H), dtype=precision, device=get_device()
        )
        for n in range(batch_size):
            for s in range(S):
                state = Kbool[n, s]
                if state.sum() == 0:
                    continue
                hsh = _get_hash(state)

                datapoint = batch[n]

                if use_storage:
                    assert self._storage is not None
                    assert hsh in self._storage
                    W_s, mus_s, Lambda_s, Lambda_s_W_s_sigma2inv = (
                        self._storage[hsh]["W_s"],
                        self._storage[hsh]["mus_s"],
                        self._storage[hsh]["Lambda_s"],
                        self._storage[hsh]["Lambda_s_W_s_sigma2inv"],
                    )
                else:
                    (
                        W_s,
                        mus_s,
                        _,
                        _,
                        Lambda_s,
                        Lambda_s_W_s_sigma2inv,
                    ) = self._common_e_m_step_terms(state, notnan[n])

                datapoint_norm = datapoint - W_s @ mus_s  # (D_nonnan,)

                batch_kappas[n, s][state] = (
                    mus_s + Lambda_s_W_s_sigma2inv @ datapoint_norm
                )  # is (|state|,)
                batch_Lambdas_plus_kappas_kappasT[n, s][to.outer(state, state)] = (
                    Lambda_s + to.outer(batch_kappas[n, s][state], batch_kappas[n, s][state])
                ).flatten()  # (|state|, |state|)

        batch_xpt_s = mean_posterior(Kfloat, lpj)  # is (batch_size, H)
        batch_xpt_ssT = mean_posterior(
            Kfloat.unsqueeze(3) * Kfloat.unsqueeze(2), lpj
        )  # (batch_size, H, H)
        batch_xpt_sz = mean_posterior(batch_kappas, lpj)  # (batch_size, H)
        batch_xpt_szszT = mean_posterior(
            batch_Lambdas_plus_kappas_kappasT, lpj
        )  # (batch_size, H, H)
        batch_xpt_sszT = (
            (batch_xpt_s.unsqueeze(2) * batch_xpt_sz.unsqueeze(1))
            if self._reformulated_psi_update
            else None
        )  # is (batch_size, H, H)

        self._my_sum_xpt_s.add_(to.sum(batch_xpt_s, dim=0))  # (H,)
        self._my_sum_xpt_ssT.add_(to.sum(batch_xpt_ssT, dim=0))  # (H, H)
        self._my_sum_xpt_sz.add_(to.sum(batch_xpt_sz, dim=0))  # (H,)
        self._my_sum_xpt_sz_xpt_szT.add_(batch_xpt_sz.t() @ batch_xpt_sz)  # (H, H)
        self._my_sum_xpt_szszT.add_(to.sum(batch_xpt_szszT, dim=0))  # (H, H)
        self._my_sum_diag_yyT.add_(to.sum(batch**2, dim=0))  # (D,)
        self._my_sum_y_szT.add_(batch.t() @ batch_xpt_sz)  # (D, H)
        self._my_N.add_(batch_size)  # (1,)
        if self._reformulated_psi_update:
            assert self._my_sum_xpt_ssz is not None and batch_xpt_sszT is not None
            self._my_sum_xpt_ssz.add_(to.sum(batch_xpt_sszT, dim=0))  # (H, H)

        return None
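
    # Posterior moments accumulated above for the M-step, per datapoint n and state s:
    #   kappa_s = mus_s + Lambda_s W_s^T (y - W_s mus_s) / sigma2   (mean of z_s | y, s)
    #   <sz>      = sum_s q(s) kappa_s,
    #   <sz sz^T> = sum_s q(s) (Lambda_s + kappa_s kappa_s^T),
    # where q(s) are the truncated posterior weights computed from lpj by mean_posterior.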
    def _invert_my_sum_xpt_ssT(self) -> to.Tensor:
        eps_eyeH = self._eps_eyeH
        try:
            Inv_my_sum_xpt_ssT = to.linalg.inv(self._my_sum_xpt_ssT)
        except Exception:
            try:
                Inv_my_sum_xpt_ssT = to.linalg.inv(self._my_sum_xpt_ssT + eps_eyeH)
                pprint("Psi update: Added diag(eps) before computing inverse")
            except Exception:
                Inv_my_sum_xpt_ssT = to.linalg.pinv(self._my_sum_xpt_ssT + eps_eyeH)
                pprint("Psi update: Added diag(eps) and computed pseudo-inverse")
        return Inv_my_sum_xpt_ssT
    def update_param_epoch(self) -> None:
        theta = self.theta
        precision = self.precision
        device = get_device()
        dtype_eps, eps = to.finfo(precision).eps, 1e-5
        eps_eyeH = self._eps_eyeH
        D, H = self.shape

        W, sigma2, pies, Psi, mus = (
            theta["W"],
            theta["sigma2"],
            theta["pies"],
            theta["Psi"],
            theta["mus"],
        )

        all_reduce(self._my_sum_y_szT)  # (D, H)
        all_reduce(self._my_sum_xpt_szszT)  # (H, H)
        all_reduce(self._my_sum_xpt_sz_xpt_szT)  # (H, H)
        all_reduce(self._my_sum_xpt_s)  # (H,)
        all_reduce(self._my_sum_xpt_sz)  # (H,)
        all_reduce(self._my_sum_xpt_ssT)  # (H, H)
        all_reduce(self._my_sum_diag_yyT)  # (D,)
        all_reduce(self._my_N)  # (1,)

        N = self._my_N.item()

        Inv_my_sum_xpt_ssT = self._invert_my_sum_xpt_ssT()

        try:
            sum_xpt_szszT_inv = to.linalg.inv(self._my_sum_xpt_szszT)
            W[:] = self._my_sum_y_szT @ sum_xpt_szszT_inv
        except Exception:
            try:
                noise = eps * to.randn(H, dtype=precision, device=device)
                noise = noise.unsqueeze(1) * noise.unsqueeze(0)
                sum_xpt_szszT_inv = to.linalg.pinv(self._my_sum_xpt_szszT + noise)
                W[:] = self._my_sum_y_szT @ sum_xpt_szszT_inv
                pprint("W update: Used noisy pseudo-inverse")
            except Exception:
                W[:] = W + eps * to.randn_like(W)
                pprint("W update: Failed to compute W^(new). Perturbed current W with AWGN.")

        pies[:] = self._my_sum_xpt_s / N
        mus[:] = self._my_sum_xpt_sz / (self._my_sum_xpt_s + dtype_eps)
        if self._reformulated_psi_update:
            assert self._my_sum_xpt_ssz is not None
            all_reduce(self._my_sum_xpt_ssz)  # (H, H)
            _Psi = (
                to.outer(mus, mus) * self._my_sum_xpt_ssT
                + self._my_sum_xpt_szszT
                - 2.0 * mus.unsqueeze(1) * self._my_sum_xpt_ssz
            )
            Psi[:] = _Psi * Inv_my_sum_xpt_ssT + eps_eyeH
            self._my_sum_xpt_ssz[:] = 0.0
        else:
            Psi[:] = (
                self._my_sum_xpt_szszT - self._my_sum_xpt_ssT * to.outer(mus, mus)
            ) * Inv_my_sum_xpt_ssT + eps_eyeH
        sigma2[:] = (
            self._my_sum_diag_yyT.sum() - to.trace(self._my_sum_xpt_sz_xpt_szT @ (W.t() @ W))
        ) / N / D + eps

        self._my_sum_y_szT[:] = 0.0
        self._my_sum_xpt_szszT[:] = 0.0
        self._my_sum_xpt_s[:] = 0.0
        self._my_sum_xpt_sz[:] = 0.0
        self._my_sum_xpt_ssT[:] = 0.0
        self._my_sum_diag_yyT[:] = 0.0
        self._my_sum_xpt_sz_xpt_szT[:] = 0.0
        self._my_N[:] = 0
        if self._use_storage:
            assert self._storage is not None
            self._storage.clear()
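
    # M-step updates applied above (sums over all datapoints, after all_reduce):
    #   pies   = sum <s> / N
    #   mus    = sum <sz> / sum <s>                                  (elementwise)
    #   W      = (sum y <sz>^T) (sum <sz sz^T>)^{-1}
    #   Psi    = (sum <sz sz^T> - sum <ss^T> * mus mus^T) * (sum <ss^T>)^{-1}
    #            (elementwise products with the matrix inverse of sum <ss^T>)
    #   sigma2 = (sum ||y||^2 - tr((sum <sz><sz>^T) W^T W)) / (N D)  (plus small eps)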
    def data_estimator(
        self,
        idx: to.Tensor,
        batch: to.Tensor,
        states: TVOVariationalStates,
    ) -> to.Tensor:
        r"""Estimator used for data reconstruction. Data reconstruction can only be supported
        by a model if it implements this method. The estimator to be implemented is defined
        as follows:

        :math:`\langle \langle y_d \rangle_{p(y_d|\vec{s},\Theta)} \rangle_{q(\vec{s}|\mathcal{K},\Theta)}`  # noqa
        """
        # TODO Find solution to avoid redundant computations in `data_estimator` and
        # `log_pseudo_joint`
        precision = self.precision
        lpj = states.lpj[idx]
        K = states.K[idx]
        batch_size, S, H = K.shape
        Kbool = K.to(dtype=to.bool)
        W = self.theta["W"]
        use_storage = self._use_storage and self._storage is not None and len(self._storage) > 0

        notnan = to.logical_not(to.isnan(batch))

        batch_kappas = to.zeros((batch_size, S, H), dtype=precision, device=get_device())
        for n in range(batch_size):
            for s in range(S):
                state = Kbool[n, s]
                if state.sum() == 0:
                    continue
                hsh = _get_hash(state)

                datapoint_notnan = batch[n][notnan[n]]

                if use_storage:
                    assert self._storage is not None
                    assert hsh in self._storage
                    W_s, mus_s, Lambda_s_W_s_sigma2inv = (
                        self._storage[hsh]["W_s"],
                        self._storage[hsh]["mus_s"],
                        self._storage[hsh]["Lambda_s_W_s_sigma2inv"],
                    )
                else:
                    (
                        W_s,
                        mus_s,
                        _,
                        _,
                        _,
                        Lambda_s_W_s_sigma2inv,
                    ) = self._common_e_m_step_terms(state, notnan[n])

                datapoint_norm = datapoint_notnan - W_s @ mus_s  # (D_nonnan,)

                batch_kappas[n, s][state] = (
                    mus_s + Lambda_s_W_s_sigma2inv @ datapoint_norm
                )  # is (|state|,)

        return to.sum(W.unsqueeze(0) * mean_posterior(batch_kappas, lpj).unsqueeze(1), dim=2)
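

# --- Minimal usage sketch (illustration, not part of the original module) ---
# Assumes a working tvo installation; H, D, N below are arbitrary. One variational
# state per datapoint (the true hidden state) is used only to exercise the API.
if __name__ == "__main__":
    model = SSSC(H=8, D=16, precision=to.float64)
    data, hidden = model.generate_data(N=100)     # data: (100, 16), hidden: (100, 8)
    states = hidden.to(to.uint8).unsqueeze(1)     # (N, S=1, H)
    lpj = model.log_pseudo_joint(data, states)    # (100, 1)
    logjoints = model.log_joint(data, states, lpj=lpj)
    print("mean log-joint:", logjoints.mean().item())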