Coverage for tvo/models/tvae.py: 95%

352 statements  

coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

from tvo.utils.model_protocols import Trainable, Sampler, Reconstructor
from tvo.variational.TVOVariationalStates import TVOVariationalStates
from tvo.variational._utils import mean_posterior
from tvo.utils.parallel import all_reduce, broadcast, mpi_average_grads
from tvo.utils import get, CyclicLR
import torch.optim as opt
import tvo
import torch as to
from typing import Tuple, List, Dict, Iterable, Optional, Sequence, Union, Callable
from math import pi as MATH_PI
from abc import abstractmethod


def _get_net_shape(net_shape: Sequence[int] = None, W_init: Sequence[to.Tensor] = None):
    if net_shape is not None:
        return tuple(reversed(net_shape))
    else:
        assert (
            W_init is not None
        ), "Must pass one of `net_shape` and `W_init` to __init__ of the\
            TVAE model"
        return tuple(w.shape[0] for w in W_init) + (W_init[-1].shape[1],)
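
# A quick orientation on the two shape conventions used above (illustrative values only):
# `net_shape`/`shape` arguments are given observable-first, (D, ..., H1, H0), while the model
# stores `self._net_shape` hidden-first, (H0, H1, ..., D). For example, assuming a decoder with
# D=784 observables, one hidden layer of 128 units and H0=32 latents:
#
#   _get_net_shape(net_shape=(784, 128, 32))                         # -> (32, 128, 784)
#   _get_net_shape(W_init=[to.empty(32, 128), to.empty(128, 784)])   # -> (32, 128, 784)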


def _init_W(
    net_shape: Sequence[int], precision: to.dtype, init: Optional[Sequence[to.Tensor]]
) -> List[to.Tensor]:
    """Return weights initialized with Xavier or to specified init values.

    This method also makes sure that device and precision are the ones required by
    the model.
    """
    if init is None:
        n_layers = len(net_shape) - 1
        W_shapes = ((net_shape[ln], net_shape[ln + 1]) for ln in range(n_layers))
        W = list(map(to.nn.init.xavier_normal_, (to.empty(s) for s in W_shapes)))
    else:
        assert (
            len(init) == len(net_shape) - 1
        ), f"Shape is {net_shape} but {len(init)} weights passed"
        Wshapes = [w.shape for w in init]
        expected_Wshapes = [(net_shape[ln], net_shape[ln + 1]) for ln in range(len(init))]
        err_msg = f"Input W shapes: {Wshapes}\nExpected W shapes {expected_Wshapes}"
        assert all(ws == exp_s for ws, exp_s in zip(Wshapes, expected_Wshapes)), err_msg
        W = list(w.clone() for w in init)
    for w in W:
        broadcast(w)
    return [w.to(device=tvo.get_device(), dtype=precision).requires_grad_(True) for w in W]


def _init_b(
    net_shape: Sequence[int], precision: to.dtype, init: Optional[Iterable[to.Tensor]]
) -> List[to.Tensor]:
    """Return biases initialized to zeros or to specified init values.

    This method also makes sure that device and precision are the ones required by the model.
    """
    if init is None:
        B = [to.zeros(s) for s in net_shape[1:]]
    else:
        assert all(b.shape == (net_shape[ln + 1],) for ln, b in enumerate(init))
        B = [b.clone() for b in init]
    return [b.to(device=tvo.get_device(), dtype=precision).requires_grad_(True) for b in B]


def _init_pi(
    precision: to.dtype, init: Optional[to.Tensor], H0: int, requires_grad: bool
) -> to.Tensor:
    if init is None:
        pi = to.full((H0,), 1 / H0)
    else:
        assert init.shape == (H0,)
        pi = init.clone()
    return pi.to(device=tvo.get_device(), dtype=precision).requires_grad_(requires_grad)


def _init_sigma2(precision: to.dtype, init: Optional[float], requires_grad: bool) -> to.Tensor:
    sigma2 = to.tensor([0.01] if init is None else [init])
    return sigma2.to(device=tvo.get_device(), dtype=precision).requires_grad_(requires_grad)


class _TVAE(Trainable, Sampler, Reconstructor):
    _theta: Dict[str, to.Tensor]
    _precision: to.dtype
    _net_shape: Sequence[int]
    _scheduler: opt.lr_scheduler._LRScheduler
    _optimizer: opt.Optimizer
    _activation: Optional[Callable] = None
    _external_model: Optional[to.nn.Module] = None

    @abstractmethod
    def log_joint(self, data: to.Tensor, states: to.Tensor, lpj: to.Tensor = None) -> to.Tensor:
        ...

    def _log_pseudo_joint(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        with to.no_grad():
            lpj, _ = self._lpj_and_mlpout(data, states)
        return lpj

    @abstractmethod
    def _lpj_and_mlpout(self, data: to.Tensor, states: to.Tensor) -> Tuple[to.Tensor, to.Tensor]:
        ...

    def free_energy(self, idx: to.Tensor, batch: to.Tensor, states: TVOVariationalStates) -> float:
        with to.no_grad():
            return super().free_energy(idx, batch, states)

    def _free_energy_from_logjoints(self, logjoints: to.Tensor) -> to.Tensor:
        Fn = to.logsumexp(logjoints, dim=1)
        assert Fn.shape == (logjoints.shape[0],)
        assert not to.isnan(Fn).any() and not to.isinf(Fn).any()
        return Fn.sum()

    def update_param_batch(
        self, idx: to.Tensor, batch: to.Tensor, states: TVOVariationalStates
    ) -> float:
        if to.isnan(batch).any():
            raise RuntimeError("There are NaNs in this batch")
        F, mlp_out = self._optimize_nn_params(idx, batch, states)
        with to.no_grad():
            self._accumulate_param_updates(idx, batch, states, mlp_out)
        return F

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of TVAE model as a bayes net: (D, H0)

        Neural network shape is ignored.
        """
        return tuple((self._net_shape[-1], self._net_shape[0]))

    @property
    def net_shape(self) -> Tuple[int, ...]:
        """Full TVAE network shape (D, Hn, Hn-1, ..., H0)."""
        return tuple(reversed(self._net_shape))

    def _optimize_nn_params(
        self, idx: to.Tensor, data: to.Tensor, states: TVOVariationalStates
    ) -> Tuple[float, to.Tensor]:
        """
        Parameters optimized by gradient descent are changed in-place. All other arguments are
        left untouched.

        :returns: F and mlp_output _before_ the weight update
        """
        assert self._optimizer is not None  # to make mypy happy

        lpj, mlp_out = self._lpj_and_mlpout(data, states.K[idx])
        F = self._free_energy_from_logjoints(self.log_joint(data, states.K[idx], lpj))
        loss = -F / data.shape[0]
        loss.backward()

        mpi_average_grads(self.theta)
        self._optimizer.step()
        self._scheduler.step()
        self._optimizer.zero_grad()

        return F.item(), mlp_out

    def _accumulate_param_updates(
        self, idx: to.Tensor, data: to.Tensor, states: TVOVariationalStates, mlp_out: to.Tensor
    ) -> None:
        pass

    def data_estimator(
        self, idx: to.Tensor, batch: to.Tensor, states: TVOVariationalStates
    ) -> to.Tensor:  # type: ignore
        r"""
        :math:`\langle \langle y_d \rangle_{p(y_d|\vec{s},\Theta)} \rangle_{q(\vec{s}|\mathcal{K},\Theta)}`  # noqa
        """

        lpj = states.lpj[idx]
        K = states.K[idx]

        with to.no_grad():
            means = self.forward(K)  # (N, S, D)

        return mean_posterior(means, lpj)  # (N, D)

    @abstractmethod
    def forward(self, x: to.Tensor) -> to.Tensor:
        """Forward application of TVAE's MLP to the specified input."""
        ...


class GaussianTVAE(_TVAE):
    def __init__(
        self,
        shape: Sequence[int] = None,
        precision: to.dtype = to.float64,
        min_lr: float = 0.001,
        max_lr: float = 0.01,
        cycliclr_step_size_up=400,
        pi_init: to.Tensor = None,
        W_init: Sequence[to.Tensor] = None,
        b_init: Sequence[to.Tensor] = None,
        sigma2_init: float = None,
        analytical_sigma_updates: bool = True,
        analytical_pi_updates: bool = True,
        clamp_sigma_updates: bool = False,
        activation: Callable = None,
        external_model: Optional[to.nn.Module] = None,
        optimizer: Optional[opt.Optimizer] = None,
    ):
        """Create a TVAE model with Gaussian observables.

        :param shape: Network shape, from observable to most hidden: (D,...,H1,H0). Exactly one
                      of `shape`, (`W_init`, `b_init`) and `external_model` must be specified.
        :param precision: One of to.float32 or to.float64, indicates the floating point precision
                          of model parameters.
        :param min_lr: See docs of tvo.utils.CyclicLR
        :param max_lr: See docs of tvo.utils.CyclicLR
        :param cycliclr_step_size_up: See docs of tvo.utils.CyclicLR
        :param pi_init: Optional tensor with initial prior values
        :param W_init: Optional list of tensors with initial weight values. Weight matrices
                       must be ordered from most hidden to observable layer. Exactly one of
                       `shape`, (`W_init`, `b_init`) and `external_model` must be specified.
        :param b_init: Optional list of tensors with initial bias values. Exactly one of `shape`,
                       (`W_init`, `b_init`) and `external_model` must be specified.
        :param sigma2_init: Optional initial value for the model variance.
        :param analytical_sigma_updates: Whether sigma2 should be updated via the analytical
                                         max-likelihood solution rather than gradient descent.
        :param analytical_pi_updates: Whether priors should be updated via the analytical
                                      max-likelihood solution rather than gradient descent.
        :param clamp_sigma_updates: Whether to limit the rate at which sigma2 can be updated.
        :param activation: Decoder activation function used if external_model is not specified.
                           Defaults to ReLU (to.nn.ReLU) if neither activation nor external_model
                           is given.
        :param external_model: Optional decoder neural network. Exactly one of `shape`,
                               (`W_init`, `b_init`) and `external_model` must be specified.
        :param optimizer: Gradient optimizer (defaults to Adam if not specified)
        """
        self._theta: Dict[str, to.Tensor] = {}
        self._clamp_sigma = clamp_sigma_updates
        self._precision = precision
        self._activation = activation
        self._external_model = external_model
        assert (
            (shape is not None and W_init is None and b_init is None and external_model is None)
            or (
                shape is None
                and W_init is not None
                and b_init is not None
                and external_model is None
            )
            or (shape is None and W_init is None and b_init is None and external_model is not None)
        ), "Must exclusively specify one of `shape`, (`W_init`, `b_init`), `external_model`"

        if external_model is not None:
            assert hasattr(
                external_model, "H0"
            ), "for externally defined models, H0 has to be provided manually"
            assert hasattr(
                external_model, "shape"
            ), "for externally defined models, shape has to be provided manually"
            assert activation is None, "Must specify activation as part of external_model"
            H0 = external_model.H0
            self._net_shape = external_model.shape
            self.W = self.b = None
            gd_parameters = list(external_model.parameters())
        else:
            self._net_shape = _get_net_shape(shape, W_init)
            H0 = self._net_shape[0]
            self.W = _init_W(self._net_shape, precision, W_init)
            self.b = _init_b(self._net_shape, precision, b_init)
            self._theta.update({f"W_{i}": W for i, W in enumerate(self.W)})
            self._theta.update({f"b_{i}": b for i, b in enumerate(self.b)})
            gd_parameters = self.W + self.b
            self._activation = to.nn.ReLU() if activation is None else activation
            assert callable(self._activation)

        self._theta["pies"] = _init_pi(
            precision, pi_init, H0, requires_grad=not analytical_pi_updates
        )
        self._theta["sigma2"] = _init_sigma2(
            precision, sigma2_init, requires_grad=not analytical_sigma_updates
        )

        self._min_lr, self._max_lr, self._step_size_up = min_lr, max_lr, cycliclr_step_size_up

        if analytical_sigma_updates:
            self._new_sigma2 = to.zeros(1, dtype=precision, device=tvo.get_device())
            self._analytical_sigma_updates = True
        else:
            gd_parameters.append(self._theta["sigma2"])
            self._analytical_sigma_updates = False

        if analytical_pi_updates:
            self._new_pi = to.zeros(H0, dtype=precision, device=tvo.get_device())
            self._analytical_pi_updates = True
        else:
            gd_parameters.append(self._theta["pies"])
            self._analytical_pi_updates = False

        if optimizer is None:
            self._optimizer = opt.Adam(gd_parameters, lr=min_lr)
        else:
            self._optimizer = optimizer

        self._scheduler = CyclicLR(
            self._optimizer,
            base_lr=min_lr,
            max_lr=max_lr,
            step_size_up=cycliclr_step_size_up,
            cycle_momentum=False,
        )
        # number of datapoints processed in a training epoch
        self._train_datapoints = to.tensor([0], dtype=to.int, device=tvo.get_device())
        self._config = dict(
            net_shape=self._net_shape,
            activation=self._activation,
            external_model=self._external_model,
            precision=self.precision,
            min_lr=self._min_lr,
            max_lr=self._max_lr,
            step_size_up=self._step_size_up,
            analytical_sigma_updates=self._analytical_sigma_updates,
            analytical_pi_updates=self._analytical_pi_updates,
            clamp_sigma_updates=self._clamp_sigma,
            device=tvo.get_device(),
        )

    def _lpj_and_mlpout(self, data: to.Tensor, states: to.Tensor) -> Tuple[to.Tensor, to.Tensor]:
        N = data.shape[0]
        N_, S, H = states.shape
        assert N == N_, "Shape mismatch between data and states"
        pi, sigma2 = get(self.theta, "pies", "sigma2")
        states = states.to(dtype=self.precision)

        mlp_out = self.forward(states)  # (N, S, D)

        # nansum used to automatically ignore missing data
        s1 = to.nansum((data.unsqueeze(1) - mlp_out).pow_(2), dim=2).div_(2 * sigma2)  # (N, S)
        s2 = states @ to.log(pi / (1.0 - pi))  # (N, S)
        lpj = s2 - s1
        assert lpj.shape == (N, S)
        assert not to.isnan(lpj).any() and not to.isinf(lpj).any()
        return lpj, mlp_out

    def log_joint(self, data, states, lpj=None):
        pi, sigma2 = get(self.theta, "pies", "sigma2")
        D = data.shape[1] - to.isnan(data).sum(dim=1)  # (N,): ignores missing data
        D = D.unsqueeze(1)  # (N, 1)
        if lpj is None:
            lpj = self._log_pseudo_joint(data, states)
        # TODO: could pre-evaluate the constant factor once per epoch
        logjoints = lpj - D / 2 * to.log(2 * MATH_PI * sigma2) + to.log(1 - pi).sum()
        return logjoints
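
    # For reference, the full Gaussian log-joint assembled above is (a sketch of the algebra, not
    # extra functionality):
    #
    #   log p(y, s | Theta) = sum_h [ s_h log pi_h + (1 - s_h) log(1 - pi_h) ]
    #                         - D/2 * log(2 * MATH_PI * sigma2) - ||y - mlp(s)||^2 / (2 * sigma2)
    #                       = lpj + sum_h log(1 - pi_h) - D/2 * log(2 * MATH_PI * sigma2),
    #
    # with lpj = sum_h s_h log(pi_h / (1 - pi_h)) - ||y - mlp(s)||^2 / (2 * sigma2) as computed in
    # _lpj_and_mlpout, and D counting only the non-missing entries of y.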

    def update_param_epoch(self) -> None:
        pi = self.theta["pies"]
        sigma2 = self.theta["sigma2"]

        if tvo.get_run_policy() == "mpi":
            with to.no_grad():
                for p in self.theta.values():
                    if p.requires_grad:
                        broadcast(p)

        if pi.requires_grad and sigma2.requires_grad:
            return  # nothing to do
        else:
            D = self._net_shape[-1]
            all_reduce(self._train_datapoints)
            N = self._train_datapoints.item()

        if not pi.requires_grad:
            all_reduce(self._new_pi)
            pi[:] = self._new_pi / N
            # avoids infinites in lpj evaluation
            to.clamp(pi, 1e-5, 1 - 1e-5, out=pi)
            self._new_pi.zero_()

        # FIXME in case of missing data there is a correction that should be applied here
        if not sigma2.requires_grad:
            all_reduce(self._new_sigma2)
            # optionally limit how fast sigma2 can change: at most 1% per iteration
            new_sigma_min = (sigma2 - sigma2.abs() * 0.01).item()
            new_sigma_max = (sigma2 + sigma2.abs() * 0.01).item()
            sigma2[:] = self._new_sigma2 / (N * D)
            if self._clamp_sigma:
                to.clamp(sigma2, new_sigma_min, new_sigma_max, out=sigma2)
            self._new_sigma2.zero_()

        self._train_datapoints[:] = 0

    def generate_data(
        self, N: int = None, hidden_state: to.Tensor = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        H = self.shape[-1]
        if hidden_state is None:
            pies = self.theta["pies"]
            hidden_state = to.rand((N, H), dtype=pies.dtype, device=pies.device) < pies
            must_return_hidden_state = True
        else:
            if N is not None:
                shape = hidden_state.shape
                assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
            must_return_hidden_state = False

        with to.no_grad():
            mlp_out = self.forward(hidden_state)
            Y = to.distributions.Normal(loc=mlp_out, scale=to.sqrt(self.theta["sigma2"])).sample()

        return (Y, hidden_state) if must_return_hidden_state else Y

    def _optimize_nn_params(
        self, idx: to.Tensor, data: to.Tensor, states: TVOVariationalStates
    ) -> Tuple[float, to.Tensor]:
        F, mlp_out = super()._optimize_nn_params(idx, data, states)

        with to.no_grad():
            sigma2 = self.theta["sigma2"]
            if sigma2.requires_grad:
                to.clamp(sigma2, 1e-5, out=sigma2)
            pi = self.theta["pies"]
            if pi.requires_grad:
                to.clamp(pi, 1e-5, 1 - 1e-5, out=pi)

        return F, mlp_out

    def _accumulate_param_updates(
        self, idx: to.Tensor, data: to.Tensor, states: TVOVariationalStates, mlp_out: to.Tensor
    ) -> None:
        """Evaluate partial updates to pi and sigma2."""
        K_batch = states.K[idx].type_as(states.lpj)

        if not self.theta["pies"].requires_grad:
            # \pi_h = \frac{1}{N} \sum_n < K_nsh >_{q^n}
            self._new_pi.add_(mean_posterior(K_batch, states.lpj[idx]).sum(dim=0))

        if not self.theta["sigma2"].requires_grad:
            # \sigma2 = \frac{1}{DN} \sum_{n,d} < (y^n_d - \vec{a}^L_d)^2 >_{q^n}
            # TODO would it be better (faster or more numerically stable) to
            # sum over D _before_ taking the mean_posterior?
            y_minus_a_sqr = (data.unsqueeze(1) - mlp_out).pow_(2)  # (N, S, D)
            assert y_minus_a_sqr.shape == (idx.numel(), K_batch.shape[1], data.shape[1])
            self._new_sigma2.add_(mean_posterior(y_minus_a_sqr, states.lpj[idx]).sum((0, 1)))

        self._train_datapoints.add_(data.shape[0])

    def forward(self, x: to.Tensor) -> to.Tensor:
        """Forward application of TVAE's MLP to the specified input."""
        assert x.shape[-1] == self._net_shape[0], "Incompatible shape in forward input"

        output = x.to(dtype=self.precision)
        if self._external_model is not None:
            output = self._external_model.forward(output).to(dtype=self.precision)
        else:
            assert isinstance(self.W, Sequence) and isinstance(
                self.b, Sequence
            )  # to make mypy happy
            assert callable(self._activation)  # to make mypy happy

            # hidden layers (configurable activation; ReLU by default)
            for W, b in zip(self.W[:-1], self.b[:-1]):
                output = self._activation(output @ W + b)

            # output layer (linear)
            output = output @ self.W[-1] + self.b[-1]

        return output
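
# Minimal usage sketch (illustrative, not part of the module): construct a Gaussian TVAE with a
# 784-64-32 decoder and sample synthetic data from the generative model. The shapes and
# hyperparameter values below are made up for the example.
#
#   model = GaussianTVAE(shape=(784, 64, 32), precision=to.float32, min_lr=1e-3, max_lr=1e-2)
#   Y, S = model.generate_data(N=100)   # Y: (100, 784) observations, S: (100, 32) latent states
#   print(model.shape)                  # (784, 32), i.e. (D, H0)
#   print(model.net_shape)              # (784, 64, 32), i.e. (D, H1, H0)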


class BernoulliTVAE(_TVAE):
    def __init__(
        self,
        shape: Sequence[int] = None,
        precision: to.dtype = to.float64,
        min_lr: float = 0.001,
        max_lr: float = 0.01,
        cycliclr_step_size_up=400,
        pi_init: to.Tensor = None,
        W_init: Sequence[to.Tensor] = None,
        b_init: Sequence[to.Tensor] = None,
        analytical_pi_updates: bool = True,
        activation: Callable = None,
        external_model: Optional[to.nn.Module] = None,
        optimizer: Optional[opt.Optimizer] = None,
    ):
        """Create a TVAE model with Bernoulli observables.

        :param shape: Network shape, from observable to most hidden: (D,...,H1,H0).
                      Can be None if W_init is not None.
        :param precision: One of to.float32 or to.float64, indicates the floating point precision
                          of model parameters.
        :param min_lr: See docs of tvo.utils.CyclicLR
        :param max_lr: See docs of tvo.utils.CyclicLR
        :param cycliclr_step_size_up: See docs of tvo.utils.CyclicLR
        :param pi_init: Optional tensor with initial prior values
        :param W_init: Optional list of tensors with initial weight values. Weight matrices
                       must be ordered from most hidden to observable layer. If this parameter
                       is not None, the shape parameter can be omitted.
        :param b_init: Optional list of tensors with initial bias values.
        :param analytical_pi_updates: Whether priors should be updated via the analytical
                                      max-likelihood solution rather than gradient descent.
        :param activation: Decoder activation function used if external_model is not specified.
                           Defaults to ReLU (to.nn.ReLU) if neither activation nor external_model
                           is given.
        :param external_model: Optional decoder neural network. Exactly one of `shape`,
                               (`W_init`, `b_init`) and `external_model` must be specified.
        :param optimizer: Gradient optimizer (defaults to Adam if not specified)
        """
        self._theta: Dict[str, to.Tensor] = {}
        self._precision = precision
        self._activation = activation
        self._external_model = external_model
        assert (
            (shape is not None and W_init is None and b_init is None and external_model is None)
            or (
                shape is None
                and W_init is not None
                and b_init is not None
                and external_model is None
            )
            or (shape is None and W_init is None and b_init is None and external_model is not None)
        ), "Must exclusively specify one of `shape`, (`W_init`, `b_init`), `external_model`"

        if external_model is not None:
            assert hasattr(
                external_model, "H0"
            ), "for externally defined models, H0 has to be provided manually"
            assert hasattr(
                external_model, "shape"
            ), "for externally defined models, shape has to be provided manually"
            assert activation is None, "Must specify activation as part of external_model"
            H0 = external_model.H0
            self._net_shape = external_model.shape
            self.W = self.b = None
            gd_parameters = list(external_model.parameters())
        else:
            self._net_shape = _get_net_shape(shape, W_init)
            H0 = self._net_shape[0]
            self.W = _init_W(self._net_shape, precision, W_init)
            self.b = _init_b(self._net_shape, precision, b_init)
            self._theta.update({f"W_{i}": W for i, W in enumerate(self.W)})
            self._theta.update({f"b_{i}": b for i, b in enumerate(self.b)})
            gd_parameters = self.W + self.b
            self._activation = to.nn.ReLU() if activation is None else activation
            assert callable(self._activation)

        self._theta["pies"] = _init_pi(
            precision, pi_init, H0, requires_grad=not analytical_pi_updates
        )

        self._min_lr, self._max_lr, self._step_size_up = min_lr, max_lr, cycliclr_step_size_up

        if analytical_pi_updates:
            self._new_pi = to.zeros(H0, dtype=precision, device=tvo.get_device())
            self._analytical_pi_updates = True
        else:
            gd_parameters.append(self._theta["pies"])
            self._analytical_pi_updates = False

        if optimizer is None:
            self._optimizer = opt.Adam(gd_parameters, lr=min_lr)
        else:
            self._optimizer = optimizer

        self._scheduler = CyclicLR(
            self._optimizer,
            base_lr=min_lr,
            max_lr=max_lr,
            step_size_up=cycliclr_step_size_up,
            cycle_momentum=False,
        )
        # number of datapoints processed in a training epoch
        self._train_datapoints = to.tensor([0], dtype=to.int, device=tvo.get_device())
        self._config = dict(
            net_shape=self._net_shape,
            activation=self._activation,
            external_model=self._external_model,
            precision=self.precision,
            min_lr=self._min_lr,
            max_lr=self._max_lr,
            step_size_up=self._step_size_up,
            analytical_pi_updates=self._analytical_pi_updates,
            device=tvo.get_device(),
        )

    def _lpj_and_mlpout(self, data: to.Tensor, states: to.Tensor) -> Tuple[to.Tensor, to.Tensor]:
        N, D = data.shape
        N_, S, H = states.shape
        assert N == N_, "Shape mismatch between data and states"
        pi = self.theta["pies"]
        states = states.to(dtype=self.precision)

        mlp_out = self.forward(states)  # (N, S, D)

        # nansum used to automatically ignore missing data
        s1 = to.nansum(
            to.nn.functional.binary_cross_entropy(
                mlp_out, data.unsqueeze(1).expand(N, S, D), reduction="none"
            ),
            dim=2,
        )  # (N, S)
        s2 = states @ to.log(pi / (1.0 - pi))  # (N, S)
        lpj = s2 - s1
        assert lpj.shape == (N, S)
        assert not to.isnan(lpj).any() and not to.isinf(lpj).any()
        return lpj, mlp_out

    def log_joint(self, data, states, lpj=None):
        if lpj is None:
            lpj = self._log_pseudo_joint(data, states)
        # TODO: could pre-evaluate the constant factor once per epoch
        logjoints = lpj + to.log(1 - self.theta["pies"]).sum()
        return logjoints
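
    # For reference (a sketch of the algebra, not extra functionality): with Bernoulli observables
    # the log-joint reduces to
    #
    #   log p(y, s | Theta) = sum_d [ y_d log m_d + (1 - y_d) log(1 - m_d) ]
    #                         + sum_h [ s_h log pi_h + (1 - s_h) log(1 - pi_h) ]
    #                       = lpj + sum_h log(1 - pi_h),
    #
    # where m = mlp(s) are the decoder outputs (after the sigmoid) and lpj is the pseudo log-joint
    # computed in _lpj_and_mlpout (negative cross-entropy term plus the prior log-odds term).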

    def update_param_epoch(self) -> None:
        pi = self.theta["pies"]

        if tvo.get_run_policy() == "mpi":
            with to.no_grad():
                for p in self.theta.values():
                    if p.requires_grad:
                        broadcast(p)

        if pi.requires_grad:
            return  # nothing to do
        else:
            all_reduce(self._train_datapoints)
            N = self._train_datapoints.item()
            all_reduce(self._new_pi)
            pi[:] = self._new_pi / N
            # avoids infinites in lpj evaluation
            to.clamp(pi, 1e-5, 1 - 1e-5, out=pi)
            self._new_pi.zero_()

        self._train_datapoints[:] = 0

    def generate_data(
        self, N: int = None, hidden_state: to.Tensor = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        H = self.shape[-1]
        if hidden_state is None:
            pies = self.theta["pies"]
            hidden_state = to.rand((N, H), dtype=pies.dtype, device=pies.device) < pies
            must_return_hidden_state = True
        else:
            if N is not None:
                shape = hidden_state.shape
                assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
            must_return_hidden_state = False

        with to.no_grad():
            mlp_out = self.forward(hidden_state)
            Y = to.distributions.Bernoulli(mlp_out).sample()

        return (Y, hidden_state) if must_return_hidden_state else Y

    def _optimize_nn_params(
        self, idx: to.Tensor, data: to.Tensor, states: TVOVariationalStates
    ) -> Tuple[float, to.Tensor]:
        F, mlp_out = super()._optimize_nn_params(idx, data, states)

        with to.no_grad():
            pi = self.theta["pies"]
            if pi.requires_grad:
                to.clamp(pi, 1e-5, 1 - 1e-5, out=pi)

        return F, mlp_out

    def _accumulate_param_updates(
        self, idx: to.Tensor, data: to.Tensor, states: TVOVariationalStates, mlp_out: to.Tensor
    ) -> None:
        """Evaluate partial updates to pi."""
        K_batch = states.K[idx].type_as(states.lpj)

        if not self.theta["pies"].requires_grad:
            # \pi_h = \frac{1}{N} \sum_n < K_nsh >_{q^n}
            self._new_pi.add_(mean_posterior(K_batch, states.lpj[idx]).sum(dim=0))

        self._train_datapoints.add_(data.shape[0])

    def forward(self, x: to.Tensor) -> to.Tensor:
        """Forward application of TVAE's MLP to the specified input."""
        assert x.shape[-1] == self._net_shape[0], "Incompatible shape in forward input"

        output = x.to(dtype=self.precision)
        if self._external_model is not None:
            output = self._external_model.forward(output).to(dtype=self.precision)
        else:
            assert isinstance(self.W, Sequence) and isinstance(
                self.b, Sequence
            )  # to make mypy happy
            assert callable(self._activation)  # to make mypy happy

            # hidden layers (configurable activation; ReLU by default)
            for W, b in zip(self.W[:-1], self.b[:-1]):
                output = self._activation(output @ W + b)

            # output layer (sigmoid)
            output = to.sigmoid(output @ self.W[-1] + self.b[-1])

        return output
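
# Minimal usage sketch (illustrative, not part of the module): a Bernoulli-observable TVAE
# initialized from explicit weights instead of a `shape` argument. The weight/bias shapes below
# are invented for the example and run from the most hidden layer (H0=16) towards the
# observables (D=100).
#
#   W0, W1 = to.randn(16, 32), to.randn(32, 100)
#   b0, b1 = to.zeros(32), to.zeros(100)
#   model = BernoulliTVAE(W_init=[W0, W1], b_init=[b0, b1], precision=to.float32)
#   Y, S = model.generate_data(N=10)   # Y: (10, 100) binary observations, S: (10, 16) latents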