tvo.utils.parallel

tvo.utils.parallel.all_reduce(tensor, op=ReduceOp.SUM)[source]

Equivalent to torch’s all_reduce if tvo.get_run_policy() is ‘mpi’, no-op otherwise.
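
Example (a minimal sketch; the tensor values are illustrative, and the call is a no-op when the run policy is not ‘mpi’):

    import torch
    from tvo.utils.parallel import all_reduce

    partial = torch.tensor([1.0, 2.0, 3.0])  # each rank holds its own partial result
    all_reduce(partial)                      # in-place SUM reduction across ranks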

tvo.utils.parallel.barrier()[source]

Equivalent to torch’s dist.barrier if tvo.get_run_policy() is ‘mpi’, no-op otherwise.
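
Example (a minimal sketch: synchronize all ranks before logging; assumes the process group was set up with init_processes):

    from tvo.utils.parallel import barrier, pprint

    # ... each rank finishes its local work ...
    barrier()                      # wait until every rank reaches this point
    pprint("all ranks finished")   # printed once, by the root rank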

tvo.utils.parallel.bcast_dtype(data, src=0)[source]

Broadcast the dtype of data from rank src to all ranks.

Parameters:
  • data (Tensor) – Tensor on src rank

  • src (int) – Source rank

Return type:

dtype

Returns:

dtype on each rank
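
Example (a minimal sketch; only rank 0 holds real data, and the placeholder tensor on the other ranks is an assumption for illustration; requires an initialized process group):

    import torch
    import torch.distributed as dist
    from tvo.utils.parallel import bcast_dtype

    rank = dist.get_rank()
    data = torch.rand(100, 5) if rank == 0 else torch.empty(0)
    dtype = bcast_dtype(data, src=0)   # every rank now knows the dataset's dtype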

tvo.utils.parallel.bcast_shape(data, src)[source]

Broadcast the shape of data from rank src to all ranks.

Parameters:
  • data (Tensor) – Tensor on src rank

  • src (int) – Source rank

Return type:

Tensor

Returns:

Tensor with shape on each rank
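
Example (a minimal sketch of a typical pattern: non-src ranks learn shape and dtype, allocate a buffer, then receive the data via broadcast; the placeholder handling is an assumption for illustration):

    import torch
    import torch.distributed as dist
    from tvo.utils.parallel import bcast_shape, bcast_dtype, broadcast

    rank = dist.get_rank()
    data = torch.rand(100, 5) if rank == 0 else torch.empty(0)
    shape = bcast_shape(data, src=0)   # 1-D tensor holding the sizes
    dtype = bcast_dtype(data, src=0)
    if rank != 0:
        data = torch.empty(tuple(shape.tolist()), dtype=dtype)
    broadcast(data, src=0)             # all ranks now hold the same data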

tvo.utils.parallel.broadcast(tensor, src=0)[source]

Equivalent to torch’s broadcast if tvo.get_run_policy() is ‘mpi’, no-op otherwise.
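
Example (a minimal sketch of sharing rank 0's initialization with every rank; the tensor is illustrative):

    import torch
    from tvo.utils.parallel import broadcast

    weights = torch.rand(3, 4)   # differs per rank before the call
    broadcast(weights, src=0)    # in-place; a no-op when the run policy is not 'mpi'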

tvo.utils.parallel.gather_from_processes(*my_tensors, dst=0)[source]

Gather tensors from process group.

Parameters:
  • my_tensors (Tensor) – Tensors of the local process to be gathered on process dst. For each tensor, shape[1:] must be identical on every process.

  • dst (int) – Rank of destination process to gather tensors.

Return type:

Union[Tensor, Iterable[Tensor]]

Returns:

List of tensors gathered from process group.

Only process with rank dst will contain gathered data.
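
Example (a minimal sketch; assumes an initialized process group and that a single input tensor yields a single tensor on dst, concatenated along dimension 0):

    import torch
    import torch.distributed as dist
    from tvo.utils.parallel import gather_from_processes

    rank = dist.get_rank()
    local = torch.full((2, 3), float(rank))         # shape[1:] identical on all ranks
    gathered = gather_from_processes(local, dst=0)
    if rank == 0:
        print(gathered.shape)                       # e.g. (2 * world_size, 3)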

tvo.utils.parallel.get_h5_dataset_to_processes(fname, possible_keys)[source]

Return the dataset stored under the first of possible_keys found in the HDF5 file fname.

Return type:

Tensor
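
Example (a minimal sketch; the file name and dataset keys are hypothetical):

    from tvo.utils.parallel import get_h5_dataset_to_processes

    # uses whichever of the two keys is present in the file
    data = get_h5_dataset_to_processes("train.h5", ("train_data", "data"))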

tvo.utils.parallel.init_processes(multi_node=False)[source]

Initialize MPI process group using torch.distributed module.

Parameters:
  • multi_node (bool) – Deploy multiple computing nodes.

May update the value of tvo.device.
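
Example (a minimal sketch of a typical launch, e.g. "mpirun -n 4 python train.py", assuming the run policy is 'mpi'):

    import tvo
    from tvo.utils.parallel import init_processes, pprint

    init_processes()                       # sets up the torch.distributed process group
    pprint(f"using device {tvo.device}")   # printed once, by the root rank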

tvo.utils.parallel.mpi_average_grads(theta)[source]

Average gradients across processes. See https://bit.ly/2FlJsxS.

Parameters:

theta (Dict[str, Tensor]) – dictionary with torch.tensors storing TVO model parameters

Return type:

None
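
Example (a minimal sketch with a hypothetical parameter dictionary and a stand-in objective; the averaging only takes effect in an initialized 'mpi' run):

    import torch
    from tvo.utils.parallel import mpi_average_grads

    theta = {"W": torch.rand(5, 3, requires_grad=True),
             "sigma2": torch.rand(1, requires_grad=True)}
    loss = theta["W"].sum() + theta["sigma2"].sum()   # stand-in for a real objective
    loss.backward()
    mpi_average_grads(theta)   # each parameter's .grad is averaged across processes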

tvo.utils.parallel.pprint(obj='', end='\n')[source]

Print on root process of torch.distributed process group.

Parameters:
  • obj – Message to print

  • end – Suffix of message. Default is a linebreak.
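
Example (a minimal sketch; the messages are illustrative):

    from tvo.utils.parallel import pprint

    pprint("epoch 3 done")               # printed only by the root rank
    pprint("still training", end="\r")   # custom suffix instead of a linebreak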

tvo.utils.parallel.scatter_to_processes(*tensors, src=0)[source]

Split tensors into chunks and scatter within process group.

Parameters:
  • tensors (Tensor) – Tensor to be scattered. Chunks are cut along dimension 0.

  • src (int) – Source rank to scatter from.

Return type:

Iterable[Tensor]

Returns:

Tensor scattered to local rank.

Tensor data is assumed to be None on all but the src process.
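
Example (a minimal sketch; rank 0 holds the full dataset and the other ranks pass None; assumes an initialized process group):

    import torch
    import torch.distributed as dist
    from tvo.utils.parallel import scatter_to_processes

    rank = dist.get_rank()
    full = torch.rand(1000, 25) if rank == 0 else None   # data only on the root rank
    my_chunk = scatter_to_processes(full)                 # each rank receives a slice along dim 0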