Coverage for tvo/utils/sanity.py: 80%

41 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg. 

3# Licensed under the Academic Free License version 3.0 

4 

5import torch as to 

6 

7from torch import Tensor 

8from typing import Union 

9import torch.distributed as dist 

10from tvo.utils.parallel import broadcast 

11from typing import Dict, List 

12 

13 

def fix_theta(theta: Dict[str, Tensor], policy: Dict[str, List]):
    """Sanitize the tensors stored in theta, in place, as prescribed by policy.

    :param theta: Dictionary containing model parameters
    :param policy: Policy dictionary. Must have exactly the same keys as theta.

    Every policy entry is a list of three items, in this order: replacement,
    low_bound and up_bound (floats or tensors). If tensors are provided they must
    have the same shape as the corresponding tensor in theta. For each key

    - the tensor from theta is handed to fix_infinite together with replacement, and
    - it is then handed to fix_bounds together with low_bound and up_bound.

    Both fixes are only executed by the process with rank 0 (or by the single
    process when torch.distributed is not initialized). Afterwards every process
    passes the tensor to broadcast.
    """
    assert set(theta) == set(policy), "theta and policy must have same keys"

    if dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0
    is_root = rank == 0

    for name in policy:
        param = theta[name]
        if is_root:
            # The policy entry is only unpacked where it is actually needed.
            replacement, low_bound, up_bound = policy[name]
            fix_infinite(param, replacement, name)
            fix_bounds(param, low_bound, up_bound, name)

        # NOTE(review): broadcast comes from tvo.utils.parallel; presumably it
        # distributes the values held by rank 0 to all other ranks -- confirm there.
        broadcast(param)
        theta[name] = param

42 

43 

def fix_infinite(values: Tensor, replacement: Union[float, Tensor], name: str = None):
    """Overwrite non-finite entries (NaN, +inf and -inf) of values in place.

    :param values: Input tensor; modified in place.
    :param replacement: Python scalar (float or int) written to every non-finite
                        position, or a tensor broadcastable to the shape of values
                        whose entries at the non-finite positions are copied over.
                        Any other object (e.g. None) leaves values untouched.
    :param name: Name of input tensor (optional). If given, a notice is printed
                 whenever non-finite entries were detected.
    """
    mask_infinite = to.isnan(values) | to.isinf(values)
    if not mask_infinite.any():
        return

    if isinstance(replacement, Tensor):
        # expand_as is a no-op for equally shaped tensors and additionally allows
        # 0-dim / broadcastable replacements, which plain mask indexing rejects.
        values[mask_infinite] = replacement.expand_as(values)[mask_infinite]
    elif isinstance(replacement, (int, float)):
        # Accept ints as well: a replacement such as 0 used to be skipped silently.
        values[mask_infinite] = replacement

    # The notice is emitted for every detection, also when replacement is of an
    # unsupported type and nothing was overwritten (unchanged legacy behaviour).
    if name is not None:
        print("Sanity check: Replaced infinite entries of %s." % name)

58 

59 

def fix_bounds(
    values: Tensor,
    lower: Union[float, Tensor] = None,
    upper: Union[float, Tensor] = None,
    name: str = None,
):
    """Clamp entries in values, in place, to not exceed lower and upper bounds.

    :param values: Input tensor; modified in place.
    :param lower: Python scalar (float or int) or tensor with lower bounds for values.
                  None disables the lower bound.
    :param upper: Python scalar (float or int) or tensor with upper bounds for values.
                  None disables the upper bound.
    :param name: Name of input tensor (optional). If given, a notice is printed for
                 each bound that was violated by at least one entry.
    """
    if (lower is not None) and (values < lower).any():
        if isinstance(lower, Tensor):
            # Element-wise maximum written back into values.
            to.max(input=lower, other=values, out=values)
        elif isinstance(lower, (int, float)):
            # Ints are accepted too; previously an int bound was silently ignored.
            to.clamp(input=values, min=lower, out=values)
        if name is not None:
            print("Sanity check: Reset lower bound of %s" % name)

    # Strict comparison, mirroring the lower-bound test: an entry that sits exactly
    # on the bound is not a violation and must not trigger the notice.
    if (upper is not None) and (values > upper).any():
        if isinstance(upper, Tensor):
            # Element-wise minimum written back into values.
            to.min(input=upper, other=values, out=values)
        elif isinstance(upper, (int, float)):
            to.clamp(input=values, max=upper, out=values)
        if name is not None:
            print("Sanity check: Reset upper bound of %s" % name)