# Owner(s): ["module: inductor"] import unittest from unittest.mock import patch import torch import torch._dynamo import torch._dynamo.logging import torch._dynamo.test_case # for some reason importing functional collectives after dynamo breaks collectives handling! import torch.distributed._functional_collectives as _functional_collectives from torch._C import FileCheck from torch._dynamo.utils import same from torch._inductor import ir, scheduler from torch._inductor.comm_analysis import ( baseLat, hwLat, llMaxBws, NCCL_ALGO, NCCL_HW, NCCL_PROTO, NVIDIA_GPU_TYPE, ) from torch._inductor.utils import run_and_get_triton_code from torch.testing._internal.common_distributed import ( _dynamo_dist_per_rank_init, at_least_x_gpu, DynamoDistributedMultiProcTestCase, requires_accelerator_dist_backend, ) from torch.testing._internal.common_fsdp import get_devtype from torch.testing._internal.inductor_utils import HAS_GPU device_type = str(get_devtype()) def get_snode_runtime_for_reorder_compute_test(snode): # NOTE: custom cost model to show that the compute reordering algorithm is working # Collective kernels if isinstance(snode.node, ir._CollectiveKernel): return 100 elif isinstance(snode.node, ir._WaitKernel): return 0 # High-arithmetic-intensity compute kernels elif isinstance(snode.node, ir.ExternKernel): return 5 # All other kernels return 1 def create_grouped_node_for_allreduce_and_its_deps(snodes): name_to_snode = {snode.node.name: snode for snode in snodes} all_reduce_snodes = [ snode for snode in snodes if isinstance(snode.node, ir._CollectiveKernel) and snode.node.op_overload == torch.ops._c10d_functional.all_reduce_.default ] assert len(all_reduce_snodes) == 1 all_reduce_snode = all_reduce_snodes[0] all_reduce_dep_snodes = [ name_to_snode[node.name] for node in all_reduce_snode.node.inputs ] assert len(all_reduce_dep_snodes) == 1 all_reduce_dep_snode = all_reduce_dep_snodes[0] grouped_snode = scheduler.GroupedSchedulerNode.create( [all_reduce_dep_snode, all_reduce_snode] ) new_snode_order = [] new_snode_order.append(grouped_snode) for snode in snodes: if snode in grouped_snode.snodes: continue new_snode_order.append(snode) return new_snode_order @requires_accelerator_dist_backend() @unittest.skipIf( torch._inductor.config.triton.native_matmul, "native matmul is fused with surrounding ops", ) class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase): """ Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under """ def get_world_trs(self): return { "tag": "", "ranks": list(range(self.world_size)), "group_size": self.world_size, } @property def world_size(self) -> int: # hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2 # works around issue with skipif<2 and workers with unpredictable #s gpu return 2 @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "allow_buffer_reuse", True) # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @patch.object(torch._inductor.config, "reorder_for_locality", False) @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True) @patch.object( torch._inductor.config, "reorder_for_compute_comm_overlap_passes", [ "sink_waits", ], ) def test_sink_waits(self): def func(a): ar = _functional_collectives.all_reduce(a, "sum", "0") b = torch.matmul(a, a) return torch.matmul(ar, b) with _dynamo_dist_per_rank_init( self.rank, 
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = (
                torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            )
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            # Verify that the wait_tensor is sunk below the 1st matmul but
            # above the 2nd matmul.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "raise_comms",
        ],
    )
    def test_raise_comms(self):
        def func(a):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            return torch.matmul(d, e)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = (
                torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            )
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            # Verify that the all_reduce_ has been raised above the 2nd matmul
            # but below the 1st matmul. Note that the all_reduce_ directly
            # writes to the output buffer of the 1st matmul, which is an input
            # to the first relu. Therefore, the all_reduce_ should be scheduled
            # after the first relu.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check_same("buf0")
                # the mm below should not use buf0 prior to the wait_tensor
                .check("extern_kernels.mm")
                .check_not("buf0")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
            "raise_comms",
        ],
    )
    def test_sink_waits_raise_comms(self):
        def func(a, *, tag, ranks, group_size):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            f = torch.relu(d)
            g = torch.matmul(f, f)
            return torch.mm(e, g)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = (
                torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            )
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Things to verify:
            # - The clone prologue of the all_reduce_ should not be fused with
            #   any relus.
            # - The all_reduce_ and its prologue should be raised above the 2nd
            #   matmul but below the 1st matmul.
            # - The wait_tensor should be sunk below the 3rd matmul but above
            #   the 4th matmul.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_all_reduce_0")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    @patch.object(
        torch._inductor.config,
        "runtime_estimations_mms_benchmark",
        False,
    )
    def test_reorder_compute_for_overlap(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = (
                torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            )
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for the second all_reduce but DO NOT depend on the first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for the second all_reduce and DO NOT depend on the first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for the second all_reduce and DO depend on the first all_reduce.
            # Finally, we schedule the second all_reduce, followed by all ops that depend on it.
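            # The FileCheck below asserts the resulting kernel order: the first
            # all_reduce_ is issued, the independent compute (the relu and two mms)
            # overlaps with it, then its wait_tensor, then the fused mul that feeds
            # the second all_reduce_, the second all_reduce_ and its wait_tensor,
            # and finally the fused add and the last mm.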
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_all_reduce_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    @patch.object(
        torch._inductor.config,
        "estimate_op_runtime",
        get_snode_runtime_for_reorder_compute_test,
    )
    def test_reorder_compute_for_overlap_custom_runtime_estimation(self):
        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = (
                torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            )
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for the second all_reduce but DO NOT depend on the first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for the second all_reduce and DO NOT depend on the first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for the second all_reduce and DO depend on the first all_reduce.
            # Finally, we schedule the second all_reduce, followed by all ops that depend on it.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_all_reduce_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @unittest.skipIf(
        torch._inductor.config.triton.native_matmul,
        "native matmul is fused with surrounding ops",
    )
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(
        torch._inductor.config,
        "_pre_fusion_custom_pass",
        create_grouped_node_for_allreduce_and_its_deps,
    )
    def test_grouped_scheduler_node(self):
        def func(a, *, tag, ranks, group_size):
            add = a + a
            div = add / a
            ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
            # Normally, `add = a + a`, `div = add / a`, and `mul = a * a` would be fused into a single op,
            # but in this unit test we intentionally put the `add`, `div`, and `ar` computation
            # into a GroupedSchedulerNode, which prevents them from being fused with any other ops.
            mul = a * a
            mm = torch.matmul(mul, ar)
            return (mm,)

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            inputs = (
                torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
            )
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Expectations:
            # 1. `add = a + a` and `div = add / a` are still fused, which means fusion
            #    still happens among nodes within a GroupedSchedulerNode.
            # 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within
            #    the GroupedSchedulerNode and thus are prevented from being fused with any outside ops.
            FileCheck().check("triton_poi_fused_add_all_reduce_div_0.").check(
                "_c10d_functional.all_reduce_."
            ).check("triton_poi_fused_mul_1.").run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_inductor_default_comms_ordering(self):
        pg_info = self.get_world_trs()
        tag = pg_info["tag"]
        ranks = pg_info["ranks"]
        group_size = pg_info["group_size"]

        g1 = torch.ones(10, 10, device=device_type)
        g2 = torch.ones(11, 11, device=device_type)
        g3 = torch.ones(12, 12, device=device_type)

        def assert_pass(graph):
            # all_reduces need to remain in order!
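            # Even though the user code waits on the handles in reverse order,
            # the expected graph below keeps the three all_reduce calls in their
            # original issue order; only the wait_tensor nodes follow the usage order.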
            self.assertExpectedInline(
                graph,
                """\
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %all_reduce : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%arg0_1, avg, 0), kwargs = {})
    %all_reduce_1 : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%arg1_1, avg, 0), kwargs = {})
    %all_reduce_2 : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%arg2_1, avg, 0), kwargs = {})
    %wait_tensor : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce_2,), kwargs = {})
    %wait_tensor_1 : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce_1,), kwargs = {})
    %wait_tensor_2 : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce,), kwargs = {})
    return (wait_tensor, wait_tensor_1, wait_tensor_2)""",  # noqa: B950
            )

        torch._inductor.config.post_grad_custom_post_pass = assert_pass

        @torch.compile
        def fn(g1, g2, g3):
            handle1 = torch.ops.c10d_functional.all_reduce(
                g1, "avg", tag, ranks, group_size
            )
            handle2 = torch.ops.c10d_functional.all_reduce(
                g2, "avg", tag, ranks, group_size
            )
            handle3 = torch.ops.c10d_functional.all_reduce(
                g3, "avg", tag, ranks, group_size
            )
            # wait on them in a different order
            grad3 = torch.ops._c10d_functional.wait_tensor.default(handle3)
            grad2 = torch.ops._c10d_functional.wait_tensor.default(handle2)
            grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1)
            return grad3, grad2, grad1

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, self.backend(device_type), fake_pg=True
        ):
            fn(g1, g2, g3)

    def test_nccl_heuristics(self):
        assert len(baseLat) == len(NCCL_ALGO)
        assert all(len(x) == len(NCCL_PROTO) for x in baseLat)

        assert len(hwLat) == len(NCCL_HW)
        assert all(len(x) == len(NCCL_ALGO) for x in hwLat)
        assert all(len(y) == len(NCCL_PROTO) for x in hwLat for y in x)

        assert len(llMaxBws) == len(NVIDIA_GPU_TYPE)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()