Coverage for tvo/utils/model_protocols.py: 93% (88 statements)
coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

from typing_extensions import Protocol, runtime_checkable
from typing import Tuple, Dict, Any, Optional, Union, TYPE_CHECKING
import torch as to
from tvo.utils.parallel import mpi_average_grads
from abc import abstractmethod

if TYPE_CHECKING:
    from tvo.variational.TVOVariationalStates import TVOVariationalStates


@runtime_checkable
class Trainable(Protocol):
    """The most basic model.

    Requires implementation of log_joint, update_param_batch and update_param_epoch.
    Provides a default implementation of free_energy.
    """

    _theta: Dict[str, to.Tensor]
    _config: Dict[str, Any] = {}
    _optimizer: Optional[to.optim.Optimizer] = None

    @abstractmethod
    def log_joint(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        """Evaluate log-joint probabilities for this model.

        :param data: shape is (N,D)
        :param states: shape is (N,S,H)
        :returns: log-joints for data and states - shape is (N,S)
        """
        ...

    def update_param_batch(
        self, idx: to.Tensor, batch: to.Tensor, states: "TVOVariationalStates"
    ) -> Optional[float]:
        """Execute batch-wise M-step or batch-wise section of an M-step computation.

        :param idx: indexes of the datapoints that compose the batch within the dataset
        :param batch: batch of datapoints, Tensor with shape (N,D)
        :param states: all variational states for this dataset

        If the model allows it, as an optimization this method can return this batch's
        free energy evaluated _before_ the model parameter update. If the batch's free
        energy is returned here, Trainers will skip a direct per-batch call to the
        free_energy method.
        """
        # by default, perform gradient-based parameter updates
        if self._optimizer is None:
            for t in self._theta.values():
                t.requires_grad_(True)
            self._optimizer = to.optim.Adam(self._theta.values())
        assert self._optimizer is not None  # to make mypy happy
        log_joints = self.log_joint(batch, states.K[idx])
        # batch free energy: logsumexp over states, summed over datapoints
        F = to.logsumexp(log_joints, dim=1).sum(dim=0)
        # gradient step on the negative free energy, normalized by batch size
        loss = -F / batch.shape[0]
        loss.backward()
        mpi_average_grads(self.theta)
        self._optimizer.step()
        self._optimizer.zero_grad()
        return F.item()

    def update_param_epoch(self) -> None:
        """Execute epoch-wise M-step or epoch-wise section of an M-step computation.

        This method is called at the end of each training epoch.
        Implementing this method is optional: models can leave the body empty (just a `pass`)
        or even not implement it at all.
        """
        # by default, do nothing
        return

    def free_energy(
        self, idx: to.Tensor, batch: to.Tensor, states: "TVOVariationalStates"
    ) -> float:
        """Evaluate free energy for the given batch of datapoints.

        :param idx: indexes of the datapoints in batch within the full dataset
        :param batch: batch of datapoints, Tensor with shape (N,D)
        :param states: all TVOVariationalStates states for this dataset

        .. note::
            This default implementation of free_energy is only appropriate for Trainable
            models that are not Optimized.
        """
        log_joints = states.lpj[idx]  # these are actual log-joints in Trainable models
        return to.logsumexp(log_joints, dim=1).sum(dim=0).item()

    @property
    def shape(self) -> Tuple[int, ...]:
        """The model shape, i.e. number of observables D and latents H as tuple (D,H).

        :returns: the model shape: observable layer size followed by hidden layer size,
            e.g. (D,H)

        The default implementation returns self._shape if present; otherwise it tries to
        infer the model's shape from the parameters self.theta: the number of latents is
        assumed to be equal to the first dimension of the first tensor in self.theta, and
        the number of observables is assumed to be equal to the last dimension of the last
        parameter in self.theta.
        """
        if hasattr(self, "_shape"):
            return getattr(self, "_shape")
        assert (
            len(self.theta) != 0
        ), "Cannot infer the model shape from self.theta and self._shape is not defined"
        th = list(self.theta.values())
        return (th[-1].shape[-1], th[0].shape[0])

    @property
    def config(self) -> Dict[str, Any]:
        """Model configuration.

        The default implementation returns self._config.
        """
        return self._config

    @property
    def theta(self) -> Dict[str, to.Tensor]:
        """Dictionary of model parameters.

        The default implementation returns self._theta.
        """
        return self._theta

    @property
    def precision(self) -> to.dtype:
        """The floating point precision the model works at (either to.float32 or to.float64).

        The default implementation returns self._precision or, if not present, the precision
        of the model parameters self.theta (expected to be identical for all floating point
        parameters).
        """
        if hasattr(self, "_precision"):
            return getattr(self, "_precision")
        assert len(self.theta) != 0
        prec: Optional[to.dtype] = None
        for dt in (p.dtype for p in self.theta.values() if p.dtype.is_floating_point):
            assert prec is None or dt == prec
            prec = dt
        assert prec is not None  # theta must contain at least one floating point tensor
        return prec
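

# A minimal sketch of a Trainable implementation (illustrative only; this toy class
# and its parametrization are hypothetical, not part of tvo). Providing _theta and
# log_joint is enough: update_param_batch, free_energy, config, theta and precision
# all fall back to the protocol defaults above, and shape is inferred from theta as
# (D,H) = (W.shape[-1], W.shape[0]).
class _SketchBernoulliModel(Trainable):
    def __init__(self, H: int, D: int):
        # a single weight tensor of shape (H,D); latents have a fixed uniform prior
        self._theta = {"W": to.rand(H, D)}

    def log_joint(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        W = self._theta["W"]
        states = states.to(W.dtype)
        # uniform Bernoulli prior: log p(s) = H * log(1/2), constant in s
        log_prior = W.shape[0] * to.log(to.tensor(0.5, dtype=W.dtype))
        # Bernoulli likelihood with sigmoid(s @ W) means, summed over D: shape (N,S)
        p = to.sigmoid(states @ W).clamp(1e-6, 1 - 1e-6)
        log_lik = (data[:, None] * p.log() + (1 - data[:, None]) * (1 - p).log()).sum(-1)
        return log_prior + log_lik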


@runtime_checkable
class Optimized(Trainable, Protocol):
    """Additionally implements log_pseudo_joint, init_storage, init_batch, init_epoch."""

    @abstractmethod
    def log_joint(
        self, data: to.Tensor, states: to.Tensor, lpj: Optional[to.Tensor] = None
    ) -> to.Tensor:
        """Evaluate log-joint probabilities for this model.

        :param data: shape is (N,D)
        :param states: shape is (N,S,H)
        :param lpj: shape is (N,S). When lpj is not None it must contain pre-evaluated
            log-pseudo-joints for the given data and states. The implementation can take
            advantage of the extra argument to save computation.
        :returns: log-joints for data and states - shape is (N,S)
        """
        raise NotImplementedError

    @abstractmethod
    def log_pseudo_joint(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        """Evaluate log-pseudo-joint probabilities for this model.

        :param data: shape is (N,D)
        :param states: shape is (N,S,H)
        :returns: log-pseudo-joints for data and states - shape is (N,S)

        Log-pseudo-joint probabilities are the log-joint probabilities of the model
        for the specified set of datapoints and variational states where, potentially,
        some factors that do not depend on the variational states have been elided.

        Implementing the Optimized protocol, and therefore this method, is an optional
        performance optimization.
        """
        ...

    def free_energy(
        self, idx: to.Tensor, batch: to.Tensor, states: "TVOVariationalStates"
    ) -> float:
        """Evaluate free energy for the given batch of datapoints.

        :param idx: indexes of the datapoints in batch within the full dataset
        :param batch: batch of datapoints, Tensor with shape (N,D)
        :param states: all TVOVariationalStates states for this dataset

        .. note::
            This default implementation of free_energy is only appropriate for Optimized
            models.
        """
        with to.no_grad():
            log_joints = self.log_joint(batch, states.K[idx], states.lpj[idx])
            return to.logsumexp(log_joints, dim=1).sum(dim=0).item()

    def init_storage(self, S: int, Snew: int, batch_size: int) -> None:
        """This method is called once by an experiment when initializing a model.

        :param S: number of variational states per datapoint to keep in memory
        :param Snew: number of new states per datapoint sampled in the variational E-step
        :param batch_size: batch size used by the data loader

        Concrete models can optionally override this method if it's convenient.
        By default, it does nothing.
        """
        pass

    def init_epoch(self) -> None:
        """This method is called once at the beginning of each training epoch.

        Concrete models can optionally override this method if it's convenient.
        By default, it does nothing.
        """
        pass

    def init_batch(self) -> None:
        """Model-specific initializations per batch."""
        pass

    @property
    def sorted_by_lpj(self) -> Dict[str, to.Tensor]:
        """Optional dictionary of Tensors that must be kept ordered in sync with
        log-pseudo-joints.

        The Trainer will take care that the tensors in this dictionary are sorted the same
        way log-pseudo-joints are during an E-step.
        Tensors must have shapes (batch_size, S, ...) where S is the number of variational
        states per datapoint used during training.
        By default the dictionary is empty. Concrete models can override this property if
        need be.
        """
        return {}
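

# A sketch of the typical Optimized split (illustrative; this toy class is
# hypothetical, not part of tvo): log_pseudo_joint elides state-independent terms,
# and log_joint adds them back, reusing pre-evaluated log-pseudo-joints when the
# Trainer passes them in.
class _SketchGaussianModel(Optimized):
    def __init__(self, H: int, D: int, sigma2: float = 1.0):
        self._theta = {"W": to.rand(H, D)}
        self._sigma2 = sigma2

    def log_pseudo_joint(self, data: to.Tensor, states: to.Tensor) -> to.Tensor:
        # quadratic term only: the Gaussian normalization and the (uniform) prior
        # are constant in the states, so they are elided here
        mean = states.to(data.dtype) @ self._theta["W"]  # shape (N,S,D)
        return -((data[:, None] - mean) ** 2).sum(-1) / (2 * self._sigma2)

    def log_joint(
        self, data: to.Tensor, states: to.Tensor, lpj: Optional[to.Tensor] = None
    ) -> to.Tensor:
        if lpj is None:
            lpj = self.log_pseudo_joint(data, states)
        D, H = data.shape[1], states.shape[2]
        # add back the elided constants: normalization and uniform binary prior
        norm = -D / 2 * to.log(to.tensor(2 * to.pi * self._sigma2, dtype=data.dtype))
        prior = -H * to.log(to.tensor(2.0, dtype=data.dtype))
        return lpj + norm + prior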


@runtime_checkable
class Sampler(Protocol):
    """Implements generate_data (hidden_state is an optional parameter)."""

    @abstractmethod
    def generate_data(
        self, N: Optional[int] = None, hidden_state: Optional[to.Tensor] = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        """Sample N datapoints from this model. At least one of N or hidden_state must be
        provided.

        :param N: number of datapoints to be generated
        :param hidden_state: Tensor with shape (N,H) where H is the number of units in the
            first latent layer
        :returns: if hidden_state was not provided, a tuple (data, hidden_state) where data
            is a Tensor with shape (N,D), D being the number of observables for this model,
            and hidden_state is the corresponding tensor of hidden variables with shape
            (N,H); if hidden_state was provided, only the data Tensor is returned.
        """
        ...
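

# Because these protocols are runtime_checkable, conformance can be probed with
# isinstance at runtime. A small usage sketch (the helper below is hypothetical,
# not part of tvo):
def _generate_or_raise(model: Any, N: int) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
    # runtime_checkable protocols only check method presence, not signatures
    if isinstance(model, Sampler):
        return model.generate_data(N=N)
    raise TypeError("model does not implement the Sampler protocol")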


@runtime_checkable
class Reconstructor(Protocol):
    """Implements data_estimator."""

    @abstractmethod
    def data_estimator(
        self, idx: to.Tensor, batch: to.Tensor, states: "TVOVariationalStates"
    ) -> to.Tensor:
        r"""Estimator used for data reconstruction. Data reconstruction can only be supported
        by a model if it implements this method. The estimator to be implemented is defined
        as follows:

        :math:`\langle \langle y_d \rangle_{p(y_d|\vec{s},\Theta)} \rangle_{q(\vec{s}|\mathcal{K},\Theta)}`
        """
        ...
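

# A sketch of the estimator above (illustrative; the helper is hypothetical, not
# part of tvo). The inner expectation <y_d>_{p(y_d|s,Theta)} is the model's
# per-state mean of the observables; the outer expectation averages those means
# under the truncated posterior q(s|K,Theta), whose weights are proportional to
# exp(lpj). State-independent terms elided in log-pseudo-joints cancel in the
# softmax, so pseudo-joints can be used directly.
def _posterior_weighted_means(means: to.Tensor, lpj: to.Tensor) -> to.Tensor:
    """means: per-state means with shape (N,S,D); lpj: log-(pseudo-)joints, shape (N,S).
    Returns the reconstruction estimate with shape (N,D)."""
    weights = to.softmax(lpj, dim=1)  # normalized q(s|K,Theta) over the S states
    return (weights[:, :, None] * means).sum(dim=1)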