Coverage for tvo/variational/_utils.py: 78% (74 statements)
coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import torch as to
import tvo
import numpy as np
from typing import Dict
from tvo.variational._set_redundant_lpj_to_low_CPU import set_redundant_lpj_to_low_CPU


def _unique_ind(x: to.Tensor) -> to.Tensor:
    """Find indices of unique rows in tensor. Prioritizes the first instance.

    :param x: torch tensor
    :returns: indices of unique rows in tensor.
    """
    # get unique rows and inverse indices
    unique_rows, inverse_ind = to.unique(x, sorted=False, return_inverse=True, dim=0)

    # get unique inverse indices
    uii = inverse_ind.unique()

    # find where each unique index appears in the inverse index (uii x ii matrix)
    where_unique = to.eq(uii.unsqueeze(1), inverse_ind.repeat(len(uii), 1))

    # get index of first instance
    unique_indices = where_unique.to(to.float).argmax(1)

    return unique_indices

    # The code below is a bit faster, but is 1. unstable and 2. non-deterministic as of
    # July 2023 and pytorch==2.0.0. When the pytorch version increases, check if the docs
    # for Tensor.scatter_reduce_ still have the respective warnings & notes about the
    # function. Until then, the deterministic function above should be used instead.
    # (If you checked, please increment the pytorch version in this comment and push.)

    # Authored by Sebastian Salwig:
    # n = x.shape[0]
    # unique_rows, inverse_ind = to.unique(x, sorted=False, return_inverse=True, dim=0)
    # n_unique = unique_rows.shape[0]
    # uniq_ind = to.zeros(n_unique, dtype=to.int, device=unique_rows.device)
    # perm = to.arange(n, device=inverse_ind.device)
    # uniq_ind = inverse_ind.new_empty(
    #     n_unique
    # ).scatter_reduce_(0, inverse_ind, perm, "amin", include_self=False)
    # return uniq_ind

    # The slow CPU code below can be used to verify:
    # for i in range(n_unique):
    #     for j, n in enumerate(inverse_ind):
    #         if n == i:
    #             uniq_ind[i] = int(j)
    #             break
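
# Illustrative sketch of _unique_ind (not part of the module, made-up values):
# on a tensor with duplicate rows, the returned indices point at first
# occurrences only.
#
#   x = to.tensor([[0, 1], [1, 1], [0, 1], [1, 0]])
#   idx = _unique_ind(x)
#   # idx contains 0, 1 and 3 (in some order); row 2 is a duplicate of row 0
#   # and is therefore skipped. x[idx] recovers the unique rows.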


def _set_redundant_lpj_to_low_GPU(new_states: to.Tensor, new_lpj: to.Tensor, old_states: to.Tensor):
    """Find redundant states in new_states w.r.t. old_states and set
    corresponding lpj to low.

    :param new_states: set of new variational states (batch_size, newS, H)
    :param new_lpj: corresponding log-pseudo-joints (batch_size, newS)
    :param old_states: set of old variational states (batch_size, S, H)
    """

    N, S, H = old_states.shape
    newS = new_states.shape[1]

    # old_states must come first so that _unique_ind discards the redundant new_states
    old_and_new = to.cat((old_states, new_states), dim=1)
    for n in range(N):
        uniq_idx = _unique_ind(old_and_new[n])
        # indices of states in new_states[n] that are not in old_states[n]
        new_uniq_idx = uniq_idx[uniq_idx >= S] - S
        # BoolTensor in pytorch>=1.2, ByteTensor otherwise
        bool_or_byte = (to.empty(0) < 0).dtype
        mask = to.ones(newS, dtype=bool_or_byte, device=new_lpj.device)
        # indices of all non-unique states in new_states (complement of new_uniq_idx)
        mask[new_uniq_idx.to(device=new_lpj.device)] = 0
        # set lpj of redundant states to an arbitrary low value
        new_lpj[n][mask] = -1e20
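
# Illustrative sketch (not part of the module, made-up values): a new state
# that already exists among the old states gets its lpj pushed to -1e20, so it
# cannot survive the top-K selection performed later in update_states_for_batch.
#
#   old_states = to.tensor([[[0, 1], [1, 0]]], dtype=to.uint8)  # (N=1, S=2, H=2)
#   new_states = to.tensor([[[0, 1], [1, 1]]], dtype=to.uint8)  # (N=1, newS=2, H=2)
#   new_lpj = to.tensor([[-1.0, -2.0]])
#   _set_redundant_lpj_to_low_GPU(new_states, new_lpj, old_states)
#   # new_lpj is now [[-1e20, -2.0]]: [0, 1] duplicates an old state, [1, 1] is new.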


# set_redundant_lpj_to_low is a performance hotspot. When running on CPU, we use a Cython
# function that operates on numpy arrays; when running on GPU, we stick to torch tensors.
def set_redundant_lpj_to_low(new_states: to.Tensor, new_lpj: to.Tensor, old_states: to.Tensor):
    if tvo.get_device().type == "cpu":
        set_redundant_lpj_to_low_CPU(new_states.numpy(), new_lpj.numpy(), old_states.numpy())
    else:
        _set_redundant_lpj_to_low_GPU(new_states, new_lpj, old_states)


def generate_unique_states(
    n_states: int, H: int, crowdedness: float = 1.0, device: to.device = None
) -> to.Tensor:
    """Generate a torch tensor containing random and unique binary vectors.

    :param n_states: number of unique vectors to be generated
    :param H: size of binary vector
    :param crowdedness: expected number of active (non-zero) units per state
    :param device: torch.device of output Tensor. Defaults to tvo.get_device()

    Requires that n_states <= 2**H. Return has shape (n_states, H).
    """
    if device is None:
        device = tvo.get_device()
    assert n_states <= 2**H, "n_states must be at most 2**H"
    n_samples = max(n_states // 2, 1)

    s_set = {tuple(s) for s in np.random.binomial(1, p=crowdedness / H, size=(n_samples, H))}
    while len(s_set) < n_states:
        s_set.update(
            {tuple(s) for s in np.random.binomial(1, p=crowdedness / H, size=(n_samples, H))}
        )
    while len(s_set) > n_states:
        s_set.pop()
    return to.from_numpy(np.array(tuple(s for s in s_set), dtype=int)).to(
        dtype=to.uint8, device=device
    )
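
# Illustrative sketch (not part of the module, made-up values): draw 4 distinct
# binary vectors of length 3; crowdedness defaults to 1.0, i.e. one active unit
# per state on average.
#
#   states = generate_unique_states(n_states=4, H=3)
#   # states has shape (4, 3) and dtype uint8, and no two rows are identical.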


def update_states_for_batch(
    new_states: to.Tensor,
    new_lpj: to.Tensor,
    idx: to.Tensor,
    all_states: to.Tensor,
    all_lpj: to.Tensor,
    sort_by_lpj: Dict[str, to.Tensor] = {},
) -> int:
    """Perform substitution of old and new states (and lpj, ...)
    according to TVO criterion.

    :param new_states: set of new variational states (idx.size, newS, H)
    :param new_lpj: corresponding log-pseudo-joints (idx.size, newS)
    :param idx: indices of the datapoints that compose the batch within the dataset
    :param all_states: set of all variational states (N, S, H)
    :param all_lpj: corresponding log-pseudo-joints (N, S)
    :param sort_by_lpj: optional dict of tensors with shape (N, S, ...) that will be
        sorted by all_lpj, the same way all_lpj and all_states are sorted.
    :returns: number of substitutions performed

    S is the number of variational states memorized for each of the N
    data-points. idx contains the ordered list of indices for which the
    new_states have been evaluated (i.e. the states in new_states[0] are to
    be put into all_states[idx[0]]). all_states[n] is updated to contain the
    set of variational states with best log-pseudo-joints.
    """
    # TODO Find out why lpj precision decreases for states without substitutions
    # (difference on the order of 1e-15).

    S = all_states.shape[1]
    batch_size, newS, H = new_states.shape

    old_states = all_states[idx]
    old_lpj = all_lpj[idx]

    assert old_states.shape == (batch_size, S, H)
    assert old_lpj.shape == (batch_size, S)

    conc_states = to.cat((old_states, new_states), dim=1)
    conc_lpj = to.cat((old_lpj, new_lpj), dim=1)  # (batch_size, S+newS)

    # indices of the S best lpj per datapoint, in ascending lpj order: (batch_size, S)
    sorted_idx = to.flip(to.topk(conc_lpj, k=S, dim=1, largest=True, sorted=True)[1], [1])
    flattened_sorted_idx = sorted_idx.flatten()

    idx_n = idx.repeat(S, 1).t().flatten()
    idx_s = to.arange(S, device=all_states.device).repeat(batch_size)
    idx_sc = to.arange(batch_size, device=all_states.device).repeat(S, 1).t().flatten()

    all_states[idx_n, idx_s] = conc_states[idx_sc, flattened_sorted_idx]
    all_lpj[idx_n, idx_s] = conc_lpj[idx_sc, flattened_sorted_idx]

    for t in sort_by_lpj.values():
        idx_n_ = to.arange(batch_size).repeat(S, 1).t().flatten()
        t[idx_n_, idx_s] = t[idx_n_, flattened_sorted_idx]

    return (sorted_idx >= old_states.shape[1]).sum().item()  # nsubs
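
# Illustrative sketch (not part of the module, made-up values): one datapoint
# with S=2 memorized states receives two candidates; the candidate with lpj
# -0.5 beats the old state with lpj -3.0, so one substitution is reported.
#
#   all_states = to.tensor([[[0, 1], [1, 0]]], dtype=to.uint8)  # (N=1, S=2, H=2)
#   all_lpj = to.tensor([[-1.0, -3.0]])
#   new_states = to.tensor([[[1, 1], [0, 0]]], dtype=to.uint8)
#   new_lpj = to.tensor([[-0.5, -5.0]])
#   nsubs = update_states_for_batch(
#       new_states, new_lpj, to.tensor([0]), all_states, all_lpj
#   )
#   # nsubs == 1; all_lpj[0] now holds -1.0 and -0.5.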


def lpj2pjc(lpj: to.Tensor):
    """Shift log-pseudo-joints and convert from log- to actual probabilities.

    :param lpj: log-pseudo-joint tensor
    :returns: probability tensor
    """
    up_lpg_bound = 0.0
    shft = up_lpg_bound - lpj.max(dim=1, keepdim=True)[0]
    tmp = to.exp(lpj + shft)
    return tmp.div_(tmp.sum(dim=1, keepdim=True))
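
# Illustrative sketch (not part of the module, made-up values): lpj2pjc behaves
# like a numerically stabilized softmax along dim 1.
#
#   lpj = to.tensor([[0.0, -1.0, -2.0]])
#   p = lpj2pjc(lpj)
#   # each row of p sums to 1 and matches to.softmax(lpj, dim=1) up to rounding.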


def _mean_post_einsum(g: to.Tensor, lpj: to.Tensor) -> to.Tensor:
    """Compute expectation value of g(s) w.r.t. truncated variational distribution q(s).

    :param g: Values of g(s) with shape (N,S,...).
    :param lpj: Log-pseudo-joint with shape (N,S).
    :returns: tensor with shape (N,...).
    """
    return to.einsum("ns...,ns->n...", (g, lpj2pjc(lpj)))


def _mean_post_mul(g: to.Tensor, lpj: to.Tensor) -> to.Tensor:
    """Compute expectation value of g(s) w.r.t. truncated variational distribution q(s).

    :param g: Values of g(s) with shape (N,S,...).
    :param lpj: Log-pseudo-joint with shape (N,S).
    :returns: tensor with shape (N,...).
    """
    # reshape lpj from (N,S) to (N,S,1,...) to match the dimensionality of g
    lpj = lpj.view(*lpj.shape, *(1 for _ in range(g.ndimension() - 2)))
    return lpj2pjc(lpj).mul(g).sum(dim=1)
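
# Illustrative sketch (not part of the module, made-up shapes): both
# implementations compute the same posterior-weighted average and can be
# checked against each other.
#
#   g, lpj = to.rand(3, 4, 5), to.rand(3, 4)
#   assert to.allclose(_mean_post_einsum(g, lpj), _mean_post_mul(g, lpj))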


def mean_posterior(g: to.Tensor, lpj: to.Tensor) -> to.Tensor:
    """Compute expectation value of g(s) w.r.t. truncated variational distribution q(s).

    :param g: Values of g(s) with shape (N,S,...).
    :param lpj: Log-pseudo-joint with shape (N,S).
    :returns: tensor with shape (N,...).
    """
    if tvo.get_device().type == "cpu":
        means = _mean_post_einsum(g, lpj)
    else:
        means = _mean_post_mul(g, lpj)

    assert means.shape == (g.shape[0], *g.shape[2:])
    assert not to.isnan(means).any() and not to.isinf(means).any()
    return means

230 return means