Coverage for tvo/variational/RandomSampledVarStates.py: 100%

21 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg. 

3# Licensed under the Academic Free License version 3.0 

4 

5import torch as to 

6from torch import Tensor 

7from tvo.utils.model_protocols import Trainable, Optimized 

8 

9from ._utils import update_states_for_batch 

10from .TVOVariationalStates import TVOVariationalStates 

11 

12 

class RandomSampledVarStates(TVOVariationalStates):
    """Variational states refined by proposing uniformly random binary states.

    On every call to :meth:`update`, ``S_new`` fresh binary states are drawn per
    datapoint. Each latent is independently set to 1 with probability
    ``sparsity``. The states are then handed to ``update_states_for_batch`` as
    candidates for the currently stored states.
    """

    def __init__(
        self,
        N: int,
        H: int,
        S: int,
        precision: to.dtype,
        S_new: int,
        sparsity: float = 0.5,
        K_init_file: str = None,
    ):
        """A TVOVariationalStates implementation that performs random sampling.

        :param N: number of datapoints
        :param H: number of latents
        :param S: number of variational states
        :param precision: floating point precision to be used for log_joint values.
                          Must be one of to.float32 or to.float64.
        :param S_new: number of states to be sampled at every call to :meth:`update`
        :param sparsity: average fraction of active units in sampled states, i.e. the
                         probability with which each latent of a sampled state is 1.
        :param K_init_file: Full path to H5 file providing initial states
        """
        # All hyperparameters are forwarded to the base class as a single config
        # dict (presumably exposed as self.config; update() reads "S_new" and
        # "sparsity" from it).
        conf = dict(
            N=N,
            H=H,
            S=S,
            precision=precision,
            S_new=S_new,
            sparsity=sparsity,
            K_init_file=K_init_file,
        )
        super().__init__(conf)

    def update(self, idx: Tensor, batch: Tensor, model: Trainable) -> int:
        """See :func:`tvo.variational.TVOVariationalStates.update`.

        The method runs in three steps:

        1. Re-evaluate the stored states of the datapoints selected by ``idx``.
        2. Draw ``config["S_new"]`` random binary candidate states per datapoint,
           with each unit equal to 1 with probability ``config["sparsity"]``.
        3. Pass the candidates to ``update_states_for_batch`` and return its result.

        :param idx: indices into the stored ``self.K`` / ``self.lpj`` for the
                    datapoints in ``batch``
        :param batch: data batch passed to the model's lpj function
        :param model: model whose ``log_pseudo_joint`` and ``sorted_by_lpj`` are used
                      if it is ``Optimized``; otherwise its ``log_joint`` is used
        :returns: the value returned by ``update_states_for_batch`` (presumably the
                  number of replaced states; confirm against that helper)
        """
        if isinstance(model, Optimized):
            lpj_fn = model.log_pseudo_joint
            sort_by_lpj = model.sorted_by_lpj
        else:
            lpj_fn = model.log_joint
            # Presumably: no extra model tensors need reordering alongside lpj.
            sort_by_lpj = {}

        K = self.K[idx]
        batch_size, _, H = K.shape

        # Refresh the stored lpj of the current states so they are compared
        # against the new candidates with up-to-date values (presumably the
        # model parameters may have changed since they were last computed).
        self.lpj[idx] = lpj_fn(batch, K)

        # Bernoulli(sparsity) sampling: a uniform [0, 1) draw is below
        # `sparsity` with probability `sparsity` (for 0 <= sparsity <= 1).
        new_K = (
            to.rand(batch_size, self.config["S_new"], H, device=K.device)
            < self.config["sparsity"]
        ).byte()
        new_lpj = lpj_fn(batch, new_K)

        return update_states_for_batch(
            new_K, new_lpj, idx, self.K, self.lpj, sort_by_lpj=sort_by_lpj
        )