Coverage for tvo/models/bsc.py: 96% (127 statements)
# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import math
import torch as to

from torch import Tensor
from typing import Union, Tuple

import tvo
from tvo.utils.parallel import pprint, all_reduce, broadcast
from tvo.variational.TVOVariationalStates import TVOVariationalStates
from tvo.variational._utils import mean_posterior, lpj2pjc
from tvo.utils.model_protocols import Optimized, Sampler, Reconstructor
from tvo.utils.sanity import fix_theta
from tvo.utils._utils import get_lstsq

lstsq = get_lstsq(torch=to)

class BSC(Optimized, Sampler, Reconstructor):
    def __init__(
        self,
        H: int,
        D: int,
        W_init: Tensor = None,
        sigma2_init: Tensor = None,
        pies_init: Tensor = None,
        individual_priors: bool = True,
        precision: to.dtype = to.float64,
    ):
        """Shallow Binary Sparse Coding (BSC) model.

        :param H: Number of hidden units.
        :param D: Number of observables.
        :param W_init: Tensor with shape (D, H), initializes BSC weights.
        :param sigma2_init: Tensor with shape (1,), initializes the noise variance sigma^2.
        :param pies_init: Tensor with shape (H,), initializes BSC priors.
        :param individual_priors: Whether to use a Bernoulli prior with H individual prior
            probabilities. If False, the same prior probability is used for all latents.
        :param precision: Floating point precision to be used. Must be one of torch.float32
            or torch.float64.
        """
        assert precision in (to.float32, to.float64), "precision must be one of torch.float{32,64}"
        self._precision = precision

        device = tvo.get_device()

        if W_init is not None:
            assert W_init.shape == (D, H)
            W_init = W_init.to(dtype=precision, device=device)
        else:
            W_init = to.rand((D, H), dtype=precision, device=device)
            broadcast(W_init)

        if pies_init is not None:
            assert pies_init.shape == (H,) if individual_priors else pies_init.shape == (1,)
            pies_init = pies_init.to(dtype=precision, device=device)
        else:
            pies_init = (
                to.full((H,), 1.0 / H, dtype=precision, device=device)
                if individual_priors
                else to.tensor([1.0 / H], dtype=precision, device=device)
            )

        if sigma2_init is not None:
            assert sigma2_init.shape == (1,)
            sigma2_init = sigma2_init.to(dtype=precision, device=device)
        else:
            sigma2_init = to.tensor([1.0], dtype=precision, device=device)

        self._theta = {"pies": pies_init, "W": W_init, "sigma2": sigma2_init}
        eps, inf = 1.0e-5, math.inf
        self.policy = {
            "W": [None, to.full_like(self._theta["W"], -inf), to.full_like(self._theta["W"], inf)],
            "pies": [
                None,
                to.full_like(self._theta["pies"], eps),
                to.full_like(self._theta["pies"], 1.0 - eps),
            ],
            "sigma2": [
                None,
                to.full_like(self._theta["sigma2"], eps),
                to.full_like(self._theta["sigma2"], inf),
            ],
        }

        self.my_Wp = to.zeros((D, H), dtype=precision, device=device)
        self.my_Wq = to.zeros((H, H), dtype=precision, device=device)
        self.my_pies = to.zeros(H, dtype=precision, device=device)
        self.my_sigma2 = to.zeros(1, dtype=precision, device=device)
        self.my_N = to.tensor([0], dtype=to.int, device=device)
        self._config = dict(
            H=H, D=D, individual_priors=individual_priors, precision=precision, device=device
        )
        self._shape = self.theta["W"].shape
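
    # Each `policy` entry is a list [replacement, lower bound, upper bound] for
    # one parameter. update_param_epoch fills the replacement slot with the
    # previous parameter values before handing the new values to fix_theta,
    # which presumably clamps them to the given bounds and falls back to the
    # replacement where the update produced invalid values.
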
    def log_pseudo_joint(self, data: Tensor, states: Tensor) -> Tensor:  # type: ignore
        """Evaluate log-pseudo-joints for BSC."""
        Kfloat = states.to(
            dtype=self.theta["W"].dtype
        )  # TODO Find solution to avoid byte->float casting
        Wbar = to.matmul(
            Kfloat, self.theta["W"].t()
        )  # TODO Pre-allocate tensor and use `out` argument of to.matmul
        Kpriorterm = (
            to.matmul(Kfloat, to.log(self.theta["pies"] / (1 - self.theta["pies"])))
            if self.config["individual_priors"]
            else to.log(self.theta["pies"] / (1 - self.theta["pies"])) * Kfloat.sum(dim=2)
        )
        lpj = (
            to.mul(
                to.nansum(to.pow(Wbar - data[:, None, :], 2), dim=2), -1 / 2 / self.theta["sigma2"]
            )
            + Kpriorterm
        )
        return lpj.to(device=states.device)
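
    # For reference, the quantity computed above: with binary state s and
    # Wbar_d = sum_h W_dh s_h, the log-pseudo-joint is (up to constants that
    # do not depend on s)
    #   lpj(s, y) = -1/(2*sigma2) * sum_d (Wbar_d - y_d)^2
    #               + sum_h s_h * log(pies_h / (1 - pies_h)),
    # with NaN entries of y treated as missing data (hence to.nansum).
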
    def log_joint(self, data: Tensor, states: Tensor, lpj: Tensor = None) -> Tensor:
        """Evaluate log-joints for BSC."""
        if lpj is None:
            lpj = self.log_pseudo_joint(data, states)
        D = to.sum(to.logical_not(to.isnan(data)), dim=1)  # (N,)
        H = self.shape[1]
        priorterm = (
            to.log(1 - self.theta["pies"]).sum()
            if self.config["individual_priors"]
            else H * to.log(1 - self.theta["pies"])
        )
        return lpj + priorterm - D.unsqueeze(1) / 2 * to.log(2 * math.pi * self.theta["sigma2"])
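
    # Relation between the two quantities, for a datapoint with D_n observed
    # (non-NaN) dimensions:
    #   log p(s, y | Theta) = lpj(s, y) + sum_h log(1 - pies_h)
    #                         - (D_n / 2) * log(2 * pi * sigma2)
    # (for the shared prior, the sum over h reduces to H * log(1 - pi)).
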
    def update_param_batch(self, idx: Tensor, batch: Tensor, states: TVOVariationalStates) -> None:
        lpj = states.lpj[idx]
        K = states.K[idx]
        batch_size, S, _ = K.shape

        Kfloat = K.to(dtype=lpj.dtype)  # TODO Find solution to avoid byte->float casting
        Wbar = to.matmul(
            Kfloat, self.theta["W"].t()
        )  # TODO Find solution to re-use evaluations from E-step

        batch_s_pjc = mean_posterior(Kfloat, lpj)  # is (batch_size, H)
        batch_Wp = batch.unsqueeze(2) * batch_s_pjc.unsqueeze(1)  # is (batch_size, D, H)
        Kq = Kfloat.mul(lpj2pjc(lpj)[:, :, None])
        batch_Wq = to.einsum("ijk,ijl->kl", Kq, Kfloat)  # is (H, H), already summed over the batch
        batch_sigma2 = mean_posterior(
            to.sum((batch[:, None, :] - Wbar) ** 2, dim=2), lpj
        )  # is (batch_size,)

        self.my_pies.add_(to.sum(batch_s_pjc, dim=0))
        self.my_Wp.add_(to.sum(batch_Wp, dim=0))
        self.my_Wq.add_(batch_Wq)
        self.my_sigma2.add_(to.sum(batch_sigma2))
        self.my_N.add_(batch_size)

        return None
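
    # The accumulators updated above are the usual BSC sufficient statistics,
    # summed over datapoints (a notational sketch, not code):
    #   my_pies   += sum_n <s>_n                   (H,)
    #   my_Wp     += sum_n y_n <s>_n^T             (D, H)
    #   my_Wq     += sum_n <s s^T>_n               (H, H)
    #   my_sigma2 += sum_n <||y_n - Wbar||^2>_n    scalar
    # where <.>_n denotes the truncated posterior expectation for datapoint n.
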
    def update_param_epoch(self) -> None:
        theta = self.theta
        policy = self.policy

        all_reduce(self.my_Wp)
        all_reduce(self.my_Wq)
        all_reduce(self.my_pies)
        all_reduce(self.my_sigma2)
        all_reduce(self.my_N)

        N = self.my_N.item()
        D, H = self.shape

        # Calculate updated W
        Wold_noisy = theta["W"] + 0.1 * to.randn_like(theta["W"])
        broadcast(Wold_noisy)
        theta_new = {}
        try:
            theta_new["W"] = to.linalg.lstsq(self.my_Wq, self.my_Wp.t())[0].t()
        except RuntimeError:
            pprint("Inversion error. Will not update W but add some noise instead.")
            theta_new["W"] = Wold_noisy

        # Calculate updated pi
        theta_new["pies"] = (
            self.my_pies / N
            if self.config["individual_priors"]
            else self.my_pies.sum(dim=0, keepdim=True) / N / H
        )

        # Calculate updated sigma^2
        theta_new["sigma2"] = self.my_sigma2 / N / D

        policy["W"][0] = Wold_noisy
        policy["pies"][0] = theta["pies"]
        policy["sigma2"][0] = theta["sigma2"]
        fix_theta(theta_new, policy)
        for key in theta:
            theta[key][:] = theta_new[key]

        self.my_Wp[:] = 0.0
        self.my_Wq[:] = 0.0
        self.my_pies[:] = 0.0
        self.my_sigma2[:] = 0.0
        self.my_N[:] = 0.0
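
    # In closed form, the M-step above amounts to (a sketch of the standard
    # BSC updates, written with the accumulators from update_param_batch):
    #   W      = my_Wp @ my_Wq^{-1}    (computed via least squares for stability)
    #   pies   = my_pies / N           (or its mean over h, for a shared prior)
    #   sigma2 = my_sigma2 / (N * D)
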
    @property
    def shape(self) -> Tuple[int, ...]:
        return self.theta["W"].shape

    def generate_data(
        self, N: int = None, hidden_state: to.Tensor = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        precision, device = self.precision, tvo.get_device()
        D, H = self.shape

        if hidden_state is None:
            assert N is not None
            pies = self.theta["pies"]
            hidden_state = to.rand((N, H), dtype=precision, device=device) < pies
            must_return_hidden_state = True
        else:
            shape = hidden_state.shape
            if N is None:
                N = shape[0]
            assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
            must_return_hidden_state = False

        Wbar = to.zeros((N, D), dtype=precision, device=device)

        # Linear superposition
        for n in range(N):
            for h in range(H):
                if hidden_state[n, h]:
                    Wbar[n] += self.theta["W"][:, h]

        # Add noise according to the model parameters
        Y = Wbar + to.sqrt(self.theta["sigma2"]) * to.randn((N, D), dtype=precision, device=device)

        return (Y, hidden_state) if must_return_hidden_state else Y
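
    # The explicit double loop above keeps the generative model readable; a
    # vectorized equivalent (a sketch, not used here) would be:
    #   Wbar = hidden_state.to(dtype=precision) @ self.theta["W"].t()
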
    def data_estimator(self, idx: Tensor, batch: Tensor, states: TVOVariationalStates) -> Tensor:
        r"""Estimator used for data reconstruction. Data reconstruction can only be supported
        by a model if it implements this method. The estimator to be implemented is defined
        as follows:

        :math:`\langle \langle y_d \rangle_{p(y_d|\vec{s},\Theta)} \rangle_{q(\vec{s}|\mathcal{K},\Theta)}`  # noqa
        """
        K = states.K[idx]
        # TODO Find solution to avoid byte->float casting of `K`
        # TODO Pre-allocate tensor and use `out` argument of to.matmul
        return mean_posterior(
            to.matmul(K.to(dtype=self.precision), self.theta["W"].t()), states.lpj[idx]
        )
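
# --- Usage sketch (illustrative, not part of the module) ---
# A minimal example of the API defined in this file; the sizes and the state
# tensor passed to log_joint are made up for illustration.
#
#   model = BSC(H=8, D=16, precision=to.float64)
#   Y, S = model.generate_data(N=100)    # Y: (100, 16) data, S: (100, 8) binary states
#   states = S[:, None, :].to(to.uint8)  # one variational state per datapoint, shape (100, 1, 8)
#   lj = model.log_joint(Y, states)      # (100, 1) log-joint values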