Coverage for tvo/models/pmm.py: 91%

107 statements  

coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

# -*- coding: utf-8 -*-
# Copyright (C) 2021 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0


import torch as to
import math
from torch.distributions.one_hot_categorical import OneHotCategorical

from torch import Tensor
from typing import Optional, Union, Tuple

import tvo
from tvo.utils.parallel import pprint, all_reduce, broadcast
from tvo.variational.TVOVariationalStates import TVOVariationalStates
from tvo.variational._utils import mean_posterior
from tvo.utils.model_protocols import Optimized, Sampler, Reconstructor
from tvo.utils.sanity import fix_theta


class PMM(Optimized, Sampler, Reconstructor):

    def __init__(
        self,
        H: int,
        D: int,
        W_init: Optional[Tensor] = None,
        pies_init: Optional[Tensor] = None,
        precision: to.dtype = to.float64,
    ):
        """Poisson Mixture Model (PMM).

        :param H: Number of hidden units.
        :param D: Number of observables.
        :param W_init: Tensor with shape (D, H), initializes PMM weights.
        :param pies_init: Tensor with shape (H,), initializes PMM priors.
        :param precision: Floating point precision required. Must be one of torch.float32 or
                          torch.float64.
        """
        assert precision in (to.float32, to.float64), "precision must be one of torch.float{32,64}"
        self._precision = precision

        device = tvo.get_device()

        if W_init is not None:
            assert W_init.shape == (D, H)
            W_init = W_init.to(dtype=precision, device=device)
        else:
            W_init = to.rand((D, H), dtype=precision, device=device)
            broadcast(W_init)

        if pies_init is not None:
            assert pies_init.shape == (H,)
            pies_init = pies_init.to(dtype=precision, device=device)
        else:
            pies_init = to.full((H,), 1.0 / H, dtype=precision, device=device)

        self._theta = {"pies": pies_init, "W": W_init}
        eps, inf = 1.0e-5, math.inf
        self.policy = {
            "W": [None, to.full_like(self._theta["W"], eps), to.full_like(self._theta["W"], inf)],
            "pies": [
                None,
                to.full_like(self._theta["pies"], eps),
                to.full_like(self._theta["pies"], 1.0 - eps),
            ],
        }

        self.my_Wp = to.zeros((D, H), dtype=precision, device=device)
        self.my_Wq = to.zeros((H,), dtype=precision, device=device)
        self.my_pies = to.zeros(H, dtype=precision, device=device)
        self.my_N = to.tensor([0], dtype=to.int, device=device)
        self._config = dict(H=H, D=D, precision=precision, device=device)
        self._shape = self.theta["W"].shape
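    # Construction sketch (illustrative; `my_W` is a hypothetical user-supplied
    # tensor, not part of this module):
    #   model = PMM(H=5, D=25)                # uniform priors, random (25, 5) weights
    #   model = PMM(H=5, D=25, W_init=my_W)   # my_W must have shape (D, H) = (25, 5)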

    def log_pseudo_joint(self, data: Tensor, states: Tensor) -> Tensor:  # type: ignore
        """Evaluate log-pseudo-joints for PMM."""
        Kfloat = states.to(
            dtype=self.theta["W"].dtype
        )  # N,S,H  # TODO Find solution to avoid byte->float casting

        Wbar = (
            to.matmul(Kfloat, self.theta["W"].t()) + to.finfo(to.float32).tiny
        )  # N,S,D  # TODO Pre-allocate tensor and use `out` argument of to.matmul

        lpj = (
            to.sum(data[:, None, :] * to.log(Wbar), dim=2)
            - to.sum(Wbar, dim=2)
            + to.matmul(Kfloat, to.log(self.theta["pies"]))
        )
        return lpj.to(device=states.device)

    def log_joint(self, data: Tensor, states: Tensor, lpj: Optional[Tensor] = None) -> Tensor:
        """Evaluate log-joints for PMM."""
        if lpj is None:
            lpj = self.log_pseudo_joint(data, states)
        return lpj - to.sum(to.lgamma(data + 1), dim=1)[:, None]
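    # Note: for a one-hot state selecting component h, the full Poisson log-joint is
    #   log p(y, s | Theta) = sum_d (y_d * log(W_dh) - W_dh - lgamma(y_d + 1)) + log(pi_h);
    # log_pseudo_joint above omits the state-independent -lgamma(y_d + 1) term and
    # log_joint adds it back, so both induce the same posterior over states.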

    def update_param_batch(self, idx: Tensor, batch: Tensor, states: Tensor) -> None:
        lpj = states.lpj[idx]
        K = states.K[idx]
        batch_size, S, _ = K.shape

        Kfloat = K.to(dtype=lpj.dtype)  # TODO Find solution to avoid byte->float casting

        batch_s_pjc = mean_posterior(Kfloat, lpj)  # is (batch_size,H)
        batch_Wp = batch.unsqueeze(2) * batch_s_pjc.unsqueeze(1)  # is (batch_size,D,H)

        self.my_pies.add_(to.sum(batch_s_pjc, dim=0))
        self.my_Wp.add_(to.sum(batch_Wp, dim=0))
        self.my_Wq.add_(to.sum(batch_s_pjc, dim=0))
        self.my_N.add_(batch_size)

        return None
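    # The my_* buffers accumulate per-batch sufficient statistics of the EM M-step;
    # update_param_epoch reduces them across workers (all_reduce) and converts them
    # into the closed-form parameter updates.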

    def update_param_epoch(self) -> None:
        theta = self.theta
        policy = self.policy

        all_reduce(self.my_Wp)
        all_reduce(self.my_Wq)
        all_reduce(self.my_pies)
        all_reduce(self.my_N)

        N = self.my_N.item()

        # Calculate updated W
        Wold_noisy = theta["W"] + 0.1 * to.randn_like(theta["W"])
        broadcast(Wold_noisy)
        theta_new = {}
        try:
            theta_new["W"] = self.my_Wp / self.my_Wq[None, :]
        except RuntimeError:
            pprint("Inversion error. Will not update W but add some noise instead.")
            theta_new["W"] = Wold_noisy

        # Calculate updated pi
        theta_new["pies"] = self.my_pies / N

        policy["W"][0] = Wold_noisy
        policy["pies"][0] = theta["pies"]
        fix_theta(theta_new, policy)
        for key in theta:
            theta[key] = theta_new[key]

        self.my_Wp[:] = 0.0
        self.my_Wq[:] = 0.0
        self.my_pies[:] = 0.0
        self.my_N[:] = 0.0
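    # Summary of the epoch update, with <s_h> the variational posterior mean of state h:
    #   W_dh <- (sum_n <s_h> y_nd) / (sum_n <s_h>)   and   pi_h <- (sum_n <s_h>) / N.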

    @property
    def shape(self) -> Tuple[int, ...]:
        return self.theta["W"].shape

    def generate_data(
        self, N: Optional[int] = None, hidden_state: Optional[to.Tensor] = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        precision, device = self.precision, tvo.get_device()
        D, H = self.shape

        if hidden_state is None:
            assert N is not None
            pies = self.theta["pies"]
            hidden_state = OneHotCategorical(probs=pies).sample([N]) == 1
            must_return_hidden_state = True
        else:
            shape = hidden_state.shape
            if N is None:
                N = shape[0]
            assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
            must_return_hidden_state = False

        Wbar = to.zeros((N, D), dtype=precision, device=device)

        for n in range(N):
            for h in range(H):
                if hidden_state[n, h]:
                    Wbar[n] += self.theta["W"][:, h]

        # Sample observed counts from Poisson distributions with rates Wbar
        Y = to.poisson(Wbar)

        return (Y, hidden_state) if must_return_hidden_state else Y
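    # The double loop above amounts to one matrix product; a vectorized sketch
    # (assumption: `hidden_state` is a boolean (N, H) tensor on `device`):
    #   Wbar = hidden_state.to(dtype=precision) @ self.theta["W"].t()  # (N, D) rates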

    def data_estimator(self, idx: Tensor, batch: Tensor, states: TVOVariationalStates) -> Tensor:
        r"""Estimator used for data reconstruction. Data reconstruction can only be supported
        by a model if it implements this method. The estimator to be implemented is defined
        as follows:

        :math:`\langle \langle y_d \rangle_{p(y_d|\vec{s},\Theta)} \rangle_{q(\vec{s}|\mathcal{K},\Theta)}`  # noqa
        """
        K = states.K[idx]
        # TODO Find solution to avoid byte->float casting of `K`
        # TODO Pre-allocate tensor and use `out` argument of to.matmul
        return mean_posterior(
            to.matmul(K.to(dtype=self.precision), self.theta["W"].t()), states.lpj[idx]
        )