Coverage for tvo/variational/evo.py: 98%
214 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
5import numpy as np
6import torch as to
8from itertools import combinations
9from typing import Callable, Tuple, Optional
10from torch import Tensor
12import tvo
13from tvo.utils import get
14from tvo.variational.TVOVariationalStates import TVOVariationalStates
15from tvo.variational._utils import update_states_for_batch, set_redundant_lpj_to_low
17from tvo.utils.model_protocols import Optimized, Trainable
class EVOVariationalStates(TVOVariationalStates):
    def __init__(
        self,
        N: int,
        H: int,
        S: int,
        precision: to.dtype,
        parent_selection: str,
        mutation: str,
        n_parents: int,
        n_generations: int,
        n_children: Optional[int] = None,
        crossover: bool = False,
        bitflip_frequency: Optional[float] = None,
        K_init_file: Optional[str] = None,
    ):
        """Evolutionary Variational Optimization class.

        :param N: number of datapoints
        :param H: number of latents
        :param S: number of variational states
        :param precision: floating point precision to be used for log_joint values.
                          Must be one of to.float32 or to.float64.
        :param parent_selection: one of "batch_fitparents" or "randparents"
        :param mutation: one of "randflip" or "sparseflip". If crossover is True, the
                         "cross_" prefix is added automatically.
        :param n_parents: number of parent states to select
        :param n_generations: number of EA generations to produce
        :param n_children: if crossover is False, number of children states to produce per
                           parent and generation. Must be None if crossover is True.
        :param crossover: if True, apply crossover. Must be False if n_children is specified.
        :param bitflip_frequency: Probability of flipping a bit during the mutation step (e.g.
                                  2/H for an average of 2 bitflips per mutation). Required when
                                  using the 'sparseflip' mutation algorithm.
        :param K_init_file: Full path to H5 file providing initial states
        """
        assert (
            not crossover or n_children is None
        ), "Exactly one of n_children and crossover may be provided."
        if crossover:
            mutation = f"cross_{mutation}"
            # with crossover every parent is paired with each of the n_parents - 1 others
            n_children = n_parents - 1
        assert n_children is not None, "n_children must be provided when crossover is False."
        S_new = get_n_new_states(mutation, n_parents, n_children, n_generations)

        conf = dict(
            N=N,
            H=H,
            S=S,
            S_new=S_new,
            precision=precision,
            parent_selection=parent_selection,
            mutation=mutation,
            n_parents=n_parents,
            n_children=n_children,
            n_generations=n_generations,
            p_bf=bitflip_frequency,
            K_init_file=K_init_file,
        )
        super().__init__(conf)

    def update(self, idx: Tensor, batch: Tensor, model: Trainable) -> int:
        """Evolve the variational states of the datapoints in `batch` and merge them into K.

        :param idx: indices into self.K / self.lpj of the datapoints contained in `batch`
        :param batch: the data of those datapoints
        :param model: model used to evaluate states. Optimized models are evaluated through
                      `log_pseudo_joint`, all others through `log_joint`.
        :returns: the value returned by `update_states_for_batch` for this batch
        """
        if isinstance(model, Optimized):
            lpj_fn = model.log_pseudo_joint
            sort_by_lpj = model.sorted_by_lpj
        else:
            lpj_fn = model.log_joint
            sort_by_lpj = {}

        K = self.K
        lpj = self.lpj

        parent_selection, mutation, n_parents, n_children, n_generations = get(
            self.config, "parent_selection", "mutation", "n_parents", "n_children", "n_generations"
        )

        # Re-evaluate the current states first (presumably because the model parameters
        # may have changed since they were last scored).
        lpj[idx] = lpj_fn(batch, K[idx])

        def lpj_fn_(states):
            return lpj_fn(batch, states)

        # The evolutionary search runs on CPU copies of the current states and lpj values.
        new_states, new_lpj = evolve_states(
            lpj=lpj[idx].to(device="cpu"),
            states=K[idx].to(device="cpu"),
            lpj_fn=lpj_fn_,
            n_parents=n_parents,
            n_children=n_children,
            n_generations=n_generations,
            parent_selection=parent_selection,
            mutation=mutation,
            # sparsity is only consumed by the sparseflip-based mutation operators
            sparsity=model.theta["pies"].mean() if "sparseflip" in mutation else None,
            p_bf=self.config.get("p_bf"),
        )

        return update_states_for_batch(
            new_states.to(device=K.device), new_lpj.to(device=lpj.device), idx, K, lpj, sort_by_lpj
        )
