Coverage for tvo/variational/tvs.py: 32%

25 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2021 Machine Learning Group of the University of Oldenburg. 

3# Licensed under the Academic Free License version 3.0 

4 

5import torch as to 

6from tvo.variational.TVOVariationalStates import TVOVariationalStates 

7from tvo.variational._utils import mean_posterior 

8from tvo.utils.model_protocols import Trainable, Optimized 

9from ._utils import update_states_for_batch, set_redundant_lpj_to_low 

10 

11 

class TVSVariationalStates(TVOVariationalStates):
    def __init__(
        self,
        N: int,
        H: int,
        S: int,
        precision: to.dtype,
        S_new_prior: int,
        S_new_marg: int,
        K_init_file: str = None,
    ):
        """Truncated Variational Sampling (TVS) variational states.

        New candidate states are drawn at each update partly from the prior and
        partly from an approximation of the per-latent posterior marginals.

        :param N: number of datapoints
        :param H: number of latents
        :param S: number of variational states kept per datapoint
        :param precision: floating point precision to be used for log_joint values.
                          Must be one of to.float32 or to.float64.
        :param S_new_prior: number of states sampled from the prior at every call to ~update
        :param S_new_marg: number of states sampled from the approximated marginal
                           p(s_h=1|vec{y}^{(n)}, Theta) at every call to ~update
        :param K_init_file: Full path to H5 file providing initial states
        """
        # Total number of fresh candidates generated per datapoint per update.
        total_new = S_new_prior + S_new_marg
        super().__init__(
            {
                "N": N,
                "H": H,
                "S": S,
                "S_new_prior": S_new_prior,
                "S_new_marg": S_new_marg,
                "S_new": total_new,
                "precision": precision,
                "K_init_file": K_init_file,
            }
        )

    def update(self, idx: to.Tensor, batch: to.Tensor, model: Trainable) -> int:
        """See :func:`tvo.variational.TVOVariationalStates.update`.

        Re-evaluates the stored states of the datapoints selected by ``idx``. It
        then draws ``S_new_prior`` candidates from the prior (bit h is on with
        probability ``model.theta["pies"]``) and ``S_new_marg`` candidates from the
        approximate marginals computed by ``mean_posterior``. Finally it hands the
        candidates to ``update_states_for_batch`` and returns that helper's result.
        """
        # Optimized models expose a cheaper pseudo log-joint and their own
        # sorting preference; other models use the full log-joint.
        optimized = isinstance(model, Optimized)
        lpj_fn = model.log_pseudo_joint if optimized else model.log_joint
        sort_by_lpj = model.sorted_by_lpj if optimized else {}

        # Local aliases: writes through these mutate self.K / self.lpj in place.
        states, log_joints = self.K, self.lpj
        n_batch = batch.shape[0]
        n_latents = states.shape[2]
        device = states.device

        # Refresh the stored (pseudo-)log-joints of this batch's current states.
        log_joints[idx] = lpj_fn(batch, states[idx])

        # Candidates from the prior: each latent is switched on with probability pies.
        prior_draws = to.rand(n_batch, self.config["S_new_prior"], n_latents, device=device)
        prior_candidates = (prior_draws < model.theta["pies"]).to(to.uint8)

        # Approximates p(s_h=1|\yVecN, \Theta) from the current truncated posterior,
        # broadcast to shape (batch_size, S_new_marg, H).
        n_marg = self.config["S_new_marg"]
        marginals = mean_posterior(states[idx].type_as(log_joints), log_joints[idx])
        marginals = marginals.unsqueeze(1).expand(n_batch, n_marg, n_latents)
        marg_draws = to.rand(n_batch, n_marg, n_latents, device=device)
        marg_candidates = (marg_draws < marginals).to(to.uint8)

        candidates = to.cat((prior_candidates, marg_candidates), dim=1)
        candidate_lpj = lpj_fn(batch, candidates)

        # Presumably demotes candidates already present in states[idx] so they are
        # not selected twice -- behaviour lives in the helper; confirm there.
        set_redundant_lpj_to_low(candidates, candidate_lpj, states[idx])

        return update_states_for_batch(
            candidates, candidate_lpj, idx, states, log_joints, sort_by_lpj=sort_by_lpj
        )