[dynamo] support group=None when rewriting collectives (#121043)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121043
Approved by: https://github.com/awgu
Author: Yifu Wang
Committed by: PyTorch MergeBot
Date: 2024-03-04 01:35:11 -08:00
Commit: d7a5e59647 (parent: 3fee05f242)

14 changed files with 95 additions and 19 deletions
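For context, a minimal sketch (not part of the diff) of the pattern this change makes traceable: calling a c10d collective under torch.compile without passing a process group, so Dynamo's collective rewrite falls back to the default group instead of graph-breaking. The single-process gloo setup and the backend="eager" choice below are assumptions for illustration only.

    import os
    import torch
    import torch.distributed as dist

    # Assumed single-process setup, purely for illustration.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    def allreduce_sum(t):
        # group is omitted (same as group=None); with this change Dynamo rewrites
        # the call into a functional collective on the default group instead of
        # falling back with a graph break.
        dist.all_reduce(t)
        return t

    compiled = torch.compile(allreduce_sum, backend="eager", fullgraph=True)
    out = compiled(torch.ones(4))

    dist.destroy_process_group()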

View File

@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,11
+moco,pass,5

(CSV columns: name, accuracy, graph_breaks)

View File

@@ -162,7 +162,7 @@ mobilenet_v3_large,pass,7
-moco,pass,17
+moco,pass,11


View File

@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,11
+moco,pass,5


View File

@@ -158,7 +158,7 @@ mobilenet_v3_large,pass,7
-moco,pass,17
+moco,pass,11


View File

@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,11
+moco,pass,5


View File

@@ -158,7 +158,7 @@ mobilenet_v3_large,pass,7
-moco,pass,17
+moco,pass,11


View File

@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,11
+moco,pass,5


View File

@@ -162,7 +162,7 @@ mobilenet_v3_large,pass,7
-moco,pass,17
+moco,pass,11


View File

@@ -234,7 +234,7 @@ mobilenet_v3_large,pass,0
-moco,pass,11
+moco,pass,5


View File

@@ -162,7 +162,7 @@ mobilenet_v3_large,pass,7
-moco,pass,17
+moco,pass,11


View File

@@ -22,7 +22,11 @@ from torch.testing._internal.common_distributed import (
     run_with_both_funcol_impls_with_arg,
     skip_if_lt_x_gpu,
 )
-from torch.testing._internal.common_utils import instantiate_parametrized_tests, requires_cuda
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    requires_cuda,
+)
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
 from torch.utils._triton import has_triton
 from torch._inductor.utils import run_and_get_triton_code
@@ -825,22 +829,43 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         assert same(outputs, correct_outputs)
 
     @run_with_both_funcol_impls
-    def test_dynamo_rewrite_dist_allreduce(self):
-        def func(tensor, pg):
+    @parametrize(
+        "pg_mode",
+        [
+            "kwargs",
+            "kwargs_none",
+            "unspecified",
+        ]
+    )
+    def test_dynamo_rewrite_dist_allreduce(self, pg_mode):
+        def func(tensor, *args, **kwargs):
             torch.distributed.all_reduce(
                 tensor,
-                group=pg
+                *args,
+                **kwargs,
             )
 
         counter = CompileCounter()
         compiled = torch.compile(func, backend=counter, fullgraph=True)
+
+        args = []
+        kwargs = {}
+
+        # TODO(yifu): test positional and positional_none
+        # once explicit reduce op is supported
+        if pg_mode == "kwargs":
+            kwargs["group"] = GroupMember.WORLD
+        elif pg_mode == "kwargs_none":
+            kwargs["group"] = None
+        else:
+            assert pg_mode == "unspecified"
+
         inputs_compiled = torch.ones(2, device=self.device)
         inputs_eager = torch.ones(2, device=self.device)
-        compiled(inputs_compiled, GroupMember.WORLD)
-        func(inputs_eager, GroupMember.WORLD)
+
+        compiled(inputs_compiled, *args, **kwargs)
+        func(inputs_eager, *args, **kwargs)
+
         assert counter.frame_count == 1
         # should test more precisely, but the 3 is supposed to be (all_reduce, wait, copy_)

View File

@@ -9,6 +9,8 @@ from .. import compiled_autograd, variables
 from .._trace_wrapped_higher_order_op import trace_wrapped
 from ..exc import unimplemented
 from ..external_utils import call_module_hooks_from_backward_state
+from ..guards import GuardBuilder, install_guard
+from ..source import AttrSource, GlobalSource
 from ..utils import istype
 from .base import VariableTracker
 from .constant import ConstantVariable
@@ -259,6 +261,33 @@ class ProcessGroupVariable(DistributedVariable):
         return istype(value, (ProcessGroup, FakeProcessGroup))
 
+    @staticmethod
+    def get_global_pg_variable():
+        """
+        Make a ProcessGroupVariable from torch.distributed.group.WORLD and
+        install guards.
+        """
+        import torch.distributed as dist
+
+        source = AttrSource(
+            AttrSource(
+                base=AttrSource(
+                    base=GlobalSource(global_name="torch"),
+                    member="distributed",
+                    get_static=False,
+                ),
+                member="group",
+                get_static=False,
+            ),
+            member="WORLD",
+            get_static=False,
+        )
+        install_guard(source.make_guard(GuardBuilder.ID_MATCH))
+        return ProcessGroupVariable(
+            dist.group.WORLD,
+            source=source,
+        )
+
 
 class BackwardHookVariable(VariableTracker):
     """

View File

@@ -17,6 +17,7 @@ from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
 from ..utils import check_constant_args, get_first_attr, identity, istype, make_cell
 from .base import MutableLocal, typestr, VariableTracker
 from .constant import ConstantVariable
+from .distributed import ProcessGroupVariable
 
 if TYPE_CHECKING:
     from torch._guards import Source
@@ -691,10 +692,21 @@ class CollectiveFunctionRewriteVariable(UserFunctionVariable):
         # call_function must check any unsupported arguments and graph-break.
         # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
         # since that's the contract for putting a mapping in `traceable_collective_remaps`
+
+        # Merge args into kwargs so positional and keyword args
+        # can be processed the same way.
+        signature = inspect.signature(self.fn)
+        kwargs = dict(signature.bind(*args, **kwargs).arguments)
+        args = ()
+
         if "async_op" in kwargs and kwargs["async_op"].as_python_constant():
             unimplemented(
                 f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.fn}"
             )
+
+        if kwargs.get("group") is None or kwargs["group"].value is None:
+            kwargs["group"] = ProcessGroupVariable.get_global_pg_variable()
+
         return self.replacement_var.call_function(tx, args, kwargs)
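A standalone sketch (plain Python, no Dynamo VariableTrackers, stand-in signature) of the signature.bind normalization used above, which folds positional and keyword arguments into one dict keyed by parameter name:

    import inspect

    def all_reduce(tensor, op="sum", group=None, async_op=False):  # stand-in signature
        ...

    sig = inspect.signature(all_reduce)
    bound = sig.bind("T", group=None)   # a mix of positional and keyword arguments
    kwargs = dict(bound.arguments)      # {'tensor': 'T', 'group': None}
    args = ()
    print(kwargs)

Note that bind only records arguments that were actually passed, which is why the rewrite checks kwargs.get("group") rather than assuming the key exists.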

View File

@@ -1033,10 +1033,20 @@ def all_gather_inplace(
     assert (
         not async_op
     ), "Can't remap async version of inplace op to functional collective"
+    assert all(
+        t.size(0) == tensor.size(0) for t in tensor_list
+    ), "Remapping variable size all_gather is not yet supported"
 
     output = all_gather_tensor(tensor, 0, group, tag)
-    for dst, src in zip(
-        tensor_list, output.split([t.size(0) for t in tensor_list], dim=0)
-    ):
+
+    # Use aten.slice instead of aten.split because the latter causes
+    # tensor.shape(0) to be unnecessarily baked in when it's a SymInt.
+    output_splits = []
+    offset = 0
+    for t in tensor_list:
+        output_splits.append(output[offset : offset + t.size(0)])
+        offset += t.size(0)
+    for dst, src in zip(tensor_list, output_splits):
         dst.copy_(src)
     return tensor_list
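A minimal self-contained sketch (assumed shapes, no distributed setup) showing that the slicing loop above yields the same views as the old output.split call for the equal-size case, while avoiding the split-size specialization the in-code comment describes:

    import torch

    tensor_list = [torch.empty(2, 4) for _ in range(3)]           # per-rank buffers
    output = torch.arange(24, dtype=torch.float32).reshape(6, 4)  # stand-in gathered result

    # Old approach: aten.split with explicit sizes.
    split_views = output.split([t.size(0) for t in tensor_list], dim=0)

    # New approach: plain slicing, which lowers to aten.slice.
    slice_views = []
    offset = 0
    for t in tensor_list:
        slice_views.append(output[offset : offset + t.size(0)])
        offset += t.size(0)

    assert all(torch.equal(a, b) for a, b in zip(split_views, slice_views))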