Coverage for tvo/models/tvae.py: 95%
352 statements
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
5from tvo.utils.model_protocols import Trainable, Sampler, Reconstructor
6from tvo.variational.TVOVariationalStates import TVOVariationalStates
7from tvo.variational._utils import mean_posterior
8from tvo.utils.parallel import all_reduce, broadcast, mpi_average_grads
9from tvo.utils import get, CyclicLR
10import torch.optim as opt
11import tvo
12import torch as to
13from typing import Tuple, List, Dict, Iterable, Optional, Sequence, Union, Callable
14from math import pi as MATH_PI
15from abc import abstractmethod
18def _get_net_shape(net_shape: Sequence[int] = None, W_init: Sequence[to.Tensor] = None):
19 if net_shape is not None:
20 return tuple(reversed(net_shape))
21 else:
22 assert (
23 W_init is not None
24 ), "Must pass one of `net_shape` and `W_init` to __init__ of the\
25 TVAE model"
26 return tuple(w.shape[0] for w in W_init) + (W_init[-1].shape[1],)
29def _init_W(
30 net_shape: Sequence[int], precision: to.dtype, init: Optional[Sequence[to.Tensor]]
31) -> List[to.Tensor]:
32 """Return weights initialized with Xavier or to specified init values.
34 This method also makes sure that device and precision are the ones required by
35 the model.
36 """
37 if init is None:
38 n_layers = len(net_shape) - 1
39 W_shapes = ((net_shape[ln], net_shape[ln + 1]) for ln in range(n_layers))
40 W = list(map(to.nn.init.xavier_normal_, (to.empty(s) for s in W_shapes)))
41 else:
42 assert (
43 len(init) == len(net_shape) - 1
44 ), f"Shape is {net_shape} but {len(init)} weights passed"
45 Wshapes = [w.shape for w in init]
46 expected_Wshapes = [(net_shape[ln], net_shape[ln + 1]) for ln in range(len(init))]
47 err_msg = f"Input W shapes: {Wshapes}\nExpected W shapes {expected_Wshapes}"
48 assert all(ws == exp_s for ws, exp_s in zip(Wshapes, expected_Wshapes)), err_msg
49 W = list(w.clone() for w in init)
50 for w in W:
51 broadcast(w)
52 return [w.to(device=tvo.get_device(), dtype=precision).requires_grad_(True) for w in W]
55def _init_b(
56 net_shape: Sequence[int], precision: to.dtype, init: Optional[Iterable[to.Tensor]]
57) -> List[to.Tensor]:
58 """Return biases initialized to zeros or to specified init values.
60 This method also makes sure that device and precision are the ones required by the model.
61 """
62 if init is None:
63 B = [to.zeros(s) for s in net_shape[1:]]
64 else:
65 assert all(b.shape == (net_shape[ln + 1],) for ln, b in enumerate(init))
66 B = [b.clone() for b in init]
67 return [b.to(device=tvo.get_device(), dtype=precision).requires_grad_(True) for b in B]
70def _init_pi(
71 precision: to.dtype, init: Optional[to.Tensor], H0: int, requires_grad: bool
72) -> to.Tensor:
73 if init is None:
74 pi = to.full((H0,), 1 / H0)
75 else:
76 assert init.shape == (H0,)
77 pi = init.clone()
78 return pi.to(device=tvo.get_device(), dtype=precision).requires_grad_(requires_grad)
81def _init_sigma2(precision: to.dtype, init: Optional[float], requires_grad: bool) -> to.Tensor:
82 sigma2 = to.tensor([0.01] if init is None else [init])
83 return sigma2.to(device=tvo.get_device(), dtype=precision).requires_grad_(requires_grad)
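# Added note (not in the original source): `net_shape` in the helpers above runs from the
# most hidden layer to the observable layer, i.e. (H0, ..., D), which is the reverse of the
# user-facing `shape` argument (D, ..., H1, H0). A minimal sketch of the resulting parameter
# shapes, assuming a hypothetical network with H0=8, H1=16 and D=32:
#
#     net_shape = _get_net_shape(net_shape=[32, 16, 8])   # -> (8, 16, 32)
#     W = _init_W(net_shape, to.float64, init=None)
#     assert [tuple(w.shape) for w in W] == [(8, 16), (16, 32)]
#     b = _init_b(net_shape, to.float64, init=None)
#     assert [tuple(x.shape) for x in b] == [(16,), (32,)]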
86class _TVAE(Trainable, Sampler, Reconstructor):
87 _theta: Dict[str, to.Tensor]
88 _precision: to.dtype
89 _net_shape: Sequence[int]
90 _scheduler: opt.lr_scheduler._LRScheduler
91 _optimizer: opt.Optimizer
92 _activation: Optional[Callable] = None
93 _external_model: Optional[to.nn.Module] = None
95 @abstractmethod
96 def log_joint(self, data: to.Tensor, states: to.Tensor, lpj: to.Tensor = None) -> to.Tensor:
97 ...
99 def _log_pseudo_joint(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
100 with to.no_grad():
101 lpj, _ = self._lpj_and_mlpout(data, states)
102 return lpj
104 @abstractmethod
105 def _lpj_and_mlpout(self, data: to.Tensor, states: to.Tensor) -> Tuple[to.Tensor, to.Tensor]:
106 ...
108 def free_energy(self, idx: to.Tensor, batch: to.Tensor, states: TVOVariationalStates) -> float:
109 with to.no_grad():
110 return super().free_energy(idx, batch, states)
112 def _free_energy_from_logjoints(self, logjoints: to.Tensor) -> to.Tensor:
113 Fn = to.logsumexp(logjoints, dim=1)
114 assert Fn.shape == (logjoints.shape[0],)
115 assert not to.isnan(Fn).any() and not to.isinf(Fn).any()
116 return Fn.sum()
118 def update_param_batch(
119 self, idx: to.Tensor, batch: to.Tensor, states: TVOVariationalStates
120 ) -> float:
121 if to.isnan(batch).any():
122 raise RuntimeError("There are NaNs in this batch")
123 F, mlp_out = self._optimize_nn_params(idx, batch, states)
124 with to.no_grad():
125 self._accumulate_param_updates(idx, batch, states, mlp_out)
126 return F
128 @property
129 def shape(self) -> Tuple[int, ...]:
130 """Shape of TVAE model as a bayes net: (D, H0)
132 Neural network shape is ignored.
133 """
134 return tuple((self._net_shape[-1], self._net_shape[0]))
136 @property
137 def net_shape(self) -> Tuple[int, ...]:
138 """Full TVAE network shape (D, Hn, Hn-1, ..., H0)."""
139 return tuple(reversed(self._net_shape))
141 def _optimize_nn_params(
142 self, idx: to.Tensor, data: to.Tensor, states: TVOVariationalStates
143 ) -> Tuple[float, to.Tensor]:
144 """
145 Parameters optimized by gradient descent are updated in-place. All other arguments are
146 left untouched.
148 :returns: F and mlp_output _before_ the weight update
149 """
150 assert self._optimizer is not None # to make mypy happy
152 lpj, mlp_out = self._lpj_and_mlpout(data, states.K[idx])
153 F = self._free_energy_from_logjoints(self.log_joint(data, states.K[idx], lpj))
154 loss = -F / data.shape[0]
155 loss.backward()
157 mpi_average_grads(self.theta)
158 self._optimizer.step()
159 self._scheduler.step()
160 self._optimizer.zero_grad()
162 return F.item(), mlp_out
164 def _accumulate_param_updates(
165 self, idx: to.Tensor, data: to.Tensor, states: TVOVariationalStates, mlp_out: to.Tensor
166 ) -> None:
167 pass
169 def data_estimator(
170 self, idx: to.Tensor, batch: to.Tensor, states: TVOVariationalStates
171 ) -> to.Tensor: # type: ignore
172 r"""
173 :math:`\langle \langle y_d \rangle_{p(y_d|\vec{s},\Theta)} \rangle_{q(\vec{s}|\mathcal{K},\Theta)}` # noqa
174 """
176 lpj = states.lpj[idx]
177 K = states.K[idx]
179 with to.no_grad():
180 means = self.forward(K) # N,S,D
182 return mean_posterior(means, lpj) # N, D
184 @abstractmethod
185 def forward(self, x: to.Tensor) -> to.Tensor:
186 """Forward application of TVAE's MLP to the specified input."""
187 ...
190class GaussianTVAE(_TVAE):
191 def __init__(
192 self,
193 shape: Sequence[int] = None,
194 precision: to.dtype = to.float64,
195 min_lr: float = 0.001,
196 max_lr: float = 0.01,
197 cycliclr_step_size_up=400,
198 pi_init: to.Tensor = None,
199 W_init: Sequence[to.Tensor] = None,
200 b_init: Sequence[to.Tensor] = None,
201 sigma2_init: float = None,
202 analytical_sigma_updates: bool = True,
203 analytical_pi_updates: bool = True,
204 clamp_sigma_updates: bool = False,
205 activation: Callable = None,
206 external_model: Optional[to.nn.Module] = None,
207 optimizer: Optional[opt.Optimizer] = None,
208 ):
209 """Create a TVAE model with Gaussian observables.
211 :param shape: Network shape, from observable to most hidden: (D,...,H1,H0). One of shape,
212 (W_init, b_init), external_model must be specified exclusively.
213 :param precision: One of to.float32 or to.float64, indicates the floating point precision
214 of model parameters.
215 :param min_lr: See docs of tvo.utils.CyclicLR
216 :param max_lr: See docs of tvo.utils.CyclicLR
217 :param cycliclr_step_size_up: See docs of tvo.utils.CyclicLR
218 :param pi_init: Optional tensor with initial prior values
219 :param W_init: Optional list of tensors with initial weight values. Weight matrices
220 must be ordered from most hidden to observable layer. One of shape,
221 (W_init, b_init), external_model must be specified exclusively.
222 :param b_init: Optional list of tensors with initial bias. One of shape,
223 (W_init, b_init), external_model must be specified exclusively.
224 :param sigma2_init: Optional initial value for model variance.
225 :param analytical_sigma_updates: Whether sigmas should be updated via the analytical
226 max-likelihood solution rather than gradient descent.
227 :param analytical_pi_updates: Whether priors should be updated via the analytical
228 max-likelihood solution rather than gradient descent.
229 :param clamp_sigma_updates: Whether to limit the rate at which sigma can be updated.
230 :param activation: Decoder activation function used if external_model is not specified.
231 Defaults to ReLU if neither activation nor external_model is specified.
232 :param external_model: Optional decoder neural network. One of shape, (W_init, b_init),
233 external_model must be specified exclusively.
234 :param optimizer: Gradient optimizer (defaults to Adam if not specified)
235 """
236 self._theta: Dict[str, to.Tensor] = {}
237 self._clamp_sigma = clamp_sigma_updates
238 self._precision = precision
239 self._activation = activation
240 self._external_model = external_model
241 assert (
242 (shape is not None and W_init is None and b_init is None and external_model is None)
243 or (
244 shape is None
245 and W_init is not None
246 and b_init is not None
247 and external_model is None
248 )
249 or (shape is None and W_init is None and b_init is None and external_model is not None)
250 ), "Must exclusively specify one one `shape`, (`W_init`, `b_init`), `external_model`"
252 if external_model is not None:
253 assert hasattr(
254 external_model, "H0"
255 ), "for externally defined models, H0 has to be provided manually"
256 assert hasattr(
257 external_model, "shape"
258 ), "for externally defined models, shape has to be provided manually"
259 assert activation is None, "Must specify activation as part of external_model"
260 H0 = external_model.H0
261 self._net_shape = external_model.shape
262 self.W = self.b = None
263 gd_parameters = list(external_model.parameters())
264 else:
265 self._net_shape = _get_net_shape(shape, W_init)
266 H0 = self._net_shape[0]
267 self.W = _init_W(self._net_shape, precision, W_init)
268 self.b = _init_b(self._net_shape, precision, b_init)
269 self._theta.update({f"W_{i}": W for i, W in enumerate(self.W)})
270 self._theta.update({f"b_{i}": b for i, b in enumerate(self.b)})
271 gd_parameters = self.W + self.b
272 self._activation = to.nn.ReLU() if activation is None else activation
273 assert callable(self._activation)
275 self._theta["pies"] = _init_pi(
276 precision, pi_init, H0, requires_grad=not analytical_pi_updates
277 )
278 self._theta["sigma2"] = _init_sigma2(
279 precision, sigma2_init, requires_grad=not analytical_sigma_updates
280 )
282 self._min_lr, self._max_lr, self._step_size_up = min_lr, max_lr, cycliclr_step_size_up
284 if analytical_sigma_updates:
285 self._new_sigma2 = to.zeros(1, dtype=precision, device=tvo.get_device())
286 self._analytical_sigma_updates = True
287 else:
288 gd_parameters.append(self._theta["sigma2"])
289 self._analytical_sigma_updates = False
291 if analytical_pi_updates:
292 self._new_pi = to.zeros(H0, dtype=precision, device=tvo.get_device())
293 self._analytical_pi_updates = True
294 else:
295 gd_parameters.append(self._theta["pies"])
296 self._analytical_pi_updates = False
298 if optimizer is None:
299 self._optimizer = opt.Adam(gd_parameters, lr=min_lr)
300 else:
301 self._optimizer = optimizer
303 self._scheduler = CyclicLR(
304 self._optimizer,
305 base_lr=min_lr,
306 max_lr=max_lr,
307 step_size_up=cycliclr_step_size_up,
308 cycle_momentum=False,
309 )
310 # number of datapoints processed in a training epoch
311 self._train_datapoints = to.tensor([0], dtype=to.int, device=tvo.get_device())
312 self._config = dict(
313 net_shape=self._net_shape,
314 activation=self._activation,
315 external_model=self._external_model,
316 precision=self.precision,
317 min_lr=self._min_lr,
318 max_lr=self._max_lr,
319 step_size_up=self._step_size_up,
320 analytical_sigma_updates=self._analytical_sigma_updates,
321 analytical_pi_updates=self._analytical_pi_updates,
322 clamp_sigma_updates=self._clamp_sigma,
323 device=tvo.get_device(),
324 )
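# Illustrative construction sketch (added, not part of the original file); the argument
# values are hypothetical and follow the __init__ signature documented above:
#
#     model = GaussianTVAE(
#         shape=[64, 32, 16],            # (D, H1, H0): observable -> most hidden
#         precision=to.float64,
#         min_lr=1e-3,
#         max_lr=1e-2,
#         analytical_sigma_updates=True,
#         analytical_pi_updates=True,
#     )
#     # model.shape == (64, 16); model.net_shape == (64, 32, 16)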
326 def _lpj_and_mlpout(self, data: to.Tensor, states: to.Tensor) -> Tuple[to.Tensor, to.Tensor]:
327 N = data.shape[0]
328 N_, S, H = states.shape
329 assert N == N_, "Shape mismatch between data and states"
330 pi, sigma2 = get(self.theta, "pies", "sigma2")
331 states = states.to(dtype=self.precision)
333 mlp_out = self.forward(states) # (N, S, D)
335 # nansum used to automatically ignore missing data
336 s1 = to.nansum((data.unsqueeze(1) - mlp_out).pow_(2), dim=2).div_(2 * sigma2) # (N, S)
337 s2 = states @ to.log(pi / (1.0 - pi)) # (N, S)
338 lpj = s2 - s1
339 assert lpj.shape == (N, S)
340 assert not to.isnan(lpj).any() and not to.isinf(lpj).any()
341 return lpj, mlp_out
343 def log_joint(self, data, states, lpj=None):
344 pi, sigma2 = get(self.theta, "pies", "sigma2")
345 D = data.shape[1] - to.isnan(data).sum(dim=1) # (N,): ignores missing data
346 D = D.unsqueeze(1) # (N, 1)
347 if lpj is None:
348 lpj = self._log_pseudo_joint(data, states)
349 # TODO: could pre-evaluate the constant factor once per epoch
350 logjoints = lpj - D / 2 * to.log(2 * MATH_PI * sigma2) + to.log(1 - pi).sum()
351 return logjoints
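# For reference (added comment, not in the original source): with binary latents s, decoder
# output a(s) = forward(s) and Bernoulli priors pi, the quantity returned above is the full
# Gaussian log-joint
#
#     log p(y, s) = sum_h [ s_h * log(pi_h / (1 - pi_h)) + log(1 - pi_h) ]
#                   - (D / 2) * log(2 * pi * sigma2)
#                   - ||y - a(s)||^2 / (2 * sigma2),
#
# of which only the state-dependent terms (first and last line) are evaluated in
# _lpj_and_mlpout as the log pseudo-joint lpj.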
353 def update_param_epoch(self) -> None:
354 pi = self.theta["pies"]
355 sigma2 = self.theta["sigma2"]
357 if tvo.get_run_policy() == "mpi":
358 with to.no_grad():
359 for p in self.theta.values():
360 if p.requires_grad:
361 broadcast(p)
363 if pi.requires_grad and sigma2.requires_grad:
364 return # nothing to do
365 else:
366 D = self._net_shape[-1]
367 all_reduce(self._train_datapoints)
368 N = self._train_datapoints.item()
370 if not pi.requires_grad:
371 all_reduce(self._new_pi)
372 pi[:] = self._new_pi / N
373 # avoids infinites in lpj evaluation
374 to.clamp(pi, 1e-5, 1 - 1e-5, out=pi)
375 self._new_pi.zero_()
377 # FIXME in case of missing data there is a correction that should be applied here
378 if not sigma2.requires_grad:
379 all_reduce(self._new_sigma2)
380 # disallow arbitrary jumps in sigma: at each iteration it can change by at most 1%
381 new_sigma_min = (sigma2 - sigma2.abs() * 0.01).item()
382 new_sigma_max = (sigma2 + sigma2.abs() * 0.01).item()
383 sigma2[:] = self._new_sigma2 / (N * D)
384 if self._clamp_sigma:
385 to.clamp(sigma2, new_sigma_min, new_sigma_max, out=sigma2)
386 self._new_sigma2.zero_()
388 self._train_datapoints[:] = 0
390 def generate_data(
391 self, N: int = None, hidden_state: to.Tensor = None
392 ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
393 H = self.shape[-1]
394 if hidden_state is None:
395 pies = self.theta["pies"]
396 hidden_state = to.rand((N, H), dtype=pies.dtype, device=pies.device) < pies
397 must_return_hidden_state = True
398 else:
399 if N is not None:
400 shape = hidden_state.shape
401 assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
402 must_return_hidden_state = False
404 with to.no_grad():
405 mlp_out = self.forward(hidden_state)
406 Y = to.distributions.Normal(loc=mlp_out, scale=to.sqrt(self.theta["sigma2"])).sample()
408 return (Y, hidden_state) if must_return_hidden_state else Y
410 def _optimize_nn_params(
411 self, idx: to.Tensor, data: to.Tensor, states: TVOVariationalStates
412 ) -> Tuple[float, to.Tensor]:
413 F, mlp_out = super()._optimize_nn_params(idx, data, states)
415 with to.no_grad():
416 sigma2 = self.theta["sigma2"]
417 if sigma2.requires_grad:
418 to.clamp(sigma2, 1e-5, out=sigma2)
419 pi = self.theta["pies"]
420 if pi.requires_grad:
421 to.clamp(pi, 1e-5, 1 - 1e-5, out=pi)
423 return F, mlp_out
425 def _accumulate_param_updates(
426 self, idx: to.Tensor, data: to.Tensor, states: TVOVariationalStates, mlp_out: to.Tensor
427 ) -> None:
428 """Evaluate partial updates to pi and sigma2."""
429 K_batch = states.K[idx].type_as(states.lpj)
431 if not self.theta["pies"].requires_grad:
432 # \pi_h = \frac{1}{N} \sum_n < K_nsh >_{q^n}
433 self._new_pi.add_(mean_posterior(K_batch, states.lpj[idx]).sum(dim=0))
435 if not self.theta["sigma2"].requires_grad:
436 # \sigma2 = \frac{1}{DN} \sum_{n,d} < (y^n_d - \vec{a}^L_d)^2 >_{q^n}
437 # TODO would it be better (faster or more numerically stable) to
438 # sum over D _before_ taking the mean_posterior?
439 y_minus_a_sqr = (data.unsqueeze(1) - mlp_out).pow_(2) # (N, S, D)
440 assert y_minus_a_sqr.shape == (idx.numel(), K_batch.shape[1], data.shape[1])
441 self._new_sigma2.add_(mean_posterior(y_minus_a_sqr, states.lpj[idx]).sum((0, 1)))
443 self._train_datapoints.add_(data.shape[0])
445 def forward(self, x: to.Tensor) -> to.Tensor:
446 """Forward application of TVAE's MLP to the specified input."""
447 assert x.shape[-1] == self._net_shape[0], "Incompatible shape in forward input"
449 output = x.to(dtype=self.precision)
450 if self._external_model is not None:
451 output = self._external_model.forward(output).to(dtype=self.precision)
452 else:
453 assert isinstance(self.W, Sequence) and isinstance(
454 self.b, Sequence
455 ) # to make mypy happy
456 assert callable(self._activation) # to make mypy happy
458 # middle layers (relu)
459 for W, b in zip(self.W[:-1], self.b[:-1]):
461 output = self._activation(output @ W + b)
463 # output layer (linear)
464 output = output @ self.W[-1] + self.b[-1]
466 return output
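# Minimal usage sketch (illustrative only, not part of the original module); a full training
# loop additionally needs a dataset and TVOVariationalStates and is not shown here:
#
#     model = GaussianTVAE(shape=[25, 10, 5])
#     Y, S = model.generate_data(N=100)              # Y: (100, 25) floats, S: (100, 5) booleans
#     Y_again = model.generate_data(hidden_state=S)  # decode the same latents once more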
469class BernoulliTVAE(_TVAE):
470 def __init__(
471 self,
472 shape: Sequence[int] = None,
473 precision: to.dtype = to.float64,
474 min_lr: float = 0.001,
475 max_lr: float = 0.01,
476 cycliclr_step_size_up=400,
477 pi_init: to.Tensor = None,
478 W_init: Sequence[to.Tensor] = None,
479 b_init: Sequence[to.Tensor] = None,
480 analytical_pi_updates: bool = True,
481 activation: Callable = None,
482 external_model: Optional[to.nn.Module] = None,
483 optimizer: Optional[opt.Optimizer] = None,
484 ):
485 """Create a TVAE model with Bernoulli observables.
487 :param shape: Network shape, from observable to most hidden: (D,...,H1,H0).
488 Can be None if W_init is not None.
489 :param precision: One of to.float32 or to.float64, indicates the floating point precision
490 of model parameters.
491 :param min_lr: See docs of tvo.utils.CyclicLR
492 :param max_lr: See docs of tvo.utils.CyclicLR
493 :param cycliclr_step_size_up: See docs of tvo.utils.CyclicLR
494 :param pi_init: Optional tensor with initial prior values
495 :param W_init: Optional list of tensors with initial weight values. Weight matrices
496 must be ordered from most hidden to observable layer. If this parameter
497 is not None, the shape parameter can be omitted.
498 :param b_init: Optional list of tensors with initial bias values.
499 :param analytical_pi_updates: Whether priors should be updated via the analytical
500 max-likelihood solution rather than gradient descent.
501 :param activation: Decoder activation function used if external_model is not specified.
502 Defaults to ReLU if neither activation nor external_model is specified.
503 :param external_model: Optional decoder neural network. One of shape, (W_init, b_init),
504 external_model must be specified exclusively.
505 :param optimizer: Gradient optimizer (defaults to Adam if not specified)
506 """
507 self._theta: Dict[str, to.Tensor] = {}
508 self._precision = precision
509 self._activation = activation
510 self._external_model = external_model
511 assert (
512 (shape is not None and W_init is None and b_init is None and external_model is None)
513 or (
514 shape is None
515 and W_init is not None
516 and b_init is not None
517 and external_model is None
518 )
519 or (shape is None and W_init is None and b_init is None and external_model is not None)
520 ), "Must exclusively specify one one `shape`, (`W_init`, `b_init`), `external_model`"
522 if external_model is not None:
523 assert hasattr(
524 external_model, "H0"
525 ), "for externally defined models, H0 has to be provided manually"
526 assert hasattr(
527 external_model, "shape"
528 ), "for externally defined models, shape has to be provided manually"
529 assert activation is None, "Must specify activation as part of external_model"
530 H0 = external_model.H0
531 self._net_shape = external_model.shape
532 self.W = self.b = None
533 gd_parameters = list(external_model.parameters())
534 else:
535 self._net_shape = _get_net_shape(shape, W_init)
536 H0 = self._net_shape[0]
537 self.W = _init_W(self._net_shape, precision, W_init)
538 self.b = _init_b(self._net_shape, precision, b_init)
539 self._theta.update({f"W_{i}": W for i, W in enumerate(self.W)})
540 self._theta.update({f"b_{i}": b for i, b in enumerate(self.b)})
541 gd_parameters = self.W + self.b
542 self._activation = to.nn.ReLU() if activation is None else activation
543 assert callable(self._activation)
545 self._theta["pies"] = _init_pi(
546 precision, pi_init, H0, requires_grad=not analytical_pi_updates
547 )
549 self._min_lr, self._max_lr, self._step_size_up = min_lr, max_lr, cycliclr_step_size_up
551 if analytical_pi_updates:
552 self._new_pi = to.zeros(H0, dtype=precision, device=tvo.get_device())
553 self._analytical_pi_updates = True
554 else:
555 gd_parameters.append(self._theta["pies"])
556 self._analytical_pi_updates = False
558 if optimizer is None:
559 self._optimizer = opt.Adam(gd_parameters, lr=min_lr)
560 else:
561 self._optimizer = optimizer
563 self._scheduler = CyclicLR(
564 self._optimizer,
565 base_lr=min_lr,
566 max_lr=max_lr,
567 step_size_up=cycliclr_step_size_up,
568 cycle_momentum=False,
569 )
570 # number of datapoints processed in a training epoch
571 self._train_datapoints = to.tensor([0], dtype=to.int, device=tvo.get_device())
572 self._config = dict(
573 net_shape=self._net_shape,
574 activation=self._activation,
575 external_model=self._external_model,
576 precision=self.precision,
577 min_lr=self._min_lr,
578 max_lr=self._max_lr,
579 step_size_up=self._step_size_up,
580 analytical_pi_updates=self._analytical_pi_updates,
581 device=tvo.get_device(),
582 )
584 def _lpj_and_mlpout(self, data: to.Tensor, states: to.Tensor) -> Tuple[to.Tensor, to.Tensor]:
585 N, D = data.shape
586 N_, S, H = states.shape
587 assert N == N_, "Shape mismatch between data and states"
588 pi = self.theta["pies"]
589 states = states.to(dtype=self.precision)
591 mlp_out = self.forward(states) # (N, S, D)
593 # nansum used to automatically ignore missing data
594 s1 = to.nansum(
595 to.nn.functional.binary_cross_entropy(
596 mlp_out, data.unsqueeze(1).expand(N, S, D), reduction="none"
597 ),
598 dim=2,
599 ) # (N, S)
600 s2 = states @ to.log(pi / (1.0 - pi)) # (N, S)
601 lpj = s2 - s1
602 assert lpj.shape == (N, S)
603 assert not to.isnan(lpj).any() and not to.isinf(lpj).any()
604 return lpj, mlp_out
606 def log_joint(self, data, states, lpj=None):
607 D = data.shape[1] - to.isnan(data).sum(dim=1) # (N,): ignores missing data
608 D = D.unsqueeze(1) # (N, 1)
609 if lpj is None:
610 lpj = self._log_pseudo_joint(data, states)
611 # TODO: could pre-evaluate the constant factor once per epoch
612 logjoints = lpj + to.log(1 - self.theta["pies"]).sum()
613 return logjoints
615 def update_param_epoch(self) -> None:
616 pi = self.theta["pies"]
618 if tvo.get_run_policy() == "mpi":
619 with to.no_grad():
620 for p in self.theta.values():
621 if p.requires_grad:
622 broadcast(p)
624 if pi.requires_grad:
625 return # nothing to do
626 else:
627 all_reduce(self._train_datapoints)
628 N = self._train_datapoints.item()
629 all_reduce(self._new_pi)
630 pi[:] = self._new_pi / N
631 # avoids infinites in lpj evaluation
632 to.clamp(pi, 1e-5, 1 - 1e-5, out=pi)
633 self._new_pi.zero_()
635 self._train_datapoints[:] = 0
637 def generate_data(
638 self, N: int = None, hidden_state: to.Tensor = None
639 ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
640 H = self.shape[-1]
641 if hidden_state is None:
642 pies = self.theta["pies"]
643 hidden_state = to.rand((N, H), dtype=pies.dtype, device=pies.device) < pies
644 must_return_hidden_state = True
645 else:
646 if N is not None:
647 shape = hidden_state.shape
648 assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
649 must_return_hidden_state = False
651 with to.no_grad():
652 mlp_out = self.forward(hidden_state)
653 Y = to.distributions.Bernoulli(mlp_out).sample()
655 return (Y, hidden_state) if must_return_hidden_state else Y
657 def _optimize_nn_params(
658 self, idx: to.Tensor, data: to.Tensor, states: TVOVariationalStates
659 ) -> Tuple[float, to.Tensor]:
660 F, mlp_out = super()._optimize_nn_params(idx, data, states)
662 with to.no_grad():
663 pi = self.theta["pies"]
664 if pi.requires_grad:
665 to.clamp(pi, 1e-5, 1 - 1e-5, out=pi)
667 return F, mlp_out
669 def _accumulate_param_updates(
670 self, idx: to.Tensor, data: to.Tensor, states: TVOVariationalStates, mlp_out: to.Tensor
671 ) -> None:
672 """Evaluate partial updates to pi."""
673 K_batch = states.K[idx].type_as(states.lpj)
675 if not self.theta["pies"].requires_grad:
676 # \pi_h = \frac{1}{N} \sum_n < K_nsh >_{q^n}
677 self._new_pi.add_(mean_posterior(K_batch, states.lpj[idx]).sum(dim=0))
679 self._train_datapoints.add_(data.shape[0])
681 def forward(self, x: to.Tensor) -> to.Tensor:
682 """Forward application of TVAE's MLP to the specified input."""
683 assert x.shape[-1] == self._net_shape[0], "Incompatible shape in forward input"
685 output = x.to(dtype=self.precision)
686 if self._external_model is not None:
687 output = self._external_model.forward(output).to(dtype=self.precision)
688 else:
689 assert isinstance(self.W, Sequence) and isinstance(
690 self.b, Sequence
691 ) # to make mypy happy
692 assert callable(self._activation) # to make mypy happy
694 # middle layers (relu)
695 for W, b in zip(self.W[:-1], self.b[:-1]):
696 output = self._activation(output @ W + b)
698 # output layer (sigmoid)
699 output = to.sigmoid(output @ self.W[-1] + self.b[-1])
701 return output
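# Illustrative sketch (added, not part of the original file): BernoulliTVAE mirrors
# GaussianTVAE, but its decoder ends in a sigmoid, so forward() returns per-dimension
# Bernoulli means in (0, 1) and generate_data() samples binary observations:
#
#     model = BernoulliTVAE(shape=[784, 256, 32], precision=to.float32)
#     Y, S = model.generate_data(N=16)   # Y: (16, 784) in {0., 1.}, S: (16, 32) booleans
#     with to.no_grad():
#         probs = model.forward(S)       # (16, 784), values in (0, 1)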