tvo.utils.parallel
- tvo.utils.parallel.all_reduce(tensor, op=<RedOpType.SUM: 0>)[source]
Equivalent to torch’s all_reduce if tvo.get_run_policy() is ‘mpi’, no-op otherwise.
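The effect of a SUM all-reduce can be sketched in plain Python, without an actual process group. This is an illustration of the semantics only: ranks are simulated as list entries, and the helper name is hypothetical, not part of tvo.

```python
# Plain-Python sketch of all_reduce(SUM) semantics: after the call, every
# rank holds the elementwise sum of all ranks' tensors. Ranks are simulated
# as entries of a list; with any run policy other than 'mpi' the real call
# is a no-op.

def simulated_all_reduce_sum(per_rank_tensors):
    """Return what each rank holds after an all-reduce with op=SUM."""
    summed = [sum(vals) for vals in zip(*per_rank_tensors)]
    # Every rank receives the same reduced result.
    return [list(summed) for _ in per_rank_tensors]

ranks = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # one "tensor" per rank
print(simulated_all_reduce_sum(ranks))  # each rank sees [9.0, 12.0]
```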
- tvo.utils.parallel.barrier()[source]
Equivalent to torch’s dist.barrier if tvo.get_run_policy() is ‘mpi’, no-op otherwise.
- tvo.utils.parallel.bcast_dtype(data, src=0)[source]
Broadcast dtype of data on src rank.
- Parameters:
  - data (Tensor) – Tensor on src rank
  - src (int) – Source rank
- Return type: dtype
- Returns: dtype on each rank
- tvo.utils.parallel.bcast_shape(data, src)[source]
Broadcast shape of data on src rank.
- Parameters:
  - data (Tensor) – Tensor on src rank
  - src (int) – Source rank
- Return type: Tensor
- Returns: Tensor with shape on each rank
- tvo.utils.parallel.broadcast(tensor, src=0)[source]
Equivalent to torch’s broadcast if tvo.get_run_policy() is ‘mpi’, no-op otherwise.
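The three broadcast-style helpers above (bcast_dtype, bcast_shape, broadcast) share one pattern: the value held on rank src is replicated to every rank. A minimal plain-Python sketch of that pattern, with ranks simulated as list entries and the function name purely illustrative:

```python
# Sketch of broadcast semantics: whatever rank src holds (a dtype, a shape,
# or a tensor) ends up on every rank. With a run policy other than 'mpi'
# the real calls are no-ops.

def simulated_broadcast(per_rank_values, src=0):
    """Return what each rank holds after broadcasting from rank src."""
    return [per_rank_values[src] for _ in per_rank_values]

shapes = [(4, 3), None, None]  # only rank 0 knows the shape
print(simulated_broadcast(shapes, src=0))  # [(4, 3), (4, 3), (4, 3)]
```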
- tvo.utils.parallel.gather_from_processes(*my_tensors, dst=0)[source]
Gather tensors from process group.
- Parameters:
  - my_tensors (Tensor) – Tensors to be gathered from the local process on process dst. For each element, tensor.shape[1:] must be identical on each process.
  - dst (int) – Rank of the destination process that gathers the tensors.
- Return type: Union[Tensor, Iterable[Tensor]]
- Returns: List of tensors gathered from the process group. Only the process with rank dst will contain the gathered data.
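The asymmetric result described above (only rank dst receives data) can be sketched as follows. This simulates the semantics only; the function name and the representation of ranks as list entries are illustrative, not part of tvo.

```python
# Sketch of gather_from_processes semantics: each rank contributes a chunk
# (all chunks agree on shape[1:]), and only the destination rank ends up
# with the full list of chunks.

def simulated_gather(per_rank_chunks, my_rank, dst=0):
    """Return the gathered list on rank dst, and None on every other rank."""
    if my_rank == dst:
        return list(per_rank_chunks)  # destination receives all chunks
    return None                       # non-destination ranks receive nothing

chunks = [[[1, 2]], [[3, 4], [5, 6]]]  # rank 0 sends 1 row, rank 1 sends 2
print(simulated_gather(chunks, my_rank=0, dst=0))  # [[[1, 2]], [[3, 4], [5, 6]]]
print(simulated_gather(chunks, my_rank=1, dst=0))  # None
```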
- tvo.utils.parallel.get_h5_dataset_to_processes(fname, possible_keys)[source]
Return the dataset stored under the first of possible_keys that is found in the HDF5 file fname.
- Return type: Tensor
- tvo.utils.parallel.init_processes(multi_node=False)[source]
Initialize MPI process group using torch.distributed module.
- Parameters:
  - multi_node (bool) – Deploy on multiple computing nodes.
May update the value of tvo.device.
- tvo.utils.parallel.mpi_average_grads(theta)[source]
Average gradients across processes. See https://bit.ly/2FlJsxS.
- Parameters:
  - theta (Dict[str, Tensor]) – Dictionary of torch tensors storing TVO model parameters
- Return type: None
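The gradient-averaging pattern behind mpi_average_grads (an all-reduce SUM over each parameter's gradient, followed by division by the world size) can be sketched in plain Python. Tensors are simulated as lists and the function name is illustrative, not tvo's API.

```python
# Sketch of distributed gradient averaging: sum each named gradient across
# all ranks, then divide by the number of ranks, so every process ends up
# with identical averaged gradients.

def average_grads(per_rank_grads):
    """per_rank_grads: one {param_name: grad_list} dict per rank."""
    world_size = len(per_rank_grads)
    names = per_rank_grads[0].keys()
    return {
        name: [sum(g[name][i] for g in per_rank_grads) / world_size
               for i in range(len(per_rank_grads[0][name]))]
        for name in names
    }  # every rank would hold this after the all-reduce

grads = [{"W": [2.0, 4.0]}, {"W": [4.0, 8.0]}]  # two ranks
print(average_grads(grads))  # {'W': [3.0, 6.0]}
```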
- tvo.utils.parallel.pprint(obj='', end='\\n')[source]
Print on root process of torch.distributed process group.
- Parameters:
  - obj – Message to print
  - end – Suffix of message. Default is linebreak.
- tvo.utils.parallel.scatter_to_processes(*tensors, src=0)[source]
Split tensors into chunks and scatter within process group.
- Parameters:
  - tensors (Tensor) – Tensor to be scattered. Chunks are cut along dimension 0.
  - src (int) – Source rank to scatter from.
- Return type: Iterable[Tensor]
- Returns: Tensor scattered to the local rank.
Tensor data is assumed to be None on all but the root processes.
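The chunking along dimension 0 described above can be sketched in plain Python. This is an assumption-laden illustration: the exact split sizes of the real implementation may differ, the helper name is hypothetical, and a "tensor" is simulated as a list of rows.

```python
# Sketch of the splitting step in scatter_to_processes: the source rank
# cuts the tensor along dimension 0 into world_size contiguous chunks,
# one per rank (the last chunk may be smaller).

def chunk_along_dim0(tensor, world_size):
    """Split a 'tensor' (list of rows) into world_size contiguous chunks."""
    per_rank = -(-len(tensor) // world_size)  # ceiling division
    return [tensor[i * per_rank:(i + 1) * per_rank] for i in range(world_size)]

data = [[1], [2], [3], [4], [5]]  # 5 rows scattered to 2 ranks
print(chunk_along_dim0(data, world_size=2))  # [[[1], [2], [3]], [[4], [5]]]
```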