Coverage for tvo/utils/global_settings.py: 84%
32 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
5import os
6import torch as to
9def _choose_device() -> to.device:
10 dev = to.device("cpu")
11 if "TVO_GPU" in os.environ:
12 gpu_n = int(os.environ["TVO_GPU"])
13 dev = to.device(f"cuda:{gpu_n}")
14 return dev
class _GlobalDevice:
    """A singleton object containing the global device settings for the framework.

    Set and get the corresponding to.device with `{set,get}_device()`.
    """

    # Class-level state shared process-wide; initialized once at import time
    # from the TVO_GPU environment variable (see _choose_device).
    _device: to.device = _choose_device()

    @classmethod
    def get_device(cls) -> to.device:
        """Return the device currently configured as the global default."""
        return cls._device

    @classmethod
    def set_device(cls, dev: to.device):
        """Replace the global default device with `dev`."""
        cls._device = dev
def get_device() -> to.device:
    """Return the torch.device all objects in the package use by default.

    The default is 'cpu'. Setting the `TVO_GPU` environment variable to a
    CUDA device number overrides it: e.g. in bash, `export TVO_GPU=0` makes
    the framework default to 'cuda:0'. Note that some computations might
    still be performed on CPU for performance reasons.
    """
    current = _GlobalDevice.get_device()
    return current
def _set_device(dev: to.device):
    """Private helper that overrides the global TVO device. USE WITH CARE."""
    _GlobalDevice.set_device(dev)
50def _choose_run_policy() -> str:
51 policy = "seq"
52 if (
53 "TVO_MPI" in os.environ and os.environ["TVO_MPI"] != 0
54 ) or "OMPI_COMM_WORLD_SIZE" in os.environ:
55 policy = "mpi"
56 return policy
class _GlobalPolicy:
    """A singleton object containing the global execution policy for the framework.

    Set and get the policy with `{set,get}_run_policy()`.
    """

    # Class-level state shared process-wide; initialized once at import time
    # from the TVO_MPI / OMPI_COMM_WORLD_SIZE environment variables
    # (see _choose_run_policy).
    _policy: str = _choose_run_policy()

    @classmethod
    def get_policy(cls) -> str:
        """Return the currently configured execution policy ('seq' or 'mpi')."""
        return cls._policy
def get_run_policy() -> str:
    """Return the current parallelization policy.

    * `'seq'`: the framework will not perform any parallelization other than what torch tensors
      offer out of the box on the relevant device.
    * `'mpi'`: the framework will perform data parallelization for the algorithms that
      implement it.

    The policy is 'seq' unless the framework detects that the program is running within `mpirun`,
    in which case the policy is 'mpi'. The default can also be overridden by setting the
    `TVO_MPI` environment variable to a non-zero value, e.g. in bash with
    `export TVO_MPI=1`.
    """
    policy = _GlobalPolicy.get_policy()
    return policy