[Traceable FSDP2][Inductor] Apply compute/comm reordering passes to achieve overlap (#131614)

This PR enables the Inductor compute/comm reordering passes to Traceable FSDP2 to achieve overlap. Note that the overlap is not maximally optimized yet and the follow-up work will be done in subsequent PRs.

Test commands:
- `pytest -rA  test/distributed/test_compute_comm_reordering.py::TestComputeCommReorderingMultiProc`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131614
Approved by: https://github.com/yifuwang
ghstack dependencies: #131510
This commit is contained in:
Will Feng
2024-07-26 21:12:39 -07:00
committed by PyTorch MergeBot
parent 9e06572704
commit aee6bcdba4
4 changed files with 250 additions and 21 deletions

View File

@ -1547,6 +1547,26 @@ def is_wait(node):
return type(node) == ir._WaitKernel
def contains_collective(snode):
from torch._inductor.scheduler import BaseSchedulerNode, GroupedSchedulerNode
assert isinstance(snode, BaseSchedulerNode)
if isinstance(snode, GroupedSchedulerNode):
return any(contains_collective(x) for x in snode.snodes)
else:
return is_collective(snode.node)
def contains_wait(snode):
from torch._inductor.scheduler import BaseSchedulerNode, GroupedSchedulerNode
assert isinstance(snode, BaseSchedulerNode)
if isinstance(snode, GroupedSchedulerNode):
return any(contains_wait(x) for x in snode.snodes)
else:
return is_wait(snode.node)
def is_fallback_op(node, op):
from . import ir