Support more backends in distributed operations #2468
Closed
Labels
enhancement (New feature or request)
Description
The current version of MONAI uses torch.distributed for all_gather, all_reduce, etc. — e.g. on line 409 below. However, this raises an error when I launch parallel programs in an ignite.distributed.Parallel context with the horovod or xla backend: RuntimeError: Default process group has not been initialized, please make sure to call init_process_group. Would it be better to replace all the torch.distributed operations with their ignite.distributed equivalents? They have native support for all the backends.
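One possible shape for the fix, sketched below under the assumption that ignite is an optional dependency (the helper name `pick_dist_module` is hypothetical, not MONAI API): route collectives through ignite.distributed when it is importable, since it dispatches all_gather/all_reduce to nccl, gloo, horovod and xla natively, and fall back to torch.distributed otherwise.

```python
# Hypothetical sketch: choose the module used for collective operations.
# ignite.distributed natively supports nccl/gloo/horovod/xla backends;
# torch.distributed is the current MONAI behavior.
import importlib.util

def pick_dist_module() -> str:
    """Return the dotted name of the module to use for collective ops."""
    if importlib.util.find_spec("ignite") is not None:
        return "ignite.distributed"  # backend-agnostic collectives
    return "torch.distributed"       # original behavior

print(pick_dist_module())
```

Detecting the package with `importlib.util.find_spec` (rather than a bare import) keeps the check cheap and avoids import-time side effects when ignite is absent.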
Lines 359 to 423 in 1a707f0
```python
import torch
import torch.distributed as dist
from typing import List


def get_dist_device():
    """
    Get the expected target device in the distributed data parallel.
    For NCCL backend, return GPU device of current process.
    For GLOO backend, return CPU.
    For any other backends, return None as the default, tensor.to(None) will not change the device.
    """
    if dist.is_initialized():
        backend = dist.get_backend()
        if backend == "nccl" and torch.cuda.is_available():
            return torch.device(f"cuda:{torch.cuda.current_device()}")
        elif backend == "gloo":
            return torch.device("cpu")
    return None


def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True):
    """
    Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather.
    The input data of every rank should have the same number of dimensions, only the first dim can be different.

    Args:
        data: source tensor to pad and execute all_gather in distributed data parallel.
        concat: whether to concat the gathered list to be a Tensor, if False, return a list
            of Tensors, similar behavior as torch.distributed.all_gather(). default to True.

    Note:
        The input data on different ranks must have exactly the same `dtype`.
    """
    if not isinstance(data, torch.Tensor):
        raise ValueError("input data must be PyTorch Tensor.")
    world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
    if world_size <= 1:
        return data
    device = get_dist_device()
    orig_device = data.device
    data = data.to(device)
    # data of all the ranks must have same number of dimensions
    ndims = data.ndimension()
    if ndims == 0:
        # tensor must have batch dimension
        data = data.unsqueeze(0)
    # make sure the data is evenly-divisible on multi-GPUs
    length: int = data.shape[0]
    length_tensor = torch.as_tensor([length], device=device)
    all_lens = [torch.zeros_like(length_tensor) for _ in range(world_size)]
    dist.all_gather(all_lens, length_tensor)
    all_lens_: List[int] = [int(i.item()) for i in all_lens]
    max_len: int = max(all_lens_)
    if length < max_len:
        size = [max_len - length] + list(data.shape[1:])
        data = torch.cat([data, data.new_full(size, 0)], dim=0)
    # all gather across all processes
    output = [torch.zeros_like(data) for _ in range(world_size)]
    dist.all_gather(output, data)
    # remove the padding items; if the input data had no batch dim, squeeze the first dim
    output = [(o.squeeze(0) if ndims == 0 else o[:l, ...]).to(orig_device) for o, l in zip(output, all_lens_)]
    return torch.cat(output, dim=0) if concat else output
```
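The pad/gather/trim logic above can be illustrated without any backend at all. The sketch below (hypothetical helper, plain Python lists standing in for each rank's tensor along the first dim) mirrors the steps: gather lengths, pad every rank to the max length, "all_gather", then trim the padding back off and concatenate.

```python
# Minimal torch-free simulation of evenly_divisible_all_gather's steps.
def simulated_all_gather(per_rank_data):
    # 1. every "rank" reports its first-dim length (dist.all_gather of lengths)
    lengths = [len(d) for d in per_rank_data]
    max_len = max(lengths)
    # 2. pad each rank's data with zeros so all shapes match
    padded = [d + [0] * (max_len - len(d)) for d in per_rank_data]
    # 3. "all_gather": every rank now holds the same padded lists
    gathered = padded
    # 4. trim the padding using the gathered lengths
    trimmed = [g[:n] for g, n in zip(gathered, lengths)]
    # 5. concatenate, mirroring concat=True
    return [x for chunk in trimmed for x in chunk]

print(simulated_all_gather([[1, 2, 3], [4], [5, 6]]))  # → [1, 2, 3, 4, 5, 6]
```

The padding step is what makes the collective safe: torch.distributed.all_gather requires identical tensor shapes on every rank, so ranks with fewer items must be padded up to the max before the call and trimmed afterwards.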