def evolve_states(
    lpj: Tensor,
    states: Tensor,
    lpj_fn: Callable[[Tensor], Tensor],
    n_parents: int,
    n_children: int,
    n_generations: int,
    parent_selection: str = "batch_fitparents",
    mutation: str = "cross_randflip",
    sparsity: Optional[float] = None,
    p_bf: Optional[float] = None,
) -> Tuple[Tensor, Tensor]:
    """Run an evolutionary algorithm on the variational states of a batch of datapoints.

    Take old variational states `states` (N,K,H) with lpj values `lpj` (N,K) and return new
    states and their log-pseudo-joints for each datapoint. The helper function
    `get_n_new_states` can be used to retrieve the exact number S of states generated by the
    chosen genetic algorithm. `lpj_fn` must be a callable that takes a set of states with shape
    (N,M,H) and returns the log-pseudo-joint evaluations of those states as a Tensor of shape
    (N,M). This function does not guarantee that all new_states returned are unique and not
    already contained in `states`, but it does guarantee that all redundant states will have
    lpj values lower than the minimum lpj of all states already in `states`, for each
    datapoint.

    Pre-conditions: H >= n_children, K >= n_parents

    Return: new_states (N,S,H), new_lpj (N,S), both allocated on the device of `lpj`.

    Each generation of new states is obtained by selecting `n_parents` parents from the
    previous generation (the input `states` for the first generation) following the strategy
    indicated by `parent_selection`, and then mutating the parents following the strategy
    indicated by `mutation`.

    parent_selection can be one of the following:
        - 'batch_fitparents'
            parents are selected using fitness-proportional sampling
            _with replacement_.
        - 'randparents'
            random uniform selection of parents.

    mutation can be one of the following:
        - 'randflip'
            each child is obtained by flipping one bit of the parent.
            every bit has the same probability of being flipped.
        - 'sparseflip'
            each child is obtained by flipping bits in the parent.
            the probability of each bit being flipped depends on the `sparsity`
            and `p_bf` parameters (see `batch_sparseflip`).
        - 'cross'
            children are generated by one-point-crossover of the parents. each
            parent is crossed-over with each other parent
            at a point chosen via random uniform sampling.
        - 'cross_randflip'
            as above, but the children additionally go through 'randflip'
        - 'cross_sparseflip'
            as 'cross', but the children additionally go through 'sparseflip'
    """
    dtype_f, device = lpj.dtype, lpj.device
    N, K, H = states.shape
    max_new_states = get_n_new_states(mutation, n_parents, n_children, n_generations)
    new_states_per_gen = max_new_states // n_generations

    # Pre-allocations
    # It's probable that not all new_states will be filled with a new unique state.
    # Redundant entries get an lpj lower than any state in states[n]
    # (see set_redundant_lpj_to_low below).
    new_states = to.empty((N, max_new_states, H), dtype=to.uint8, device=device)
    new_lpj = to.empty((N, max_new_states), dtype=dtype_f, device=device)
    parents = to.empty((N, n_parents, H), dtype=to.uint8, device=device)

    select, mutate = get_EA(parent_selection, mutation)
    # lpj_fn is evaluated on the device returned by tvo.get_device()
    compute_device = tvo.get_device()

    for g in range(n_generations):
        # parent selection: generation 0 draws from the input states, every later
        # generation from the children produced in the previous generation
        gen_idx = to.arange(g * new_states_per_gen, (g + 1) * new_states_per_gen, device=device)
        if g == 0:
            parents[:] = select(states, n_parents, lpj)
        else:
            old_gen_idx = gen_idx - new_states_per_gen
            parents[:] = select(new_states[:, old_gen_idx], n_parents, new_lpj[:, old_gen_idx])

        # children generation
        new_states[:, gen_idx] = mutate(parents, n_children, sparsity, p_bf)

        # children fitness evaluation; results are moved back to the device new_lpj lives on
        new_lpj[:, gen_idx] = lpj_fn(new_states[:, gen_idx].to(device=compute_device)).to(
            device=device
        )

    set_redundant_lpj_to_low(new_states, new_lpj, states)

    return new_states, new_lpj
