mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Traceable FSDP2][Inductor] Apply compute/comm reordering passes to achieve overlap (#131614)
This PR enables the Inductor compute/comm reordering passes to Traceable FSDP2 to achieve overlap. Note that the overlap is not maximally optimized yet and the follow-up work will be done in subsequent PRs. Test commands: - `pytest -rA test/distributed/test_compute_comm_reordering.py::TestComputeCommReorderingMultiProc` - `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor` - `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor` Pull Request resolved: https://github.com/pytorch/pytorch/pull/131614 Approved by: https://github.com/yifuwang ghstack dependencies: #131510
This commit is contained in:
committed by
PyTorch MergeBot
parent
9e06572704
commit
aee6bcdba4
@ -1547,6 +1547,26 @@ def is_wait(node):
|
||||
return type(node) == ir._WaitKernel
|
||||
|
||||
|
||||
def contains_collective(snode):
|
||||
from torch._inductor.scheduler import BaseSchedulerNode, GroupedSchedulerNode
|
||||
|
||||
assert isinstance(snode, BaseSchedulerNode)
|
||||
if isinstance(snode, GroupedSchedulerNode):
|
||||
return any(contains_collective(x) for x in snode.snodes)
|
||||
else:
|
||||
return is_collective(snode.node)
|
||||
|
||||
|
||||
def contains_wait(snode):
|
||||
from torch._inductor.scheduler import BaseSchedulerNode, GroupedSchedulerNode
|
||||
|
||||
assert isinstance(snode, BaseSchedulerNode)
|
||||
if isinstance(snode, GroupedSchedulerNode):
|
||||
return any(contains_wait(x) for x in snode.snodes)
|
||||
else:
|
||||
return is_wait(snode.node)
|
||||
|
||||
|
||||
def is_fallback_op(node, op):
|
||||
from . import ir
|
||||
|
||||
|
Reference in New Issue
Block a user