Coverage for tvo/models/gmm.py: 91%

121 statements  

coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

# -*- coding: utf-8 -*-
# Copyright (C) 2021 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0


import torch as to
import math
from torch.distributions.one_hot_categorical import OneHotCategorical

from torch import Tensor
from typing import Union, Tuple

import tvo
from tvo.utils.parallel import pprint, all_reduce, broadcast
from tvo.variational.TVOVariationalStates import TVOVariationalStates
from tvo.variational._utils import mean_posterior
from tvo.utils.model_protocols import Optimized, Sampler, Reconstructor
from tvo.utils.sanity import fix_theta


class GMM(Optimized, Sampler, Reconstructor):
    def __init__(
        self,
        H: int,
        D: int,
        W_init: Tensor = None,
        sigma2_init: Tensor = None,
        pies_init: Tensor = None,
        precision: to.dtype = to.float64,
    ):
        """Gaussian Mixture model (GMM).

        :param H: Number of hidden units.
        :param D: Number of observables.
        :param W_init: Tensor with shape (D,H), initializes GM weights.
        :param sigma2_init: Tensor with shape (1,), initializes the GM variance.
        :param pies_init: Tensor with shape (H,), initializes GM priors.
        :param precision: Floating point precision required. Must be one of torch.float32 or
                          torch.float64.
        """
        assert precision in (to.float32, to.float64), "precision must be one of torch.float{32,64}"
        self._precision = precision

        device = tvo.get_device()

        if W_init is not None:
            assert W_init.shape == (D, H)
            W_init = W_init.to(dtype=precision, device=device)
        else:
            W_init = to.rand((D, H), dtype=precision, device=device)
            broadcast(W_init)

        if pies_init is not None:
            assert pies_init.shape == (H,)
            pies_init = pies_init.to(dtype=precision, device=device)
        else:
            pies_init = to.full((H,), 1.0 / H, dtype=precision, device=device)

        if sigma2_init is not None:
            assert sigma2_init.shape == (1,)
            sigma2_init = sigma2_init.to(dtype=precision, device=device)
        else:
            sigma2_init = to.tensor([1.0], dtype=precision, device=device)

        self._theta = {"pies": pies_init, "W": W_init, "sigma2": sigma2_init}
        eps, inf = 1.0e-5, math.inf
        # Each policy entry is [replacement value, lower bound, upper bound]; the
        # replacement is set each epoch and consumed by `fix_theta` in
        # `update_param_epoch` to sanitize the parameter updates.
        self.policy = {
            "W": [None, to.full_like(self._theta["W"], -inf), to.full_like(self._theta["W"], inf)],
            "pies": [
                None,
                to.full_like(self._theta["pies"], eps),
                to.full_like(self._theta["pies"], 1.0 - eps),
            ],
            "sigma2": [
                None,
                to.full_like(self._theta["sigma2"], eps),
                to.full_like(self._theta["sigma2"], inf),
            ],
        }

        self.my_Wp = to.zeros((D, H), dtype=precision, device=device)
        self.my_Wq = to.zeros(H, dtype=precision, device=device)
        self.my_pies = to.zeros(H, dtype=precision, device=device)
        self.my_sigma2 = to.zeros(1, dtype=precision, device=device)
        self.my_N = to.tensor([0], dtype=to.int, device=device)
        self._config = dict(H=H, D=D, precision=precision, device=device)
        self._shape = self.theta["W"].shape

    def log_pseudo_joint(self, data: Tensor, states: Tensor) -> Tensor:  # type: ignore
        """Evaluate log-pseudo-joints for GMM."""
        Kfloat = states.to(
            dtype=self.theta["W"].dtype
        )  # N,S,H  # TODO Find solution to avoid byte->float casting
        Wbar = to.matmul(
            Kfloat, self.theta["W"].t()
        )  # N,S,D  # TODO Pre-allocate tensor and use `out` argument of to.matmul
        lpj = to.mul(
            to.sum(to.pow(Wbar - data[:, None, :], 2), dim=2), -1 / 2 / self.theta["sigma2"]
        ) + to.matmul(Kfloat, to.log(self.theta["pies"]))
        return lpj.to(device=states.device)

    def log_joint(self, data: Tensor, states: Tensor, lpj: Tensor = None) -> Tensor:
        """Evaluate log-joints for GMM."""
        if lpj is None:
            lpj = self.log_pseudo_joint(data, states)
        D = self.shape[0]
        return lpj - D / 2 * to.log(2 * math.pi * self.theta["sigma2"])
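
    # For reference, the density the two methods above evaluate (a sketch read
    # off the code, with s a one-hot vector selecting component h):
    #
    #   log p(y, s | Theta) = log pies_h
    #                         - ||y - W[:, h]||^2 / (2 * sigma2)
    #                         - (D / 2) * log(2 * pi * sigma2)
    #
    # `log_pseudo_joint` omits the state-independent last term; `log_joint`
    # adds it back on top of the (optionally precomputed) lpj.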

    def update_param_batch(self, idx: Tensor, batch: Tensor, states: TVOVariationalStates) -> None:
        lpj = states.lpj[idx]
        K = states.K[idx]
        batch_size, S, _ = K.shape

        Kfloat = K.to(dtype=lpj.dtype)  # TODO Find solution to avoid byte->float casting
        Wbar = to.matmul(
            Kfloat, self.theta["W"].t()
        )  # N,S,D  # TODO Find solution to re-use evaluations from E-step

        batch_s_pjc = mean_posterior(Kfloat, lpj)  # is (batch_size,H)
        batch_Wp = batch.unsqueeze(2) * batch_s_pjc.unsqueeze(1)  # is (batch_size,D,H)
        batch_sigma2 = mean_posterior(to.sum((batch[:, None, :] - Wbar) ** 2, dim=2), lpj)

        self.my_pies.add_(to.sum(batch_s_pjc, dim=0))
        self.my_Wp.add_(to.sum(batch_Wp, dim=0))
        self.my_Wq.add_(to.sum(batch_s_pjc, dim=0))
        self.my_sigma2.add_(to.sum(batch_sigma2))
        self.my_N.add_(batch_size)

        return None

    def update_param_epoch(self) -> None:
        theta = self.theta
        policy = self.policy

        all_reduce(self.my_Wp)
        all_reduce(self.my_Wq)
        all_reduce(self.my_pies)
        all_reduce(self.my_sigma2)
        all_reduce(self.my_N)

        N = self.my_N.item()
        D = self.shape[0]

        # Calculate updated W
        Wold_noisy = theta["W"] + 0.1 * to.randn_like(theta["W"])
        broadcast(Wold_noisy)
        theta_new = {}
        try:
            theta_new["W"] = self.my_Wp / self.my_Wq[None, :]
        except RuntimeError:
            pprint("Inversion error. Will not update W but add some noise instead.")
            theta_new["W"] = Wold_noisy

        # Calculate updated pi
        theta_new["pies"] = self.my_pies / N

        # Calculate updated sigma^2
        theta_new["sigma2"] = self.my_sigma2 / N / D

        policy["W"][0] = Wold_noisy
        policy["pies"][0] = theta["pies"]
        policy["sigma2"][0] = theta["sigma2"]
        fix_theta(theta_new, policy)
        for key in theta:
            theta[key] = theta_new[key]

        self.my_Wp[:] = 0.0
        self.my_Wq[:] = 0.0
        self.my_pies[:] = 0.0
        self.my_sigma2[:] = 0.0
        self.my_N[:] = 0
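
    # The closed-form M-step applied above (a sketch; <.>_n denotes the
    # variational posterior average accumulated in `update_param_batch`):
    #
    #   pies_h  = (1/N) * sum_n <s_h>_n
    #   W[:, h] = (sum_n y_n * <s_h>_n) / (sum_n <s_h>_n)
    #   sigma2  = (1/(N*D)) * sum_n <||y_n - W s||^2>_n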

    @property
    def shape(self) -> Tuple[int, ...]:
        return self.theta["W"].shape

    def generate_data(
        self, N: int = None, hidden_state: to.Tensor = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        precision, device = self.precision, tvo.get_device()
        D, H = self.shape

        if hidden_state is None:
            assert N is not None
            pies = self.theta["pies"]
            hidden_state = OneHotCategorical(probs=pies).sample([N]) == 1
            must_return_hidden_state = True
        else:
            shape = hidden_state.shape
            if N is None:
                N = shape[0]
            assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
            must_return_hidden_state = False

        Wbar = to.zeros((N, D), dtype=precision, device=device)

        for n in range(N):
            for h in range(H):
                if hidden_state[n, h]:
                    Wbar[n] += self.theta["W"][:, h]
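
        # A vectorized equivalent of the loop above (a sketch, assuming a
        # boolean one-hot `hidden_state` on the same device as W):
        #
        #   Wbar = hidden_state.to(dtype=precision) @ self.theta["W"].t()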

        # Add noise according to the model parameters
        Y = Wbar + to.sqrt(self.theta["sigma2"]) * to.randn((N, D), dtype=precision, device=device)

        return (Y, hidden_state) if must_return_hidden_state else Y

    def data_estimator(self, idx: Tensor, batch: Tensor, states: TVOVariationalStates) -> Tensor:
        r"""Estimator used for data reconstruction. Data reconstruction can only be supported
        by a model if it implements this method. The estimator to be implemented is defined
        as follows:

        :math:`\langle \langle y_d \rangle_{p(y_d|\vec{s},\Theta)} \rangle_{q(\vec{s}|\mathcal{K},\Theta)}`
        """
        K = states.K[idx]
        # TODO Find solution to avoid byte->float casting of `K`
        # TODO Pre-allocate tensor and use `out` argument of to.matmul
        return mean_posterior(
            to.matmul(K.to(dtype=self.precision), self.theta["W"].t()), states.lpj[idx]
        )
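

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# It builds a small GMM, samples synthetic data, and scores each datapoint
# under all H one-hot states. H=3, D=2 and N=5 are arbitrary example values,
# and model tensors are assumed to live on the CPU (tvo's default device).
if __name__ == "__main__":
    model = GMM(H=3, D=2)
    Y, s = model.generate_data(N=5)  # Y: (5, 2) observations, s: (5, 3) one-hot states

    # One (N, S, H) tensor holding every one-hot state for each datapoint.
    all_states = to.eye(3, dtype=to.uint8).expand(5, 3, 3)
    print(model.log_joint(Y, all_states).shape)  # -> torch.Size([5, 3])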