def get_n_new_states(mutation: str, n_parents: int, n_children: int, n_gen: int) -> int:
    """Return how many new states `evolve_states` produces over `n_gen` generations.

    Crossover-based mutations (names starting with "cross") produce
    n_parents * (n_parents - 1) children per generation and ignore `n_children`;
    every other mutation produces n_parents * n_children children per generation.
    """
    children_per_parent = n_parents - 1 if mutation.startswith("cross") else n_children
    return n_parents * children_per_parent * n_gen
def get_EA(parent_selection: str, mutation: str) -> Tuple:
    """Return the (parent selection, mutation) operator pair for the given names.

    Refer to the doc of `evolve_states` for the list of valid arguments.

    :raises ValueError: if `parent_selection` or `mutation` is not a known operator name.
    """
    parent_sel_dict = {"batch_fitparents": batch_fitparents, "randparents": batch_randparents}
    mutation_dict = {
        "randflip": batch_randflip,
        "sparseflip": batch_sparseflip,
        "cross": batch_cross,
        "cross_randflip": batch_cross_randflip,
        "cross_sparseflip": batch_cross_sparseflip,
    }
    # input validation -- messages are built from adjacent literals so that no stray
    # whitespace from line continuations ends up inside them
    if parent_selection not in parent_sel_dict:  # pragma: no cover
        raise ValueError(
            f'Parent selection "{parent_selection}" not supported. '
            f"Valid options: {list(parent_sel_dict)}"
        )
    if mutation not in mutation_dict:  # pragma: no cover
        raise ValueError(
            f'Mutation operator "{mutation}" not supported. '
            f"Valid options: {list(mutation_dict)}"
        )

    return (parent_sel_dict[parent_selection], mutation_dict[mutation])
def batch_randflip(
    parents: Tensor, n_children: int, sparsity: Optional[float] = None, p_bf: Optional[float] = None
) -> Tensor:
    """Create n_children children per parent, each differing from its parent in exactly one bit.

    Children of the same parent always flip different bits. `sparsity` and `p_bf` are
    accepted only for signature compatibility with the other mutation operators.

    :param parents: Tensor with shape (N, n_parents, H)
    :param n_children: How many children to generate per parent per datapoint (at most H)
    :returns: children, a Tensor with shape (N, n_parents * n_children, H); children
              [p*n_children, (p+1)*n_children) stem from parent p
    """
    N, n_parents, H = parents.shape
    n_kids = n_parents * n_children

    # The positions of the n_children largest of H random scores are distinct per parent,
    # which guarantees that siblings flip different bits.
    scores = to.rand((N, n_parents, H), dtype=to.float64, device=parents.device)
    flip_pos = to.topk(scores, k=n_children, dim=2, sorted=False)[1].view(N, n_kids)

    # copy every parent n_children times, then flip one bit per copy
    children = parents.repeat(1, 1, n_children).view(N, -1, H)
    rows = to.arange(N)[:, None]
    kids = to.arange(n_kids)[None, :]
    children[rows, kids, flip_pos] = 1 - children[rows, kids, flip_pos]
    return children
