Coverage for tvo/utils/global_settings.py: 84%

32 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg. 

3# Licensed under the Academic Free License version 3.0 

4 

5import os 

6import torch as to 

7 

8 

def _choose_device() -> to.device:
    """Pick the default torch device from the process environment.

    Returns the CPU device unless the `TVO_GPU` environment variable is set. When it
    is set, its value is converted with `int()` and the device `cuda:<n>` is returned;
    a value that `int()` cannot parse raises `ValueError`.
    """
    gpu_index = os.environ.get("TVO_GPU")
    if gpu_index is None:
        # No GPU requested: stay on the CPU.
        return to.device("cpu")
    return to.device(f"cuda:{int(gpu_index)}")

15 

16 

class _GlobalDevice:
    """Holder of the framework-wide default torch device.

    The device is kept in a class attribute and accessed through classmethods, so the
    class is used directly and does not need to be instantiated. Read it with
    `get_device()` and replace it with `set_device()`.
    """

    # Initial value, computed once when the class body is executed.
    _device: to.device = _choose_device()

    @classmethod
    def set_device(cls, dev: to.device):
        """Store `dev` as the default device."""
        cls._device = dev

    @classmethod
    def get_device(cls) -> to.device:
        """Return the currently stored default device."""
        return cls._device

32 

33 

def get_device() -> to.device:
    """Return the torch.device that objects in the package use by default.

    The default is 'cpu'. Setting the `TVO_GPU` environment variable to the index of a
    CUDA device changes it: for example, `export TVO_GPU=0` in bash makes the framework
    default to 'cuda:0'. Note that some computations might still be performed on CPU
    for performance reasons.
    """
    current_device = _GlobalDevice.get_device()
    return current_device

43 

44 

def _set_device(dev: to.device):
    """Private hook that overrides the TVO device settings. USE WITH CARE.

    Stores `dev` in `_GlobalDevice`, replacing whatever device was stored before.
    """
    store = _GlobalDevice.set_device
    store(dev)

48 

49 

def _choose_run_policy() -> str:
    """Pick the default parallelization policy from the process environment.

    Returns:
        `'mpi'` if `OMPI_COMM_WORLD_SIZE` is present in the environment, or if
        `TVO_MPI` is set to anything other than `0`; `'seq'` otherwise.
    """
    # Environment values are always strings, so the flag must be compared against the
    # string "0". Comparing against the integer 0 is always unequal, which made
    # `TVO_MPI=0` select 'mpi' as well.
    mpi_flag = os.environ.get("TVO_MPI")
    forced_by_flag = mpi_flag is not None and mpi_flag.strip() != "0"

    # NOTE(review): OMPI_COMM_WORLD_SIZE is presumably exported by Open MPI's launcher;
    # its mere presence selects 'mpi', regardless of TVO_MPI.
    launched_by_mpirun = "OMPI_COMM_WORLD_SIZE" in os.environ

    return "mpi" if forced_by_flag or launched_by_mpirun else "seq"

57 

58 

class _GlobalPolicy:
    """Holder of the framework-wide execution policy.

    The policy is kept in a class attribute and read through `get_policy()`, so the
    class is used directly and does not need to be instantiated. Only a getter is
    defined here.
    """

    # Computed once when the class body is executed.
    _policy: str = _choose_run_policy()

    @classmethod
    def get_policy(cls) -> str:
        """Return the stored execution policy."""
        return cls._policy

70 

71 

def get_run_policy() -> str:
    """Return the current parallelization policy.

    Possible values:

    * `'seq'`: the framework performs no parallelization beyond what torch tensors
      offer out of the box on the relevant device.
    * `'mpi'`: the framework performs data parallelization for the algorithms that
      implement it.

    The policy is 'seq' unless the framework detects that the program is running within
    `mpirun`, in which case it is 'mpi'. The default can also be overridden by setting
    the `TVO_MPI` environment variable to a non-zero value, e.g. in bash with
    `export TVO_MPI=1`.
    """
    current_policy = _GlobalPolicy.get_policy()
    return current_policy