Coverage for tvo/variational/RandomSampledVarStates.py: 100%
21 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
from typing import Optional

import torch as to
from torch import Tensor
from tvo.utils.model_protocols import Trainable, Optimized

from ._utils import update_states_for_batch
from .TVOVariationalStates import TVOVariationalStates
class RandomSampledVarStates(TVOVariationalStates):
    def __init__(
        self,
        N: int,
        H: int,
        S: int,
        precision: to.dtype,
        S_new: int,
        sparsity: float = 0.5,
        K_init_file: Optional[str] = None,
    ):
        """A TVOVariationalStates implementation that performs random sampling.

        At every call to :meth:`update`, ``S_new`` new binary states are drawn per
        datapoint, each unit being active independently with probability
        ``sparsity``. The drawn states are handed, together with the current
        states, to ``update_states_for_batch``.

        :param N: number of datapoints
        :param H: number of latents
        :param S: number of variational states
        :param precision: floating point precision to be used for log_joint values.
                          Must be one of to.float32 or to.float64.
        :param S_new: number of states to be sampled at every call to :meth:`update`
        :param sparsity: average fraction of active units in sampled states
                         (each unit is active independently with this probability).
        :param K_init_file: Full path to H5 file providing initial states
        """
        conf = dict(
            N=N,
            H=H,
            S=S,
            precision=precision,
            S_new=S_new,
            sparsity=sparsity,
            K_init_file=K_init_file,
        )
        super().__init__(conf)

    def update(self, idx: Tensor, batch: Tensor, model: Trainable) -> int:
        """See :func:`tvo.variational.TVOVariationalStates.update`.

        Recomputes the log-(pseudo-)joints of the current states of the datapoints
        in ``idx``, draws ``S_new`` random candidate states per datapoint and
        passes both to ``update_states_for_batch``, whose return value is
        returned unchanged.
        """
        if isinstance(model, Optimized):
            # Optimized models expose a log-pseudo-joint and a dict of extra
            # tensors (``sorted_by_lpj``) that is forwarded as ``sort_by_lpj``.
            lpj_fn = model.log_pseudo_joint
            sort_by_lpj = model.sorted_by_lpj
        else:
            lpj_fn = model.log_joint
            sort_by_lpj = {}
        K = self.K[idx]
        # Only batch size and number of latents are needed; the number of
        # current states is implied by ``self.K`` itself.
        batch_size, _, H = K.shape
        # Refresh the stored values for the current states (presumably because
        # model parameters may have changed since they were last evaluated).
        self.lpj[idx] = lpj_fn(batch, K)
        # Bernoulli(sparsity) sampling of every unit, on the same device as K.
        # NOTE(review): no check is made here that new_K differs from the
        # current states in K or from itself; confirm update_states_for_batch
        # tolerates duplicate states.
        new_K = (
            to.rand(batch_size, self.config["S_new"], H, device=K.device)
            < self.config["sparsity"]
        ).byte()
        new_lpj = lpj_fn(batch, new_K)

        return update_states_for_batch(
            new_K, new_lpj, idx, self.K, self.lpj, sort_by_lpj=sort_by_lpj
        )