Coverage for tvo/utils/gen.py: 71%
17 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
5import torch as to
6import tvo
def generate_bars(
    H: int,
    bar_amp: float = 1.0,
    neg_amp: bool = False,
    bg_amp: float = 0.0,
    add_unit: float = None,
    precision: to.dtype = to.float64,
    device: to.device = None,
):
    """Generate a ground-truth dictionary W suitable for a std. bars test.

    Creates H basis vectors with horizontal and vertical bars on an R*R pixel grid
    (with R = H // 2). For i < R, basis vector i is a horizontal bar in row i and
    basis vector R + i is a vertical bar in column i. If H is odd, the last basis
    vector contains background only.

    :param H: Number of latent variables
    :param bar_amp: Amplitude of each bar
    :param neg_amp: If True, each basis vector (bars and background) is multiplied by
                    -1 with probability 50 percent
    :param bg_amp: Background amplitude
    :param add_unit: If not None, an additional unit with constant amplitude add_unit
                     is appended as an extra column
    :param precision: torch.dtype of the returned tensor
    :param device: torch.device of the returned tensor; defaults to tvo.get_device()
    :returns: tensor of shape (R*R, H), or (R*R, H + 1) if add_unit is not None,
              containing the bars dictionary
    """
    if device is None:
        device = tvo.get_device()

    R = H // 2
    D = R**2

    W = bg_amp * to.ones((R, R, H), dtype=precision, device=device)
    for i in range(R):
        W[i, :, i] = bar_amp  # horizontal bar in row i
        W[:, i, R + i] = bar_amp  # vertical bar in column i

    if neg_amp:
        # size must be a tuple: ``(H)`` is just the int H, not a 1-tuple.
        sign = 1 - 2 * to.randint(0, 2, (H,), device=device)
        W = sign.to(dtype=precision)[None, None, :] * W

    # Flatten the R x R pixel grid (row-major) into D pixels *before* appending the
    # extra unit: concatenating a (D, 1) column onto the 3-D (R, R, H) tensor fails.
    W = W.view(D, H)

    if add_unit is not None:
        extra = to.full((D, 1), add_unit, dtype=precision, device=device)
        W = to.cat((W, extra), dim=1)

    return W