Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[inductor] do comm compute overlap at aten fx level (#163215)
This is the first part of a stack that does comm/compute reordering and then uses the exposure analysis to do bucketing. Subsequent PRs will handle:
- use of the exposure analysis to do bucketing
- making sure Inductor respects the comm/compute overlapping done at the FX level
- non-profiling mm estimation / rank broadcasting of profile results

Other misc:
- validate accuracy of the NCCL estimations (use ruisi's profiling instead?)

For a Llama 2D-parallelism test, on forward we overlap all but 2 of the potentially hidden collectives. On backward we overlap 217/269 of the potentially hidden collectives. If you increase `compute_overlap_multipler` (a fudge factor for inaccurate comms estimation), that improves to all but 16 of the potentially hidden collectives.

fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3
bwd example: https://gist.github.com/eellison/6cfc2285df53a94cfa4012f5fdae5c51

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163215
Approved by: https://github.com/IvanKobzarev
Committed by: PyTorch MergeBot
Parent: c39357bab6
Commit: 0d7994ca97
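For intuition on the exposure accounting the pass performs, here is a minimal sketch with made-up example numbers (the real logic lives in OverlapScheduler._handle_compute in the new overlap_scheduling.py below); a collective counts as "hidden" once its exposed time reaches zero.

# Hypothetical numbers: one in-flight all_reduce estimated at 1.0 ms,
# overlapped by a matmul benchmarked at 0.6 ms.
collective_estimate_ms = 1.0
exposed_ms = collective_estimate_ms          # initially fully exposed
compute_overlap_multipler = 1.0              # fudge factor exposed as a pass knob
available = 0.6 * compute_overlap_multipler  # compute time usable for hiding

overlap = min(exposed_ms, available)
exposed_ms -= overlap                        # 0.4 ms of the collective remains exposed
# Scheduling another ~0.4 ms of compute before the wait drives exposed_ms to 0,
# at which point the collective is treated as hidden rather than exposed.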
@@ -26,6 +26,7 @@ if [[ "${SHARD_NUMBER:-2}" == "2" ]]; then
   time python test/run_test.py --verbose -i distributed/test_c10d_spawn_gloo
   time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
   time python test/run_test.py --verbose -i distributed/test_compute_comm_reordering
+  time python test/run_test.py --verbose -i distributed/test_aten_comm_compute_reordering
   time python test/run_test.py --verbose -i distributed/test_store
   time python test/run_test.py --verbose -i distributed/test_symmetric_memory
   time python test/run_test.py --verbose -i distributed/test_pg_wrapper
@@ -435,7 +435,7 @@ test_inductor_distributed() {
   # this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported
   # with if required # gpus aren't available
-  python test/run_test.py --include distributed/test_dynamo_distributed distributed/test_inductor_collectives distributed/test_compute_comm_reordering --verbose
+  python test/run_test.py --include distributed/test_dynamo_distributed distributed/test_inductor_collectives distributed/test_aten_comm_compute_reordering distributed/test_compute_comm_reordering --verbose
   assert_git_not_dirty
 }
test/distributed/test_aten_comm_compute_reordering.py (new file, 353 lines)
@@ -0,0 +1,353 @@
# flake8: noqa: B950
# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case

# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.utils import counters, same
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    at_least_x_gpu,
    DynamoDistributedMultiProcTestCase,
    requires_accelerator_dist_backend,
)


aten = torch.ops.aten
import functools

from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_GPU


def estimate_aten_runtime(fx_node):
    # for tests, assume a matmul can hide a single collective
    if "c10" in str(fx_node.target):
        return 1.0
    elif fx_node.target == aten.mm.default:
        return 1.0
    else:
        return None


device_type = str(get_devtype())


def apply_reordering_and_get_graph(graph, out_li) -> None:
    gm = graph.owning_module
    from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing

    schedule_overlap_bucketing(gm)
    gm.graph.lint()
    out_li.append(str(gm.graph))


def run_and_get_aten_graph(fn, *inputs):
    li = []
    apply = functools.partial(apply_reordering_and_get_graph, out_li=li)
    with torch._inductor.config.patch(post_grad_custom_post_pass=apply):
        out = fn(*inputs)

    return out, li[0]


def get_patches():
    return {
        "test_configs.estimate_aten_runtime": estimate_aten_runtime,
        "reorder_for_locality": False,
        "reorder_for_compute_comm_overlap_passes": [],
        "compile_threads": 1,
        "force_disable_caches": True,
    }


@requires_accelerator_dist_backend()
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under

    Note: these tests are a fork of test/distributed/test_compute_comm_reordering.py
    """

    def setUp(self):
        super().setUp()
        torch._dynamo.reset()
        torch._dynamo.utils.counters.clear()

    def get_world_trs(self):
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @property
    def world_size(self) -> int:
        # hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2
        # works around issue with skipif<2 and workers with unpredictable #s gpu
        return 2

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(get_patches())
    def test_sink_waits(self):
        def func(a):
            ar = _functional_collectives.all_reduce(a, "sum", "0")
            b = torch.matmul(a, a)
            return torch.matmul(ar, b)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank

            out, aten_graph_str = run_and_get_aten_graph(torch.compile(func), inputs)

            # Verify that the wait_tensor is sinked below the 1st matmul but
            # above the 2nd matmul.
            (
                FileCheck()
                .check("all_reduce.default")
                .check("aten.mm.default")
                .check("wait_tensor.default")
                .check("aten.mm.default")
                .run(aten_graph_str)
            )
            correct = func(inputs)
            self.assertTrue(same(out, correct))
            self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)

    @torch._inductor.config.patch(get_patches())
    def test_raise_comms(self):
        def func(a):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce((b + 1), "sum", "0")
            return torch.matmul(d, e)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            compiled = torch.compile(func)
            out, aten_graph_str = run_and_get_aten_graph(torch.compile(func), inputs)
            # Verify that the all_reduce_ has been raised above the 2nd matmul
            # but below the 1st matmul. Note that the all_reduce_ directly
            # writes to the output buffer of the 1st matmul, which is an input
            # to the first relu. Therefore, the all_reduce_ should be scheduled
            # after the first relu.
            (
                FileCheck()
                .check("aten.mm")
                .check("all_reduce.default")
                .check("aten.mm")
                .check("wait_tensor.default")
                .check("aten.mm")
                .run(aten_graph_str)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))
            self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)

    @torch._inductor.config.patch(get_patches())
    def test_sink_waits_raise_comms(self):
        def func(a, *, tag, ranks, group_size):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            f = torch.relu(d)
            g = torch.matmul(f, f)
            return torch.mm(e, g)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = torch.ones(
                4, 4, dtype=torch.float, device=device_type
            )  # + self.rank
            kwargs = self.get_world_trs()
            func = functools.partial(func, **kwargs)
            compiled = torch.compile(func)
            out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
            # Things to verify:
            # - The all_reduce_ and its prologue should be raised above the 2nd
            #   matmul but below the 1st matmul.
            # - The wait_tensor should be sinked below the 3rd matmul but above
            #   the 4th matmul.

            self.assertExpectedInline(
                aten_graph_str,
                """\
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %mm : [num_users=2] = call_function[target=torch.ops.aten.mm.default](args = (%arg0_1, %arg0_1), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%mm,), kwargs = {})
    %all_reduce : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%mm, sum, 0), kwargs = {})
    %mm_1 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%relu, %relu), kwargs = {})
    %relu_1 : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%mm_1,), kwargs = {})
    %mm_2 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%relu_1, %relu_1), kwargs = {})
    %wait_tensor : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce,), kwargs = {})
    %mm_3 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%wait_tensor, %mm_2), kwargs = {})
    return (mm_3,)""",
            )

            # Note: this triggered an all_reduce_ bug
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))
            self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)

    @torch._inductor.config.patch(get_patches())
    def test_reorder_compute_for_overlap_mul(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            func_c = functools.partial(func, **self.get_world_trs())
            compiled = torch.compile(func_c)
            out_c, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
            # Note: because we have given collectives and mms equal estimation,
            # we overlap each collective with a single mm.
            # Same schedule as in test_reorder_compute_for_overlap_custom_runtime_estimation
            # although there is an exposed collective
            (
                FileCheck()
                .check("all_reduce.default")
                .check("aten.mm")
                .check("aten.mm")
                .check("wait_tensor.default")
                .check("aten.mul")
                .check("all_reduce.default")
                .check("wait_tensor.default")
                .check("aten.mm")
                .run(aten_graph_str)
            )
            correct = func(inputs, **self.get_world_trs())
            self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 1)
            self.assertTrue(same(out_c, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @skipIfRocm
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @unittest.skipIf(True, "Logic not yet implemented")
    @torch._inductor.config.patch(get_patches())
    def test_grouped_scheduler_node(self):
        def func(a, *, tag, ranks, group_size):
            add = a + a
            div = add / a
            ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
            # Normally, we would fuse `add = a + a`, `div = add / a` and `mul = a * a` together into a single fused op,
            # but here in this unit test, we intentionally put `add`, `div` and `ar` computation
            # into a GroupedSchedulerNode, which prevents them from being fused with any other ops.
            mul = a * a
            mm = torch.matmul(mul, ar)
            return (mm,)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Expectations:
            # 1. `add = a + a` and `div = add / a` are still fused, which means fusion
            #    still happens among nodes within a GroupedSchedulerNode.
            # 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within
            #    GroupedSchedulerNode and thus are prevented from being fused with any outside ops.
            FileCheck().check("triton_poi_fused_add_all_reduce_div_0.").check(
                "_c10d_functional.all_reduce_."
            ).check("triton_poi_fused_mul_1.").run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(get_patches())
    def test_inductor_default_comms_ordering(self):
        pg_info = self.get_world_trs()
        tag = pg_info["tag"]
        ranks = pg_info["ranks"]
        group_size = pg_info["group_size"]

        g1 = torch.ones(10, 10, device=device_type)
        g2 = torch.ones(11, 11, device=device_type)
        g3 = torch.ones(12, 12, device=device_type)

        @torch.compile
        def fn(g1, g2, g3):
            handle1 = torch.ops.c10d_functional.all_reduce(
                g1, "avg", tag, ranks, group_size
            )
            handle2 = torch.ops.c10d_functional.all_reduce(
                g2, "avg", tag, ranks, group_size
            )
            handle3 = torch.ops.c10d_functional.all_reduce(
                g3, "avg", tag, ranks, group_size
            )

            # wait on them in a different order
            grad3 = torch.ops._c10d_functional.wait_tensor.default(handle3)
            grad2 = torch.ops._c10d_functional.wait_tensor.default(handle2)
            grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1)
            return grad3, grad2, grad1

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, self.backend(device_type), fake_pg=True
        ):
            # all_reduces remain in order!
            # note: this isnt actually invariant of pass currently..
            # but we should keep collectives stable without reordering opportunities

            _, code = run_and_get_aten_graph(fn, g1, g2, g3)

            FileCheck().check("all_reduce").check_same("arg0_1").check(
                "all_reduce"
            ).check_same("arg1_1").check("all_reduce").check_same("arg2_1").run(code)
            self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 3)
            # these have no overlap opportunities
            self.assertEqual(counters["inductor"]["overlap_scheduling_bad_exposed"], 0)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -7,6 +7,7 @@ from typing import Optional
 import sympy

 import torch
+from torch.fx.operator_schemas import normalize_function

 from . import ir
 from .utils import get_dtype_size, snode_args_kwargs, sympy_product
@@ -43,11 +44,7 @@ def get_gpu_type() -> NVIDIA_GPU_TYPE:
     return NVIDIA_GPU_TYPE.AMPERE


-def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
-    if not isinstance(node, ir._CollectiveKernel):
-        raise ValueError(f"node is not a collective kernel: {node}")
-
-    kernel_name = node.python_kernel_name
+def get_collective_type_from_kernel_name(kernel_name: str) -> NCCL_COLL:
     assert kernel_name is not None
     if "all_reduce" in kernel_name:
         return NCCL_COLL.ALL_REDUCE
@@ -61,6 +58,15 @@ def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
         raise ValueError(f"Unsupported collective kernel: {kernel_name}")


+def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
+    if not isinstance(node, ir._CollectiveKernel):
+        raise ValueError(f"node is not a collective kernel: {node}")
+
+    name = node.python_kernel_name
+    assert name is not None
+    return get_collective_type_from_kernel_name(name)
+
+
 def get_collective_input_size_bytes(node: ir.IRNode) -> int:
     sz_bytes = 0
     for inp in node.inputs:  # type: ignore[attr-defined]
@@ -210,7 +216,9 @@ def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]:
     return est_time_ms


-def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
+def estimate_nccl_collective_runtime_impl(
+    tensor_storage_size_bytes: int, group_size: int, coll: NCCL_COLL
+) -> float:
     """
     Returns estimated NCCL collective runtime in milliseconds (ms).
@@ -223,14 +231,12 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
     - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
     - collective is one of: allreduce, reducescatter, allgather
     """
-    tensor_storage_size_bytes = get_collective_input_size_bytes(node)
     # Convert bytes to GB
     tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024

     # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
     # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
     num_gpus_per_node = 8
-    group_size = get_collective_group_size(node)
     nNodes = math.ceil(group_size / num_gpus_per_node)
     nRanks = group_size  # this is total # of gpus globally that participate in this collective op
@@ -240,7 +246,6 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
     # Assumes ring algorithm
     nccl_algo = NCCL_ALGO.RING
     nccl_proto = NCCL_PROTO.LL
-    coll = get_collective_type(node)

     # =============== bandwidth computation ===============
     # First compute bandwidth in GB/s; then at the end, convert it to GB/ns
@@ -318,3 +323,70 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
 ################################################################################################################
 # The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
 ################################################################################################################
+
+
+def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
+    """
+    Returns estimated NCCL collective runtime in nanoseconds (ns).
+
+    The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
+    We aim to estimate the runtime as accurately as possible.
+
+    Assumptions:
+    - only ring algorithm (NCCL_ALGO_RING) is used
+    - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used
+    - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
+    - collective is one of: allreduce, reducescatter, allgather
+    """
+    tensor_storage_size_bytes = get_collective_input_size_bytes(node)
+    group_size = get_collective_group_size(node)
+    coll = get_collective_type(node)
+    return estimate_nccl_collective_runtime_impl(
+        tensor_storage_size_bytes, group_size, coll
+    )
+
+
+def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
+    size = 0
+    for node in fx_node.all_input_nodes:
+        if (t := node.meta.get("val")) is not None:
+            size += t.numel() * t.element_size()
+
+    # TODO - symbolic
+    return size
+
+
+def estimate_nccl_collective_runtime_from_fx_node(fx_node: torch.fx.Node) -> float:
+    """
+    Returns estimated NCCL collective runtime in nanoseconds (ns).
+
+    The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
+    We aim to estimate the runtime as accurately as possible.
+
+    Assumptions:
+    - only ring algorithm (NCCL_ALGO_RING) is used
+    - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used
+    - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
+    - collective is one of: allreduce, reducescatter, allgather
+    """
+    from torch.distributed.distributed_c10d import _get_group_size_by_name
+
+    tensor_storage_size_bytes = estimate_fx_collective_size(fx_node)
+
+    assert not isinstance(fx_node.target, str)
+    opt_args_kwargs = normalize_function(
+        fx_node.target,
+        args=fx_node.args,
+        kwargs=fx_node.kwargs,
+        normalize_to_only_use_kwargs=True,
+    )
+    assert opt_args_kwargs is not None
+    _, kwargs = opt_args_kwargs
+
+    group_size = _get_group_size_by_name(kwargs["group_name"])
+    assert isinstance(fx_node.target, torch._ops.OpOverload)
+    coll = get_collective_type_from_kernel_name(fx_node.target.name())
+
+    return estimate_nccl_collective_runtime_impl(
+        tensor_storage_size_bytes, group_size, coll
+    )
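For intuition on the byte count above: estimate_fx_collective_size sums numel() * element_size() over the collective's input values. A hedged worked example with an assumed tensor shape (not taken from the PR): a single bf16 all_reduce input of shape (4096, 4096) contributes 32 MiB.

import torch

# Meta tensor: shape/dtype only, no real allocation needed for the arithmetic.
t = torch.empty(4096, 4096, dtype=torch.bfloat16, device="meta")
size_bytes = t.numel() * t.element_size()
print(size_bytes)          # 33554432 bytes
print(size_bytes / 2**20)  # 32.0 MiB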
@@ -2007,6 +2007,17 @@ class test_configs:
     # for unit testing
     use_libtorch = False
+
+    # to be migrated when ready for use
+    aten_fx_overlap_scheduling = False
+
+    # to be migrated when ready for use
+    # runtime estimation function for ops
+    # for user-defined estimation function, pass in the function handle
+    # TODO - need estimated and profile based version
+    estimate_aten_runtime: Union[
+        Literal["default"], Callable[[torch.fx.Node], Optional[float]]
+    ] = "default"


 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
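A sketch of how this hook might be supplied by a user, following the pattern of get_patches() in the test file above; the estimator function here is illustrative, not part of the PR.

import torch

def my_estimate(fx_node):
    # Return an estimated runtime for the node, or None to fall back
    # to the built-in NCCL / benchmarking estimates.
    if "c10" in str(fx_node.target):
        return 1.0
    return None

# The test file patches this same dotted key via torch._inductor.config.patch.
with torch._inductor.config.patch(
    {"test_configs.estimate_aten_runtime": my_estimate}
):
    fn = torch.compile(lambda x: torch.relu(x @ x))
    fn(torch.randn(8, 8))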
torch/_inductor/fx_passes/overlap_scheduling.py (new file, 655 lines)
@@ -0,0 +1,655 @@
import functools
import heapq
import itertools
import logging
import sys
from collections import Counter, defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

import torch
import torch.fx as fx
from torch._dynamo.utils import counters, dynamo_timed
from torch.utils._mode_utils import no_dispatch
from torch.utils._ordered_set import OrderedSet


log = logging.getLogger(__name__)

from ..pattern_matcher import stable_topological_sort


def is_wait_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.wait_tensor.default
    )


def get_custom_estimation(n: fx.Node) -> Optional[float]:
    runtime_estimation = torch._inductor.config.test_configs.estimate_aten_runtime
    if runtime_estimation == "default":
        return None

    assert callable(runtime_estimation)
    return runtime_estimation(n)


def estimate_collective_time(n: fx.Node) -> float:
    if (est := get_custom_estimation(n)) is not None:
        return est

    return torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
        n
    )


def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
    size = 0
    for node in fx_node.all_input_nodes:
        if (t := node.meta.get("val")) is not None:
            # todo - symbolic
            size += t.numel() * t.element_size()

    return size


def is_compute_node(n: fx.Node) -> bool:
    return (
        getattr(n.target, "overloadpacket", None)
        in torch.utils.flop_counter.flop_registry
    )


def get_hint(x: Union[int, torch.SymInt]) -> Optional[int]:
    if isinstance(x, int):
        return x
    assert isinstance(x, torch.SymInt)
    if not x.node.has_hint():
        return None
    return x.node.hint


def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]:
    with dynamo_timed("collective_compute_do_bench"):
        return functools.partial(
            torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu,
            warmup=5,
        )


def benchmark_node(n: fx.Node) -> float:
    if (est := get_custom_estimation(n)) is not None:
        return est

    from torch._dynamo.testing import rand_strided

    # todo - skip unbacked, symbolic
    success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(n)

    if not success:
        return 0

    unbacked_tensor = False

    key = f"{str(n.target)}: "

    def to_real(t: torch.Tensor) -> Optional[torch.Tensor]:
        shape = [get_hint(dim) for dim in t.shape]
        stride = [get_hint(s) for s in t.stride()]

        if any(s is None for s in itertools.chain(shape, stride)):
            nonlocal unbacked_tensor
            unbacked_tensor = True
            return None

        nonlocal key
        key += f"T: {shape, stride, t.dtype} "
        return rand_strided(shape, stride, device=t.device, dtype=t.dtype)  # type: ignore[arg-type]

    with no_dispatch():
        args, kwargs = torch.utils._pytree.tree_map_only(
            torch.Tensor,
            lambda t: to_real(t),
            (args, kwargs),
        )

    if val := get_cached_node_time(key):
        return val

    if unbacked_tensor:
        return 0

    bench = get_collective_do_bench()
    out = bench(lambda: n.target(*args, **kwargs))  # type: ignore[operator]
    set_cached_node_time(key, out)
    return out


@functools.cache
def get_benchmark_cache() -> torch._inductor.codecache.LocalCache:
    return torch._inductor.codecache.LocalCache()


def get_cached_node_time(key: str) -> float:
    return get_benchmark_cache().lookup(key)  # type: ignore[return-value]


def set_cached_node_time(key: str, value: float) -> None:
    return get_benchmark_cache().set_value(key, value=value)


@dataclass
class CollectiveInfo:
    """Track info about a collective operation"""

    start_node: fx.Node
    wait_node: fx.Node
    size_bytes: int
    estimated_time_ms: float
    exposed_time_ms: float  # How much of this collective is still exposed
    hiding_node: Optional[fx.Node] = None  # Node that hides this collective

    @property
    def is_exposed(self) -> bool:
        return self.exposed_time_ms != 0


class OverlapScheduler:
    """
    Scheduler that reorders operations to maximize compute-collective overlap.

    The reordering is done as a scheduling pass. We maintain a priority queue of
    schedulable nodes. The nodes are ranked by:

    1) the compute node depth they dominate. this allows reordering locally, such as with
    parallel mms, and also allows overlapping reduce scatter nodes outputs in the backward
    with compute by deferring their waits.

    2) whether the current node is a collective or wait that is currently exposed but has a compute
    node which it could be overlapped with.

    3) original order in the graph for stability.

    When we schedule compute nodes, we first overlap exposed in-flight collectives, then look for unscheduled
    collectives that can be scheduled concurrently.

    TODO:
    - experiment with other priority scores / allow other mechanisms of reorder / more strict adherence to original graph
    - memory limit for deferred scheduling of reduce_scatter nodes.
    """

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        max_in_flight_gb: float = 2.0,
        compute_overlap_multipler: float = 1.0,
        max_coll_distance: int = 1000,
    ):
        self.gm = gm
        self.graph = gm.graph
        self.compute_overlap_multipler = compute_overlap_multipler
        self.max_node_distance = max_coll_distance
        self.max_in_flight_bytes: int = int(max_in_flight_gb * 1024 * 1024 * 1024)

        # Build structures
        stable_topological_sort(self.graph)
        self.nodes = list(self.graph.nodes)
        self.node_idx = {n: i for i, n in enumerate(self.nodes)}
        self.node_ancestors: dict[fx.Node, OrderedSet[fx.Node]] = (
            self._collect_node_ancestors()
        )

        # Identify collectives and compute nodes
        self.collective_info: dict[fx.Node, CollectiveInfo] = {}
        self.unscheduled_collectives: OrderedSet[fx.Node] = OrderedSet()

        self.wait_to_start: dict[fx.Node, fx.Node] = {}
        self._identify_collectives()

        self.compute_depth = self._calculate_compute_node_depth()
        self.compute_nodes = [n for n in self.nodes if is_compute_node(n)]

        # Scheduling state
        self.potentially_hidden_collectives = (
            self.compute_potential_hidden_collectives()
        )
        self.potentially_hidden_waits = self.compute_potential_hidden_waits()
        self.in_degree = Counter(user for node in self.nodes for user in node.users)
        self.ready: list[tuple[object, fx.Node]] = []

        for node in self.nodes:
            if self.in_degree[node] == 0:
                heapq.heappush(self.ready, (self._compute_score(node), node))

        self.in_flight: dict[fx.Node, CollectiveInfo] = {}  # start -> info
        self.in_flight_bytes = 0
        self.scheduled: OrderedSet[fx.Node] = OrderedSet()

    def _collect_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]:
        """Collect all ancestors for each node."""
        ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
        for node in self.nodes:
            for input_node in node.all_input_nodes:
                ancestors[node].add(input_node)
                ancestors[node] |= ancestors[input_node]

        return ancestors

    def _identify_collectives(self) -> None:
        """Identify all collective operations."""
        for node in self.nodes:
            if is_wait_tensor(node):
                start = node.args[0]
                coll_time_ms = estimate_collective_time(start)

                info = CollectiveInfo(
                    start_node=start,
                    wait_node=node,
                    size_bytes=estimate_fx_collective_size(start),
                    estimated_time_ms=coll_time_ms,
                    exposed_time_ms=coll_time_ms,  # Initially fully exposed
                )
                self.collective_info[start] = info
                self.wait_to_start[node] = start
                self.unscheduled_collectives.add(start)

    def _calculate_compute_node_depth(self) -> dict[fx.Node, int]:
        """Compute forward depth and minimum dominance depth (infinity if blocks no compute)."""

        # First pass: forward compute depth
        in_degree: dict[fx.Node, int] = {}
        compute_depth: dict[fx.Node, int] = {}
        queue: list[fx.Node] = []

        for node in self.graph.nodes:
            num_inputs = len(node.all_input_nodes)
            if num_inputs == 0:
                queue.append(node)
            else:
                in_degree[node] = num_inputs

        while queue:
            node = queue.pop()

            max_input_depth = max(
                (compute_depth[inp] for inp in node.all_input_nodes), default=0
            )
            compute_depth[node] = max_input_depth + is_compute_node(node)

            for use in node.users:
                in_degree[use] -= 1
                if in_degree[use] == 0:
                    queue.append(use)

        # Second pass: minimum dominance (what's the earliest compute this blocks)
        compute_depth_dominance: dict[fx.Node, int] = {}

        for node in reversed(self.graph.nodes):
            if is_compute_node(node):
                # consider compute nodes to be at their own depth
                dominance = compute_depth[node]
            else:
                # For non-compute nodes, find minimum compute they block
                dominance = min(
                    (compute_depth_dominance[succ] for succ in node.users),
                    default=sys.maxsize,
                )

            compute_depth_dominance[node] = dominance

        return compute_depth_dominance

    def run(self) -> torch.fx.GraphModule:
        """Run the scheduling algorithm."""

        while self.ready:
            if self._should_force_wait_for_memory():
                self._force_oldest_wait()
                continue

            _, node = heapq.heappop(self.ready)

            # we don't always remove nodes from the heap when we schedule them
            if node in self.scheduled:
                continue

            if is_compute_node(node):
                self._handle_compute(node)
            elif node in self.collective_info:
                self._handle_collective_start(node)
            elif is_wait_tensor(node):
                self._handle_wait(node)
            else:
                self._handle_other(node)

        self._reorder_graph()
        return self.gm

    def _handle_other(self, node: fx.Node) -> None:
        self._schedule(node)

    def _schedule(self, node: fx.Node) -> None:
        """Schedule a node."""
        assert node not in self.scheduled
        assert all(n in self.scheduled for n in node.all_input_nodes)
        self.scheduled.add(node)

        for user in node.users:
            self.in_degree[user] -= 1
            if self.in_degree[user] == 0:
                heapq.heappush(self.ready, (self._compute_score(user), user))

    def _compute_score(self, node: fx.Node) -> object:
        """Compute priority score for a node"""

        if is_wait_tensor(node):
            info = self.collective_info[self.wait_to_start[node]]
            # TODO: we could consider even deferring waits that are not potentially hidden
            # so as to overlap comm with itself. although exposed comms should bucketed with each other.
            overlappable = info.is_exposed and node in self.potentially_hidden_waits
        else:
            overlappable = self.in_overlappable_collective_unary_chain(node)

        return (
            self.compute_depth[node],  # what depth compute it blocks
            overlappable,  # Defer hideable collective ops
            self.node_idx[node],  # Original order for stability
        )

    @staticmethod
    def is_cheap_fn(node: fx.Node) -> bool:
        return getattr(node.target, "is_view", False) or torch.Tag.pointwise in getattr(
            node.target, "tags", ()
        )

    def in_overlappable_collective_unary_chain(self, curr: fx.Node) -> bool:
        while True:
            if len(curr.users) != 1:
                return False

            user = next(iter(curr.users))
            if len(user.all_input_nodes) != 1:
                return False

            if user in self.unscheduled_collectives:
                return user in self.potentially_hidden_collectives

            if not self.is_cheap_fn(user):
                return False

            curr = user

        return False

    def _should_force_wait_for_memory(self) -> bool:
        """Check if we need to force a wait due to memory pressure"""
        return self.in_flight_bytes >= self.max_in_flight_bytes

    def _force_oldest_wait(self) -> None:
        """Schedule the oldest in flight wait"""
        self._handle_wait(self._get_oldest_wait())

    def _handle_collective_start(self, node: fx.Node) -> None:
        """Handle scheduling a collective start."""
        info = self.collective_info[node]
        self.in_flight[node] = info
        self.in_flight_bytes += info.size_bytes
        self.unscheduled_collectives.discard(node)
        self._schedule(node)

    def _handle_wait(self, node: fx.Node) -> None:
        """Handle scheduling a wait."""
        assert node in self.wait_to_start
        coll_start = self.wait_to_start[node]

        assert coll_start in self.in_flight
        self.in_flight_bytes -= self.in_flight[coll_start].size_bytes
        del self.in_flight[coll_start]
        self._schedule(node)

    def _handle_compute(self, node: fx.Node) -> None:
        """Handle scheduling compute and finding overlaps."""

        compute_time = benchmark_node(node)
        available_compute = compute_time * self.compute_overlap_multipler

        # First reduce exposed time of in-flight collectives
        for info in self.in_flight.values():
            if info.exposed_time_ms == 0:
                continue
            overlap_amount = min(info.exposed_time_ms, available_compute)
            info.exposed_time_ms -= overlap_amount
            available_compute -= overlap_amount
            if info.exposed_time_ms == 0:
                info.hiding_node = node
            elif available_compute == 0:
                break

        # Then, look for unscheduled collectives we can overlap
        if available_compute:
            self._schedule_collectives_for_overlap(node, available_compute)

        self._schedule(node)

    def _schedule_collectives_for_overlap(
        self, compute_node: fx.Node, available_compute_time: float
    ) -> None:
        """Opportunistically schedule collectives that can be hidden by compute."""
        compute_ancestors = self.node_ancestors[compute_node]

        # copy unscheduled_collectives to local because we modify it during iteration
        possible_collectives = []
        for collective in self.unscheduled_collectives:
            distance = abs(self.node_idx[compute_node] - self.node_idx[collective])
            if distance > self.max_node_distance:
                break

            possible_collectives.append(collective)

        for collective in possible_collectives:
            if available_compute_time == 0:
                break

            info = self.collective_info[collective]

            # Skip if compute depends on collective or vice versa
            if (
                collective in compute_ancestors
                or compute_node in self.node_ancestors[collective]
            ):
                continue

            while (
                self.in_flight
                and (self.max_in_flight_bytes - self.in_flight_bytes) < info.size_bytes
                and self._wait_is_hidden(self._get_oldest_wait(), compute_node)
            ):
                self._force_oldest_wait()

            if (self.max_in_flight_bytes - self.in_flight_bytes) < info.size_bytes:
                continue

            # Check if we can reach this collective without scheduling compute, other collectives, or waits
            path = self._find_schedulable_path(collective, compute_node)
            if path is None:
                continue

            # Schedule path to this collective
            self._schedule_path_to_collective(path, compute_node)
            # Update the exposed time for this newly scheduled collective
            overlap_amount = min(info.estimated_time_ms, available_compute_time)
            info.exposed_time_ms -= overlap_amount
            if info.exposed_time_ms == 0:
                info.hiding_node = compute_node
            available_compute_time -= overlap_amount
            self._handle_collective_start(collective)

    def _find_schedulable_path(
        self, target: fx.Node, curr_compute_node: Optional[fx.Node]
    ) -> Optional[OrderedSet[fx.Node]]:
        """Find path to target by collecting unscheduled dependencies."""

        # TODO - following path faster than doing set difference here
        unscheduled_ancestors = self.node_ancestors[target] - self.scheduled

        # only schedule non distributed, non compute nodes
        for node in unscheduled_ancestors:
            if is_compute_node(node):
                return None

            if node in self.unscheduled_collectives:
                return None

            # if we schedule a wait tensor whose start collective is hidden by the
            # current compute node we are scheduling, then we are effectively exposing it.
            # similarly, dont schedule a wait of a collective that could be otherwise hidden,
            # thus forcing it to be exposed.
            # however, if it is already hidden or it cannot be possible hidden,
            # it's fine to schedule it
            if is_wait_tensor(node):
                info = self.collective_info[self.wait_to_start[node]]
                if info.hiding_node and info.hiding_node != curr_compute_node:
                    continue
                elif node not in self.potentially_hidden_waits:
                    continue

                return None

        return unscheduled_ancestors

    def _get_oldest_wait(self) -> fx.Node:
        oldest_start = next(iter(self.in_flight))
        return self.collective_info[oldest_start].wait_node

    def _wait_is_hidden(
        self, wait_node: fx.Node, compute_node: Optional[fx.Node] = None
    ) -> bool:
        assert is_wait_tensor(wait_node)
        info = self.collective_info[self.wait_to_start[wait_node]]
        return not info.is_exposed and info.hiding_node != compute_node

    def _schedule_path_to_collective(
        self, path: OrderedSet[fx.Node], curr_compute_node: fx.Node
    ) -> None:
        """Schedule all nodes needed to reach a collective."""
        for node in sorted(path, key=lambda n: self.node_idx[n]):
            assert not (is_compute_node(node) or node in self.unscheduled_collectives)

            if is_wait_tensor(node):
                info = self.collective_info[self.wait_to_start[node]]
                assert not info.hiding_node == curr_compute_node
                self._handle_wait(node)
                continue

            self._schedule(node)

    def reorder_graph(self) -> None:
        output_node = self.graph.output_node()
        for node in self.scheduled:
            if node.op == "placeholder":
                continue
            output_node.prepend(node)
        self.graph.lint()

    def _reorder_graph(self) -> None:
        """Reorder graph based on schedule."""
        exposed = [
            c
            for c in self.collective_info.values()
            if c.exposed_time_ms == c.estimated_time_ms
        ]

        potentially_hidden_collectives = self.compute_potential_hidden_collectives(
            limit_coll_per_compute=True
        )
        bad_exposed = [
            c for c in exposed if c.start_node in potentially_hidden_collectives
        ]

        counters["inductor"]["overlap_scheduling_exposed"] += len(exposed)
        counters["inductor"]["overlap_scheduling_bad_exposed"] += len(bad_exposed)
        counters["inductor"]["overlap_scheduling_potentially_hidden"] += len(
            potentially_hidden_collectives
        )

        log.info(
            "Overlap scheduling: total exposed %s, total bad exposed %s, total potentially hidden %s",
            len(exposed),
            len(bad_exposed),
            len(potentially_hidden_collectives),
        )

        self.reorder_graph()

    def compute_potential_hidden_nodes(
        self, nodes_to_check: Iterable[fx.Node], limit_coll_per_compute: bool = False
    ) -> dict[fx.Node, fx.Node]:
        """
        Returns a dict containing a mapping of nodes which could potentially be hidden to their hiding node
        """

        used_compute_nodes: OrderedSet[fx.Node] = OrderedSet()

        def could_be_hidden(start: fx.Node) -> Optional[fx.Node]:
            for compute_node in self.compute_nodes:
                if limit_coll_per_compute and compute_node in used_compute_nodes:
                    continue
                if (
                    start not in self.node_ancestors[compute_node]
                    and compute_node not in self.node_ancestors[start]
                ):
                    if limit_coll_per_compute:
                        used_compute_nodes.add(compute_node)
                    return compute_node

            return None

        # TODO: We could potentially limit compute nodes per overlap time,
        # today, this is optimistic, and just serves to avoid deferring
        # collectives/waits that have no possible overlap as well as for analysis of how
        # successfully we hid compute
        potentially_hidden = {}
        for node in nodes_to_check:
            if mm := could_be_hidden(node):
                potentially_hidden[node] = mm

        return potentially_hidden

    def compute_potential_hidden_collectives(
        self, limit_coll_per_compute: bool = False
    ) -> dict[fx.Node, fx.Node]:
        """Compute which collective operations could be hidden by compute."""
        return self.compute_potential_hidden_nodes(
            self.collective_info.keys(), limit_coll_per_compute
        )

    def compute_potential_hidden_waits(
        self, limit_coll_per_compute: bool = False
    ) -> dict[fx.Node, fx.Node]:
        """Compute which wait operations could be hidden by compute."""
        wait_nodes = [info.wait_node for info in self.collective_info.values()]
        return self.compute_potential_hidden_nodes(wait_nodes, limit_coll_per_compute)


def schedule_overlap_bucketing(
    gm: torch.fx.GraphModule,
    max_in_flight_gb: float = 2.0,
    compute_overlap_multipler: float = 1.0,
    max_coll_distance: int = 1000,
) -> torch.fx.GraphModule:
    """Schedule nodes to maximize compute-collective overlap.

    Args:
        gm: Input graph module to optimize.
        max_in_flight_gb: Maximum GB of concurrent collective data.
        compute_overlap_multipler: Scale factor for compute time used to hide collectives.
        max_coll_distance: Maximum node distance for overlap consideration.
    """
    return OverlapScheduler(
        gm,
        compute_overlap_multipler=compute_overlap_multipler,
        max_in_flight_gb=max_in_flight_gb,
        max_coll_distance=max_coll_distance,
    ).run()
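A minimal sketch of running the new pass manually, mirroring the apply_reordering_and_get_graph helper in the test file above; it assumes gm is a post-grad ATen-level torch.fx.GraphModule containing functional collectives.

from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing

def reorder_for_overlap(gm):
    # Reorders collective starts/waits against compute nodes in place
    # and returns the same GraphModule.
    gm = schedule_overlap_bucketing(
        gm, max_in_flight_gb=2.0, compute_overlap_multipler=1.0
    )
    gm.graph.lint()
    return gm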
@@ -202,6 +202,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass)

     collectives_bucketing: bool = False

     if config.bucket_reduce_scatters_fx != "none":
         from torch._inductor.fx_passes.bucketing import bucket_reduce_scatter
         from torch._inductor.fx_passes.fsdp import bucket_fsdp_reduce_scatter
@@ -9112,6 +9112,7 @@ class _CollectiveKernel(FallbackKernel):
         assert not unbacked_bindings, f"{kernel} {unbacked_bindings}"
         for tensor_arg in tensor_args:
             tensor_arg.realize()
+            V.graph.mark_buffer_mutated(tensor_arg.get_name())

         device = tensor_args[0].get_device()
         packed = cls(