Coverage for tvo/variational/_utils.py: 78%
74 statements
coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import torch as to
import tvo
import numpy as np
from typing import Dict
from tvo.variational._set_redundant_lpj_to_low_CPU import set_redundant_lpj_to_low_CPU


def _unique_ind(x: to.Tensor) -> to.Tensor:
    """Find indices of unique rows in tensor. Prioritizes the first instance.

    :param x: torch tensor
    :returns: indices of unique rows in tensor.
    """
    # Get unique rows and inverse indices
    unique_rows, inverse_ind = to.unique(x, sorted=False, return_inverse=True, dim=0)

    # get unique inverse indices
    uii = inverse_ind.unique()

    # find where each unique index appears in the inverse index (uii x ii matrix)
    where_unique = to.eq(uii.unsqueeze(1), inverse_ind.repeat(len(uii), 1))

    # get index of first instance
    unique_indices = where_unique.to(to.float).argmax(1)

    return unique_indices
    # The code below is a bit faster, but is 1. unstable and 2. non-deterministic as of July 2023
    # and pytorch=2.0.0. When the pytorch version increases, check if the docs for
    # Tensor.scatter_reduce_ still have the respective warnings & notes about the function.
    # Until then, the deterministic function above should be used instead. (If you checked,
    # please increment the pytorch version in this comment and push.)

    # Authored by Sebastian Salwig:
    # n = x.shape[0]
    # unique_rows, inverse_ind = to.unique(x, sorted=False, return_inverse=True, dim=0)
    # n_unique = unique_rows.shape[0]
    # uniq_ind = to.zeros(n_unique, dtype=to.int, device=unique_rows.device)
    # perm = to.arange(n, device=inverse_ind.device)
    # uniq_ind = inverse_ind.new_empty(
    #     n_unique
    # ).scatter_reduce_(0, inverse_ind, perm, "amin", include_self=False)
    # return uniq_ind

    # The slow CPU code below can be used to verify:
    # for i in range(n_unique):
    #     for j, n in enumerate(inverse_ind):
    #         if n == i:
    #             uniq_ind[i] = int(j)
    #             uniq_ind.long()
    #             break


def _set_redundant_lpj_to_low_GPU(
    new_states: to.Tensor, new_lpj: to.Tensor, old_states: to.Tensor
):
    """Find redundant states in new_states w.r.t. old_states and set
    corresponding lpj to a low value.

    :param new_states: set of new variational states (batch_size, newS, H)
    :param new_lpj: corresponding log-pseudo-joints (batch_size, newS)
    :param old_states: (batch_size, S, H)
    """
    N, S, H = old_states.shape
    newS = new_states.shape[1]

    # old_states must come first so that _unique_ind discards redundant new_states
    old_and_new = to.cat((old_states, new_states), dim=1)
    for n in range(N):
        uniq_idx = _unique_ind(old_and_new[n])
        # indices of states in new_states[n] that are not in old_states[n]
        new_uniq_idx = uniq_idx[uniq_idx >= S] - S
        # BoolTensor in pytorch>=1.2, ByteTensor otherwise
        bool_or_byte = (to.empty(0) < 0).dtype
        mask = to.ones(newS, dtype=bool_or_byte, device=new_lpj.device)
        # indices of all non-unique states in new_states (complement of new_uniq_idx)
        mask[new_uniq_idx.to(device=new_lpj.device)] = 0
        # set lpj of redundant states to an arbitrary low value
        new_lpj[n][mask] = -1e20


# set_redundant_lpj_to_low is a performance hotspot. When running on CPU, we use a cython
# function that operates on numpy arrays; when running on GPU, we stick to torch tensors.
def set_redundant_lpj_to_low(new_states: to.Tensor, new_lpj: to.Tensor, old_states: to.Tensor):
    if tvo.get_device().type == "cpu":
        set_redundant_lpj_to_low_CPU(new_states.numpy(), new_lpj.numpy(), old_states.numpy())
    else:
        _set_redundant_lpj_to_low_GPU(new_states, new_lpj, old_states)
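
# Example (illustrative sketch; shapes and values are made up): with old_states of
# shape (1, 2, H) containing states A and B, and new_states of shape (1, 2, H)
# containing B and C, the entry of new_lpj corresponding to the duplicate state B
# is overwritten with -1e20, so B cannot displace an already stored state during
# the subsequent selection in update_states_for_batch.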


def generate_unique_states(
    n_states: int, H: int, crowdedness: float = 1.0, device: to.device = None
) -> to.Tensor:
    """Generate a torch tensor containing random and unique binary vectors.

    :param n_states: number of unique vectors to be generated
    :param H: size of binary vector
    :param crowdedness: average crowdedness per state
    :param device: torch.device of output Tensor. Defaults to tvo.get_device()

    Requires that n_states <= 2**H. Return has shape (n_states, H).
    """
    if device is None:
        device = tvo.get_device()
    assert n_states <= 2**H, "n_states must not be greater than 2**H"
    n_samples = max(n_states // 2, 1)

    s_set = {tuple(s) for s in np.random.binomial(1, p=crowdedness / H, size=(n_samples, H))}
    while len(s_set) < n_states:
        s_set.update(
            {tuple(s) for s in np.random.binomial(1, p=crowdedness / H, size=(n_samples, H))}
        )
    while len(s_set) > n_states:
        s_set.pop()
    return to.from_numpy(np.array(tuple(s for s in s_set), dtype=int)).to(
        dtype=to.uint8, device=device
    )
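
# Example (illustrative sketch, not part of the library):
# generate_unique_states(n_states=4, H=3) returns a (4, 3) uint8 tensor whose rows
# are distinct binary vectors, e.g. something like
# tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]); the exact rows are random,
# and on average each row contains `crowdedness` ones.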


def update_states_for_batch(
    new_states: to.Tensor,
    new_lpj: to.Tensor,
    idx: to.Tensor,
    all_states: to.Tensor,
    all_lpj: to.Tensor,
    sort_by_lpj: Dict[str, to.Tensor] = {},
) -> int:
    """Perform substitution of old and new states (and lpj, ...)
    according to TVO criterion.

    :param new_states: set of new variational states (idx.size, newS, H)
    :param new_lpj: corresponding log-pseudo-joints (idx.size, newS)
    :param idx: indices of the datapoints that compose the batch within the dataset
    :param all_states: set of all variational states (N, S, H)
    :param all_lpj: corresponding log-pseudo-joints (N, S)
    :param sort_by_lpj: optional dictionary of tensors with shape (n, s, ...) that will be
        sorted by all_lpj, the same way all_lpj and all_states are sorted.

    S is the number of variational states memorized for each of the N
    data-points. idx contains the ordered list of indices for which the
    new_states have been evaluated (i.e. the states in new_states[0] are to
    be put into all_states[idx[0]]). all_states[n] is updated to contain the set of
    variational states with best log-pseudo-joints.
    """
    # TODO Find out why lpj precision decreases for states without substitutions
    # (difference on the order of 1e-15).

    S = all_states.shape[1]
    batch_size, newS, H = new_states.shape

    old_states = all_states[idx]
    old_lpj = all_lpj[idx]

    assert old_states.shape == (batch_size, S, H)
    assert old_lpj.shape == (batch_size, S)

    conc_states = to.cat((old_states, new_states), dim=1)
    conc_lpj = to.cat((old_lpj, new_lpj), dim=1)  # (batch_size, S+newS)

    # sorted_idx has shape (batch_size, S)
    sorted_idx = to.flip(to.topk(conc_lpj, k=S, dim=1, largest=True, sorted=True)[1], [1])
    flattened_sorted_idx = sorted_idx.flatten()

    idx_n = idx.repeat(S, 1).t().flatten()
    idx_s = to.arange(S, device=all_states.device).repeat(batch_size)
    idx_sc = to.arange(batch_size, device=all_states.device).repeat(S, 1).t().flatten()

    all_states[idx_n, idx_s] = conc_states[idx_sc, flattened_sorted_idx]
    all_lpj[idx_n, idx_s] = conc_lpj[idx_sc, flattened_sorted_idx]

    for t in sort_by_lpj.values():
        idx_n_ = to.arange(batch_size).repeat(S, 1).t().flatten()
        t[idx_n_, idx_s] = t[idx_n_, flattened_sorted_idx]

    return (sorted_idx >= old_states.shape[1]).sum().item()  # nsubs
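
# Example (illustrative sketch, not part of the library): with S = 2 states
# memorized per datapoint,
#   all_states = to.zeros(10, 2, 4, dtype=to.uint8)
#   all_lpj = to.full((10, 2), -1e3)
#   new_states = to.ones(3, 1, 4, dtype=to.uint8)
#   new_lpj = to.zeros(3, 1)
#   idx = to.tensor([0, 4, 7])
#   n_subs = update_states_for_batch(new_states, new_lpj, idx, all_states, all_lpj)
# replaces the worse of the two stored states for datapoints 0, 4 and 7 with the
# new all-ones state (which has higher lpj), and n_subs == 3.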


def lpj2pjc(lpj: to.Tensor):
    """Shift log-pseudo-joint and convert log- to actual probability.

    :param lpj: log-pseudo-joint tensor
    :returns: probability tensor
    """
    up_lpg_bound = 0.0
    shft = up_lpg_bound - lpj.max(dim=1, keepdim=True)[0]
    tmp = to.exp(lpj + shft)
    return tmp.div_(tmp.sum(dim=1, keepdim=True))
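
# Example (illustrative, not part of the library): lpj2pjc normalizes each row
# after subtracting the row maximum for numerical stability, so it matches a
# row-wise softmax. For instance lpj2pjc(to.tensor([[0.0, float("-inf")]]))
# gives tensor([[1., 0.]]), and to.allclose(lpj2pjc(lpj), to.softmax(lpj, dim=1))
# holds for finite lpj.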


def _mean_post_einsum(g: to.Tensor, lpj: to.Tensor) -> to.Tensor:
    """Compute expectation value of g(s) w.r.t. truncated variational distribution q(s).

    :param g: Values of g(s) with shape (N,S,...).
    :param lpj: Log-pseudo-joint with shape (N,S).
    :returns: tensor with shape (N,...).
    """
    return to.einsum("ns...,ns->n...", (g, lpj2pjc(lpj)))


def _mean_post_mul(g: to.Tensor, lpj: to.Tensor) -> to.Tensor:
    """Compute expectation value of g(s) w.r.t. truncated variational distribution q(s).

    :param g: Values of g(s) with shape (N,S,...).
    :param lpj: Log-pseudo-joint with shape (N,S).
    :returns: tensor with shape (N,...).
    """
    # reshape lpj from (N,S) to (N,S,1,...), to match dimensionality of g
    lpj = lpj.view(*lpj.shape, *(1 for _ in range(g.ndimension() - 2)))
    return lpj2pjc(lpj).mul(g).sum(dim=1)


def mean_posterior(g: to.Tensor, lpj: to.Tensor) -> to.Tensor:
    """Compute expectation value of g(s) w.r.t. truncated variational distribution q(s).

    :param g: Values of g(s) with shape (N,S,...).
    :param lpj: Log-pseudo-joint with shape (N,S).
    :returns: tensor with shape (N,...).
    """
    if tvo.get_device().type == "cpu":
        means = _mean_post_einsum(g, lpj)
    else:
        means = _mean_post_mul(g, lpj)

    assert means.shape == (g.shape[0], *g.shape[2:])
    assert not to.isnan(means).any() and not to.isinf(means).any()
    return means
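
# Example (illustrative sketch, not part of the library): for
#   g = to.tensor([[[1.0], [3.0]]])   # shape (N=1, S=2, 1)
#   lpj = to.tensor([[0.0, 0.0]])     # equal weights -> q(s) = [0.5, 0.5]
# mean_posterior(g, lpj) returns tensor([[2.]]), i.e. the posterior-weighted
# average of g over the S variational states.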