def batch_sparseflip(
    parents: Tensor, n_children: int, sparsity: Optional[float], p_bf: Optional[float]
) -> Tensor:
    """Mutate every parent n_children times by independent, sparsity-aware bitflips.

    A 0-bit flips with probability p_0 and a 1-bit with p_1 = alpha * p_0. These are chosen
    per parent so that, up to `eps` and ignoring clipping to [0, 1], a child has on average
    p_bf*H flipped bits and sparsity*H active bits.

    :param parents: Tensor with shape (N, n_parents, H)
    :param n_children: number of children to produce per parent per datapoint
    :param sparsity: the algorithm will strive to produce children with the given sparsity
    :param p_bf: overall probability that a bit is flipped. the average number
                 of bitflips per child is p_bf*H
    :returns: Tensor with shape (N, n_parents*n_children, H)
    """
    f64, device = to.float64, parents.device
    N, n_parents, n_bits = parents.shape
    tiny = 1e-100
    target_on = sparsity * n_bits  # desired number of active bits per child

    Hf = float(n_bits)
    n_on = parents.sum(dim=2).to(dtype=f64)  # (N, n_parents)

    # ratio between the 1->0 and the 0->1 flip probabilities, per parent
    alpha = (
        (Hf - n_on)
        * ((Hf * p_bf) - (target_on - n_on))
        / ((target_on - n_on + Hf * p_bf) * n_on + tiny)
    )
    prob_off_to_on = (Hf * p_bf) / (Hf + (alpha - 1.0) * n_on) + tiny  # (N, n_parents)
    prob_on_to_off = alpha * prob_off_to_on

    # broadcast the per-parent probabilities to every bit of every child
    prob_off_to_on = (
        prob_off_to_on[:, :, None].expand(-1, -1, n_bits).repeat(1, 1, n_children).view(N, -1, n_bits)
    )
    prob_on_to_off = (
        prob_on_to_off[:, :, None].expand(-1, -1, n_bits).repeat(1, 1, n_children).view(N, -1, n_bits)
    )

    # every child starts as a copy of its parent
    children = parents.repeat(1, 1, n_children).view(N, n_parents * n_children, n_bits)
    assert children.shape == (N, n_parents * n_children, n_bits)
    mask_dtype = (to.empty(0) < 0).dtype  # BoolTensor in pytorch>=1.2, ByteTensor otherwise
    flip_prob = to.where(children.to(mask_dtype), prob_on_to_off, prob_off_to_on)

    draws = to.rand((N, n_parents * n_children, n_bits), dtype=f64, device=device)
    flips = draws < flip_prob
    children[flips] = 1 - children[flips]
    return children
336# TODO probably to be made a cython helper function for performance
def _fill_crossed_idxs_for_batch(
    parent_pairs, crossed_idxs, parent1_starts, cutting_points_1, cutting_points_2, parent2_ends
):
    """Write, in place, the source-parent index of every flat (child, bit) slot.

    For datapoint n and pair k = (p1, p2) of `parent_pairs`, the slots
    [parent1_starts[n, k], cutting_points_1[n, k]) of crossed_idxs[n] receive p1,
    [cutting_points_1[n, k], cutting_points_2[n, k]) receive p2 and
    [cutting_points_2[n, k], parent2_ends[n, k]) receive p1 again.

    :param parent_pairs: array of shape (n_pairs, 2) with the parent indices of each pair
    :param crossed_idxs: array with one row per datapoint; modified in place
    :param parent1_starts: (N, n_pairs) slot offsets where each pair's block starts
    :param cutting_points_1: (N, n_pairs) offsets where the first switch to p2 happens
    :param cutting_points_2: (N, n_pairs) offsets where the switch back to p1 happens
    :param parent2_ends: (N, n_pairs) offsets where each pair's block ends
    """
    for n in range(parent1_starts.shape[0]):
        row = crossed_idxs[n]
        for k, (first, second) in enumerate(parent_pairs):
            begin = parent1_starts[n, k]
            cut_a = cutting_points_1[n, k]
            cut_b = cutting_points_2[n, k]
            end = parent2_ends[n, k]
            row[begin:cut_a] = first
            row[cut_a:cut_b] = second
            row[cut_b:end] = first
