Coverage for tvo/utils/data.py: 70% (37 statements)
# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import torch as to
import torch.distributed as dist
from torch.utils.data import TensorDataset, DataLoader, Dataset, Sampler
import numpy as np

import tvo
from tvo.utils.parallel import broadcast
class TVODataLoader(DataLoader):
    def __init__(self, *data: to.Tensor, **kwargs):
        """TVO DataLoader class. Derived from torch.utils.data.DataLoader.

        :param data: Tensors containing the input dataset. Each must have exactly two
                     dimensions (N, D) and all must share the same leading dimension N.
        :param kwargs: forwarded to pytorch's DataLoader.

        TVODataLoader is constructed exactly the same way as pytorch's DataLoader,
        but it restricts datasets to a TensorDataset constructed from the *data passed
        as parameters. All other arguments are forwarded to pytorch's DataLoader.

        When iterated over, TVODataLoader yields a tuple containing the indices of
        the datapoints in each batch as well as the actual datapoints for each
        of the input tensors.

        TVODataLoader instances optionally expose the attribute `precision`, which is set
        to the dtype of the first dataset in *data if it is a floating point dtype.
        """
        N = data[0].shape[0]
        assert all(d.shape[0] == N for d in data), "Dimension mismatch in data sets."

        if data[0].dtype is not to.uint8:
            self.precision = data[0].dtype

        dataset = TensorDataset(to.arange(N), *data)

        if tvo.get_run_policy() == "mpi" and "sampler" not in kwargs:
            # Number of _desired_ datapoints per worker: the last worker might have fewer
            # actual datapoints, but we want it to sample as many as the other workers so
            # that all processes can loop over batches in sync.
            # NOTE: this means that the E-step will sometimes write over a certain K[idx] and
            # lpj[idx] twice over the course of an epoch, even in the same batch (although
            # that will happen rarely). This double writing is not a race condition: the last
            # write wins.
            n_samples = to.tensor(N)
            assert dist.is_initialized()
            comm_size = dist.get_world_size()
            # Ranks ..., (comm_size-2), (comm_size-1) are assigned one data point more than
            # ranks 0, 1, ... if the dataset cannot be evenly distributed across MPI
            # processes. The split point depends on the total number of data points and the
            # number of MPI processes (see scatter_to_processes, gather_from_processes).
            broadcast(n_samples, src=comm_size - 1)
            kwargs["sampler"] = ShufflingSampler(dataset, int(n_samples))
            kwargs["shuffle"] = None

        super().__init__(dataset, **kwargs)
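

# Illustrative usage sketch (not part of the original module): constructing a
# TVODataLoader from a single 2-D tensor and iterating over index/data batches.
# The tensor shape, batch size and helper name are assumptions for demonstration only.
def _example_dataloader_usage():
    data = to.rand(100, 16)  # N=100 datapoints with D=16 features each
    loader = TVODataLoader(data, batch_size=10)
    assert loader.precision == to.float32  # set because the data has a floating point dtype
    for idx, batch in loader:
        # idx holds the dataset indices of this batch, batch the corresponding datapoints
        assert idx.shape[0] == batch.shape[0] == 10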
class ShufflingSampler(Sampler):
    def __init__(self, dataset: Dataset, n_samples: int = None):
        """A torch sampler that shuffles datapoints.

        :param dataset: The torch dataset for this sampler.
        :param n_samples: Number of desired samples. Defaults to len(dataset). If larger than
                          len(dataset), some datapoints will be sampled multiple times.
        """
        self._ds_len = len(dataset)
        self.n_samples = n_samples if n_samples is not None else self._ds_len

    def __iter__(self):
        idxs = np.arange(self._ds_len)
        np.random.shuffle(idxs)

        if self.n_samples > self._ds_len:
            n_extra_samples = self.n_samples - self._ds_len
            replace = True if n_extra_samples > idxs.size else False
            extra_samples = np.random.choice(idxs, size=n_extra_samples, replace=replace)
            idxs = np.concatenate((idxs, extra_samples))
        else:
            idxs = idxs[: self.n_samples]

        return iter(idxs)

    def __len__(self):
        return self.n_samples
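

# Illustrative sketch (an assumption, not part of the original module): using
# ShufflingSampler directly to oversample a small dataset, mirroring what the MPI
# branch in TVODataLoader does for workers that hold fewer datapoints than requested.
def _example_sampler_usage():
    dataset = TensorDataset(to.arange(5))
    sampler = ShufflingSampler(dataset, n_samples=8)
    assert len(sampler) == 8
    idxs = list(sampler)  # 5 shuffled indices plus 3 indices drawn again
    assert len(idxs) == 8
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)
    for (batch,) in loader:
        pass  # two batches of 4 datapoints are drawn per epoch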