Coverage for tvo/variational/TVOVariationalStates.py: 92%
25 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
5import torch as to
6from torch import Tensor
8from abc import ABC, abstractmethod
9from typing import Dict, Any
11from tvo.variational._utils import generate_unique_states
12from tvo.utils import get
13import tvo
14from tvo.utils.model_protocols import Trainable
15from tvo.utils.parallel import get_h5_dataset_to_processes
class TVOVariationalStates(ABC):
    """Abstract base class for TVO realizations.

    Holds, for each of the N data points, S variational states over H latent
    variables in ``self.K`` (uint8 tensor of shape (N, S, H)) and one ``lpj`` value per
    state in ``self.lpj`` (shape (N, S), dtype ``precision``). Subclasses implement
    :meth:`update` to propose new states and keep the best ones.
    """

    def __init__(self, conf: Dict[str, Any], K_init: Tensor = None):
        """Validate the configuration and initialize ``self.K`` and ``self.lpj``.

        :param conf: dictionary with hyper-parameters. Required keys (each must be present
                     and not None): N, H, S, S_new, precision. Optional key: K_init_file,
                     path to an HDF5 file from which initial states are loaded via
                     get_h5_dataset_to_processes with candidate dataset names
                     ("initial_states", "states").
        :param K_init: initial states of shape (N, S, H). Used only if `conf` does NOT
                       specify `K_init_file`. If neither is given, every data point starts
                       from the same states produced by generate_unique_states(S, H).
        :raises AssertionError: if a required key is missing/None or the initial states
                                do not have shape (N, S, H).
        """
        required_keys = ("N", "H", "S", "S_new", "precision")
        for c in required_keys:
            assert c in conf and conf[c] is not None, f"conf must define a non-None '{c}'"
        self.config = conf

        # S_new is validated above but not used by this base class.
        N, H, S, _, precision = get(conf, *required_keys)

        # A K_init_file given in conf takes precedence over the K_init argument.
        _K_init = (
            get_h5_dataset_to_processes(conf["K_init_file"], ("initial_states", "states"))
            if conf.get("K_init_file") is not None
            else K_init
        )

        if _K_init is not None:
            assert _K_init.shape == (N, S, H), (
                f"initial states must have shape {(N, S, H)}, got {tuple(_K_init.shape)}"
            )
            # clone so that self.K does not share storage with the provided tensor
            self.K = _K_init.clone().to(dtype=to.uint8)
        else:
            # the same set of S states is replicated for every data point
            self.K = generate_unique_states(S, H).repeat(N, 1, 1)  # (N, S, H)
        # allocated uninitialized; values are expected to be filled in later
        self.lpj = to.empty((N, S), dtype=precision, device=tvo.get_device())
        self.precision = precision

    @abstractmethod
    def update(self, idx: Tensor, batch: Tensor, model: Trainable) -> int:
        """Generate new variational states, update K and lpj with best samples and their lpj.

        :param idx: data point indices of batch w.r.t. K
        :param batch: batch of data points
        :param model: the model being used
        :returns: average number of variational state substitutions per datapoint performed
        """
        pass  # pragma: no cover