Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)

[dynamo] support group=None when rewriting collectives (#121043)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121043
Approved by: https://github.com/awgu

Committed by: PyTorch MergeBot
Parent: 3fee05f242
Commit: d7a5e59647
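With this change, dynamo can rewrite torch.distributed.all_reduce (and the other remappable collectives) into traceable functional collectives even when the process group is passed as None or omitted, instead of graph-breaking on the missing group; that appears to be why moco's expected graph-break counts drop in the benchmark CSV hunks below. A minimal sketch of the new behavior, assuming an initialized default process group (the function and tensor here are illustrative, not from the PR):

import torch
import torch.distributed as dist

def allreduce_mean(t):
    # group is omitted (equivalent to group=None): with this change, dynamo
    # rewrites the call to the traceable functional collective against the
    # default (WORLD) process group instead of graph-breaking.
    dist.all_reduce(t)
    return t / dist.get_world_size()

compiled = torch.compile(allreduce_mean, fullgraph=True)
# Requires dist.init_process_group(...) before calling, e.g.:
# out = compiled(torch.ones(4))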
@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,11
+moco,pass,5

@@ -162,7 +162,7 @@ mobilenet_v3_large,pass,7
-moco,pass,17
+moco,pass,11

@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,11
+moco,pass,5

@@ -158,7 +158,7 @@ mobilenet_v3_large,pass,7
-moco,pass,17
+moco,pass,11

@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,11
+moco,pass,5

@@ -158,7 +158,7 @@ mobilenet_v3_large,pass,7
-moco,pass,17
+moco,pass,11

@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,11
+moco,pass,5

@@ -162,7 +162,7 @@ mobilenet_v3_large,pass,7
-moco,pass,17
+moco,pass,11

@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,11
+moco,pass,5

@@ -162,7 +162,7 @@ mobilenet_v3_large,pass,7
-moco,pass,17
+moco,pass,11
@@ -22,7 +22,11 @@ from torch.testing._internal.common_distributed import (
     run_with_both_funcol_impls_with_arg,
     skip_if_lt_x_gpu,
 )
-from torch.testing._internal.common_utils import instantiate_parametrized_tests, requires_cuda
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    requires_cuda,
+)
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
 from torch.utils._triton import has_triton
 from torch._inductor.utils import run_and_get_triton_code
@@ -825,22 +829,43 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         assert same(outputs, correct_outputs)
 
     @run_with_both_funcol_impls
-    def test_dynamo_rewrite_dist_allreduce(self):
+    @parametrize(
+        "pg_mode",
+        [
+            "kwargs",
+            "kwargs_none",
+            "unspecified",
+        ]
+    )
+    def test_dynamo_rewrite_dist_allreduce(self, pg_mode):
 
-        def func(tensor, pg):
+        def func(tensor, *args, **kwargs):
             torch.distributed.all_reduce(
                 tensor,
-                group=pg
+                *args,
+                **kwargs,
             )
 
         counter = CompileCounter()
         compiled = torch.compile(func, backend=counter, fullgraph=True)
 
+        args = []
+        kwargs = {}
+
+        # TODO(yifu): test positional and positional_none
+        # once explicit reduce op is supported
+        if pg_mode == "kwargs":
+            kwargs["group"] = GroupMember.WORLD
+        elif pg_mode == "kwargs_none":
+            kwargs["group"] = None
+        else:
+            assert pg_mode == "unspecified"
+
         inputs_compiled = torch.ones(2, device=self.device)
         inputs_eager = torch.ones(2, device=self.device)
 
-        compiled(inputs_compiled, GroupMember.WORLD)
-        func(inputs_eager, GroupMember.WORLD)
+        compiled(inputs_compiled, *args, **kwargs)
+        func(inputs_eager, *args, **kwargs)
 
         assert counter.frame_count == 1
         # should test more precisely, but the 3 is supposed to be (all_reduce, wait, copy_)
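For reference, a standalone sketch of the parametrization machinery the updated test relies on; ExampleTest and test_modes are illustrative names, not part of the PR. @parametrize expands one test body into one generated case per pg_mode value, and instantiate_parametrized_tests materializes those cases on the class:

from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

class ExampleTest(TestCase):
    @parametrize("pg_mode", ["kwargs", "kwargs_none", "unspecified"])
    def test_modes(self, pg_mode):
        # Each pg_mode value becomes its own generated test case.
        self.assertIn(pg_mode, ("kwargs", "kwargs_none", "unspecified"))

instantiate_parametrized_tests(ExampleTest)

if __name__ == "__main__":
    run_tests()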
@@ -9,6 +9,8 @@ from .. import compiled_autograd, variables
 from .._trace_wrapped_higher_order_op import trace_wrapped
 from ..exc import unimplemented
 from ..external_utils import call_module_hooks_from_backward_state
+from ..guards import GuardBuilder, install_guard
+from ..source import AttrSource, GlobalSource
 from ..utils import istype
 from .base import VariableTracker
 from .constant import ConstantVariable
@@ -259,6 +261,33 @@ class ProcessGroupVariable(DistributedVariable):
 
         return istype(value, (ProcessGroup, FakeProcessGroup))
 
+    @staticmethod
+    def get_global_pg_variable():
+        """
+        Make a ProcessGroupVariable from torch.distributed.group.WORLD and
+        install guards.
+        """
+        import torch.distributed as dist
+
+        source = AttrSource(
+            AttrSource(
+                base=AttrSource(
+                    base=GlobalSource(global_name="torch"),
+                    member="distributed",
+                    get_static=False,
+                ),
+                member="group",
+                get_static=False,
+            ),
+            member="WORLD",
+            get_static=False,
+        )
+        install_guard(source.make_guard(GuardBuilder.ID_MATCH))
+        return ProcessGroupVariable(
+            dist.group.WORLD,
+            source=source,
+        )
+
 
 class BackwardHookVariable(VariableTracker):
     """
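A small sketch of what the new helper resolves to at runtime, assuming a default process group has been initialized; the AttrSource chain above mirrors the attribute path torch.distributed.group.WORLD, and the ID_MATCH guard pins the returned variable to that exact object:

import torch.distributed as dist

# dist.init_process_group(backend="gloo", ...) must run first in a real program.
# The AttrSource chain above names exactly this attribute path:
#   GlobalSource("torch") -> "distributed" -> "group" -> "WORLD"
world_pg = dist.group.WORLD  # the default process group (None before init)
# The ID_MATCH guard ties the compiled graph to this specific object, so a
# re-initialized default group fails the guard and triggers recompilation
# instead of silently reusing stale compiled code.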
@@ -17,6 +17,7 @@ from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
 from ..utils import check_constant_args, get_first_attr, identity, istype, make_cell
 from .base import MutableLocal, typestr, VariableTracker
 from .constant import ConstantVariable
+from .distributed import ProcessGroupVariable
 
 if TYPE_CHECKING:
     from torch._guards import Source

@@ -691,10 +692,21 @@ class CollectiveFunctionRewriteVariable(UserFunctionVariable):
         # call_function must check any unsupported arguments and graph-break.
         # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
         # since that's the contract for putting a mapping in `traceable_collective_remaps`
+
+        # Merge args into kwargs so positional and keyword args
+        # can be processed the same way.
+        signature = inspect.signature(self.fn)
+        kwargs = dict(signature.bind(*args, **kwargs).arguments)
+        args = ()
 
         if "async_op" in kwargs and kwargs["async_op"].as_python_constant():
             unimplemented(
                 f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.fn}"
             )
 
+        if kwargs.get("group") is None or kwargs["group"].value is None:
+            kwargs["group"] = ProcessGroupVariable.get_global_pg_variable()
+
         return self.replacement_var.call_function(tx, args, kwargs)
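A minimal sketch of the args-to-kwargs merge performed above, using a hypothetical stand-in function (fake_all_reduce is illustrative, not the real dynamo code path):

import inspect

def fake_all_reduce(tensor, op="sum", group=None, async_op=False):
    pass

sig = inspect.signature(fake_all_reduce)

# Positional and keyword calls normalize to the same mapping; unsupplied
# parameters simply don't appear in .arguments.
positional = dict(sig.bind("t0").arguments)           # {"tensor": "t0"}
keyword = dict(sig.bind("t0", group=None).arguments)  # {"tensor": "t0", "group": None}

# After the merge, a single kwargs.get("group")-style check covers both the
# "group omitted" and "group=None" call forms.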
@@ -1033,10 +1033,20 @@ def all_gather_inplace(
     assert (
         not async_op
     ), "Can't remap async version of inplace op to functional collective"
+    assert all(
+        t.size(0) == tensor.size(0) for t in tensor_list
+    ), "Remapping variable size all_gather is not yet supported"
 
     output = all_gather_tensor(tensor, 0, group, tag)
-    for dst, src in zip(
-        tensor_list, output.split([t.size(0) for t in tensor_list], dim=0)
-    ):
+    # Use aten.slice instead of aten.split because the latter causes
+    # tensor.shape(0) to be unnecessarily baked in when it's a SymInt.
+    output_splits = []
+    offset = 0
+    for t in tensor_list:
+        output_splits.append(output[offset : offset + t.size(0)])
+        offset += t.size(0)
+    for dst, src in zip(tensor_list, output_splits):
         dst.copy_(src)
     return tensor_list