Coverage for tvo/utils/gen.py: 71%

17 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg. 

3# Licensed under the Academic Free License version 3.0 

4 

5import torch as to 

6import tvo 

7 

8 

def generate_bars(
    H: int,
    bar_amp: float = 1.0,
    neg_amp: bool = False,
    bg_amp: float = 0.0,
    add_unit: float = None,
    precision: to.dtype = to.float64,
) -> to.Tensor:
    """Generate a ground-truth dictionary W suitable for a std. bars test

    Creates H basis vectors with horizontal and vertical bars on an R*R pixel grid
    (with R = H // 2). Column i (i < R) is a horizontal bar covering pixel row i;
    column R + i is a vertical bar covering pixel column i. Pixel (row, col) is
    stored in row ``row * R + col`` of W. If H is odd, the last column holds only
    the background amplitude.

    :param H: Number of latent variables
    :param bar_amp: Amplitude of each bar
    :param neg_amp: If True, independently flip the sign of each of the H columns
                    (bar and background alike) with probability 50 percent
    :param bg_amp: Background amplitude
    :param add_unit: If not None, an additional column with constant amplitude
                     add_unit is appended (it is never sign-flipped)
    :param precision: torch.dtype of the returned tensor
    :returns: tensor containing the bars dictionary, of shape (R**2, H), or
              (R**2, H + 1) if add_unit is not None, on tvo.get_device()
    """
    R = H // 2
    D = R**2
    device = tvo.get_device()

    # Build on an (R, R, H) grid so each bar can be written as a row/column slice.
    W = bg_amp * to.ones((R, R, H), dtype=precision, device=device)
    for i in range(R):
        W[i, :, i] = bar_amp  # horizontal bar in pixel row i
        W[:, i, R + i] = bar_amp  # vertical bar in pixel column i

    if neg_amp:
        # `size` must be a tuple; `(H)` would just be the int H and randint rejects it.
        sign = 1 - 2 * to.randint(high=2, size=(H,), device=device)
        W = W * sign.to(W.dtype)[None, None, :]

    # Flatten the pixel grid before any concatenation: the extra unit is a (D, 1)
    # column and can only be joined to a 2-D (D, H) dictionary.
    W = W.reshape(D, H)

    if add_unit is not None:
        # Match W's dtype/device so the returned tensor honours `precision`.
        extra = add_unit * to.ones((D, 1), dtype=W.dtype, device=device)
        W = to.cat((W, extra), dim=1)

    return W