Coverage for tvo/models/bsc.py: 96% (127 statements)
coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import math
import torch as to

from torch import Tensor
from typing import Union, Tuple

import tvo
from tvo.utils.parallel import pprint, all_reduce, broadcast
from tvo.variational.TVOVariationalStates import TVOVariationalStates
from tvo.variational._utils import mean_posterior, lpj2pjc
from tvo.utils.model_protocols import Optimized, Sampler, Reconstructor
from tvo.utils.sanity import fix_theta
from tvo.utils._utils import get_lstsq


lstsq = get_lstsq(torch=to)


class BSC(Optimized, Sampler, Reconstructor):
    def __init__(
        self,
        H: int,
        D: int,
        W_init: Tensor = None,
        sigma2_init: Tensor = None,
        pies_init: Tensor = None,
        individual_priors: bool = True,
        precision: to.dtype = to.float64,
    ):

34 """Shallow Binary Sparse Coding (BSC) model. 

35 

36 :param H: Number of hidden units. 

37 :param D: Number of observables. 

38 :param W_init: Tensor with shape (D,H), initializes BSC weights. 

39 :param pies_init: Tensor with shape (H,), initializes BSC priors. 

40 :param individual_priors: Whether to use a Bernoulli prior with H individual prior 

41 probabilities. If False, the same prior probability will be used for all latents. 

42 :param precision: Floating point precision required. Must be one of torch.float32 or 

43 torch.float64. 

44 

45 """ 

        assert precision in (to.float32, to.float64), "precision must be one of torch.float{32,64}"
        self._precision = precision

        device = tvo.get_device()

        if W_init is not None:
            assert W_init.shape == (D, H)
            W_init = W_init.to(dtype=precision, device=device)
        else:
            W_init = to.rand((D, H), dtype=precision, device=device)
            broadcast(W_init)

        if pies_init is not None:
            assert pies_init.shape == (H,) if individual_priors else pies_init.shape == (1,)
            pies_init = pies_init.to(dtype=precision, device=device)
        else:
            pies_init = (
                to.full((H,), 1.0 / H, dtype=precision, device=device)
                if individual_priors
                else to.tensor([1.0 / H], dtype=precision, device=device)
            )

        if sigma2_init is not None:
            assert sigma2_init.shape == (1,)
            sigma2_init = sigma2_init.to(dtype=precision, device=device)
        else:
            sigma2_init = to.tensor([1.0], dtype=precision, device=device)

        self._theta = {"pies": pies_init, "W": W_init, "sigma2": sigma2_init}
        eps, inf = 1.0e-5, math.inf
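        # Clipping policy handed to `fix_theta` at the end of each epoch (see
        # update_param_epoch): for each parameter, [placeholder that update_param_epoch
        # fills with the previous values before calling fix_theta, lower bound, upper bound].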

        self.policy = {
            "W": [None, to.full_like(self._theta["W"], -inf), to.full_like(self._theta["W"], inf)],
            "pies": [
                None,
                to.full_like(self._theta["pies"], eps),
                to.full_like(self._theta["pies"], 1.0 - eps),
            ],
            "sigma2": [
                None,
                to.full_like(self._theta["sigma2"], eps),
                to.full_like(self._theta["sigma2"], inf),
            ],
        }

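        # Buffers for the sufficient statistics accumulated per batch in update_param_batch
        # and reduced across workers before the parameter update in update_param_epoch.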

        self.my_Wp = to.zeros((D, H), dtype=precision, device=device)
        self.my_Wq = to.zeros((H, H), dtype=precision, device=device)
        self.my_pies = to.zeros(H, dtype=precision, device=device)
        self.my_sigma2 = to.zeros(1, dtype=precision, device=device)
        self.my_N = to.tensor([0], dtype=to.int, device=device)
        self._config = dict(
            H=H, D=D, individual_priors=individual_priors, precision=precision, device=device
        )
        self._shape = self.theta["W"].shape

    def log_pseudo_joint(self, data: Tensor, states: Tensor) -> Tensor:  # type: ignore
        """Evaluate log-pseudo-joints for BSC"""
        Kfloat = states.to(
            dtype=self.theta["W"].dtype
        )  # TODO Find solution to avoid byte->float casting
        Wbar = to.matmul(
            Kfloat, self.theta["W"].t()
        )  # TODO Pre-allocate tensor and use `out` argument of to.matmul
        Kpriorterm = (
            to.matmul(Kfloat, to.log(self.theta["pies"] / (1 - self.theta["pies"])))
            if self.config["individual_priors"]
            else to.log(self.theta["pies"] / (1 - self.theta["pies"])) * Kfloat.sum(dim=2)
        )
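        # Up to terms that do not depend on the latent state s, the value computed below is
        #   log p(y, s | Theta) = -1/(2*sigma^2) * sum_d (y_d - sum_h W_dh s_h)^2
        #                         + sum_h s_h * log(pi_h / (1 - pi_h)),
        # where NaN entries of y are ignored via nansum (missing-data support).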

        lpj = (
            to.mul(
                to.nansum(to.pow(Wbar - data[:, None, :], 2), dim=2), -1 / 2 / self.theta["sigma2"]
            )
            + Kpriorterm
        )
        return lpj.to(device=states.device)

    def log_joint(self, data: Tensor, states: Tensor, lpj: Tensor = None) -> Tensor:
        """Evaluate log-joints for BSC."""
        if lpj is None:
            lpj = self.log_pseudo_joint(data, states)
        D = to.sum(to.logical_not(to.isnan(data)), dim=1)  # (N,)
        H = self.shape[1]
        priorterm = (
            to.log(1 - self.theta["pies"]).sum()
            if self.config["individual_priors"]
            else H * to.log(1 - self.theta["pies"])
        )
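        # Add back the state-independent terms dropped in log_pseudo_joint:
        #   log p(y, s | Theta) = lpj + sum_h log(1 - pi_h) - D_n/2 * log(2*pi*sigma^2),
        # with D_n the number of non-NaN observables of data point n.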

        return lpj + priorterm - D.unsqueeze(1) / 2 * to.log(2 * math.pi * self.theta["sigma2"])

    def update_param_batch(self, idx: Tensor, batch: Tensor, states: TVOVariationalStates) -> None:
        lpj = states.lpj[idx]
        K = states.K[idx]
        batch_size, S, _ = K.shape

        Kfloat = K.to(dtype=lpj.dtype)  # TODO Find solution to avoid byte->float casting
        Wbar = to.matmul(
            Kfloat, self.theta["W"].t()
        )  # TODO Find solution to re-use evaluations from E-step

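        # Per-batch sufficient statistics for the M-step: posterior means <s>_q per data
        # point, the products y_n <s>_q^T, the summed second moments sum_n <s s^T>_q, and
        # the posterior expectation of the squared reconstruction error (sigma^2 update).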

        batch_s_pjc = mean_posterior(Kfloat, lpj)  # is (batch_size,H)
        batch_Wp = batch.unsqueeze(2) * batch_s_pjc.unsqueeze(1)  # is (batch_size,D,H)
        Kq = Kfloat.mul(lpj2pjc(lpj)[:, :, None])
        batch_Wq = to.einsum("ijk,ijl->kl", Kq, Kfloat)  # is (H,H), summed over the batch
        batch_sigma2 = mean_posterior(
            to.sum((batch[:, None, :] - Wbar) ** 2, dim=2), lpj
        )  # is (batch_size,)


        self.my_pies.add_(to.sum(batch_s_pjc, dim=0))
        self.my_Wp.add_(to.sum(batch_Wp, dim=0))
        self.my_Wq.add_(batch_Wq)
        self.my_sigma2.add_(to.sum(batch_sigma2))
        self.my_N.add_(batch_size)

        return None

    def update_param_epoch(self) -> None:
        theta = self.theta
        policy = self.policy

        all_reduce(self.my_Wp)
        all_reduce(self.my_Wq)
        all_reduce(self.my_pies)
        all_reduce(self.my_sigma2)
        all_reduce(self.my_N)

        N = self.my_N.item()
        D, H = self.shape

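        # Closed-form M-step updates from the accumulated statistics:
        #   W*:      solve (sum_n <s s^T>) W^T = (sum_n y_n <s>^T)^T via least squares
        #   pies*:   (1/N) sum_n <s>  (or its average over latents if the prior is shared)
        #   sigma2*: 1/(N*D) sum_n <||y_n - W s||^2>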

        # Calculate updated W
        Wold_noisy = theta["W"] + 0.1 * to.randn_like(theta["W"])
        broadcast(Wold_noisy)
        theta_new = {}
        try:
            theta_new["W"] = to.linalg.lstsq(self.my_Wq, self.my_Wp.t())[0].t()
        except RuntimeError:
            pprint("Inversion error. Will not update W but add some noise instead.")
            theta_new["W"] = Wold_noisy

        # Calculate updated pi
        theta_new["pies"] = (
            self.my_pies / N
            if self.config["individual_priors"]
            else self.my_pies.sum(dim=0, keepdim=True) / N / H
        )

        # Calculate updated sigma^2
        theta_new["sigma2"] = self.my_sigma2 / N / D

        policy["W"][0] = Wold_noisy
        policy["pies"][0] = theta["pies"]
        policy["sigma2"][0] = theta["sigma2"]
        fix_theta(theta_new, policy)
        for key in theta:
            theta[key][:] = theta_new[key]

        self.my_Wp[:] = 0.0
        self.my_Wq[:] = 0.0
        self.my_pies[:] = 0.0
        self.my_sigma2[:] = 0.0
        self.my_N[:] = 0.0


    @property
    def shape(self) -> Tuple[int, ...]:
        return self.theta["W"].shape

    def generate_data(
        self, N: int = None, hidden_state: to.Tensor = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        precision, device = self.precision, tvo.get_device()
        D, H = self.shape

        if hidden_state is None:
            assert N is not None
            pies = self.theta["pies"]
            hidden_state = to.rand((N, H), dtype=precision, device=device) < pies
            must_return_hidden_state = True
        else:
            shape = hidden_state.shape
            if N is None:
                N = shape[0]
            assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
            must_return_hidden_state = False

        Wbar = to.zeros((N, D), dtype=precision, device=device)

        # Linear superposition
        for n in range(N):
            for h in range(H):
                if hidden_state[n, h]:
                    Wbar[n] += self.theta["W"][:, h]

        # Add noise according to the model parameters
        Y = Wbar + to.sqrt(self.theta["sigma2"]) * to.randn((N, D), dtype=precision, device=device)

        return (Y, hidden_state) if must_return_hidden_state else Y


    def data_estimator(self, idx: Tensor, batch: Tensor, states: TVOVariationalStates) -> Tensor:
        """Estimator used for data reconstruction. Data reconstruction can only be supported
        by a model if it implements this method. The estimator to be implemented is defined
        as follows:""" r"""
        :math:`\\langle \langle y_d \rangle_{p(y_d|\vec{s},\Theta)} \rangle_{q(\vec{s}|\mathcal{K},\Theta)}`  # noqa
        """

        K = states.K[idx]
        # TODO Find solution to avoid byte->float casting of `K`
        # TODO Pre-allocate tensor and use `out` argument of to.matmul
        return mean_posterior(
            to.matmul(K.to(dtype=self.precision), self.theta["W"].t()), states.lpj[idx]
        )
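

# -------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module): constructs
# a small BSC instance, samples synthetic data with `generate_data`, and evaluates
# log-(pseudo-)joints for a toy set of binary variational states. Assumes `tvo` is
# importable and its default device applies; all sizes below are arbitrary.
# -------------------------------------------------------------------------------------
if __name__ == "__main__":
    H_, D_, N_, S_ = 8, 16, 100, 64
    model = BSC(H=H_, D=D_)
    Y, S_true = model.generate_data(N=N_)  # Y is (N_, D_), S_true is (N_, H_)
    # Toy variational states: S_ random binary configurations per data point.
    states = (to.rand(N_, S_, H_, device=Y.device) < 0.2).to(to.uint8)
    lpj = model.log_pseudo_joint(Y, states)  # (N_, S_)
    lj = model.log_joint(Y, states, lpj)  # (N_, S_)
    print(tuple(lpj.shape), tuple(lj.shape))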