Coverage for tvo/variational/tvs.py: 32% (25 statements)
coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
# -*- coding: utf-8 -*-
# Copyright (C) 2021 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import torch as to
from tvo.variational.TVOVariationalStates import TVOVariationalStates
from tvo.variational._utils import mean_posterior
from tvo.utils.model_protocols import Trainable, Optimized
from ._utils import update_states_for_batch, set_redundant_lpj_to_low

class TVSVariationalStates(TVOVariationalStates):
    def __init__(
        self,
        N: int,
        H: int,
        S: int,
        precision: to.dtype,
        S_new_prior: int,
        S_new_marg: int,
        K_init_file: str = None,
    ):
23 """Truncated Variational Sampling class.
25 :param N: number of datapoints
26 :param H: number of latents
27 :param S: number of variational states
28 :param precision: floating point precision to be used for log_joint values.
29 Must be one of to.float32 or to.float64.
30 :param S_new_prior: number of states to be sampled from prior at every call to ~update
31 :param S_new_marg: number of states to be sampled from approximated marginal\
32 p(s_h=1|vec{y}^{(n)}, Theta) at every call to ~update
33 :param K_init_file: Full path to H5 file providing initial states
34 """
        conf = {
            "N": N,
            "H": H,
            "S": S,
            "S_new_prior": S_new_prior,
            "S_new_marg": S_new_marg,
            "S_new": S_new_prior + S_new_marg,
            "precision": precision,
            "K_init_file": K_init_file,
        }
        super().__init__(conf)
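
    # Note: the total number of new states drawn per call to ~update is
    # S_new = S_new_prior + S_new_marg (stored above as conf["S_new"]).
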
    def update(self, idx: to.Tensor, batch: to.Tensor, model: Trainable) -> int:
        """See :func:`tvo.variational.TVOVariationalStates.update`."""
        if isinstance(model, Optimized):
            lpj_fn = model.log_pseudo_joint
            sort_by_lpj = model.sorted_by_lpj
        else:
            lpj_fn = model.log_joint
            sort_by_lpj = {}
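
        # Re-evaluate the log (pseudo-)joints of the currently stored states:
        # the model parameters may have changed since the last update.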
        K, lpj = self.K, self.lpj
        batch_size, H = batch.shape[0], K.shape[2]
        lpj[idx] = lpj_fn(batch, K[idx])
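
        # Draw S_new_prior candidate states per datapoint, sampling each unit h
        # independently from the Bernoulli prior with mean pies[h].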
        new_K_prior = (
            to.rand(batch_size, self.config["S_new_prior"], H, device=K.device)
            < model.theta["pies"]
        ).byte()
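
        # Estimate the posterior marginals p(s_h=1|y^(n)) as the posterior-weighted
        # average of the currently stored states, then draw S_new_marg candidates
        # from these Bernoulli marginals.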
        approximate_marginals = (
            mean_posterior(K[idx].type_as(lpj), lpj[idx])
            .unsqueeze(1)
            .expand(batch_size, self.config["S_new_marg"], H)
        )  # approximates p(s_h=1|\yVecN, \Theta), shape is (batch_size, S_new_marg, H)
        new_K_marg = (
            to.rand(batch_size, self.config["S_new_marg"], H, device=K.device)
            < approximate_marginals
        ).byte()
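
        # Pool the prior and marginal samples and evaluate their log (pseudo-)joints.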
        new_K = to.cat((new_K_prior, new_K_marg), dim=1)

        new_lpj = lpj_fn(batch, new_K)
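
        # Candidates that duplicate an already stored state are assigned a very
        # low lpj so that they cannot be selected again.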
        set_redundant_lpj_to_low(new_K, new_lpj, K[idx])
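
        # Replace stored states with new candidates wherever a candidate has a
        # higher lpj; the return value is the number of substituted states.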
        return update_states_for_batch(new_K, new_lpj, idx, K, lpj, sort_by_lpj=sort_by_lpj)
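
A hedged usage sketch (not part of tvs.py): how a TVSVariationalStates object might be driven from a training loop, based only on the constructor and the update signature shown above. The dataset tensor, batch size, and hyperparameter values are hypothetical placeholders, and `model` stands for any object satisfying the Trainable protocol (providing theta["pies"] and log_joint); neither is defined in this file.

import torch as to
from tvo.variational.tvs import TVSVariationalStates

N, D, H = 1000, 25, 16  # datapoints, observables, latents (made-up sizes)
data = to.rand(N, D)    # placeholder dataset

varstates = TVSVariationalStates(
    N=N, H=H, S=64, precision=to.float64, S_new_prior=32, S_new_marg=32
)

# `model` must be supplied by the user; it is deliberately left undefined here.
for idx in to.arange(N).split(32):  # mini-batches of datapoint indices
    batch = data[idx]
    n_subs = varstates.update(idx, batch, model)  # number of replaced states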