Fix torch.distributed._functional_collectives.AsyncCollectiveTensor for aten.to. (#134661)

Fixes #133421

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134661
Approved by: https://github.com/bdhirsh
Author: PHLens
Date: 2025-01-04 02:33:36 +00:00
Committed by: PyTorch MergeBot
Parent: 7e3cd0e488
Commit: 98949df7a4


@@ -435,6 +435,12 @@ def reduce_scatter_tensor_coalesced(
 # Today, this maps 1:1 with "aten ops that are views".
 def _is_view_op(tgt):
     assert isinstance(tgt, torch._ops.OpOverload)
+    # Don't apply the view optimization to any `CompositeImplicitAutograd` ops.
+    # See issue: https://github.com/pytorch/pytorch/issues/133421
+    if torch._C._dispatch_has_kernel_for_dispatch_key(
+        tgt.name(), torch.DispatchKey.CompositeImplicitAutograd
+    ):
+        return False
     schema = tgt._schema
     if len(schema.arguments) > 0:
         first_arg = schema.arguments[0]
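
For context, a minimal sketch of what the new dispatch-key check observes; the overload names `aten::to.dtype` and `aten::view` below are illustrative choices, not taken from the patch. `aten.to` is registered as `CompositeImplicitAutograd` (it decomposes into other ops), so `_is_view_op` now returns False for it rather than treating it as a view, while genuine view ops still fall through to the existing schema/alias check.

import torch

# aten.to has a CompositeImplicitAutograd kernel, so the new early-return in
# _is_view_op should trigger for it.
print(torch._C._dispatch_has_kernel_for_dispatch_key(
    "aten::to.dtype", torch.DispatchKey.CompositeImplicitAutograd
))  # expected: True -> _is_view_op returns False

# A real view op such as aten.view has no CompositeImplicitAutograd kernel,
# so it proceeds to the alias_info check on the schema's first argument.
print(torch._C._dispatch_has_kernel_for_dispatch_key(
    "aten::view", torch.DispatchKey.CompositeImplicitAutograd
))  # expected: False -> still treated as a view op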