Coverage for tvo/models/sssc.py: 73%

307 statements  

coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

# -*- coding: utf-8 -*-
# Copyright (C) 2021 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import torch as to
from typing import Dict, Optional, Tuple, Union, Any
from math import pi as MATH_PI
from tvo import get_device
from tvo.utils.model_protocols import Sampler, Optimized, Reconstructor
from tvo.utils.parallel import broadcast, all_reduce, pprint
from tvo.variational.TVOVariationalStates import TVOVariationalStates
from tvo.variational._utils import mean_posterior


def _get_hash(x: to.Tensor) -> int:
    # Hash a state vector by its raw bytes; used as key for the E-step storage.
    return hash(x.detach().cpu().numpy().tobytes())

class SSSC(Sampler, Optimized, Reconstructor):
    def __init__(
        self,
        H: int,
        D: int,
        W_init: Optional[to.Tensor] = None,
        sigma2_init: Optional[to.Tensor] = None,
        mus_init: Optional[to.Tensor] = None,
        Psi_init: Optional[to.Tensor] = None,
        pies_init: Optional[to.Tensor] = None,
        reformulated_lpj: bool = True,
        use_storage: bool = True,
        reformulated_psi_update: bool = False,
        precision: to.dtype = to.float32,
    ):
        """Spike-and-Slab Sparse Coding (SSSC) model.

        :param H: Number of hidden units.
        :param D: Number of observables.
        :param W_init: Tensor with shape (D, H), initializes SSSC weights.
        :param sigma2_init: Tensor initializing SSSC observable variance.
        :param mus_init: Tensor with shape (H,), initializes SSSC latent means.
        :param Psi_init: Tensor with shape (H, H), initializes SSSC latent variance.
        :param pies_init: Tensor with shape (H,), initializes SSSC priors.
        :param reformulated_lpj: Use a looped instead of a batchified E-step and a mathematically
                                 reformulated form of the log-pseudo-joint formula (exploiting the
                                 matrix determinant lemma and the Woodbury matrix identity).
                                 Yields more accurate solutions in large dimensions (i.e. large
                                 D and H).
        :param use_storage: Whether to memorize state-vector-dependent and datapoint-independent
                            terms computed in the E-step. Terms will be looked up rather than
                            recomputed if a datapoint evaluates a state that has already been
                            evaluated for another datapoint. The storage is cleared after each
                            epoch.
        :param reformulated_psi_update: Whether to update Psi using the reformulated form of the
                                        update equation.
        :param precision: Floating point precision required. Must be one of torch.float32 or
                          torch.float64.
        """
        assert precision in (to.float32, to.float64), "precision must be one of torch.float{32,64}"
        device = get_device()
        self._precision = precision
        self._shape = (D, H)
        self._reformulated_lpj = reformulated_lpj
        self._use_storage = use_storage
        self._reformulated_psi_update = reformulated_psi_update

        self._theta: Dict[str, to.Tensor] = {}
        self._theta["W"] = self._init_W(W_init)
        self._theta["sigma2"] = self._init_sigma2(sigma2_init)
        self._theta["mus"] = self._init_mus(mus_init)
        self._theta["Psi"] = self._init_Psi(Psi_init)
        self._theta["pies"] = self._init_pies(pies_init)

        self._log2pi = to.log(to.tensor([2.0 * MATH_PI], dtype=precision, device=device))

        self._my_sum_y_szT = to.zeros((D, H), dtype=precision, device=device)
        self._my_sum_xpt_sz_xpt_szT = to.zeros((H, H), dtype=precision, device=device)
        self._my_sum_xpt_szszT = to.zeros((H, H), dtype=precision, device=device)
        self._my_sum_xpt_s = to.zeros((H,), dtype=precision, device=device)
        self._my_sum_xpt_sz = to.zeros((H,), dtype=precision, device=device)
        self._my_sum_xpt_ssT = to.zeros((H, H), dtype=precision, device=device)
        self._my_sum_xpt_ssz = (
            to.zeros((H, H), dtype=precision, device=device) if reformulated_psi_update else None
        )
        self._my_sum_diag_yyT = to.zeros((D,), dtype=precision, device=device)
        self._my_N = to.tensor([0], dtype=to.int, device=device)
        self._eps_eyeH = to.eye(H, dtype=precision, device=device) * 1e-6
        self._storage: Optional[Dict[int, Dict[str, to.Tensor]]] = {} if use_storage else None

        self._config = dict(
            shape=self._shape,
            reformulated_lpj=reformulated_lpj,
            reformulated_psi_update=reformulated_psi_update,
            use_storage=use_storage,
            precision=precision,
            device=device,
        )
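For orientation, a minimal construction sketch (hypothetical values; it assumes a working tvo installation and that the `shape` and `theta` properties are provided by the model protocols, as their use below suggests):

    model = SSSC(H=8, D=16, precision=to.float64)
    print(model.shape)          # (16, 8), i.e. (D, H)
    print(model.theta["pies"])  # randomly initialized priors in [0.1, 0.6)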

    def _init_W(self, init: Optional[to.Tensor]):
        D, H = self.shape
        if init is not None:
            assert init.shape == (D, H)
            return init.to(dtype=self.precision, device=get_device())
        else:
            W_init = to.rand((D, H), dtype=self.precision, device=get_device())
            broadcast(W_init)
            return W_init

    def _init_sigma2(self, init: Optional[to.Tensor]):
        if init is not None:
            assert init.shape == (1,)
            return init.to(dtype=self.precision, device=get_device())
        else:
            return to.tensor([1.0], dtype=self.precision, device=get_device())

    def _init_mus(self, init: Optional[to.Tensor]):
        H = self.shape[1]
        if init is not None:
            assert init.shape == (H,)
            return init.to(dtype=self.precision, device=get_device())
        else:
            mus_init = to.normal(
                mean=to.zeros(H, dtype=self.precision, device=get_device()),
                std=to.ones(H, dtype=self.precision, device=get_device()),
            )
            broadcast(mus_init)
            return mus_init

    def _init_Psi(self, init: Optional[to.Tensor]):
        H = self.shape[1]
        if init is not None:
            assert init.shape == (H, H)
            return init.to(dtype=self.precision, device=get_device())
        else:
            return to.eye(H, dtype=self.precision, device=get_device())

    def _init_pies(self, init: Optional[to.Tensor]):
        H = self.shape[1]
        if init is not None:
            assert init.shape == (H,)
            return init.to(dtype=self.precision, device=get_device())
        else:
            return 0.1 + 0.5 * to.rand(H, dtype=self.precision, device=get_device())

    def generate_data(
        self, N: Optional[int] = None, hidden_state: Optional[to.Tensor] = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        precision, device = self.precision, get_device()
        D, H = self.shape

        if hidden_state is None:
            assert N is not None
            pies = self.theta["pies"]
            hidden_state = to.rand((N, H), dtype=precision, device=device) < pies
            must_return_hidden_state = True
        else:
            shape = hidden_state.shape
            if N is None:
                N = shape[0]
            assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
            must_return_hidden_state = False

        Z = to.distributions.multivariate_normal.MultivariateNormal(
            loc=self.theta["mus"],
            covariance_matrix=self.theta["Psi"],
        )

        Wbar = to.einsum("dh,nh->nd", self.theta["W"], hidden_state * Z.sample((N,)))

        Y = to.distributions.multivariate_normal.MultivariateNormal(
            loc=Wbar,
            covariance_matrix=self.theta["sigma2"]
            * to.eye(D, dtype=self.theta["W"].dtype, device=get_device()),
        )

        return (Y.sample(), hidden_state) if must_return_hidden_state else Y.sample()
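Given a model constructed as in the sketch above, sampling from the generative model reads, e.g. (hypothetical call, shapes illustrative):

    Y, S = model.generate_data(N=100)         # Y: (100, 16) observations, S: (100, 8) boolean latents
    Y2 = model.generate_data(hidden_state=S)  # reuses the given latents, returns only the data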

    def _lpj_fn(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        """Straightforward batchified implementation of the log-pseudo-joint for SSSC."""
        precision = self.precision
        W, sigma2, _pies, mus, Psi = (
            self.theta["W"],
            self.theta["sigma2"],
            self.theta["pies"],
            self.theta["mus"],
            self.theta["Psi"],
        )
        pies = _pies.clamp(1e-2, 1.0 - 1e-2)
        Kfloat = states.type_as(pies)
        N, D, S, H = data.shape + states.shape[1:]
        eyeD = to.eye(D, dtype=precision, device=get_device())

        s1 = Kfloat @ to.log(pies / (1.0 - pies))  # (N, S)
        Ws = Kfloat.unsqueeze(2) * W.unsqueeze(0).unsqueeze(1)  # (N, S, D, H)
        data_norm = data.unsqueeze(1) - Ws @ mus  # (N, S, D)
        data_norm[to.isnan(data_norm)] = 0.0

        WsPsi = Ws @ Psi  # (N, S, D, H)
        WsPsiWsT = (
            to.matmul(WsPsi, Ws.permute([0, 1, 3, 2]))
            if precision == to.float32
            else to.einsum("nsxh,nsyh->nsxy", WsPsi, Ws)
        )  # (N, S, D, D)
        C_s = WsPsiWsT + sigma2 * eyeD.unsqueeze(0).unsqueeze(1)  # (N, S, D, D)
        log_det_C_s = to.linalg.slogdet(C_s)[1]
        try:
            Inv_C_s = to.linalg.inv(C_s)
        except Exception:
            Inv_C_s = to.linalg.pinv(C_s)

        return (
            s1
            - 0.5 * log_det_C_s
            - 0.5 * (to.einsum("nsx,nsxd->nsd", data_norm, Inv_C_s) * data_norm).sum(dim=2)
        )
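In formulas (an informal transcription of the code above, not part of the original module), `_lpj_fn` evaluates per datapoint :math:`\vec{y}` and state :math:`\vec{s}`, with :math:`W_{\vec{s}} = W\,\mathrm{diag}(\vec{s})`:

    \log\tilde{p}(\vec{s},\vec{y}) = \vec{s}^{\top}\log\frac{\vec{\pi}}{1-\vec{\pi}}
        - \frac{1}{2}\log|C_{\vec{s}}|
        - \frac{1}{2}\,(\vec{y}-W_{\vec{s}}\vec{\mu})^{\top} C_{\vec{s}}^{-1}\,(\vec{y}-W_{\vec{s}}\vec{\mu}),
    \qquad C_{\vec{s}} = W_{\vec{s}}\,\Psi\,W_{\vec{s}}^{\top} + \sigma^{2} I_D .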

    def _common_e_m_step_terms(
        self, state: to.Tensor, inds_d_not_isnan: to.Tensor
    ) -> Tuple[to.Tensor, to.Tensor, to.Tensor, to.Tensor, to.Tensor, to.Tensor]:
        W = self.theta["W"]
        sigma2 = self.theta["sigma2"]
        Psi = self.theta["Psi"]
        mus = self.theta["mus"]

        W_s = W[inds_d_not_isnan][:, state]
        Psi_s = Psi[state, :][:, state]
        mus_s = mus[state]

        try:
            Inv_Psi_s = to.linalg.inv(Psi_s)
        except Exception:
            Inv_Psi_s = to.linalg.pinv(Psi_s)

        Inv_Lambda_s = W_s.t() @ W_s / sigma2 + Inv_Psi_s  # (|state|, |state|)
        try:
            Lambda_s = to.linalg.inv(Inv_Lambda_s)
        except Exception:
            Lambda_s = to.linalg.pinv(Inv_Lambda_s)

        Lambda_s_W_s_sigma2inv = Lambda_s @ W_s.t() / sigma2  # (|state|, D_nonnan)

        return (
            W_s,
            mus_s,
            Psi_s,
            Inv_Lambda_s,
            Lambda_s,
            Lambda_s_W_s_sigma2inv,
        )
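The quantities returned here are the posterior statistics of the active latents (W, Psi and mus restricted to the entries selected by `state`); in an informal transcription, :math:`\Lambda_{\vec{s}}` is the posterior covariance, and the posterior mean used downstream in `update_param_batch` and `data_estimator` is

    \Lambda_{\vec{s}} = \big(W_{\vec{s}}^{\top}W_{\vec{s}}/\sigma^{2} + \Psi_{\vec{s}}^{-1}\big)^{-1},
    \qquad
    \vec{\kappa}_{\vec{s}} = \vec{\mu}_{\vec{s}} + \Lambda_{\vec{s}}W_{\vec{s}}^{\top}(\vec{y}-W_{\vec{s}}\vec{\mu}_{\vec{s}})/\sigma^{2} .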

    def _check_if_storage_reliable(self, incomplete: bool, batch_size: int):
        """Disable the storage logic (set `self._use_storage = False`) if the data is incomplete
        and the specified batch_size is larger than one. The stored terms are only
        data-independent if the data contains no missing values; otherwise, the data-dependent
        indices of missing values enter the computation of the respective terms.

        :param incomplete: Boolean indicating whether the data contains missing values
        :param batch_size: Batch size
        """
        use_storage = self._use_storage if not (incomplete and batch_size > 1) else False
        if self._use_storage != use_storage:
            pprint("Disabled storage (inaccurate for incomplete data and batch_size > 1)")
        self._use_storage = use_storage

    def _reformulated_lpj_fn(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        """Batchified implementation of the log-pseudo-joint for SSSC using the matrix
        determinant lemma and the Woodbury matrix identity to compute the determinant and
        inverse of the matrix C_s.
        """
        precision = self.precision
        sigma2, _pies = (
            self.theta["sigma2"],
            self.theta["pies"],
        )
        pies = _pies.clamp(1e-2, 1.0 - 1e-2)
        Kbool = states.to(dtype=to.bool)
        Kfloat = states.to(dtype=precision)
        batch_size, S = data.shape[0], Kbool.shape[1]

        self._check_if_storage_reliable(incomplete=to.isnan(data).any(), batch_size=batch_size)
        use_storage = self._use_storage

        notnan = to.logical_not(to.isnan(data))

        lpj = Kfloat @ to.log(pies / (1.0 - pies))  # initial allocation, (N, S)
        for n in range(batch_size):
            for s in range(S):
                hsh = _get_hash(Kbool[n, s])
                datapoint_notnan, D_notnan = data[n][notnan[n]], notnan[n].sum()
                if use_storage and self._storage is not None and hsh in self._storage:
                    W_s, mus_s, log_det_C_s_wo_last_term, Inv_C_s = (
                        self._storage[hsh]["W_s"],
                        self._storage[hsh]["mus_s"],
                        self._storage[hsh]["log_det_C_s_wo_last_term"],
                        self._storage[hsh]["Inv_C_s"],
                    )
                else:
                    (
                        W_s,
                        mus_s,
                        Psi_s,
                        Inv_Lambda_s,
                        Lambda_s,
                        Lambda_s_W_s_sigma2inv,
                    ) = self._common_e_m_step_terms(Kbool[n, s], notnan[n])

                    Inv_C_s = (
                        to.eye(D_notnan, dtype=self.precision, device=get_device()) / sigma2
                        - W_s @ Lambda_s_W_s_sigma2inv / sigma2
                    )  # Woodbury matrix identity, (D_nonnan, D_nonnan)
                    log_det_C_s_wo_last_term = (
                        to.linalg.slogdet(Inv_Lambda_s)[1] + to.linalg.slogdet(Psi_s)[1]
                    )  # matrix determinant lemma, last term added in log_joint (1,)

                    if use_storage:
                        assert self._storage is not None
                        self._storage[hsh] = {
                            "W_s": W_s,
                            "mus_s": mus_s,
                            "Lambda_s": Lambda_s,
                            "Lambda_s_W_s_sigma2inv": Lambda_s_W_s_sigma2inv,
                            "log_det_C_s_wo_last_term": log_det_C_s_wo_last_term,
                            "Inv_C_s": Inv_C_s,
                        }

                datapoint_norm = datapoint_notnan - W_s @ mus_s  # (D_nonnan,)

                lpj[n, s] -= (
                    0.5
                    * (
                        log_det_C_s_wo_last_term
                        + (datapoint_norm * (Inv_C_s @ datapoint_norm)).sum()
                    ).item()
                )

        return lpj
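Not part of the original module: the two linear-algebra identities used above can be checked numerically. A minimal, self-contained sketch with toy shapes:

    import torch as to

    D, H = 8, 3
    to.manual_seed(0)
    sigma2 = to.tensor(0.5, dtype=to.float64)
    W = to.randn(D, H, dtype=to.float64)
    Psi = to.eye(H, dtype=to.float64) + 0.1 * to.ones(H, H, dtype=to.float64)  # PSD toy covariance

    C = W @ Psi @ W.t() + sigma2 * to.eye(D, dtype=to.float64)
    Inv_Lambda = W.t() @ W / sigma2 + to.linalg.inv(Psi)

    # matrix determinant lemma: log|C| = log|Inv_Lambda| + log|Psi| + D * log(sigma2)
    assert to.allclose(
        to.linalg.slogdet(C)[1],
        to.linalg.slogdet(Inv_Lambda)[1] + to.linalg.slogdet(Psi)[1] + D * to.log(sigma2),
    )

    # Woodbury identity: C^{-1} = I / sigma2 - W Lambda W^T / sigma2^2
    Lambda = to.linalg.inv(Inv_Lambda)
    assert to.allclose(
        to.linalg.inv(C),
        to.eye(D, dtype=to.float64) / sigma2 - W @ Lambda @ W.t() / sigma2**2,
    )

Note that the D * log(sigma2) term is exactly the "last term" that `_reformulated_lpj_fn` omits and `log_joint` adds back.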

    def log_pseudo_joint(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        """Evaluate log-pseudo-joints for SSSC."""
        lpj_fn = self._reformulated_lpj_fn if self._reformulated_lpj else self._lpj_fn
        lpj = lpj_fn(data, states)
        min_ = to.finfo(self.precision).min
        lpj[to.isnan(lpj)] = min_
        lpj[to.isinf(lpj)] = min_
        return lpj

    def log_joint(self, data: to.Tensor, states: to.Tensor, lpj=None) -> to.Tensor:
        """Evaluate log-joints for SSSC."""
        assert states.dtype == to.uint8
        notnan = to.logical_not(to.isnan(data))
        if lpj is None:
            lpj = self.log_pseudo_joint(data, states)
        # TODO: could pre-evaluate the constant factor once per epoch
        pies = self.theta["pies"].clamp(1e-2, 1.0 - 1e-2)
        D = to.sum(notnan, dim=1)  # (N,)
        logjoints = (
            lpj
            + to.log(1.0 - pies).sum()
            - D.unsqueeze(1) / 2.0 * (self._log2pi + to.log(self.theta["sigma2"]))
        )
        assert logjoints.shape == lpj.shape
        assert not to.isnan(logjoints).any() and not to.isinf(logjoints).any()
        return logjoints
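As implemented, `log_joint` adds the state-independent constants back onto the pseudo-joint:

    \log p(\vec{y},\vec{s}\,|\,\Theta) = \log\tilde{p}(\vec{s},\vec{y}) + \sum_{h}\log(1-\pi_h) - \frac{D}{2}\big(\log 2\pi + \log\sigma^{2}\big),

where D counts the non-missing entries of :math:`\vec{y}`.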

    def update_param_batch(
        self,
        idx: to.Tensor,
        batch: to.Tensor,
        states: TVOVariationalStates,
        **kwargs: Dict[str, Any],
    ) -> None:
        precision = self.precision
        lpj = states.lpj[idx]
        Kbool = states.K[idx].to(dtype=to.bool)
        Kfloat = states.K[idx].to(dtype=lpj.dtype)
        batch_size, S, H = Kbool.shape

        use_storage = self._use_storage and self._storage is not None and len(self._storage) > 0

        # TODO: Add option to neglect reconstructed values
        notnan = to.ones_like(batch, dtype=to.bool, device=batch.device)

        batch_kappas = to.zeros((batch_size, S, H), dtype=precision, device=get_device())
        batch_Lambdas_plus_kappas_kappasT = to.zeros(
            (batch_size, S, H, H), dtype=precision, device=get_device()
        )
        for n in range(batch_size):
            for s in range(S):
                state = Kbool[n, s]
                if state.sum() == 0:
                    continue
                hsh = _get_hash(state)

                datapoint = batch[n]

                if use_storage:
                    assert self._storage is not None
                    assert hsh in self._storage
                    W_s, mus_s, Lambda_s, Lambda_s_W_s_sigma2inv = (
                        self._storage[hsh]["W_s"],
                        self._storage[hsh]["mus_s"],
                        self._storage[hsh]["Lambda_s"],
                        self._storage[hsh]["Lambda_s_W_s_sigma2inv"],
                    )
                else:
                    (
                        W_s,
                        mus_s,
                        _,
                        _,
                        Lambda_s,
                        Lambda_s_W_s_sigma2inv,
                    ) = self._common_e_m_step_terms(state, notnan[n])

                datapoint_norm = datapoint - W_s @ mus_s  # (D_nonnan,)

                batch_kappas[n, s][state] = (
                    mus_s + Lambda_s_W_s_sigma2inv @ datapoint_norm
                )  # (|state|,)
                batch_Lambdas_plus_kappas_kappasT[n, s][to.outer(state, state)] = (
                    Lambda_s + to.outer(batch_kappas[n, s][state], batch_kappas[n, s][state])
                ).flatten()  # (|state|, |state|)

        batch_xpt_s = mean_posterior(Kfloat, lpj)  # (batch_size, H)
        batch_xpt_ssT = mean_posterior(
            Kfloat.unsqueeze(3) * Kfloat.unsqueeze(2), lpj
        )  # (batch_size, H, H)
        batch_xpt_sz = mean_posterior(batch_kappas, lpj)  # (batch_size, H)
        batch_xpt_szszT = mean_posterior(
            batch_Lambdas_plus_kappas_kappasT, lpj
        )  # (batch_size, H, H)
        batch_xpt_sszT = (
            (batch_xpt_s.unsqueeze(2) * batch_xpt_sz.unsqueeze(1))
            if self._reformulated_psi_update
            else None
        )  # (batch_size, H, H)

        self._my_sum_xpt_s.add_(to.sum(batch_xpt_s, dim=0))  # (H,)
        self._my_sum_xpt_ssT.add_(to.sum(batch_xpt_ssT, dim=0))  # (H, H)
        self._my_sum_xpt_sz.add_(to.sum(batch_xpt_sz, dim=0))  # (H,)
        self._my_sum_xpt_sz_xpt_szT.add_(batch_xpt_sz.t() @ batch_xpt_sz)  # (H, H)
        self._my_sum_xpt_szszT.add_(to.sum(batch_xpt_szszT, dim=0))  # (H, H)
        self._my_sum_diag_yyT.add_(to.sum(batch**2, dim=0))  # (D,)
        self._my_sum_y_szT.add_(batch.t() @ batch_xpt_sz)  # (D, H)
        self._my_N.add_(batch_size)  # (1,)
        if self._reformulated_psi_update:
            assert self._my_sum_xpt_ssz is not None and batch_xpt_sszT is not None
            self._my_sum_xpt_ssz.add_(to.sum(batch_xpt_sszT, dim=0))  # (H, H)

        return None
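The accumulators updated above collect, per batch, the posterior expectations needed by the M-step (an informal transcription; :math:`\langle\cdot\rangle_n` denotes `mean_posterior` over the variational states of datapoint n, :math:`\odot` the elementwise product):

    \sum_n \langle\vec{s}\rangle_n,\quad
    \sum_n \langle\vec{s}\vec{s}^{\top}\rangle_n,\quad
    \sum_n \langle\vec{s}\odot\vec{z}\rangle_n,\quad
    \sum_n \langle(\vec{s}\odot\vec{z})(\vec{s}\odot\vec{z})^{\top}\rangle_n,\quad
    \sum_n \langle\vec{s}\odot\vec{z}\rangle_n\langle\vec{s}\odot\vec{z}\rangle_n^{\top},\quad
    \sum_n \vec{y}_n\langle\vec{s}\odot\vec{z}\rangle_n^{\top},\quad
    \sum_n \vec{y}_n\odot\vec{y}_n .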

    def _invert_my_sum_xpt_ssT(self) -> to.Tensor:
        eps_eyeH = self._eps_eyeH
        try:
            Inv_my_sum_xpt_ssT = to.linalg.inv(self._my_sum_xpt_ssT)
        except Exception:
            try:
                Inv_my_sum_xpt_ssT = to.linalg.inv(self._my_sum_xpt_ssT + eps_eyeH)
                pprint("Psi update: Added diag(eps) before computing inverse")
            except Exception:
                Inv_my_sum_xpt_ssT = to.linalg.pinv(self._my_sum_xpt_ssT + eps_eyeH)
                pprint("Psi update: Added diag(eps) and computed pseudo-inverse")
        return Inv_my_sum_xpt_ssT

    def update_param_epoch(self) -> None:
        theta = self.theta
        precision = self.precision
        device = get_device()
        dtype_eps, eps = to.finfo(precision).eps, 1e-5
        eps_eyeH = self._eps_eyeH
        D, H = self.shape

        W, sigma2, pies, Psi, mus = (
            theta["W"],
            theta["sigma2"],
            theta["pies"],
            theta["Psi"],
            theta["mus"],
        )

        all_reduce(self._my_sum_y_szT)  # (D, H)
        all_reduce(self._my_sum_xpt_szszT)  # (H, H)
        all_reduce(self._my_sum_xpt_sz_xpt_szT)  # (H, H)
        all_reduce(self._my_sum_xpt_s)  # (H,)
        all_reduce(self._my_sum_xpt_sz)  # (H,)
        all_reduce(self._my_sum_xpt_ssT)  # (H, H)
        all_reduce(self._my_sum_diag_yyT)  # (D,)
        all_reduce(self._my_N)  # (1,)

        N = self._my_N.item()

        Inv_my_sum_xpt_ssT = self._invert_my_sum_xpt_ssT()

        try:
            sum_xpt_szszT_inv = to.linalg.inv(self._my_sum_xpt_szszT)
            W[:] = self._my_sum_y_szT @ sum_xpt_szszT_inv
        except Exception:
            try:
                noise = eps * to.randn(H, dtype=precision, device=device)
                noise = noise.unsqueeze(1) * noise.unsqueeze(0)
                sum_xpt_szszT_inv = to.linalg.pinv(self._my_sum_xpt_szszT + noise)
                W[:] = self._my_sum_y_szT @ sum_xpt_szszT_inv
                pprint("W update: Used noisy pseudo-inverse")
            except Exception:
                W[:] = W + eps * to.randn_like(W)
                pprint("W update: Failed to compute W^(new). Perturbed current W with AWGN.")

        pies[:] = self._my_sum_xpt_s / N
        mus[:] = self._my_sum_xpt_sz / (self._my_sum_xpt_s + dtype_eps)
        if self._reformulated_psi_update:
            assert self._my_sum_xpt_ssz is not None
            all_reduce(self._my_sum_xpt_ssz)  # (H, H)
            _Psi = (
                to.outer(mus, mus) * self._my_sum_xpt_ssT
                + self._my_sum_xpt_szszT
                - 2.0 * mus.unsqueeze(1) * self._my_sum_xpt_ssz
            )
            Psi[:] = _Psi * Inv_my_sum_xpt_ssT + eps_eyeH
            self._my_sum_xpt_ssz[:] = 0.0
        else:
            Psi[:] = (
                self._my_sum_xpt_szszT - self._my_sum_xpt_ssT * to.outer(mus, mus)
            ) * Inv_my_sum_xpt_ssT + eps_eyeH
        sigma2[:] = (
            self._my_sum_diag_yyT.sum() - to.trace(self._my_sum_xpt_sz_xpt_szT @ (W.t() @ W))
        ) / N / D + eps

        self._my_sum_y_szT[:] = 0.0
        self._my_sum_xpt_szszT[:] = 0.0
        self._my_sum_xpt_s[:] = 0.0
        self._my_sum_xpt_sz[:] = 0.0
        self._my_sum_xpt_ssT[:] = 0.0
        self._my_sum_diag_yyT[:] = 0.0
        self._my_sum_xpt_sz_xpt_szT[:] = 0.0
        self._my_N[:] = 0
        if self._use_storage:
            assert self._storage is not None
            self._storage.clear()
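In formulas, the closed-form updates implemented above read (an informal transcription of the code, omitting the small regularizers eps and eps*I; sums run over all datapoints after `all_reduce`, :math:`\odot` is the elementwise product):

    W^{\mathrm{new}} = \Big(\sum_n \vec{y}_n\langle\vec{s}\odot\vec{z}\rangle_n^{\top}\Big)\Big(\sum_n \langle(\vec{s}\odot\vec{z})(\vec{s}\odot\vec{z})^{\top}\rangle_n\Big)^{-1},
    \quad
    \vec{\pi}^{\mathrm{new}} = \frac{1}{N}\sum_n \langle\vec{s}\rangle_n,
    \quad
    \vec{\mu}^{\mathrm{new}} = \frac{\sum_n \langle\vec{s}\odot\vec{z}\rangle_n}{\sum_n \langle\vec{s}\rangle_n},

    \Psi^{\mathrm{new}} = \Big(\sum_n \langle(\vec{s}\odot\vec{z})(\vec{s}\odot\vec{z})^{\top}\rangle_n - \vec{\mu}\vec{\mu}^{\top}\odot\sum_n \langle\vec{s}\vec{s}^{\top}\rangle_n\Big)\odot\Big(\sum_n \langle\vec{s}\vec{s}^{\top}\rangle_n\Big)^{-1},

    \sigma^{2,\mathrm{new}} = \frac{1}{ND}\Big(\sum_n \vec{y}_n^{\top}\vec{y}_n - \mathrm{tr}\Big(\Big(\sum_n \langle\vec{s}\odot\vec{z}\rangle_n\langle\vec{s}\odot\vec{z}\rangle_n^{\top}\Big)W^{\top}W\Big)\Big).

In the Psi update, the inverse is the matrix inverse computed by `_invert_my_sum_xpt_ssT`, combined elementwise exactly as in the code.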

    def data_estimator(
        self,
        idx: to.Tensor,
        batch: to.Tensor,
        states: TVOVariationalStates,
    ) -> to.Tensor:
        r"""Estimator used for data reconstruction. Data reconstruction can only be supported
        by a model if it implements this method. The estimator to be implemented is defined
        as follows:

        :math:`\langle \langle y_d \rangle_{p(y_d|\vec{s},\Theta)} \rangle_{q(\vec{s}|\mathcal{K},\Theta)}`  # noqa
        """
        # TODO Find solution to avoid redundant computations in `data_estimator` and
        # `log_pseudo_joint`
        precision = self.precision
        lpj = states.lpj[idx]
        K = states.K[idx]
        batch_size, S, H = K.shape
        Kbool = K.to(dtype=to.bool)
        W = self.theta["W"]
        use_storage = self._use_storage and self._storage is not None and len(self._storage) > 0

        notnan = to.logical_not(to.isnan(batch))

        batch_kappas = to.zeros((batch_size, S, H), dtype=precision, device=get_device())
        for n in range(batch_size):
            for s in range(S):
                state = Kbool[n, s]
                if state.sum() == 0:
                    continue
                hsh = _get_hash(state)

                datapoint_notnan = batch[n][notnan[n]]

                if use_storage:
                    assert self._storage is not None
                    assert hsh in self._storage
                    W_s, mus_s, Lambda_s_W_s_sigma2inv = (
                        self._storage[hsh]["W_s"],
                        self._storage[hsh]["mus_s"],
                        self._storage[hsh]["Lambda_s_W_s_sigma2inv"],
                    )
                else:
                    (
                        W_s,
                        mus_s,
                        _,
                        _,
                        _,
                        Lambda_s_W_s_sigma2inv,
                    ) = self._common_e_m_step_terms(state, notnan[n])

                datapoint_norm = datapoint_notnan - W_s @ mus_s  # (D_nonnan,)

                batch_kappas[n, s][state] = (
                    mus_s + Lambda_s_W_s_sigma2inv @ datapoint_norm
                )  # (|state|,)

        return to.sum(W.unsqueeze(0) * mean_posterior(batch_kappas, lpj).unsqueeze(1), dim=2)
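The final line assembles the reconstruction from the posterior means of the sparse codes; in the notation used above, the estimator returns per datapoint

    \langle\vec{y}\rangle = W\,\langle\vec{\kappa}\rangle_{q(\vec{s}|\mathcal{K},\Theta)},

i.e. the learned dictionary W applied to the posterior mean of :math:`\vec{s}\odot\vec{z}`.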