def batch_cross(parents: Tensor) -> Tensor:
    """For each datapoint, cross every pair of parents to generate two children.

    :param parents: Tensor with shape (N, n_parents, H)
    :returns: Tensor with shape (N, n_parents*(n_parents - 1), H)

    The crossover is performed by selecting a "cut point" c uniformly from [1, H-1]
    (independently per datapoint and pair) and switching the contents of the parents
    after the cut point: the pair (p1, p2) yields p1[:c] + p2[c:] and p2[:c] + p1[c:].
    """
    N, n_parents, H = parents.shape
    pairs = np.array(list(combinations(range(n_parents), 2)), dtype=np.int64)
    n_pairs = pairs.shape[0]
    cuts = np.random.randint(low=1, high=H, size=(N, n_pairs))
    n_kids = 2 * n_pairs  # two children per pair

    # Per datapoint, the (n_kids * H) flat slots are grouped pair by pair: each pair owns
    # 2*H consecutive slots (first child, then second child). Every slot receives the index
    # of the parent its bit is copied from.
    block_starts = np.tile(np.arange(n_pairs) * (2 * H), (N, 1))  # (N, n_pairs)
    first_cut = block_starts + cuts
    source = np.empty((N, n_kids * H), dtype=np.int64)
    _fill_crossed_idxs_for_batch(
        pairs, source, block_starts, first_cut, first_cut + H, block_starts + 2 * H
    )
    source = source.reshape(N, n_kids, H)

    return parents[np.arange(N)[:, None, None], source, np.arange(H)[None, None, :]]
def batch_cross_randflip(
    parents: Tensor, n_children: int, sparsity: Optional[float] = None, p_bf: Optional[float] = None
) -> Tensor:
    """Cross all parent pairs (`batch_cross`), then flip one random bit in every child.

    `n_children`, `sparsity` and `p_bf` are ignored; they exist for signature compatibility.

    :param parents: Tensor with shape (N, n_parents, H)
    :returns: Tensor with shape (N, n_parents*(n_parents - 1), H)
    """
    return batch_randflip(batch_cross(parents), 1)
def batch_cross_sparseflip(
    parents: Tensor, n_children: int, sparsity: float, p_bf: float
) -> Tensor:
    """Cross all parent pairs (`batch_cross`), then mutate every child once with sparseflip.

    `n_children` is ignored: crossover already fixes the number of children.

    :param parents: Tensor with shape (N, n_parents, H)
    :returns: Tensor with shape (N, n_parents*(n_parents - 1), H)
    """
    crossed = batch_cross(parents)
    return batch_sparseflip(crossed, 1, sparsity, p_bf)
def batch_fitparents(candidates: Tensor, n_parents: int, lpj: Tensor) -> Tensor:
    """Select parents by fitness-proportional sampling *with replacement*, per datapoint.

    :param candidates: Tensor with shape (N, n_candidates, H)
    :param n_parents: number of parents to draw for each datapoint
    :param lpj: Tensor with shape (N, n_candidates) holding the candidates' lpj values
    :returns: Tensor with shape (N, n_parents, H)
    """
    precision, device = lpj.dtype, candidates.device
    assert candidates.shape[:2] == lpj.shape, "candidates and lpj must have same shape"
    N, n_candidates = lpj.shape

    # Shift lpj so that all fitness values are non-negative (the shift uses the minimum
    # over the whole batch).
    lpj_fitness = lpj - 2 * min(to.min(lpj).item(), 0.0)
    # Normalize per datapoint so that every row is a probability distribution. Rows whose
    # fitness is all zero fall back to uniform sampling instead of producing NaNs.
    row_sums = lpj_fitness.sum(dim=-1, keepdim=True)
    uniform = to.full_like(lpj_fitness, 1.0 / max(n_candidates, 1))
    lpj_fitness = to.where((row_sums > 0).expand_as(lpj_fitness), lpj_fitness / row_sums, uniform)
    assert lpj_fitness.shape == lpj.shape

    # Inverse-CDF sampling: for x ~ U[0, 1) the selected index i satisfies
    # cum_p[i-1] <= x < cum_p[i], i.e. i is the number of cum_p entries <= x.
    cum_p = to.cumsum(lpj_fitness, dim=-1)  # (N, n_candidates)
    x = to.rand((N, n_parents), dtype=precision, device=device)
    chosen_idx = (cum_p[:, None, :] <= x[:, :, None]).sum(dim=-1)  # (N, n_parents)
    # cum_p[:, -1] can be slightly below 1 due to rounding
    chosen_idx = chosen_idx.clamp(max=n_candidates - 1)

    batch_idx = to.arange(N, device=chosen_idx.device)[:, None]
    choices = candidates[batch_idx, chosen_idx]
    assert choices.shape == (candidates.shape[0], n_parents, candidates.shape[2])
    return choices
