mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[DTensor] Fix _get_or_create_default_group() (#96961)
Summary: This PR fixes `_get_or_create_default_group()` of `DeviceMesh`. When the `mesh` of the first created `DeviceMesh` is not `[0, 1, 2, ..., WORLD_SIZE - 1]` and `is_initialized() == False`, the function wrongly raises a `RuntimeError`. This PR fixes the issue by removing the offending checks.

---

More specifically, `_get_or_create_default_group()` has 4 checks:

1. `DeviceMesh must include every process in WORLD`
2. `DeviceMesh cannot have duplicate values`
3. `DeviceMesh ranks must start from 0`
4. `DeviceMesh should have all ranks of WORLD`

Checks 1, 3, and 4 are not satisfied when `self.mesh` is not `[0, 1, 2, ..., WORLD_SIZE - 1]`. Check 2 is valid, but it is already performed in `__init__()`, so it does not need to be repeated in this function.

Test Plan: CI

Reviewed By: wanchaol

Differential Revision: D44098849

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96961

Approved by: https://github.com/wanchaol
This commit is contained in:
committed by
PyTorch MergeBot
parent
ffddb2219a
commit
95575f0a5f
@ -1,5 +1,4 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Sequence, TypeVar, Union
|
||||
|
||||
@ -222,33 +221,6 @@ class DeviceMesh:
|
||||
|
||||
def _get_or_create_default_group(self):
|
||||
if not is_initialized():
|
||||
# TODO: we will support mesh on a subset of WORLD in future
|
||||
world_size = int(os.getenv("WORLD_SIZE", 1))
|
||||
if self.mesh.numel() < world_size:
|
||||
raise RuntimeError(
|
||||
"DeviceMesh must include every process in WORLD, "
|
||||
f"but WORLD_SIZE({world_size}) != mesh size({self.mesh.numel()})"
|
||||
)
|
||||
|
||||
unique_mesh_values = self.mesh.unique(sorted=True)
|
||||
if unique_mesh_values.numel() != self.mesh.numel():
|
||||
raise RuntimeError(
|
||||
f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
|
||||
)
|
||||
|
||||
# ranks in mesh must start from 0
|
||||
if unique_mesh_values[0] != 0:
|
||||
raise RuntimeError(
|
||||
"DeviceMesh ranks must start from 0, "
|
||||
f"but found min rank = {unique_mesh_values[0]}"
|
||||
)
|
||||
|
||||
# mesh must be contiguous (i.e. from 0 to N-1)
|
||||
if 2 * unique_mesh_values.sum().item() != world_size * (world_size - 1):
|
||||
raise RuntimeError(
|
||||
f"DeviceMesh should have all ranks of WORLD, but found {self.mesh.tolist()}"
|
||||
)
|
||||
|
||||
_backend = "gloo" if self.device_type == "cpu" else "nccl"
|
||||
init_process_group(backend=_backend)
|
||||
return _get_default_group()
|
||||
|
Reference in New Issue
Block a user