Coverage for tvo/utils/model_protocols.py: 93% (88 statements)


# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

from typing_extensions import Protocol, runtime_checkable
from typing import Tuple, Dict, Any, Optional, Union, TYPE_CHECKING
import torch as to
from tvo.utils.parallel import mpi_average_grads
from abc import abstractmethod

if TYPE_CHECKING:
    from tvo.variational.TVOVariationalStates import TVOVariationalStates


@runtime_checkable
class Trainable(Protocol):
    """The most basic model.

    Requires implementation of log_joint. Default implementations are provided for
    update_param_batch (gradient-based), update_param_epoch (no-op) and free_energy.
    """

    _theta: Dict[str, to.Tensor]
    _config: Dict[str, Any] = {}
    _optimizer: Optional[to.optim.Optimizer] = None

    @abstractmethod
    def log_joint(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        """Evaluate log-joint probabilities for this model.

        :param data: shape is (N,D)
        :param states: shape is (N,S,H)
        :returns: log-joints for data and states - shape is (N,S)
        """
        ...

    def update_param_batch(
        self, idx: to.Tensor, batch: to.Tensor, states: "TVOVariationalStates"
    ) -> Optional[float]:
        """Execute batch-wise M-step or batch-wise section of an M-step computation.

        :param idx: indexes of the datapoints that compose the batch within the dataset
        :param batch: batch of datapoints, Tensor with shape (N,D)
        :param states: all variational states for this dataset

        If the model allows it, as an optimization this method can return this batch's free
        energy evaluated _before_ the model parameter update. If the batch's free energy is
        returned here, Trainers will skip a direct per-batch call to the free_energy method.
        """
        # by default, perform a gradient-based parameter update that maximizes
        # the free energy of the batch
        if self._optimizer is None:
            for t in self._theta.values():
                t.requires_grad_(True)
            self._optimizer = to.optim.Adam(self._theta.values())
        assert self._optimizer is not None  # to make mypy happy
        log_joints = self.log_joint(batch, states.K[idx])
        F = to.logsumexp(log_joints, dim=1).sum(dim=0)
        loss = -F / batch.shape[0]  # negative free energy per datapoint
        loss.backward()
        mpi_average_grads(self.theta)  # average gradients across MPI processes
        self._optimizer.step()
        self._optimizer.zero_grad()

        return F.item()

    def update_param_epoch(self) -> None:
        """Execute epoch-wise M-step or epoch-wise section of an M-step computation.

        This method is called at the end of each training epoch.
        Implementing this method is optional: models can leave the body empty (just a `pass`)
        or even not implement it at all.
        """
        # by default, do nothing
        return

    def free_energy(
        self, idx: to.Tensor, batch: to.Tensor, states: "TVOVariationalStates"
    ) -> float:
        """Evaluate free energy for the given batch of datapoints.

        :param idx: indexes of the datapoints in batch within the full dataset
        :param batch: batch of datapoints, Tensor with shape (N,D)
        :param states: all TVOVariationalStates states for this dataset

        .. note::
            This default implementation of free_energy is only appropriate for Trainable
            models that are not Optimized.
        """
        log_joints = states.lpj[idx]  # these are actual log-joints in Trainable models
        return to.logsumexp(log_joints, dim=1).sum(dim=0).item()

    @property
    def shape(self) -> Tuple[int, ...]:
        """The model shape, i.e. number of observables D and latents H as the tuple (D,H).

        :returns: the model shape: observable layer size followed by hidden layer size, i.e. (D,H)

        The default implementation returns self._shape if present; otherwise it tries to
        infer the model's shape from the parameters self.theta: the number of latents is
        assumed to be equal to the first dimension of the first tensor in self.theta, and
        the number of observables is assumed to be equal to the last dimension of the last
        parameter in self.theta.
        """
        if hasattr(self, "_shape"):
            return getattr(self, "_shape")
        assert (
            len(self.theta) != 0
        ), "Cannot infer the model shape from self.theta and self._shape is not defined"
        th = list(self.theta.values())
        return (th[-1].shape[-1], th[0].shape[0])
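
    # Illustrative note (an assumption, not from the original source): for a model with
    # _theta = {"pies": to.rand(H), "W": to.rand(H, D)} and no _shape attribute, the
    # property above returns (W.shape[-1], pies.shape[0]) == (D, H).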

    @property
    def config(self) -> Dict[str, Any]:
        """Model configuration.

        The default implementation returns self._config.
        """
        return self._config

    @property
    def theta(self) -> Dict[str, to.Tensor]:
        """Dictionary of model parameters.

        The default implementation returns self._theta.
        """
        return self._theta

    @property
    def precision(self) -> to.dtype:
        """The floating point precision the model works at (either to.float32 or to.float64).

        The default implementation returns self._precision or, if not present, the precision
        of model parameters self.theta (expected to be identical for all floating point
        parameters).
        """
        if hasattr(self, "_precision"):
            return getattr(self, "_precision")
        assert len(self.theta) != 0
        prec: Optional[to.dtype] = None
        for dt in (p.dtype for p in self.theta.values() if p.dtype.is_floating_point):
            assert prec is None or dt == prec  # all float parameters must share one dtype
            prec = dt
        assert prec is not None, "self.theta contains no floating point parameters"
        return prec
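

# --- Illustrative sketch (an assumption for exposition, not part of the original module) ---
# A minimal concrete Trainable only has to populate `_theta` and implement `log_joint`;
# the default gradient-based `update_param_batch` above then works out of the box.
# The linear model with Bernoulli prior and unit-variance Gaussian noise below is a
# hypothetical toy example.
import math


class _ExampleGaussianModel(Trainable):
    """Toy linear model: y = s @ W + Gaussian noise, with s ~ Bernoulli(pies)."""

    def __init__(self, D: int, H: int, precision: to.dtype = to.float64):
        self._theta = {
            "pies": to.full((H,), 0.2, dtype=precision),  # Bernoulli prior means, shape (H,)
            "W": to.rand(H, D, dtype=precision),  # generative weights, shape (H,D)
        }

    def log_joint(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        # data: (N,D), states: (N,S,H) binary -> log-joints with shape (N,S)
        pies, W = self._theta["pies"], self._theta["W"]
        s = states.to(dtype=W.dtype)
        log_prior = (s * to.log(pies) + (1 - s) * to.log(1 - pies)).sum(dim=2)
        log_lik = (
            -0.5 * ((data.unsqueeze(1) - s @ W) ** 2).sum(dim=2)
            - 0.5 * data.shape[1] * math.log(2 * math.pi)
        )
        return log_prior + log_lik
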

@runtime_checkable
class Optimized(Trainable, Protocol):
    """Additionally implements log_pseudo_joint, init_storage, init_batch, init_epoch."""

    @abstractmethod
    def log_joint(
        self, data: to.Tensor, states: to.Tensor, lpj: Optional[to.Tensor] = None
    ) -> to.Tensor:
        """Evaluate log-joint probabilities for this model.

        :param data: shape is (N,D)
        :param states: shape is (N,S,H)
        :param lpj: shape is (N,S). When lpj is not None it must contain pre-evaluated
            log-pseudo-joints for the given data and states. The implementation can take
            advantage of the extra argument to save computation.
        :returns: log-joints for data and states - shape is (N,S)
        """
        raise NotImplementedError

    @abstractmethod
    def log_pseudo_joint(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        """Evaluate log-pseudo-joint probabilities for this model.

        :param data: shape is (N,D)
        :param states: shape is (N,S,H)
        :returns: log-pseudo-joints for data and states - shape is (N,S)

        Log-pseudo-joint probabilities are the log-joint probabilities of the model
        for the specified set of datapoints and variational states where, potentially,
        some factors that do not depend on the variational states have been elided.

        Implementation of this method is an optional performance optimization.
        """
        ...

    def free_energy(
        self, idx: to.Tensor, batch: to.Tensor, states: "TVOVariationalStates"
    ) -> float:
        """Evaluate free energy for the given batch of datapoints.

        :param idx: indexes of the datapoints in batch within the full dataset
        :param batch: batch of datapoints, Tensor with shape (N,D)
        :param states: all TVOVariationalStates states for this dataset

        .. note::
            This default implementation of free_energy is only appropriate for Optimized
            models.
        """
        with to.no_grad():
            log_joints = self.log_joint(batch, states.K[idx], states.lpj[idx])
            return to.logsumexp(log_joints, dim=1).sum(dim=0).item()

    def init_storage(self, S: int, Snew: int, batch_size: int) -> None:
        """This method is called once by an experiment when initializing a model.

        :param S: number of variational states per datapoint to keep in memory
        :param Snew: number of new states per datapoint sampled in the variational E-step
        :param batch_size: batch size used by the data loader

        Concrete models can optionally override this method if it's convenient.
        By default, it does nothing.
        """
        pass

    def init_epoch(self) -> None:
        """This method is called once at the beginning of each training epoch.

        Concrete models can optionally override this method if it's convenient.
        By default, it does nothing.
        """
        pass

    def init_batch(self) -> None:
        """Model-specific initializations per batch.

        This method is called once at the beginning of each training batch.
        Concrete models can optionally override this method if it's convenient.
        By default, it does nothing.
        """
        pass

    @property
    def sorted_by_lpj(self) -> Dict[str, to.Tensor]:
        """Optional dictionary of Tensors that must be kept ordered in sync with log-pseudo-joints.

        The Trainer will take care that the tensors in this dictionary are sorted the same
        way log-pseudo-joints are sorted during an E-step.
        Tensors must have shapes (batch_size, S, ...) where S is the number of variational
        states per datapoint used during training.
        By default the dictionary is empty. Concrete models can override this property if
        need be.
        """
        return {}
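

# --- Illustrative sketch (an assumption for exposition, not part of the original module) ---
# Building on _ExampleGaussianModel above: an Optimized variant lets the E-step work with
# cheaper log-pseudo-joints and re-adds the elided state-independent constant only when
# full log-joints are required.
class _ExampleOptimizedGaussian(_ExampleGaussianModel, Optimized):
    def log_pseudo_joint(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        # equal to the full log-joints up to the constant -D/2 * log(2*pi), which does
        # not depend on the variational states
        pies, W = self._theta["pies"], self._theta["W"]
        s = states.to(dtype=W.dtype)
        log_prior = (s * to.log(pies) + (1 - s) * to.log(1 - pies)).sum(dim=2)
        return log_prior - 0.5 * ((data.unsqueeze(1) - s @ W) ** 2).sum(dim=2)

    def log_joint(
        self, data: to.Tensor, states: to.Tensor, lpj: Optional[to.Tensor] = None
    ) -> to.Tensor:
        if lpj is None:
            lpj = self.log_pseudo_joint(data, states)
        # restore the constant that log_pseudo_joint elides
        return lpj - 0.5 * data.shape[1] * math.log(2 * math.pi)
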

@runtime_checkable
class Sampler(Protocol):
    """Implements generate_data (hidden_state is an optional parameter)."""

    @abstractmethod
    def generate_data(
        self, N: Optional[int] = None, hidden_state: Optional[to.Tensor] = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        """Sample N datapoints from this model. At least one of N or hidden_state must be
        provided.

        :param N: number of datapoints to be generated
        :param hidden_state: Tensor with shape (N,H) where H is the number of units in the
            first latent layer
        :returns: if hidden_state was not provided, a tuple (data, hidden_state) where data
            is a Tensor with shape (N,D), D being the number of observables for this model,
            and hidden_state is the corresponding tensor of hidden variables with shape
            (N,H), H being the number of hidden variables for this model; otherwise, only
            the generated data Tensor is returned.
        """
        ...
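

# Illustrative usage of a concrete Sampler (hypothetical `model`, not from the source):
#     data, hidden_state = model.generate_data(N=100)         # sample states, then data
#     data = model.generate_data(hidden_state=hidden_state)   # data for given states
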

@runtime_checkable
class Reconstructor(Protocol):
    """Implements data_estimator."""

    @abstractmethod
    def data_estimator(
        self, idx: to.Tensor, batch: to.Tensor, states: "TVOVariationalStates"
    ) -> to.Tensor:
        r"""Estimator used for data reconstruction. Data reconstruction can only be supported
        by a model if it implements this method. The estimator to be implemented is defined
        as follows:

        :math:`\langle \langle y_d \rangle_{p(y_d|\vec{s},\Theta)} \rangle_{q(\vec{s}|\mathcal{K},\Theta)}`
        """  # noqa
        ...
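

# --- Illustrative usage (an assumption, not part of the original module) ---
# Because the protocols are @runtime_checkable, training code can dispatch on model
# capabilities at runtime (isinstance only checks member presence, not signatures).
def _example_dispatch(model: Trainable) -> str:
    if isinstance(model, Optimized):
        return "E-step can rank states by the cheaper log_pseudo_joint"
    return "E-step must evaluate the full log_joint"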