def batch_randparents(candidates: Tensor, n_parents: int, lpj: Tensor = None) -> Tensor:
    """Pick n_parents distinct candidates uniformly at random for every datapoint.

    :param candidates: Tensor with shape (batch_size, n_candidates, H)
    :param n_parents: number of parents per datapoint (should not exceed n_candidates)
    :param lpj: unused; accepted for signature compatibility with `batch_fitparents`
    :returns: Tensor with shape (batch_size, n_parents, H)
    """
    device = candidates.device
    batch_size, n_candidates, H = candidates.shape

    # one random permutation per datapoint, truncated to n_parents, all flattened together
    picks = [to.randperm(n_candidates, device=device)[:n_parents] for _ in range(batch_size)]
    flat_candidate_idx = to.cat(picks)
    # matching datapoint index, e.g. [0,0,1,1,2,2] for batch_size=3 and n_parents=2
    flat_batch_idx = to.arange(batch_size)[:, None].expand(batch_size, n_parents).reshape(-1)

    # fancy indexing flattens the first two dimensions, so restore them
    return candidates[flat_batch_idx, flat_candidate_idx].reshape(batch_size, n_parents, H)
456# ******* The following functions are currently unused ******* #
457# They are kept here as a reference for the new batched implementations.
def fitparents(candidates: Tensor, n_parents: int, lpj: Tensor) -> Tensor:
    """Select n_parents distinct parents of one datapoint by fitness-proportional sampling.

    Unbatched reference implementation (currently unused).

    :param candidates: Tensor with shape (n_candidates, H)
    :param n_parents: number of parents to select (without replacement)
    :param lpj: Tensor with shape (n_candidates,)
    :returns: Tensor with shape (n_parents, H)
    """
    device = candidates.device

    # compute fitness (per data point); the shift makes all values non-negative.
    # torch.min does not accept a Python list, so combine with the builtin min.
    lpj_fitness = lpj - 2 * min(to.min(lpj).item(), 0.0)  # is (no_candidates,)
    lpj_fitness = lpj_fitness / lpj_fitness.sum()

    # numpy requires p to sum to 1 within a tight tolerance: use float64 and renormalize
    p = lpj_fitness.to(device="cpu", dtype=to.float64).numpy()
    p = p / p.sum()

    # sample (indices of) parents according to fitness
    ind_children = np.random.choice(candidates.shape[0], size=n_parents, replace=False, p=p)
    # is (n_parents, H)
    return candidates[to.from_numpy(ind_children).to(device=device)]
def randflip(
    parents: Tensor, n_children: int, sparsity: Optional[float] = None, p_bf: Optional[float] = None
) -> Tensor:
    """Create n_children children per parent, each differing from its parent in one bit.

    Unbatched reference implementation (currently unused). Children of the same parent
    flip different bits. `sparsity` and `p_bf` are ignored.

    :param parents: Tensor with shape (n_parents, H)
    :param n_children: number of children per parent (at most H)
    :returns: Tensor with shape (n_parents * n_children, H); row r stems from parent
              r // n_children
    """
    num_parents, n_bits = parents.shape
    scores = to.rand((num_parents, n_bits), dtype=to.float64, device=parents.device)
    # the positions of the n_children largest scores are distinct within each row
    _, flip_pos = to.topk(scores, k=n_children, dim=1, largest=True, sorted=False)
    flip_pos = flip_pos.flatten()  # [p0 flip0, p0 flip1, ..., p1 flip0, ...]

    children = parents.repeat(1, n_children).view(-1, n_bits)
    rows = to.arange(n_children * num_parents, device=parents.device)
    children[rows, flip_pos] = 1 - children[rows, flip_pos]
    return children
def cross_sparseflip(parents: Tensor, n_children: int, sparsity: float, p_bf: float) -> Tensor:
    """Cross all parent pairs with `cross`, then mutate every child once with `sparseflip`.

    Unbatched reference implementation (currently unused); `n_children` is ignored.

    :param parents: Tensor with shape (n_parents, H)
    :returns: Tensor with shape (n_parents*(n_parents - 1), H)
    """
    crossed = cross(parents)
    return sparseflip(crossed, 1, sparsity, p_bf)
