mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
Summary: Rename _device_mesh.py to device_mesh.py, update all callsites, and add documentation.

Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/114991

The original landing failed a public module binding test on macOS because of a change in the import order in torch/distributed/fsdp/_common_utils.py. Since the original import still works, the changes to that file have been dropped.

Test Plan: CI.

Differential Revision: D51825114
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115099
Approved by: https://github.com/wanchaol, https://github.com/fegin
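As a rough illustration of the renamed public module (a sketch, not part of this PR; assumes a single-process run where a gloo default process group can be initialized from the environment variables set below):

import os
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh  # previously torch.distributed._device_mesh

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

mesh = init_device_mesh("cpu", (1,))  # 1-D mesh over the single-rank world
assert isinstance(mesh, DeviceMesh)
print(mesh.ndim)  # 1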
216 lines
7.1 KiB
Python
import inspect
from typing import Dict, List

import torch

from .. import variables
from ..exc import unimplemented
from ..utils import istype
from .base import VariableTracker
from .constant import ConstantVariable


class DistributedVariable(VariableTracker):
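    """Base class for VariableTrackers that wrap objects from torch.distributed.

    Instantiating any subclass calls ``unimplemented`` (a dynamo graph break)
    when the distributed package is not built into this copy of PyTorch.
    """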
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if not DistributedVariable.is_available():
            unimplemented("torch.distributed package is not available!")

    @staticmethod
    def is_available():
        # check if the distributed package is available or not
        return torch.distributed.is_available()


def is_from_local(value):
    if not DistributedVariable.is_available():
        return False
    from torch.distributed._tensor import DTensor

    return inspect.isfunction(value) and value is DTensor.from_local


def is_constant_pg_functions(value):
    if not DistributedVariable.is_available():
        return False

    from torch.distributed.distributed_c10d import (
        _get_group_tag,
        get_process_group_ranks,
    )

    constant_processgroup_functions = [
        get_process_group_ranks,
        _get_group_tag,
    ]

    return inspect.isfunction(value) and value in constant_processgroup_functions


class PlacementClassVariable(DistributedVariable):
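    """Tracks a Placement subclass itself (e.g. Shard or Replicate); calling it
    inside the compiled region builds a PlacementVariable around a fresh
    instance, otherwise deferring to the base class.
    """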
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    @staticmethod
    def is_placement_type(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.placement_types import Placement

        return type(value) is type and issubclass(value, Placement)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if (
            inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
            and self.source
        ):
            # NOTE: we don't need to track mutations to the placement class, as
            # placements are supposed to be immutable.
            new_obj = object.__new__(self.value)
            var = PlacementVariable(new_obj)
            if inspect.getattr_static(self.value, "__init__", None):
                var.call_method(tx, "__init__", args, kwargs)
            return var

        return super().call_function(tx, args, kwargs)


class PlacementVariable(DistributedVariable):
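    """Tracks a Placement instance. Only ``__init__`` and ``__setattr__`` calls
    are handled; they are executed eagerly on the wrapped value so the placement
    behaves like a Python constant during tracing.
    """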
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    @staticmethod
    def is_placement(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.placement_types import Placement

        return isinstance(value, Placement)

    def as_python_constant(self):
        return self.value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import ConstantVariable

        allowed_methods = ["__init__", "__setattr__"]
        # Dynamo tracking of placement types allows only the __init__ and
        # __setattr__ methods; the latter covers cases like `Shard(dim)`.
        if name in allowed_methods:
            try:
                value_type = type(self.value)
                assert (
                    inspect.getattr_static(value_type, "__getattr__", None) is None
                ), "no custom getattr allowed!"
                method = inspect.getattr_static(value_type, name)
            except AttributeError:
                method = None
            if method is object.__init__:
                return ConstantVariable.create(None)

            args = [x.as_python_constant() for x in args]
            kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            method(self.value, *args, **kwargs)
            return self

        return super().call_method(tx, name, args, kwargs)


class DeviceMeshVariable(DistributedVariable):
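    """Tracks a DeviceMesh instance; ``ndim`` is folded into a compile-time
    constant, while other attribute accesses defer to the base class.
    """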
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    @staticmethod
    def is_device_mesh(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed.device_mesh import DeviceMesh

        return istype(value, DeviceMesh)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx, name: str) -> VariableTracker:
        if name == "ndim":
            return ConstantVariable.create(self.value.ndim)
        return super().var_getattr(tx, name)


class ProcessGroupVariable(DistributedVariable):
    """
    We don't want a ProcessGroup object to end up in our output graph.

    But it's common for dynamo to intercept a PG that is then used to get info like
    rank() or world_size(), as well as passed to utility functions in distributed_c10d
    which desugar it into plain types like a rank list and tag.

    For convenience and proper guarding, we construct a variable type.

    TODO: make it possible to use ProcessGroupVariable as input to simple functions
    like _expand_group without dynamo complaining about making a proxy for it.
    It is not a tensor-like type, and we don't want a proxy, but dynamo assumes
    torch library functions deal with tensor-like types and would have proxies
    for their args.
    TODO: should we make this inherit VT instead of UDOV? Do we want any of the
    default behaviors, or just graph-break whenever one of our special cases is not hit?
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def as_python_constant(self):
        return self.value

    def python_type(self):
        return type(self.value)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "rank":
            return variables.ConstantVariable.create(self.value.rank())
        if name == "size":
            return variables.ConstantVariable.create(self.value.size())

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        if name in ["rank", "size"]:
            return variables.LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        # TODO: should this just raise unimplemented?
        return super().var_getattr(tx, name)

    @staticmethod
    def is_process_group(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False
        from torch._C._distributed_c10d import ProcessGroup
        from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

        return istype(value, (ProcessGroup, FakeProcessGroup))
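A minimal sketch of the behavior this file enables, assuming a single-process gloo setup (the environment and init lines below are illustrative and not part of the file): when a compiled function reads a ProcessGroup, ProcessGroupVariable folds rank() and size() into Python constants instead of letting the group leak into the output graph.

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
pg = dist.distributed_c10d._get_default_group()  # private helper for the default world group

@torch.compile
def scale(x):
    # pg.size() and pg.rank() become the constants 1 and 0 during tracing
    return x * pg.size() + pg.rank()

print(scale(torch.ones(2)))
dist.destroy_process_group()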