Coverage for tvo/variational/TVOVariationalStates.py: 92%

25 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg. 

3# Licensed under the Academic Free License version 3.0 

4 

5import torch as to 

6from torch import Tensor 

7 

8from abc import ABC, abstractmethod 

9from typing import Dict, Any 

10 

11from tvo.variational._utils import generate_unique_states 

12from tvo.utils import get 

13import tvo 

14from tvo.utils.model_protocols import Trainable 

15from tvo.utils.parallel import get_h5_dataset_to_processes 

16 

17 

class TVOVariationalStates(ABC):
    def __init__(self, conf: Dict[str, Any], K_init: Tensor = None):
        """Abstract base class for TVO realizations.

        Sets ``self.K``, the variational states of shape (N, S, H) cast to ``torch.uint8``.
        Allocates an uninitialized ``self.lpj`` buffer of shape (N, S) for the per-state lpj
        values, which concrete subclasses update in :meth:`update`.

        :param conf: dictionary with hyper-parameters. Required keys (each must be present and
                     not None): N, H, S, S_new, precision. ``precision`` is used as the dtype
                     of ``self.lpj``. Optional key: K_init_file.
        :param K_init: if specified and if `conf` does NOT specify `K_init_file`, self.K will be
                       initialized with a uint8 copy of this Tensor of shape (N,S,H).
                       If `K_init_file` is set in `conf`, the states loaded from that file take
                       precedence and `K_init` is ignored. If neither is given, self.K is
                       ``generate_unique_states(S, H)`` repeated N times along a new first axis.
        :raises AssertionError: if a required key is missing or None, or if the initial states
                                do not have shape (N, S, H).
        """
        required_keys = ("N", "H", "S", "S_new", "precision")
        # Explicit raise instead of `assert`: this keeps validation active under `python -O`.
        # AssertionError is kept (not ValueError) for backward compatibility with callers.
        missing = [key for key in required_keys if key not in conf or conf[key] is None]
        if missing:
            raise AssertionError(f"conf is missing required keys (or they are None): {missing}")
        self.config = conf

        # S_new is validated above but not needed by the base class.
        N, H, S, _, precision = get(conf, *required_keys)

        # A states file configured in conf takes precedence over the K_init argument.
        K_init_file = conf.get("K_init_file")
        _K_init = (
            get_h5_dataset_to_processes(K_init_file, ("initial_states", "states"))
            if K_init_file is not None
            else K_init
        )

        if _K_init is not None:
            if _K_init.shape != (N, S, H):
                raise AssertionError(
                    f"initial states must have shape {(N, S, H)}, got {tuple(_K_init.shape)}"
                )
            # clone() so self.K never shares storage with the caller's tensor
            # (.to() alone returns the same tensor when it is already uint8).
            # NOTE(review): unlike self.lpj below, this copy is not moved to
            # tvo.get_device(); confirm callers / the h5 loader already place it correctly.
            self.K = _K_init.clone().to(dtype=to.uint8)
        else:
            self.K = generate_unique_states(S, H).repeat(N, 1, 1)  # (N, S, H)
        # Left uninitialized on purpose; subclasses fill it with the states' lpj values.
        self.lpj = to.empty((N, S), dtype=precision, device=tvo.get_device())
        self.precision = precision

    @abstractmethod
    def update(self, idx: Tensor, batch: Tensor, model: Trainable) -> int:
        """Generate new variational states, update K and lpj with best samples and their lpj.

        :param idx: data point indices of batch w.r.t. K
        :param batch: batch of data points
        :param model: the model being used
        :returns: average number of variational state substitutions per datapoint performed
        """
        pass  # pragma: no cover