Coverage for tvo/utils/parallel.py: 30%

145 statements  

coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

# -*- coding: utf-8 -*-
# Copyright (C) 2019 Machine Learning Group of the University of Oldenburg.
# Licensed under the Academic Free License version 3.0

import platform
import math
import h5py

import torch
import torch.distributed as dist
from torch import Tensor
from typing import Iterable, Union, Dict, Tuple

import tvo


def pprint(obj: object = "", end: str = "\n"):
    """Print on the root process of the torch.distributed process group.

    :param obj: Message to print
    :param end: Suffix of message. Default is linebreak.
    """
    if tvo.get_run_policy() == "mpi" and dist.get_rank() != 0:
        return

    print(obj, end=end)


def init_processes(multi_node: bool = False):
    """Initialize the MPI process group using the torch.distributed module.

    :param multi_node: Deploy multiple computing nodes.

    May update the value of tvo.device.
    """
    if torch.distributed.is_initialized():
        return

    dist.init_process_group("mpi")

    global_rank = dist.get_rank()
    comm_size = dist.get_world_size()

    if tvo.get_device().type == "cuda":
        device_count = int(torch.cuda.device_count())
        if multi_node:
            node_count = comm_size // device_count
        else:
            node_count = 1
        # 0..device_count (first_node), ..., 0..device_count (last_node)
        local_rank = (list(range(device_count)) * node_count)[global_rank]
        device_str = "cuda:%i" % local_rank
    else:
        device_str = "cpu"

    tvo._set_device(torch.device(device_str))

    pprint("Initializing %i processes." % comm_size)
    print(
        "New process on %s. Global rank %d. Device %s. Total no processes %d."
        % (platform.node(), global_rank, device_str, comm_size)
    )
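
# Illustrative usage sketch (not part of the original module): a typical worker
# script calls init_processes() once and then uses pprint() so that status
# messages appear only once, on rank 0. Assumes the script was launched through
# MPI, e.g. `mpirun -n 4 python train.py`, with tvo's run policy set to "mpi".
def _example_init_and_pprint():  # hypothetical helper, for illustration only
    init_processes()
    pprint("This line is printed once, by the root process.")
    print("This line is printed by every rank, including rank %d." % dist.get_rank())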

def bcast_dtype(data: Tensor, src: int = 0) -> torch.dtype:
    """Broadcast dtype of data on src rank.

    :param data: Tensor on src rank
    :param src: Source rank
    :returns: dtype on each rank
    """
    if tvo.get_run_policy() == "seq":
        return data.dtype
    else:
        comm_rank = dist.get_rank()

        dtypes = [
            torch.float32,
            torch.float64,
            torch.float16,
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.bool,
        ]

        ind_dtype = torch.empty((1,), dtype=torch.uint8)
        if comm_rank == src:
            dtype = data.dtype
            ind_dtype[:] = [*map(str, dtypes)].index(str(dtype))
        dist.broadcast(ind_dtype, src)
        return dtypes[ind_dtype.item()]


def bcast_shape(data: Tensor, src: int) -> Tensor:
    """Broadcast shape of data on src rank.

    :param data: Tensor on src rank
    :param src: Source rank
    :returns: Tensor with shape on each rank
    """
    if tvo.get_run_policy() == "seq":
        return torch.tensor(data.shape)
    else:
        comm_rank = dist.get_rank()

        ndim = torch.empty((1,), dtype=torch.int64)
        if comm_rank == src:
            ndim[:] = data.dim()
        dist.broadcast(ndim, src)
        shape = torch.empty((ndim.item(),), dtype=torch.int64)
        if comm_rank == src:
            shape[:] = torch.tensor(data.shape)
        dist.broadcast(shape, src)
        return shape
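
# Illustrative sketch (not part of the original module): bcast_dtype and
# bcast_shape are typically used together to pre-allocate a receive buffer on
# ranks that do not hold the data yet. `payload` is a hypothetical tensor that
# exists only on rank 0; the sketch assumes an "mpi" run policy.
def _example_bcast_metadata(payload):  # hypothetical helper, for illustration only
    dtype = bcast_dtype(payload, src=0)
    shape = bcast_shape(payload, src=0)
    if dist.get_rank() != 0:
        # non-source ranks can now allocate a buffer of matching dtype and shape
        payload = torch.empty(tuple(shape.tolist()), dtype=dtype)
    broadcast(payload, src=0)
    return payload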

def scatter_to_processes(*tensors: Tensor, src: int = 0) -> Iterable[Tensor]:
    """Split tensors into chunks and scatter within the process group.

    :param tensors: Tensors to be scattered. Chunks are cut along dimension 0.
    :param src: Source rank to scatter from.
    :returns: Tensors scattered to the local rank.

    Tensor data is assumed to be None on all but the src process.
    """
    my_tensors = []

    if tvo.get_run_policy() == "seq":
        for data in tensors:
            my_tensors.append(data)

    elif tvo.get_run_policy() == "mpi":
        comm_size, comm_rank = dist.get_world_size(), dist.get_rank()
        for data in tensors:
            this_dtype = bcast_dtype(data, src)
            is_bool = this_dtype == torch.bool
            if is_bool:
                # workaround to avoid `IndexError: map::at` when scattering torch.bool tensors
                this_dtype = torch.uint8
            this_device = tvo.get_device()

            shape = bcast_shape(data, src)
            total_length = shape[0].item()
            other_length = tuple(shape[1:])

            # logic to ensure that the input to `dist.scatter` is evenly divisible by comm_size
            assert (
                total_length / comm_size
            ) >= 1, "number of data points must be greater or equal to number of MPI processes"
            local_length_ceiled = math.ceil(total_length / comm_size)
            total_length_ceiled = local_length_ceiled * comm_size
            no_dummy = total_length_ceiled - total_length
            local_length = local_length_ceiled - 1 if comm_rank < no_dummy else local_length_ceiled

            # split into chunks and scatter
            chunks = []  # type: ignore
            if comm_rank == src:
                to_cut_into_chunks = torch.zeros(
                    ((total_length_ceiled,) + other_length), dtype=this_dtype, device=this_device
                )
                local_start = 0
                for r in range(comm_size):
                    local_length_ = local_length_ceiled - 1 if r < no_dummy else local_length_ceiled
                    to_cut_into_chunks[
                        r * local_length_ceiled : r * local_length_ceiled + local_length_
                    ] = data[range(local_start, local_start + local_length_)]
                    local_start += local_length_
                chunks = list(torch.chunk(to_cut_into_chunks, comm_size, dim=0))

            my_data = torch.zeros(
                (local_length_ceiled,) + other_length, dtype=this_dtype, device=this_device
            )

            dist.scatter(my_data, src=src, scatter_list=chunks)

            if is_bool:
                my_data = my_data.to(dtype=torch.bool)

            # drop the padding rows that were only added to make the scatter even
            my_data = my_data[:local_length]

            N = torch.tensor([local_length])
            all_reduce(N)
            assert N.item() == total_length

            my_tensors.append(my_data)

    return my_tensors[0] if len(my_tensors) == 1 else my_tensors
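
# Illustrative sketch (not part of the original module): scatter a dataset held
# on rank 0 so that every rank receives roughly N/comm_size data points along
# dimension 0. `N` and `D` are hypothetical sizes; assumes an "mpi" run policy
# and an initialized process group.
def _example_scatter(N: int = 100, D: int = 8):  # hypothetical helper
    data = torch.rand(N, D) if dist.get_rank() == 0 else None
    my_data = scatter_to_processes(data)
    print("rank %d received %d of %d data points" % (dist.get_rank(), my_data.shape[0], N))
    return my_data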

def gather_from_processes(*my_tensors: Tensor, dst: int = 0) -> Union[Tensor, Iterable[Tensor]]:
    """Gather tensors from the process group.

    :param my_tensors: List of tensors to be gathered from the local process on process dst.
                       For each element tensor.shape[1:] must be identical on each process.
    :param dst: Rank of the destination process to gather tensors on.
    :returns: List of tensors gathered from the process group.

    Only the process with rank dst will contain the gathered data.
    """
    tensors = []

    if tvo.get_run_policy() == "seq":
        for data in my_tensors:
            tensors.append(data)

    elif tvo.get_run_policy() == "mpi":
        comm_size, comm_rank = dist.get_world_size(), dist.get_rank()
        for my_data in my_tensors:
            local_length = my_data.shape[0]
            other_length = tuple(my_data.shape[1:])
            total_length = torch.tensor([local_length])
            all_reduce(total_length)
            total_length = total_length.item()
            local_length_ceiled = math.ceil(total_length / comm_size)
            no_dummy = local_length_ceiled * comm_size - total_length

            # gather_list must only be provided on the destination rank
            chunks = (
                [
                    torch.zeros(
                        (local_length_ceiled,) + other_length,
                        dtype=my_data.dtype,
                        device=my_data.device,
                    )
                    for r in range(comm_size)
                ]
                if comm_rank == dst
                else []
            )

            # ranks holding one data point less than local_length_ceiled send a dummy row
            dist.gather(
                tensor=torch.cat(
                    (
                        my_data,
                        torch.zeros(
                            (1,) + other_length, dtype=my_data.dtype, device=my_data.device
                        ),
                    )
                )
                if comm_rank < no_dummy
                else my_data,
                gather_list=chunks,
                dst=dst,
            )

            if comm_rank == dst:
                # strip the dummy rows again and concatenate the chunks
                for r in range(no_dummy):
                    chunks[r] = chunks[r][:-1]
                data = torch.cat(chunks)
                assert data.shape[0] == total_length
                tensors.append(data)

    return tensors[0] if len(tensors) == 1 else tensors
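
# Illustrative sketch (not part of the original module): a round-trip of the two
# collectives above. Each rank computes on its local chunk, and the per-rank
# results are gathered back on rank 0 in the original order. Assumes an "mpi"
# run policy; `full_data` is a hypothetical tensor living on rank 0 only.
def _example_scatter_gather_roundtrip(full_data):  # hypothetical helper
    my_chunk = scatter_to_processes(full_data)
    my_result = my_chunk * 2  # stand-in for an actual per-rank computation
    gathered = gather_from_processes(my_result)
    if dist.get_rank() == 0:
        assert gathered.shape[0] == full_data.shape[0]
    return gathered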

def all_reduce(tensor: Tensor, op=dist.ReduceOp.SUM):
    """Equivalent to torch's all_reduce if tvo.get_run_policy() is 'mpi', no-op otherwise."""
    if tvo.get_run_policy() == "mpi":
        dist.all_reduce(tensor, op)


def broadcast(tensor: Tensor, src: int = 0):
    """Equivalent to torch's broadcast if tvo.get_run_policy() is 'mpi', no-op otherwise."""
    if tvo.get_run_policy() == "mpi":
        dist.broadcast(tensor, src)


def barrier():
    """Equivalent to torch's dist.barrier if tvo.get_run_policy() is 'mpi', no-op otherwise."""
    if tvo.get_run_policy() == "mpi":
        dist.barrier()
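
# Illustrative sketch (not part of the original module): the three thin wrappers
# above make collective calls safe to use in code that also runs sequentially.
# Here a per-rank partial sum is turned into a global sum; under the "seq" run
# policy the same code is a no-op and the local sum is returned unchanged.
def _example_global_sum(my_values: Tensor) -> Tensor:  # hypothetical helper
    partial = my_values.sum().unsqueeze(0)  # local contribution of this rank
    all_reduce(partial)  # in-place sum over all ranks (no-op for "seq")
    barrier()  # wait until every rank has the result
    return partial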

def mpi_average_grads(theta: Dict[str, torch.Tensor]) -> None:
    """Average gradients across processes. See https://bit.ly/2FlJsxS.

    :param theta: dictionary with torch.tensors storing TVO model parameters
    """
    if tvo.get_run_policy() != "mpi":
        return  # nothing to do

    n_procs = dist.get_world_size()
    parameters = [p for p in theta.values() if p.requires_grad]
    with torch.no_grad():
        for p in parameters:
            all_reduce(p.grad)
            p.grad /= n_procs
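
# Illustrative sketch (not part of the original module): a single data-parallel
# gradient step. `theta` maps parameter names to tensors with requires_grad=True,
# and `local_loss` is a hypothetical loss computed from this rank's data chunk.
# After backward(), the gradients are averaged across ranks before the update.
def _example_grad_step(theta: Dict[str, torch.Tensor], local_loss: Tensor, lr: float = 1e-2):
    local_loss.backward()
    mpi_average_grads(theta)  # no-op unless the run policy is "mpi"
    with torch.no_grad():
        for p in theta.values():
            if p.requires_grad:
                p -= lr * p.grad  # plain SGD update, for illustration only
                p.grad.zero_()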

def get_h5_dataset_to_processes(fname: str, possible_keys: Tuple[str, ...]) -> torch.Tensor:
    """Return dataset with the first of `possible_keys` that is found in hdf5 file `fname`."""
    rank = dist.get_rank() if dist.is_initialized() else 0

    f = h5py.File(fname, "r")
    for dataset in possible_keys:
        if dataset in f.keys():
            break
    else:  # pragma: no cover
        raise ValueError(f'File "{fname}" does not contain any of keys {possible_keys}')
    if rank == 0:
        data = torch.tensor(f[dataset][...], device=tvo.get_device())
    else:
        data = None
    return scatter_to_processes(data)
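
# Illustrative sketch (not part of the original module): load whichever of two
# hypothetical dataset keys exists in a made-up file "data.h5" and distribute it
# over the process group. Only rank 0 reads the file; every rank receives its
# own chunk. Assumes the script was launched via MPI (see init_processes above).
def _example_load_and_distribute():  # hypothetical helper
    init_processes()
    my_data = get_h5_dataset_to_processes("data.h5", ("train_data", "data"))
    pprint("distributed a dataset; this rank holds %d data points" % my_data.shape[0])
    return my_data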