Coverage for tvo/models/noisyor.py: 99%

105 statements  


# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0


from tvo.utils.model_protocols import Optimized, Sampler
from tvo.variational import TVOVariationalStates  # type: ignore
from tvo.variational._utils import mean_posterior
from tvo.utils.parallel import all_reduce, broadcast
from torch import Tensor
import torch as to
from typing import Dict, Optional, Union, Tuple
import tvo


class NoisyOR(Optimized, Sampler):
    eps = 1e-7

    def __init__(
        self,
        H: int,
        D: int,
        W_init: Optional[Tensor] = None,
        pi_init: Optional[Tensor] = None,
        precision: to.dtype = to.float64,
    ):
        """Shallow NoisyOR model.

        :param H: Number of hidden units.
        :param D: Number of observables.
        :param W_init: Tensor with shape (D, H), initializes NoisyOR weights.
        :param pi_init: Tensor with shape (H,), initializes NoisyOR priors.
        :param precision: Floating point precision required. Must be one of torch.float32
                          or torch.float64.
        """

        assert precision in (to.float32, to.float64), "precision must be one of torch.float{32,64}"
        self._precision = precision

        device = tvo.get_device()

        if W_init is not None:
            assert W_init.shape == (D, H)
        else:
            W_init = to.rand(D, H, device=device)
            broadcast(W_init)

        if pi_init is not None:
            assert pi_init.shape == (H,)
            assert (pi_init <= 1.0).all() and (pi_init >= 0).all()
        else:
            pi_init = to.full((H,), 1.0 / H, device=device, dtype=self.precision)

        self._theta = {
            "pies": pi_init.to(device=device, dtype=precision),
            "W": W_init.to(device=device, dtype=precision),
        }

        self.new_pi = to.zeros(H, device=device, dtype=precision)
        self.Btilde = to.zeros(D, H, device=device, dtype=precision)
        self.Ctilde = to.zeros(D, H, device=device, dtype=precision)
        # number of datapoints processed in a training epoch
        self._train_datapoints = to.tensor([0], dtype=to.int, device=device)
        self._config = dict(H=H, D=D, precision=self.precision, device=device)
        self._shape = self.theta["W"].shape
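
    # Generative model implemented by this class (added note, not in the original
    # source): with binary latents s and binary observables y,
    #   p(s_h = 1) = pi_h                              (independent Bernoulli priors)
    #   p(y_d = 1 | s) = 1 - prod_h (1 - W_dh * s_h)   (noisy-OR likelihood)
    # i.e. y_d stays off only if every active hidden unit fails to turn it on.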

    def log_pseudo_joint(self, data: Tensor, states: Tensor) -> Tensor:  # type: ignore
        """Evaluate log-pseudo-joints for NoisyOR."""
        K = states
        Y = data
        assert K.dtype == to.uint8 and Y.dtype == to.uint8
        pi = self.theta["pies"]
        W = self.theta["W"]
        batch_size, S, H = K.shape
        D = W.shape[0]
        dev = pi.device

        logPriors = to.matmul(K.type_as(pi), to.log(pi / (1 - pi)))

        logPy = to.empty((batch_size, S), device=dev, dtype=self.precision)
        # We will manually set the lpjs of all-zero states to the appropriate value.
        # For now, transform all-zero states into all-one states, to avoid computing log(0).
        zeroStatesInd = to.nonzero((K == 0).all(dim=2))
        # https://discuss.pytorch.org/t/use-torch-nonzero-as-index/33218
        zeroStatesInd = (zeroStatesInd[:, 0], zeroStatesInd[:, 1])
        K[zeroStatesInd] = 1
        # prods_nsd = prod_h (1 - W_dh * K_nsh)
        prods = (W * K.type_as(W).unsqueeze(2)).neg_().add_(1).prod(dim=-1)
        to.clamp(prods, self.eps, 1 - self.eps, out=prods)
        # logPy_ns = sum_d [ y_nd * log(1 / prods_nsd - 1) + log(prods_nsd) ]
        f1 = to.log(1.0 / prods - 1.0)
        indices = 1 - Y[:, None, :].expand(batch_size, S, D)
        # convert to BoolTensor in pytorch>=1.2, leave as ByteTensor in earlier versions
        indices = indices.type_as(to.empty(0) < 0)
        f1[indices] = 0.0
        logPy[:, :] = to.sum(f1, dim=-1) + to.sum(to.log(prods), dim=2)
        K[zeroStatesInd] = 0

        lpj = logPriors + logPy
        # for all-zero states, set lpj to an arbitrary, very low value if y != 0, 0 otherwise:
        # in the end we want exp(lpj(y, s=0)) = 1 if y = 0, 0 otherwise
        lpj[zeroStatesInd] = -1e30 * data[zeroStatesInd[0]].any(dim=1).type_as(lpj)
        assert (
            not to.isnan(lpj).any() and not to.isinf(lpj).any()
        ), "some NoisyOR lpj values are invalid!"
        return lpj.to(device=states.device)  # (N, S)
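
    # Derivation note (added for clarity, not in the original source): writing
    # prods_nsd = prod_h (1 - W_dh * K_nsh), the noisy-OR likelihood gives
    #   log p(y_n | s) = sum_d [ y_nd * log(1 - prods_nsd) + (1 - y_nd) * log(prods_nsd) ]
    #                  = sum_d [ y_nd * log(1 / prods_nsd - 1) + log(prods_nsd) ]
    # which is the form computed above: f1 holds the first summand and is zeroed
    # wherever y_nd == 0, so only the log(prods) term survives for "off" observables.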

    def update_param_batch(
        self,
        idx: Tensor,
        batch: Tensor,
        states: TVOVariationalStates,
        mstep_factors: Optional[Dict[str, Tensor]] = None,
    ) -> Optional[float]:
        lpj = states.lpj[idx]
        K = states.K[idx]
        Kfloat = K.type_as(lpj)

        # pi_h = sum_n <K_nh> / N
        # (division by N has to wait until after the MPI all_reduce)
        self.new_pi += mean_posterior(Kfloat, lpj).sum(dim=0)
        assert not to.isnan(self.new_pi).any()

        # Ws_nsdh = 1 - (W_dh * Kfloat_nsh)
        Ws = (self.theta["W"][None, None, :, :] * Kfloat[:, :, None, :]).neg_().add_(1)
        Ws_prod = to.prod(Ws, dim=3, keepdim=True)
        B = Kfloat.unsqueeze(2) / (Ws * Ws_prod.neg().add_(1)).add_(self.eps)  # (N,S,D,H)
        self.Btilde.add_(
            (mean_posterior(B, lpj) * (batch.type_as(lpj) - 1).unsqueeze(2)).sum(dim=0)
        )
        C = B.mul_(Ws_prod).div_(Ws)  # (N,S,D,H)
        self.Ctilde.add_(to.sum(mean_posterior(C, lpj), dim=0))
        assert not to.isnan(self.Ctilde).any()
        assert not to.isnan(self.Btilde).any()

        self._train_datapoints.add_(batch.shape[0])

        return None
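
    # Note on the accumulators above (added for clarity, not in the original source):
    # with prods_nsd = prod_h (1 - W_dh * K_nsh), the per-state quantities are
    #   B_nsdh = K_nsh / ((1 - W_dh * K_nsh) * (1 - prods_nsd) + eps)
    #   C_nsdh = B_nsdh * prods_nsd / (1 - W_dh * K_nsh)
    # Btilde accumulates sum_n <B_ndh> * (y_nd - 1) and Ctilde accumulates sum_n <C_ndh>,
    # where <.> denotes the posterior mean over states; both feed the W update in
    # update_param_epoch below.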

    def update_param_epoch(self) -> None:
        all_reduce(self._train_datapoints)
        all_reduce(self.new_pi)
        N = self._train_datapoints.item()
        self.theta["pies"][:] = self.new_pi / N
        to.clamp(self.theta["pies"], self.eps, 1 - self.eps, out=self.theta["pies"])
        self.new_pi[:] = 0.0

        all_reduce(self.Btilde)
        all_reduce(self.Ctilde)
        self.theta["W"][:] = 1 + self.Btilde / (self.Ctilde + self.eps)
        to.clamp(self.theta["W"], self.eps, 1 - self.eps, out=self.theta["W"])
        self.Btilde[:] = 0.0
        self.Ctilde[:] = 0.0

        self._train_datapoints[:] = 0
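
    # Closed-form M-step (added note, not in the original source): with N the number
    # of datapoints seen this epoch,
    #   pi_h <- (1/N) * sum_n <K_nh>         (posterior-averaged unit activations)
    #   W_dh <- 1 + Btilde_dh / Ctilde_dh    (noisy-OR fixed-point update)
    # both clamped to [eps, 1 - eps] for numerical stability.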

    def log_joint(self, data: Tensor, states: Tensor, lpj: Optional[Tensor] = None) -> Tensor:
        pi = self.theta["pies"]
        if lpj is None:
            lpj = self.log_pseudo_joint(data, states)
        # TODO: could pre-evaluate the constant factor once per epoch
        return to.sum(to.log(1 - pi)) + lpj
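
    # Why this constant works (added note, not in the original source): the full
    # Bernoulli log-prior is sum_h [ s_h * log(pi_h) + (1 - s_h) * log(1 - pi_h) ]
    #   = sum_h s_h * log(pi_h / (1 - pi_h)) + sum_h log(1 - pi_h),
    # and log_pseudo_joint only evaluates the first, state-dependent term; adding
    # sum_h log(1 - pi_h) here recovers the exact log-joint.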

    def generate_data(
        self, N: Optional[int] = None, hidden_state: Optional[Tensor] = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        """Use hidden states to sample datapoints according to the NoisyOR generative model.

        :param N: number of datapoints to generate (only used if hidden_state is not given).
        :param hidden_state: a tensor with shape (N, H) where H is the number of hidden units.
        :returns: the datapoints, as a tensor with shape (N, D) where D is the number of
                  observables; if hidden_state was not given, a tuple of the datapoints and
                  the sampled hidden states is returned instead.
        """
        theta = self.theta
        W = theta["W"]
        D, H = W.shape

        if hidden_state is None:
            pies = theta["pies"]
            hidden_state = to.rand((N, H), dtype=pies.dtype, device=pies.device) < pies
            must_return_hidden_state = True
        else:
            if N is not None:
                shape = hidden_state.shape
                assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
            must_return_hidden_state = False

        # py_nd = 1 - prod_h (1 - W_dh * s_nh)
        py = 1 - to.prod(1 - W[None, :, :] * hidden_state.type_as(W)[:, None, :], dim=2)
        Y = (to.rand_like(py) < py).byte()

        return (Y, hidden_state) if must_return_hidden_state else Y
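

# Usage sketch (added for illustration, not part of the original module). A minimal
# round-trip on toy sizes; assumes tvo.get_device() resolves to a usable device and
# that the parallel utilities no-op outside an MPI run.
if __name__ == "__main__":
    model = NoisyOR(H=4, D=8)
    # Sample N=5 datapoints together with the hidden states that generated them.
    Y, S = model.generate_data(N=5)
    print(Y.shape, S.shape)  # torch.Size([5, 8]) torch.Size([5, 4])
    # Evaluate log-pseudo-joints of the generated data under the generating states
    # themselves, reshaped to the expected (N, S, H) uint8 layout.
    lpj = model.log_pseudo_joint(Y, S.to(to.uint8).unsqueeze(1))
    print(lpj.shape)  # torch.Size([5, 1])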