def cross_randflip(
    parents: Tensor, n_children: int, sparsity: Optional[float] = None, p_bf: Optional[float] = None
) -> Tensor:
    """Cross all parent pairs with `cross`, then flip one random bit in every child.

    Unbatched reference implementation (currently unused); `n_children`, `sparsity` and
    `p_bf` are ignored.

    :param parents: Tensor with shape (n_parents, H)
    :returns: Tensor with shape (n_parents*(n_parents - 1), H)
    """
    return randflip(cross(parents), 1)
def cross(parents: Tensor) -> Tensor:
    """Cross every pair of parents to generate two children per pair.

    Unbatched reference implementation (currently unused).

    :param parents: Tensor with shape (n_parents, H)
    :returns: Tensor with shape (n_parents*(n_parents - 1), H)

    For each pair (p1, p2) a cut point c is drawn uniformly from [1, H-1]; the pair yields
    the children p1[:c] + p2[c:] and p2[:c] + p1[c:].
    """
    n_parents, H = parents.shape
    n_children = n_parents * (n_parents - 1)
    n_pairs = n_children // 2
    cuts = np.random.randint(low=1, high=H, size=(n_pairs,))
    pairs = np.array(list(combinations(range(n_parents), 2)), dtype=np.int64)

    # Flat (n_children * H) source-parent indices; pair k owns the 2*H slots starting at
    # k*2*H (first child followed by second child).
    source = np.empty(n_children * H, dtype=np.int64)
    for k in range(n_pairs):
        first, second = pairs[k]
        start = k * (2 * H)
        cut = start + cuts[k]
        source[start:cut] = first
        source[cut:cut + H] = second
        source[cut + H:start + 2 * H] = first
    source = source.reshape(n_children, H)

    return parents[source, range(H)]
def sparseflip(
    parents: Tensor, n_children: int, sparsity: Optional[float], p_bf: Optional[float]
) -> Tensor:
    """Mutate every parent n_children times by independent, sparsity-aware bitflips.

    Unbatched reference implementation (currently unused).

    A 0-bit flips with probability p_0 and a 1-bit with p_1 = alpha * p_0. These are chosen
    per parent so that, up to `eps` and ignoring clipping to [0, 1], a child has on average
    p_bf*H flipped bits and sparsity*H active bits.

    :param parents: Tensor with shape (n_parents, H)
    :param n_children: number of children per parent
    :param sparsity: the algorithm will strive to produce children with the given sparsity
    :param p_bf: overall probability that a bit is flipped. the average number
                 of bitflips per child is p_bf*H
    :returns: Tensor with shape (n_parents*n_children, H)
    """
    f64, device = to.float64, parents.device
    n_par, n_bits = parents.shape
    n_on = parents.sum(dim=1)  # (n_parents,)
    children = parents.repeat(1, n_children).view(-1, n_bits)
    tiny = 1e-100
    target_on = sparsity * n_bits

    Hf = float(n_bits)
    n_on = n_on.to(dtype=f64)

    # ratio between the 1->0 and 0->1 flip probabilities, per parent
    alpha = (
        (Hf - n_on)
        * ((Hf * p_bf) - (target_on - n_on))
        / ((target_on - n_on + Hf * p_bf) * n_on + tiny)
    )
    prob_off_to_on = (Hf * p_bf) / (Hf + (alpha - 1.0) * n_on) + tiny  # (n_parents,)
    prob_on_to_off = (
        (alpha * prob_off_to_on)[:, None].expand(-1, n_bits).repeat(1, n_children).view(-1, n_bits)
    )
    prob_off_to_on = prob_off_to_on[:, None].expand(-1, n_bits).repeat(1, n_children).view(-1, n_bits)

    mask_dtype = (to.empty(0) < 0).dtype  # BoolTensor in pytorch>=1.2, ByteTensor otherwise
    is_on = children.to(mask_dtype)
    flip_prob = to.where(is_on, prob_on_to_off, prob_off_to_on)

    draws = to.rand((n_par * n_children, n_bits), dtype=f64, device=device)
    flips = draws < flip_prob
    children[flips] = 1 - children[flips]
    return children