Coverage for tvo/models/noisyor.py: 99% (105 statements)

# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

from tvo.utils.model_protocols import Optimized, Sampler
from tvo.variational import TVOVariationalStates  # type: ignore
from tvo.variational._utils import mean_posterior
from tvo.utils.parallel import all_reduce, broadcast
from torch import Tensor
import torch as to
from typing import Dict, Optional, Union, Tuple
import tvo


class NoisyOR(Optimized, Sampler):
    eps = 1e-7

    def __init__(
        self,
        H: int,
        D: int,
        W_init: Optional[Tensor] = None,
        pi_init: Optional[Tensor] = None,
        precision: to.dtype = to.float64,
    ):
27 """Shallow NoisyOR model.
29 :param H: Number of hidden units.
30 :param D: Number of observables.
31 :param W_init: Tensor with shape (D,H), initializes NoisyOR weights.
32 :param pi_init: Tensor with shape (H,), initializes NoisyOR priors.
33 :param precision: Floating point precision required. Must be one of torch.float32 or
34 torch.float64.
35 """
        assert precision in (to.float32, to.float64), "precision must be one of torch.float{32,64}"
        self._precision = precision

        device = tvo.get_device()

        if W_init is not None:
            assert W_init.shape == (D, H)
        else:
            W_init = to.rand(D, H, device=device)
            broadcast(W_init)

        if pi_init is not None:
            assert pi_init.shape == (H,)
            assert (pi_init <= 1.0).all() and (pi_init >= 0).all()
        else:
            pi_init = to.full((H,), 1.0 / H, device=device, dtype=self.precision)

        self._theta = {
            "pies": pi_init.to(device=device, dtype=precision),
            "W": W_init.to(device=device, dtype=precision),
        }

        self.new_pi = to.zeros(H, device=device, dtype=precision)
        self.Btilde = to.zeros(D, H, device=device, dtype=precision)
        self.Ctilde = to.zeros(D, H, device=device, dtype=precision)
        # number of datapoints processed in a training epoch
        self._train_datapoints = to.tensor([0], dtype=to.int, device=device)
        self._config = dict(H=H, D=D, precision=self.precision, device=device)
        self._shape = self.theta["W"].shape
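
    # A minimal construction sketch (hypothetical sizes, not part of the
    # original file; assumes tvo's device has already been configured):
    #
    #   model = NoisyOR(H=8, D=16, precision=to.float32)
    #   model.theta["W"].shape    # -> torch.Size([16, 8])
    #   model.theta["pies"]       # -> uniform priors, all entries 1/8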

    def log_pseudo_joint(self, data: Tensor, states: Tensor) -> Tensor:  # type: ignore
        """Evaluate log-pseudo-joints for NoisyOR."""
        K = states
        Y = data
        assert K.dtype == to.uint8 and Y.dtype == to.uint8
        pi = self.theta["pies"]
        W = self.theta["W"]
        batch_size, S, H = K.shape
        D = W.shape[0]
        dev = pi.device

        logPriors = to.matmul(K.type_as(pi), to.log(pi / (1 - pi)))

        logPy = to.empty((batch_size, S), device=dev, dtype=self.precision)
        # We will manually set the lpjs of all-zero states to the appropriate value.
        # For now, transform all-zero states into all-one states, to avoid computation of log(0).
        zeroStatesInd = to.nonzero((K == 0).all(dim=2))
        # https://discuss.pytorch.org/t/use-torch-nonzero-as-index/33218
        zeroStatesInd = (zeroStatesInd[:, 0], zeroStatesInd[:, 1])
        K[zeroStatesInd] = 1
        # prods_nsd = prod{h}{1 - W_dh*K_nsh}
        prods = (W * K.type_as(W).unsqueeze(2)).neg_().add_(1).prod(dim=-1)
        to.clamp(prods, self.eps, 1 - self.eps, out=prods)
        # logPy_ns = sum{d}{y_nd*log(1/prods_nsd - 1) + log(prods_nsd)}
        f1 = to.log(1.0 / prods - 1.0)
        indices = 1 - Y[:, None, :].expand(batch_size, S, D)
        # convert to BoolTensor in pytorch>=1.2, leave it as ByteTensor in earlier versions
        indices = indices.type_as(to.empty(0) < 0)
        f1[indices] = 0.0
        logPy[:, :] = to.sum(f1, dim=-1) + to.sum(to.log(prods), dim=2)
        K[zeroStatesInd] = 0

        lpj = logPriors + logPy
        # for all-zero states, set lpj to arbitrary very low value if y!=0, 0 otherwise
        # in the end we want exp(lpj(y,s=0)) = 1 if y=0, 0 otherwise
        lpj[zeroStatesInd] = -1e30 * data[zeroStatesInd[0]].any(dim=1).type_as(lpj)
        assert (
            not to.isnan(lpj).any() and not to.isinf(lpj).any()
        ), "some NoisyOR lpj values are invalid!"
        return lpj.to(device=states.device)  # (N, S)
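
    # The quantity computed above, restated as one formula (following the
    # inline comments): with prods_nsd = prod_h (1 - W_dh * K_nsh),
    #
    #   lpj_ns = sum_h K_nsh * log(pi_h / (1 - pi_h))
    #            + sum_d [ y_nd * log(1/prods_nsd - 1) + log(prods_nsd) ]
    #
    # This is log p(y_n, s_ns) up to the constant sum_h log(1 - pi_h), which
    # log_joint adds back below.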

    def update_param_batch(
        self,
        idx: Tensor,
        batch: Tensor,
        states: TVOVariationalStates,
        mstep_factors: Optional[Dict[str, Tensor]] = None,
    ) -> Optional[float]:
        lpj = states.lpj[idx]
        K = states.K[idx]
        Kfloat = K.type_as(lpj)

        # pi_h = sum{n}{<K_nsh>} / N
        # (division by N has to wait until after the MPI all_reduce)
        self.new_pi += mean_posterior(Kfloat, lpj).sum(dim=0)
        assert not to.isnan(self.new_pi).any()

        # Ws_nsdh = 1 - (W_dh * Kfloat_nsh)
        Ws = (self.theta["W"][None, None, :, :] * Kfloat[:, :, None, :]).neg_().add_(1)
        Ws_prod = to.prod(Ws, dim=3, keepdim=True)
        B = Kfloat.unsqueeze(2) / (Ws * Ws_prod.neg().add_(1)).add_(self.eps)  # (N,S,D,H)
        self.Btilde.add_(
            (mean_posterior(B, lpj) * (batch.type_as(lpj) - 1).unsqueeze(2)).sum(dim=0)
        )
        C = B.mul_(Ws_prod).div_(Ws)  # (N,S,D,H)
        self.Ctilde.add_(to.sum(mean_posterior(C, lpj), dim=0))
        assert not to.isnan(self.Ctilde).any()
        assert not to.isnan(self.Btilde).any()

        self._train_datapoints.add_(batch.shape[0])

        return None
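
    # Restating the batch statistics accumulated above: with
    # Ws_nsdh = 1 - W_dh * K_nsh and Ws_prod_nsd = prod_h Ws_nsdh,
    #
    #   B_nsdh = K_nsh / (Ws_nsdh * (1 - Ws_prod_nsd) + eps)
    #   Btilde_dh += sum_n (y_nd - 1) * <B_nsdh>_s
    #   Ctilde_dh += sum_n <B_nsdh * Ws_prod_nsd / Ws_nsdh>_s
    #
    # where <.>_s is the posterior average computed by mean_posterior.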

    def update_param_epoch(self) -> None:
        all_reduce(self._train_datapoints)
        all_reduce(self.new_pi)
        N = self._train_datapoints.item()
        self.theta["pies"][:] = self.new_pi / N
        to.clamp(self.theta["pies"], self.eps, 1 - self.eps, out=self.theta["pies"])
        self.new_pi[:] = 0.0

        all_reduce(self.Btilde)
        all_reduce(self.Ctilde)
        self.theta["W"][:] = 1 + self.Btilde / (self.Ctilde + self.eps)
        to.clamp(self.theta["W"], self.eps, 1 - self.eps, out=self.theta["W"])
        self.Btilde[:] = 0.0
        self.Ctilde[:] = 0.0

        self._train_datapoints[:] = 0
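
    # Summary of the epoch-level M-step performed above, restating the code in
    # formulas (the standard NoisyOR EM updates):
    #
    #   pi_h <- (1/N) * sum_n <K_nsh>_s
    #   W_dh <- 1 + Btilde_dh / Ctilde_dh
    #
    # with both parameters clamped to [eps, 1 - eps] for numerical stability.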

    def log_joint(self, data, states, lpj=None):
        """Evaluate full log-joints for NoisyOR: lpj plus the constant sum_h log(1 - pi_h)."""
        pi = self.theta["pies"]
        if lpj is None:
            lpj = self.log_pseudo_joint(data, states)
        # TODO: could pre-evaluate the constant factor once per epoch
        return to.sum(to.log(1 - pi)) + lpj

    def generate_data(
        self, N: Optional[int] = None, hidden_state: Optional[Tensor] = None
    ) -> Union[to.Tensor, Tuple[to.Tensor, to.Tensor]]:
        """Use hidden states to sample datapoints according to the NoisyOR generative model.

        :param N: number of datapoints to generate; if hidden_state is given, N is only
                  used to verify its shape.
        :param hidden_state: a tensor with shape (N, H) where H is the number of hidden units.
        :returns: the datapoints, as a tensor with shape (N, D) where D is the number of
                  observables. If hidden_state is not provided, the sampled hidden states
                  are returned as well, as a (datapoints, hidden_state) tuple.
        """
        theta = self.theta
        W = theta["W"]
        D, H = W.shape

        if hidden_state is None:
            pies = theta["pies"]
            hidden_state = to.rand((N, H), dtype=pies.dtype, device=pies.device) < pies
            must_return_hidden_state = True
        else:
            if N is not None:
                shape = hidden_state.shape
                assert shape == (N, H), f"hidden_state has shape {shape}, expected ({N},{H})"
            must_return_hidden_state = False

        # py_nd = 1 - prod_h (1 - W_dh * s_nh)
        py = 1 - to.prod(1 - W[None, :, :] * hidden_state.type_as(W)[:, None, :], dim=2)
        Y = (to.rand_like(py) < py).byte()

        return (Y, hidden_state) if must_return_hidden_state else Y
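
# End-to-end usage sketch (hypothetical shapes, not part of the original
# module; assumes a configured tvo device):
#
#   model = NoisyOR(H=4, D=10)
#   Y, S = model.generate_data(N=32)         # Y: (32, 10) uint8, S: (32, 4) bool
#   states = to.randint(2, (32, 8, 4), dtype=to.uint8, device=Y.device)
#   lpj = model.log_pseudo_joint(Y, states)  # -> shape (32, 8)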