Coverage for tvo/utils/_utils.py: 53%
15 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
1# -*- coding: utf-8 -*-
2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
3# Licensed under the Academic Free License version 3.0
5from typing import Dict, Any
def get(d: Dict[Any, Any], *keys: Any):
    """Look up several keys of a dictionary in a single call.

    :param d: dictionary to read from
    :param keys: the keys to look up, in the order their values are wanted
    :returns: a lazy ``map`` iterator yielding ``d.get(key)`` for each key,
              so a key that is absent from ``d`` yields ``None``

    Example usage::

        val1, val2 = get(my_dict, 'key1', 'key2')
    """
    lookup = d.get
    return map(lookup, keys)
def get_lstsq(torch):
    """Return a least-squares function matching the given PyTorch version.

    The returned callable is always invoked as ``lstsq(a, b)``, i.e. with the
    argument order of the legacy ``torch.lstsq``. For PyTorch >= 1.10 the
    arguments are swapped before being forwarded to ``torch.linalg.lstsq``.
    Only the call signature is unified: the result is passed through
    unchanged from whichever function is used.

    :param torch: the ``torch`` module (any object exposing ``__version__``
                  and the ``lstsq`` function appropriate for that version)
    :returns: a callable ``lstsq(a, b)``
    :raises ValueError: if the PyTorch version is below 1.2
    """
    # Only the first two components matter; a local build suffix such as
    # "+cu118" lives in the patch component and is discarded by the slice.
    major, minor = (int(part) for part in torch.__version__.split(".")[:2])
    version = (major, minor)

    # Compare (major, minor) as a tuple so that 0.x releases are rejected
    # instead of being matched on their minor number alone.
    if version >= (1, 10):
        # pytorch 1.10 deprecates torch.lstsq in favour of torch.linalg.lstsq,
        # which takes arguments in reversed order
        def lstsq(a, b):
            return torch.linalg.lstsq(b, a)

    elif version >= (1, 2):
        # pytorch 1.2 deprecates torch.gels in favour of torch.lstsq
        lstsq = torch.lstsq

    else:
        raise ValueError("Pytorch versions below 1.2 are unsupported")

    return lstsq