Coverage for tvo/utils/_utils.py: 53%

15 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

1# -*- coding: utf-8 -*- 

2# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg. 

3# Licensed under the Academic Free License version 3.0 

4 

5from typing import Dict, Any 

6 

7 

def get(d: Dict[Any, Any], *keys: Any):
    """Look up several keys of a dictionary at once.

    Each key is resolved with ``d.get``, so a missing key yields ``None``
    instead of raising ``KeyError``.

    :param d: input dictionary
    :param keys: the keys to look up in ``d``, in the desired order
    :returns: a lazy iterator yielding one value per key, in order

    Example usage::

        val1, val2 = get(my_dict, 'key1', 'key2')
    """
    # Bind the lookup method once; the values themselves are fetched lazily,
    # as the caller consumes (e.g. unpacks) the returned iterator.
    lookup = d.get
    return map(lookup, keys)

20 

21 

def get_lstsq(torch):
    """Return a least-squares solver matching the installed PyTorch version.

    The returned callable is always invoked as ``lstsq(a, b)``, following the
    argument order of the legacy ``torch.lstsq``:

    * PyTorch >= 1.10 (including 2.x): a wrapper forwarding to
      ``torch.linalg.lstsq(b, a)``, which takes its arguments in the
      reversed order.
    * 1.2 <= PyTorch < 1.10: ``torch.lstsq`` itself.

    :param torch: the ``torch`` module (any object exposing ``__version__``
        and the ``lstsq`` functions used above)
    :returns: least-squares callable with signature ``lstsq(a, b)``
    :raises ValueError: if ``torch.__version__`` has no leading
        ``major.minor`` number, or the version is older than 1.2
    """
    import re

    # Only the leading "major.minor" digits matter; pre-release and local
    # builds append suffixes (e.g. "2.0a0+git1234", "1.13.1+cu117") that
    # would make a plain int() conversion of the minor field fail.
    match = re.match(r"(\d+)\.(\d+)", torch.__version__)
    if match is None:
        raise ValueError(
            "Unable to parse Pytorch version {!r}".format(torch.__version__)
        )
    version = (int(match.group(1)), int(match.group(2)))

    # Compare (major, minor) tuples so the minor-version thresholds only
    # apply within the 1.x series; e.g. 0.4 must not be treated as >= 1.2.
    if version >= (1, 10):
        # pytorch 1.10 deprecates to.lstsq in favour of to.linalg.lstsq,
        # which takes arguments in reversed order (this also covers 2.x,
        # where to.lstsq is no longer available)
        def lstsq(a, b):
            return torch.linalg.lstsq(b, a)

    elif version >= (1, 2):
        # pytorch 1.2 deprecates to.gels in favour of to.lstsq
        lstsq = torch.lstsq

    else:
        raise ValueError("Pytorch versions below 1.2 are unsupported")

    return lstsq