[DTensor] Fix _get_or_create_default_group() (#96961)

Summary:
This PR fixes `_get_or_create_default_group()` of `DeviceMesh`. When `mesh` of the first created `DeviceMesh` is not `[0, 1, 2, ... WORLD_SIZE - 1]` and `is_initialized() == False`, it wrongly asserts. This PR fixes this issue by removing these assertions.

 ---

More specifically, `_get_or_create_default_group()` has 4 checks:

1. `DeviceMesh must include every process in WORLD`
2. `DeviceMesh cannot have duplicate values`
3. `DeviceMesh ranks must start from 0`
4. `DeviceMesh should have all ranks of WORLD`

1, 3, and 4 are not satisfied when `self.mesh` is not `[0, 1, 2, ... WORLD_SIZE - 1]`.

2 is a valid check, but it is also checked in `__init__()`, so we don't need to check it again in this function.

Test Plan: CI

Reviewed By: wanchaol

Differential Revision: D44098849

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96961
Approved by: https://github.com/wanchaol
This commit is contained in:
Shintaro Iwasaki
2023-03-17 15:52:19 +00:00
committed by PyTorch MergeBot
parent ffddb2219a
commit 95575f0a5f

View File

@ -1,5 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import os
import warnings
from typing import List, Optional, Sequence, TypeVar, Union
@ -222,33 +221,6 @@ class DeviceMesh:
def _get_or_create_default_group(self):
if not is_initialized():
# TODO: we will support mesh on a subset of WORLD in future
world_size = int(os.getenv("WORLD_SIZE", 1))
if self.mesh.numel() < world_size:
raise RuntimeError(
"DeviceMesh must include every process in WORLD, "
f"but WORLD_SIZE({world_size}) != mesh size({self.mesh.numel()})"
)
unique_mesh_values = self.mesh.unique(sorted=True)
if unique_mesh_values.numel() != self.mesh.numel():
raise RuntimeError(
f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
)
# ranks in mesh must start from 0
if unique_mesh_values[0] != 0:
raise RuntimeError(
"DeviceMesh ranks must start from 0, "
f"but found min rank = {unique_mesh_values[0]}"
)
# mesh must be contiguous (i.e. from 0 to N-1)
if 2 * unique_mesh_values.sum().item() != world_size * (world_size - 1):
raise RuntimeError(
f"DeviceMesh should have all ranks of WORLD, but found {self.mesh.tolist()}"
)
_backend = "gloo" if self.device_type == "cpu" else "nccl"
init_process_group(backend=_backend)
return _get_default_group()