Coverage for tvo/utils/sanity.py: 80%
41 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
5import torch as to
7from torch import Tensor
8from typing import Union
9import torch.distributed as dist
10from tvo.utils.parallel import broadcast
11from typing import Dict, List
def fix_theta(theta: Dict[str, Tensor], policy: Dict[str, List]):
    """Perform sanity check of values in theta dict according to policy.

    :param theta: Dictionary containing model parameters. On rank 0 the tensors are fixed in
                  place; each entry is then passed through ``broadcast`` and stored back.
    :param policy: Policy dictionary. Must have same keys as theta.

    Each key in policy contains a list with three floats/tensors referred to as replacement,
    low_bound and up_bound. If tensors are provided these must have the same shape as the
    corresponding tensors in the theta dictionary. For each key
    - non-finite (NaN and +/-inf) values in tensors from the theta dictionary are replaced with
      values from the corresponding entry in replacement and
    - the values in tensors from the theta dictionary are clamped to the corresponding values
      of low_bound and up_bound.

    When torch.distributed is initialized, the fixes are only computed on rank 0; the result
    is then handed to ``broadcast`` (presumably syncing rank 0's values to all ranks — see
    tvo.utils.parallel).
    """
    assert set(theta.keys()) == set(policy.keys()), "theta and policy must have same keys"

    rank = dist.get_rank() if dist.is_initialized() else 0

    for key, val in policy.items():
        # Unpack on every rank, not only on rank 0: a malformed policy entry then raises on all
        # ranks alike instead of leaving the non-zero ranks waiting in the broadcast below.
        replacement, low_bound, up_bound = val
        new_val = theta[key]
        if rank == 0:
            fix_infinite(new_val, replacement, key)
            fix_bounds(new_val, low_bound, up_bound, key)
        broadcast(new_val)
        theta[key] = new_val
def fix_infinite(values: Tensor, replacement: Union[float, Tensor], name: str = None):
    """Fill non-finite entries (NaN, +inf, -inf) of values in place with replacement.

    :param values: Input tensor; modified in place.
    :param replacement: Scalar (int or float) or tensor with replacements for non-finite values.
                        A tensor is indexed with the boolean mask of offending entries, so it
                        must have the same shape as values.
    :param name: Name of input tensor (optional). If given, a message is printed whenever
                 entries were replaced.
    """
    mask_infinite = to.isnan(values) | to.isinf(values)
    if mask_infinite.any():
        if isinstance(replacement, Tensor):
            # Element-wise replacement: take the replacement values at the offending positions.
            values[mask_infinite] = replacement[mask_infinite]
        else:
            # Any scalar is assigned to all offending entries. Dispatching on Tensor first means
            # int scalars (e.g. 0) work too, instead of being silently skipped by a float-only
            # check while the message below still claims a replacement happened.
            values[mask_infinite] = replacement
        if name is not None:
            print("Sanity check: Replaced infinite entries of %s." % name)
def fix_bounds(
    values: Tensor,
    lower: Union[float, Tensor] = None,
    upper: Union[float, Tensor] = None,
    name: str = None,
):
    """Clamp entries in values in place to not exceed lower and upper bounds.

    :param values: Input tensor; modified in place.
    :param lower: Scalar (int or float) or tensor with lower bounds for values. None disables
                  the lower bound.
    :param upper: Scalar (int or float) or tensor with upper bounds for values. None disables
                  the upper bound.
    :param name: Name of input tensor (optional). If given, a message is printed for each bound
                 that actually had to be enforced.
    """
    if (lower is not None) and (values < lower).any():
        if isinstance(lower, Tensor):
            # Element-wise maximum against a tensor of bounds, written back into values.
            to.max(input=lower, other=values, out=values)
        else:
            # Any scalar bound (int as well as float); a float-only check would silently skip
            # ints while still printing the message below.
            to.clamp(input=values, min=lower, out=values)
        if name is not None:
            print("Sanity check: Reset lower bound of %s" % name)

    # Strict comparison, symmetric to the lower-bound check: entries equal to the bound are
    # valid and must not trigger a no-op reset and a misleading message.
    if (upper is not None) and (values > upper).any():
        if isinstance(upper, Tensor):
            # Element-wise minimum against a tensor of bounds, written back into values.
            to.min(input=upper, other=values, out=values)
        else:
            to.clamp(input=values, max=upper, out=values)
        if name is not None:
            print("Sanity check: Reset upper bound of %s" % name)