Coverage for tvo/utils/parallel.py: 30%
145 statements
coverage.py v7.4.3, created at 2024-03-01 11:33 +0000
# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import platform
import math
import h5py

import torch
import torch.distributed as dist
from torch import Tensor
from typing import Iterable, Union, Dict, Tuple

import tvo


def pprint(obj: object = "", end: str = "\n"):
    """Print on root process of torch.distributed process group.

    :param obj: Message to print
    :param end: Suffix of message. Default is a linebreak.
    """
    if tvo.get_run_policy() == "mpi" and dist.get_rank() != 0:
        return

    print(obj, end=end)
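

# --- Illustrative usage sketch (added by the editor, not part of the original module): `pprint`
# replaces a bare `print` so that log lines appear once instead of once per rank. The helper
# name `_example_log` is hypothetical.
def _example_log(epoch: int, free_energy: float):
    pprint("epoch %d: F=%.2f" % (epoch, free_energy))  # printed by rank 0 only under "mpi"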


def init_processes(multi_node: bool = False):
    """Initialize MPI process group using the torch.distributed module.

    :param multi_node: Deploy multiple computing nodes.

    Eventually updates the value of tvo.device.
    """
    if torch.distributed.is_initialized():
        return

    dist.init_process_group("mpi")

    global_rank = dist.get_rank()
    comm_size = dist.get_world_size()

    if tvo.get_device().type == "cuda":
        device_count = int(torch.cuda.device_count())
        if multi_node:
            node_count = comm_size // device_count
        else:
            node_count = 1
        # 0..device_count-1 (first node), ..., 0..device_count-1 (last node)
        local_rank = (list(range(device_count)) * node_count)[global_rank]
        device_str = "cuda:%i" % local_rank
    else:
        device_str = "cpu"

    tvo._set_device(torch.device(device_str))

    pprint("Initializing %i processes." % comm_size)
    print(
        "New process on %s. Global rank %d. Device %s. Total number of processes %d."
        % (platform.node(), global_rank, device_str, comm_size)
    )
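

# --- Illustrative usage sketch (added by the editor, not part of the original module): a typical
# entry point of an MPI run, assuming an MPI-enabled PyTorch build and a launch command such as
# `mpirun -n 4 python train.py`. The helper name `_example_startup` is hypothetical.
def _example_startup(multi_node: bool = False):
    init_processes(multi_node=multi_node)  # sets up torch.distributed and tvo.device
    if tvo.get_run_policy() == "mpi":
        pprint("initialized %d processes" % dist.get_world_size())  # printed by rank 0 only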


def bcast_dtype(data: Tensor, src: int = 0) -> torch.dtype:
    """Broadcast dtype of data on src rank.

    :param data: Tensor on src rank
    :param src: Source rank
    :returns: dtype on each rank
    """
    if tvo.get_run_policy() == "seq":
        return data.dtype
    else:
        comm_rank = dist.get_rank()

        dtypes = [
            torch.float32,
            torch.float64,
            torch.float16,
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.bool,
        ]

        # encode the dtype as an index into `dtypes` and broadcast the index from src
        ind_dtype = torch.empty((1,), dtype=torch.uint8)
        if comm_rank == src:
            dtype = data.dtype
            ind_dtype[:] = [*map(str, dtypes)].index(str(dtype))
        dist.broadcast(ind_dtype, src)
        return dtypes[ind_dtype.item()]


def bcast_shape(data: Tensor, src: int) -> Tensor:
    """Broadcast shape of data on src rank.

    :param data: Tensor on src rank
    :param src: Source rank
    :returns: Tensor with shape on each rank
    """
    if tvo.get_run_policy() == "seq":
        return torch.tensor(data.shape)
    else:
        comm_rank = dist.get_rank()

        ndim = torch.empty((1,), dtype=torch.int64)
        if comm_rank == src:
            ndim[:] = data.dim()
        dist.broadcast(ndim, src)
        shape = torch.empty((ndim.item(),), dtype=torch.int64)
        if comm_rank == src:
            shape[:] = torch.tensor(data.shape)
        dist.broadcast(shape, src)
        return shape
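

# --- Illustrative usage sketch (added by the editor, not part of the original module): how
# `bcast_dtype` and `bcast_shape` let non-source ranks allocate a matching buffer before a
# broadcast. The helper name `_example_bcast_tensor` is hypothetical; it assumes `init_processes`
# has been called and that `data` is a real tensor only on rank `src` (None elsewhere).
def _example_bcast_tensor(data, src: int = 0) -> Tensor:
    dtype = bcast_dtype(data, src)  # every rank learns the dtype...
    shape = tuple(bcast_shape(data, src).tolist())  # ...and the full shape
    buf = data if dist.get_rank() == src else torch.empty(shape, dtype=dtype)
    dist.broadcast(buf, src)  # now all ranks hold a copy of the tensor
    return buf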


def scatter_to_processes(*tensors: Tensor, src: int = 0) -> Iterable[Tensor]:
    """Split tensors into chunks and scatter within process group.

    :param tensors: Tensors to be scattered. Chunks are cut along dimension 0.
    :param src: Source rank to scatter from.
    :returns: Tensors scattered to the local rank.

    Tensor data is assumed to be None on all but the src process.
    """
    my_tensors = []

    if tvo.get_run_policy() == "seq":
        for data in tensors:
            my_tensors.append(data)

    elif tvo.get_run_policy() == "mpi":
        comm_size, comm_rank = dist.get_world_size(), dist.get_rank()
        for data in tensors:
            this_dtype = bcast_dtype(data, src)
            is_bool = this_dtype == torch.bool
            if is_bool:
                # workaround to avoid `IndexError: map::at` when scattering torch.bool tensors
                this_dtype = torch.uint8
            this_device = tvo.get_device()

            shape = bcast_shape(data, src)
            total_length = shape[0].item()
            other_length = tuple(shape[1:])

            # logic to ensure that the input to `dist.scatter` is evenly divisible by comm_size
            assert (
                total_length / comm_size
            ) >= 1, "number of data points must be greater or equal to number of MPI processes"
            local_length_ceiled = math.ceil(total_length / comm_size)
            total_length_ceiled = local_length_ceiled * comm_size
            no_dummy = total_length_ceiled - total_length
            local_length = local_length_ceiled - 1 if comm_rank < no_dummy else local_length_ceiled

            # split into chunks and scatter
            chunks = []  # type: ignore
            if comm_rank == src:
                to_cut_into_chunks = torch.zeros(
                    ((total_length_ceiled,) + other_length), dtype=this_dtype, device=this_device
                )
                local_start = 0
                for r in range(comm_size):
                    local_length_ = local_length_ceiled - 1 if r < no_dummy else local_length_ceiled
                    to_cut_into_chunks[
                        r * local_length_ceiled : r * local_length_ceiled + local_length_
                    ] = data[range(local_start, local_start + local_length_)]
                    local_start += local_length_
                chunks = list(torch.chunk(to_cut_into_chunks, comm_size, dim=0))

            my_data = torch.zeros(
                (local_length_ceiled,) + other_length, dtype=this_dtype, device=this_device
            )

            dist.scatter(my_data, src=src, scatter_list=chunks)

            if is_bool:
                my_data = my_data.to(dtype=torch.bool)

            # drop padding rows again
            my_data = my_data[:local_length]

            # sanity check: the per-rank lengths must sum up to the total number of data points
            N = torch.tensor([local_length])
            all_reduce(N)
            assert N.item() == total_length

            my_tensors.append(my_data)

    return my_tensors[0] if len(my_tensors) == 1 else my_tensors
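

# --- Illustrative usage sketch (added by the editor, not part of the original module): scattering
# a dataset that is only materialized on rank 0. The helper name `_example_scatter` and the data
# shape are hypothetical; under the "seq" policy the tensor is simply returned unchanged.
def _example_scatter() -> Tensor:
    if tvo.get_run_policy() == "seq" or dist.get_rank() == 0:
        full_data = torch.randn(1000, 28, 28)  # complete dataset, present on rank 0 only
    else:
        full_data = None
    my_data = scatter_to_processes(full_data)  # each rank receives roughly 1000 / world_size rows
    return my_data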


def gather_from_processes(*my_tensors: Tensor, dst: int = 0) -> Union[Tensor, Iterable[Tensor]]:
    """Gather tensors from process group.

    :param my_tensors: List of tensors to be gathered from the local process on process dst.
                       For each element tensor.shape[1:] must be identical on each process.
    :param dst: Rank of destination process to gather tensors on.
    :returns: List of tensors gathered from process group.

    Only the process with rank dst will contain the gathered data.
    """
    tensors = []

    if tvo.get_run_policy() == "seq":
        for data in my_tensors:
            tensors.append(data)

    elif tvo.get_run_policy() == "mpi":
        comm_size, comm_rank = dist.get_world_size(), dist.get_rank()
        for my_data in my_tensors:
            local_length = my_data.shape[0]
            other_length = tuple(my_data.shape[1:])
            total_length = torch.tensor([local_length])
            all_reduce(total_length)
            total_length = total_length.item()
            local_length_ceiled = math.ceil(total_length / comm_size)
            no_dummy = local_length_ceiled * comm_size - total_length

            chunks = (
                [
                    torch.zeros(
                        (local_length_ceiled,) + other_length,
                        dtype=my_data.dtype,
                        device=my_data.device,
                    )
                    for r in range(comm_size)
                ]
                if comm_rank == dst
                else []
            )

            # ranks with a short chunk append one dummy row so that all gathered chunks
            # have identical shape
            dist.gather(
                tensor=torch.cat(
                    (
                        my_data,
                        torch.zeros(
                            (1,) + other_length, dtype=my_data.dtype, device=my_data.device
                        ),
                    )
                )
                if comm_rank < no_dummy
                else my_data,
                gather_list=chunks,
                dst=dst,
            )

            if comm_rank == dst:
                # drop the dummy rows again and re-assemble the full tensor
                for r in range(no_dummy):
                    chunks[r] = chunks[r][:-1]
                data = torch.cat(chunks)
                assert data.shape[0] == total_length
                tensors.append(data)

    return tensors[0] if len(tensors) == 1 else tensors
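

# --- Illustrative usage sketch (added by the editor, not part of the original module): gathering
# per-rank results back on rank 0, e.g. after each rank processed its shard of the data. The
# helper name `_example_gather` is hypothetical.
def _example_gather(my_results: Tensor):
    results = gather_from_processes(my_results)  # full tensor on rank 0, empty list on other ranks
    if tvo.get_run_policy() == "seq" or dist.get_rank() == 0:
        pprint("gathered %d rows in total" % results.shape[0])
    return results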


def all_reduce(tensor: Tensor, op=dist.ReduceOp.SUM):
    """Equivalent to torch's all_reduce if tvo.get_run_policy() is 'mpi', no-op otherwise."""
    if tvo.get_run_policy() == "mpi":
        dist.all_reduce(tensor, op)


def broadcast(tensor: Tensor, src: int = 0):
    """Equivalent to torch's broadcast if tvo.get_run_policy() is 'mpi', no-op otherwise."""
    if tvo.get_run_policy() == "mpi":
        dist.broadcast(tensor, src)


def barrier():
    """Equivalent to torch's dist.barrier if tvo.get_run_policy() is 'mpi', no-op otherwise."""
    if tvo.get_run_policy() == "mpi":
        dist.barrier()
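

# --- Illustrative usage sketch (added by the editor, not part of the original module): the
# policy-aware wrappers above make the same code work for "seq" and "mpi" runs. The helper name
# `_example_global_mean` is hypothetical; it averages a per-rank scalar across all processes.
def _example_global_mean(local_value: float) -> float:
    world_size = dist.get_world_size() if tvo.get_run_policy() == "mpi" else 1
    t = torch.tensor([float(local_value)])
    all_reduce(t)  # sums across ranks under "mpi", no-op under "seq"
    barrier()  # synchronize all ranks before using the result
    return t.item() / world_size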


def mpi_average_grads(theta: Dict[str, torch.Tensor]) -> None:
    """Average gradients across processes. See https://bit.ly/2FlJsxS.

    :param theta: dictionary with torch.tensors storing TVO model parameters
    """
    if tvo.get_run_policy() != "mpi":
        return  # nothing to do

    n_procs = dist.get_world_size()
    parameters = [p for p in theta.values() if p.requires_grad]
    with torch.no_grad():
        for p in parameters:
            all_reduce(p.grad)
            p.grad /= n_procs
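

# --- Illustrative usage sketch (added by the editor, not part of the original module): where
# `mpi_average_grads` would sit in a gradient-based update. `theta`, `loss_fn` and `opt` are
# hypothetical placeholders (an optimizer over the tensors in `theta`), not part of this module.
def _example_train_step(theta: Dict[str, torch.Tensor], loss_fn, opt):
    opt.zero_grad()
    loss = loss_fn(theta)  # each rank computes the loss on its own shard of the data
    loss.backward()
    mpi_average_grads(theta)  # make gradients identical on all ranks before the update
    opt.step()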


def get_h5_dataset_to_processes(fname: str, possible_keys: Tuple[str, ...]) -> torch.Tensor:
    """Return dataset with the first of `possible_keys` that is found in hdf5 file `fname`."""
    rank = dist.get_rank() if dist.is_initialized() else 0

    f = h5py.File(fname, "r")
    for dataset in possible_keys:
        if dataset in f.keys():
            break
    else:  # pragma: no cover
        raise ValueError(f'File "{fname}" does not contain any of keys {possible_keys}')
    if rank == 0:
        data = torch.tensor(f[dataset][...], device=tvo.get_device())
    else:
        data = None
    return scatter_to_processes(data)
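

# --- Illustrative usage sketch (added by the editor, not part of the original module): loading an
# HDF5 dataset and distributing it across ranks in one call. The file name "data.h5" and the key
# names are hypothetical; the first key found in the file is used.
def _example_load() -> Tensor:
    my_data = get_h5_dataset_to_processes("data.h5", ("train_data", "data"))
    pprint("each rank now holds a shard of shape %s" % (tuple(my_data.shape),))
    return my_data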