From ca7315c17162ea21b1ca5ba23f4bf6168766c7b9 Mon Sep 17 00:00:00 2001
From: Boyuan Feng
Date: Mon, 11 Aug 2025 16:25:12 +0000
Subject: [PATCH] [Graph Partition] Pass all OSS unit tests (#154667)

Graph partition leads to a 6.2% speedup on vision_maskrcnn and a 5.8% speedup on yolov3
[P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563),
a 39.5% speedup on speech_transformer inference
[P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200),
and an 85% speedup on speech_transformer training
[P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315).
We ran the same diff on two different days and both runs show a speedup on average.

[first TorchInductor Benchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d)

[second TorchInductor Benchmark ci run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667
Approved by: https://github.com/eellison
---
 test/inductor/test_compiled_autograd.py    |  22 +-
 test/inductor/test_control_flow.py         |   3 +
 test/inductor/test_cuda_repro.py           |   6 +-
 test/inductor/test_cudagraph_trees.py      | 330 +++++++++++++++++++--
 test/inductor/test_inductor_annotations.py |   7 +-
 test/inductor/test_torchinductor.py        | 296 ------------------
 torch/_inductor/codegen/wrapper.py         |  10 +-
 torch/_inductor/config.py                  |   6 +-
 torch/_inductor/cudagraph_utils.py         |   5 +-
 torch/_inductor/scheduler.py               |  11 +-
 torch/_inductor/utils.py                   |   7 +
 11 files changed, 378 insertions(+), 325 deletions(-)

diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index 241528b159cc..dff94b4aa092 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -3085,7 +3085,16 @@ main()
         self.assertEqual(counters["compiled_autograd"]["captures"], 1)
         # Compiled autograd lifts custom autograd.Function bwd instead of tracing it.
         # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
-        self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
+        if inductor_config.graph_partition:
+            # Instead of skipping cudagraphs, graph partition splits off cpu inputs/outputs
+            # and ops and cudagraphifies the remaining computation, so there is no cudagraph skip.
+            expected_cudagraph_skips = 0
+        else:
+            expected_cudagraph_skips = 1
+
+        self.assertEqual(
+            counters["inductor"]["cudagraph_skips"], expected_cudagraph_skips
+        )
 
     @scoped_load_inline
     @requires_cuda_and_triton
@@ -3150,9 +3159,18 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
         # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops.
         # In the future, we can consider having a cpu scalar movement pass sometime after we trace
         # into the custom C++ autograd::Function (like in AOTDispatcher)
+        if inductor_config.graph_partition:
+            # Instead of skipping cudagraphs, graph partition splits off cpu inputs/outputs
+            # and ops and cudagraphifies the remaining computation, so there is no cudagraph skip.
+            expected_cudagraph_skips = 0
+        elif inductor_config.cpp_wrapper:
+            expected_cudagraph_skips = 2
+        else:
+            expected_cudagraph_skips = 1
+
         self.assertEqual(
             counters["inductor"]["cudagraph_skips"],
-            2 if inductor_config.cpp_wrapper else 1,
+            expected_cudagraph_skips,
         )
 
     def test_logs(self):
diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py
index 107a65d6fa1d..511b9cea5e14 100644
--- a/test/inductor/test_control_flow.py
+++ b/test/inductor/test_control_flow.py
@@ -472,6 +472,9 @@ class CondTests(TestCase):
     @requires_gpu
     @parametrize("device", ["cpu", GPU_TYPE])
     @torch._inductor.config.patch(size_asserts=False)
+    # TODO: graph partition does not support creating tensors
+    # with dynamic shapes in conditional subgraphs yet
+    @torch._inductor.config.patch(graph_partition=False)
     def test_cond_unbacked_symint_inner(self, device):
         class Model(torch.nn.Module):
             def forward(self, p, a):
diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py
index 00511c572239..53506698297f 100644
--- a/test/inductor/test_cuda_repro.py
+++ b/test/inductor/test_cuda_repro.py
@@ -189,9 +189,9 @@ class CudaReproTests(TestCase):
         # padded bias should have an expanded dim
         FileCheck().check("buf0 =").check_same(", 0, ").run(code[0])
         # single fused padded kernel
-        FileCheck().check("def call").check_count(
-            "empty_strided_cuda", 1, exactly=True
-        ).check("return").run(code[0])
+        FileCheck().check_count("empty_strided_cuda(", 1, exactly=True).check(
+            "return"
+        ).run(code[0])
 
         self.assertEqual(out, f(*inputs))
 
diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py
index 1408a0208cf0..763384671eb5 100644
--- a/test/inductor/test_cudagraph_trees.py
+++ b/test/inductor/test_cudagraph_trees.py
@@ -279,10 +279,14 @@ if HAS_CUDA_AND_TRITON:
             with capture_stderr() as captured_output:
                 foo(torch.ones([10], device="cuda"), torch.ones([20]))
 
-            FileCheck().check(
-                "skipping cudagraphs due to cpu device (arg1_1). Found from"
-            ).check("y + 2").run(captured_output[0])
-            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
+            if torch._inductor.config.graph_partition:
+                # graph partition splits on cpu ops
+                self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
+            else:
+                FileCheck().check(
+                    "skipping cudagraphs due to cpu device (arg1_1). Found from"
+                ).check("y + 2").run(captured_output[0])
+                self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
 
             with capture_stderr() as captured_output:
                 foo(
@@ -292,7 +296,10 @@ if HAS_CUDA_AND_TRITON:
             FileCheck().check("skipping cudagraphs due to multiple devices").run(
                 captured_output[0]
             )
-            self.assertEqual(counters["inductor"]["cudagraph_skips"], 2)
+            self.assertEqual(
+                counters["inductor"]["cudagraph_skips"],
+                1 if torch._inductor.config.graph_partition else 2,
+            )
 
         @torch._inductor.config.patch("triton.cudagraph_skip_dynamic_graphs", True)
         def test_skip_symbolic(self):
@@ -807,10 +814,16 @@ if HAS_CUDA_AND_TRITON:
             # the three saved tensors should die in the backward
             # we kept alive the output
             self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
-            self.assertEqual(
-                self.curr_node().expected_dead_indices_after_graph,
-                [(0, 1), (0, 2)],
-            )
+            if torch._inductor.config.graph_partition:
+                self.assertEqual(
+                    self.curr_node().expected_dead_indices_after_graph,
+                    [(0, 0), (0, 2)],
+                )
+            else:
+                self.assertEqual(
+                    self.curr_node().expected_dead_indices_after_graph,
+                    [(0, 1), (0, 2)],
+                )
 
             self.assertFalse(self.get_manager().new_graph_id().id == 0)
             self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
@@ -1127,8 +1140,13 @@ if HAS_CUDA_AND_TRITON:
             node = self.curr_node()
             first_node = next(node._path_from_root)
-            self.assertFalse(first_node.unaliased_in_all_paths[0])
-            self.assertTrue(first_node.cached_tensor_outputs[0] is None)
+            if torch._inductor.config.graph_partition:
+                # graph partition may have changed the order of outputs
+                self.assertFalse(first_node.unaliased_in_all_paths[1])
+                self.assertTrue(first_node.cached_tensor_outputs[1] is None)
+            else:
+                self.assertFalse(first_node.unaliased_in_all_paths[0])
+                self.assertTrue(first_node.cached_tensor_outputs[0] is None)
 
         @torch._inductor.config.patch("implicit_fallbacks", True)
         def test_multinomial(self):
@@ -1631,10 +1649,16 @@ if HAS_CUDA_AND_TRITON:
             # the three saved tensors should die in the backward
             # we kept alive the output
             self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
-            self.assertEqual(
-                self.curr_node().expected_dead_indices_after_graph,
-                [(0, 1), (0, 2)],
-            )
+            if torch._inductor.config.graph_partition:
+                self.assertEqual(
+                    self.curr_node().expected_dead_indices_after_graph,
+                    [(0, 0), (0, 2)],
+                )
+            else:
+                self.assertEqual(
+                    self.curr_node().expected_dead_indices_after_graph,
+                    [(0, 1), (0, 2)],
+                )
 
             self.assertFalse(self.get_manager().new_graph_id().id == 0)
 
         def test_separate_recordings(self):
@@ -2137,8 +2161,8 @@ if HAS_CUDA_AND_TRITON:
             with self.assertRaisesRegex(
                 Exception,
                 r"(?s)static input data pointer changed.\n"
-                r"input name: primals_2. data pointer changed from .* to .*. input stack trace:.*"
-                r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*,"
+                r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*"
+                r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*,"
                 r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n",
             ):
                 self.curr_node().run(
@@ -3551,6 +3575,278 @@ if HAS_CUDA_AND_TRITON:
 
             self.assertEqual(self.get_manager().new_graph_id().id, 2)
 
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_simple(self):
+            def f(x, y):
+                x1 = x + 1
+                y1 = y + 1
+                y_cpu = y1.cpu() + 1
+                z = x @ y
+                return x1 + y1 + z + y_cpu.to("cuda")
+
+            x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)]
+            x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
+            eager_out = f(x, y)
+
+            f_compiled = torch.compile(f)
+            compiled_out = f_compiled(x_cloned, y_cloned)
+            self.assertEqual(eager_out, compiled_out)
+
+            _, code = run_and_get_code(f_compiled, x_cloned, y_cloned)
+
+            if not config.cpp_wrapper:
+                FileCheck().check("def partition_0(args):").check(
+                    "recursively_apply_fns = runner.recursively_apply_fns"
+                ).run(code[0])
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_foreach_op(self):
+            def fn(a0, a1):
+                c = torch._foreach_abs([a0, a1])
+                return torch.mul(c[0], a0)
+
+            compiled_fn = torch.compile(fn)
+
+            a0 = torch.randn(2, 3, device="cuda")
+            a1 = torch.randn(2, 3, device="cuda")
+            eager_out = fn(a0, a1)
+            compiled_out = compiled_fn(a0, a1)
+            self.assertEqual(eager_out, compiled_out)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_condition_op(self):
+            def f(p, b):
+                def true_fn(x):
+                    return torch.cos(x)
+
+                def false_fn(x):
+                    return torch.sin(x)
+
+                return torch.cond(p, true_fn, false_fn, [b])
+
+            compiled_f = torch.compile(f)
+
+            # static shape
+            p = torch.tensor([True], device="cuda")
+            a = torch.ones([2, 3], device="cuda")
+            eager_out = f(p, a)
+            compiled_out = compiled_f(p, a)
+            self.assertEqual(eager_out, compiled_out)
+
+            # dynamic shape with backed symint
+            p = torch.tensor([True], device="cuda")
+            a = torch.ones([4, 5], device="cuda")
+            eager_out = f(p, a)
+            compiled_out = compiled_f(p, a)
+            self.assertEqual(eager_out, compiled_out)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        @torch._dynamo.config.patch("capture_scalar_outputs", True)
+        def test_graph_partition_unbacked_symint_multi_output_layout(self):
+            def f(p, size_tensor):
+                size_val = size_tensor.item()
+                b = torch.ones([size_val, 3], device="cuda")
+
+                def true_fn(x):
+                    return torch.cos(x), torch.cos(x) + 1
+
+                def false_fn(x):
+                    return torch.sin(x), torch.sin(x) + 1
+
+                cond_out = torch.cond(p, true_fn, false_fn, [b])
+                return cond_out[0] + cond_out[1]
+
+            compiled_f = torch.compile(f)
+            p = torch.tensor([True], device="cuda")
+            size_tensor = torch.tensor(2, device="cuda")
+            eager_out = f(p, size_tensor)
+            compiled_out = compiled_f(p, size_tensor)
+            self.assertEqual(eager_out, compiled_out)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_symint(self):
+            def f(x, y):
+                x1 = x + 1
+                y1 = y + 1
+                y_cpu = y1.cpu() + 1
+                z = x @ y
+                return x1 + y1 + z + y_cpu.to("cuda")
+
+            f_compiled = torch.compile(f)
+            x, y = (
+                torch.ones(3, 3, device="cuda"),
+                torch.randn(3, 3, device="cuda"),
+            )
+            compiled_out = f_compiled(x, y)
+            self.assertEqual(compiled_out, f(x, y))
+
+            x, y = (
+                torch.ones(4, 4, device="cuda"),
+                torch.randn(4, 4, device="cuda"),
+            )
+            compiled_out = f_compiled(x, y)
+            self.assertEqual(compiled_out, f(x, y))
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_symint_cat_backward(self):
+            def f(x, w):
+                y = torch.cat((x, x), dim=0)
+                z = y @ w
+                return z @ z.T
+
+            compiled_f = torch.compile(f)
+
+            for shape in (2, 3):
+                torch.manual_seed(42)
+                eager_x = torch.randn(shape, 2, device="cuda")
+                eager_w = torch.randn(2, 2, device="cuda", requires_grad=True)
+                torch.manual_seed(42)
+                compiled_x = torch.randn(shape, 2, device="cuda")
+                compiled_w = torch.randn(2, 2, device="cuda", requires_grad=True)
+
+                f(eager_x, eager_w).sum().backward()
+                compiled_f(compiled_x, compiled_w).sum().backward()
+                self.assertEqual(eager_w.grad, compiled_w.grad)
+
+        @dynamo_config.patch("capture_dynamic_output_shape_ops", True)
+        @config.patch(implicit_fallbacks=True)
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_symint_from_nested_indirect_indexing(self):
+            def nested(x, repeats):
+                rank = torch.arange(repeats.numel(), device=x.device)
+                index = rank.repeat_interleave(repeats, dim=0)
+                return torch.index_select(x, index=index, dim=0)
+
+            example_inputs = (
+                torch.randn((32, 64), device="cuda"),
+                repeats := torch.tensor([5, 10, 15], device="cuda"),
+            )
+            torch._dynamo.mark_dynamic(repeats, 0)  # create backed symint
+
+            nested_opt = torch.compile(nested, backend="inductor")
+
+            expect = nested(*example_inputs)
+            actual = nested_opt(*example_inputs)
+            self.assertEqual(expect, actual)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_symint_from_mutation_index(self):
+            x = torch.zeros(7, device="cuda")
+
+            def fn(n, a):
+                a[n] = -1
+                return a
+
+            opt_fn = torch.compile(fn, fullgraph=True)
+
+            for n in range(2, x.shape[0]):
+                opt_fn(n, x)
+                self.assertEqual(x[n], -1)
+
+            # Negative index triggers new compilation.
+            opt_fn(-x.shape[0], x)
+
+            self.assertEqual(x[0], -1)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_unbacked_symint(self):
+            def f(x, y):
+                x1 = x + 1
+                y1 = y + 1
+                y_cpu = y1.cpu() + 1
+                z = x @ y
+                return x1 + y1 + z + y_cpu.to("cuda")
+
+            f_compiled = torch.compile(f)
+            x, y = (
+                torch.ones(3, 3, device="cuda"),
+                torch.randn(3, 3, device="cuda"),
+            )
+
+            torch._dynamo.decorators.mark_unbacked(x, 0)
+            torch._dynamo.decorators.mark_unbacked(y, 1)
+
+            compiled_out = f_compiled(x, y)
+            eager_out = f(x, y)
+            self.assertEqual(compiled_out, eager_out)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_dynamic_scalar_inputs(self):
+            def f(x, y, integer):
+                x1 = x + 1
+                y1 = y + 1
+                y_cpu = y1.cpu() + 1
+                z = x @ y
+                z += integer
+                return x1 + y1 + z + y_cpu.to("cuda")
+
+            f_compiled = torch.compile(f)
+            x, y = (
+                torch.ones(3, 3, device="cuda"),
+                torch.randn(3, 3, device="cuda"),
+            )
+
+            torch._dynamo.decorators.mark_unbacked(x, 0)
+            torch._dynamo.decorators.mark_unbacked(y, 1)
+
+            compiled_out = f_compiled(x, y, 5)
+            self.assertEqual(compiled_out, f(x, y, 5))
+
+            compiled_out = f_compiled(x, y, 6)
+            self.assertEqual(compiled_out, f(x, y, 6))
+
+        @torch._inductor.config.patch("graph_partition", True)
+        @torch._dynamo.config.patch("capture_scalar_outputs", True)
+        def test_graph_partition_item(self):
+            def f(x):
+                y = x + 1
+                scalar = y.item()
+                return x + y + scalar
+
+            compiled_f = torch.compile(f)
+            compiled_out = compiled_f(torch.tensor(1, device="cuda"))
+            self.assertEqual(compiled_out, f(torch.tensor(1, device="cuda")))
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_buffer_reuse(self):
+            def f(x, y):
+                x1 = x + 1
+                y1 = y + 1
+                y_cpu = y1.cpu() + 1
+                z = x1 + y1 + x @ y
+                u = (y_cpu.to("cuda") + 2) @ y + 3
+                u_cpu = u.cpu() + 2
+                return z + u_cpu.to("cuda")
+
+            x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)]
+            x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
+            eager_out = f(x, y)
+
+            f_compiled = torch.compile(f)
+            compiled_out = f_compiled(x_cloned, y_cloned)
+
+            self.assertEqual(eager_out, compiled_out)
+
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_fused_scheduler_node(self):
+            def foo(x):
+                x = x * 20
+                x_alias = x[0]
+                y = x * 10
+                y_alias = y[0]
+                torch._dynamo.graph_break()
+                ind = torch.tensor(4, device="cuda")
+                x_alias2 = x[ind:]
+                y_alias2 = y[ind:]
+                return x, x_alias, x_alias2, y_alias, y_alias2
+
+            compiled_foo = torch.compile(foo)
+            x = torch.rand([20, 20], device="cuda")
+
+            eager_out = foo(x)
+            compiled_out = compiled_foo(x)
+            self.assertEqual(eager_out, compiled_out)
+
         def test_meta_tensor(self):
             def foobar(x, y):
                 return x * 2, y * 3
diff --git a/test/inductor/test_inductor_annotations.py b/test/inductor/test_inductor_annotations.py
index bee7e0ad917d..3824b25cdeae 100644
--- a/test/inductor/test_inductor_annotations.py
+++ b/test/inductor/test_inductor_annotations.py
@@ -31,10 +31,11 @@ class InductorAnnotationTestCase(TestCase):
         code = self.get_code()
 
         self.assertTrue("from torch.cuda import nvtx" in code)
-        self.assertEqual(
-            code.count("training_annotation = nvtx._device_range_start('inference')"), 1
+        self.assertTrue(
+            code.count("training_annotation = nvtx._device_range_start('inference')")
+            >= 1
         )
-        self.assertEqual(code.count("nvtx._device_range_end(training_annotation)"), 1)
+        self.assertTrue(code.count("nvtx._device_range_end(training_annotation)") >= 1)
 
 
 if __name__ == "__main__":
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index cdcedd5a1771..385a75d98f94 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -15044,302 +15044,6 @@ if RUN_GPU:
                 "'XBLOCK': 'constexpr'"
             ).run(code[0])
 
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition(self):
-            def f(x, y):
-                x1 = x + 1
-                y1 = y + 1
-                y_cpu = y1.cpu() + 1
-                z = x @ y
-                return x1 + y1 + z + y_cpu.to(GPU_TYPE)
-
-            x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)]
-            x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
-            eager_out = f(x, y)
-
-            f_compiled = torch.compile(f)
-            compiled_out = f_compiled(x_cloned, y_cloned)
-            self.assertEqual(eager_out, compiled_out)
-
-            _, code = run_and_get_code(f_compiled, x_cloned, y_cloned)
-
-            if not config.cpp_wrapper:
-                FileCheck().check("def partition_0(args):").check(
-                    "(buf0, buf1, arg0_1, arg1_1) = self.partitions[0](partition0_args)"
-                ).check("recursively_apply_fns = runner.recursively_apply_fns").run(
-                    code[0]
-                )
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_foreach_op(self):
-            def fn(a0, a1):
-                c = torch._foreach_abs([a0, a1])
-                return torch.mul(c[0], a0)
-
-            compiled_fn = torch.compile(fn)
-
-            a0 = torch.randn(2, 3, device=self.device)
-            a1 = torch.randn(2, 3, device=self.device)
-            eager_out = fn(a0, a1)
-            compiled_out = compiled_fn(a0, a1)
-            self.assertEqual(eager_out, compiled_out)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_multiple_functions(self):
-            def f(x, y):
-                x1 = x + 1
-                y1 = y + 1
-                y_cpu = y1.cpu() + 1
-                z = x @ y
-                return x1 + y1 + z + y_cpu.to(GPU_TYPE)
-
-            def g(x):
-                return x + 1
-
-            x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)]
-            x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
-            eager_out = g(f(x, y))
-
-            f_compiled = torch.compile(f)
-            g_compiled = torch.compile(g)
-            compiled_out = g_compiled(f_compiled(x_cloned, y_cloned))
-
-            self.assertEqual(eager_out, compiled_out)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_condition_op(self):
-            def f(p, b):
-                def true_fn(x):
-                    return torch.cos(x)
-
-                def false_fn(x):
-                    return torch.sin(x)
-
-                return torch.cond(p, true_fn, false_fn, [b])
-
-            compiled_f = torch.compile(f)
-
-            # static shape
-            p = torch.tensor([True], device=self.device)
-            a = torch.ones([2, 3], device=self.device)
-            eager_out = f(p, a)
-            compiled_out = compiled_f(p, a)
-            self.assertEqual(eager_out, compiled_out)
-
-            # dynamic shape with backed symint
-            p = torch.tensor([True], device=self.device)
-            a = torch.ones([4, 5], device=self.device)
-            eager_out = f(p, a)
-            compiled_out = compiled_f(p, a)
-            self.assertEqual(eager_out, compiled_out)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        @torch._dynamo.config.patch("capture_scalar_outputs", True)
-        def test_graph_partition_unbacked_symint_multi_output_layout(self):
-            def f(p, size_tensor):
-                size_val = size_tensor.item()
-                b = torch.ones([size_val, 3], device=GPU_TYPE)
-
-                def true_fn(x):
-                    return torch.cos(x), torch.cos(x) + 1
-
-                def false_fn(x):
-                    return torch.sin(x), torch.sin(x) + 1
-
-                cond_out = torch.cond(p, true_fn, false_fn, [b])
-                return cond_out[0] + cond_out[1]
-
-            compiled_f = torch.compile(f)
-            p = torch.tensor([True], device=GPU_TYPE)
-            size_tensor = torch.tensor(2, device=GPU_TYPE)
-            eager_out = f(p, size_tensor)
-            compiled_out = compiled_f(p, size_tensor)
-            self.assertEqual(eager_out, compiled_out)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_symint(self):
-            def f(x, y):
-                x1 = x + 1
-                y1 = y + 1
-                y_cpu = y1.cpu() + 1
-                z = x @ y
-                return x1 + y1 + z + y_cpu.to(GPU_TYPE)
-
-            f_compiled = torch.compile(f)
-            x, y = (
-                torch.ones(3, 3, device=self.device),
-                torch.randn(3, 3, device=self.device),
-            )
-            compiled_out = f_compiled(x, y)
-            self.assertEqual(compiled_out, f(x, y))
-
-            x, y = (
-                torch.ones(4, 4, device=self.device),
-                torch.randn(4, 4, device=self.device),
-            )
-            compiled_out = f_compiled(x, y)
-            self.assertEqual(compiled_out, f(x, y))
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_symint_cat_backward(self):
-            def f(x, w):
-                y = torch.cat((x, x), dim=0)
-                z = y @ w
-                return z @ z.T
-
-            compiled_f = torch.compile(f)
-
-            for shape in (2, 3):
-                torch.manual_seed(42)
-                eager_x = torch.randn(shape, 2, device=self.device)
-                eager_w = torch.randn(2, 2, device=self.device, requires_grad=True)
-                torch.manual_seed(42)
-                compiled_x = torch.randn(shape, 2, device=self.device)
-                compiled_w = torch.randn(2, 2, device=self.device, requires_grad=True)
-
-                f(eager_x, eager_w).sum().backward()
-                compiled_f(compiled_x, compiled_w).sum().backward()
-                self.assertEqual(eager_w.grad, compiled_w.grad)
-
-        @dynamo_config.patch("capture_dynamic_output_shape_ops", True)
-        @config.patch(implicit_fallbacks=True)
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_symint_from_nested_indirect_indexing(self):
-            def nested(x, repeats):
-                rank = torch.arange(repeats.numel(), device=x.device)
-                index = rank.repeat_interleave(repeats, dim=0)
-                return torch.index_select(x, index=index, dim=0)
-
-            example_inputs = (
-                torch.randn((32, 64), device=self.device),
-                repeats := torch.tensor([5, 10, 15], device=self.device),
-            )
-            torch._dynamo.mark_dynamic(repeats, 0)  # create backed symint
-
-            nested_opt = torch.compile(nested, backend="inductor")
-
-            expect = nested(*example_inputs)
-            actual = nested_opt(*example_inputs)
-            self.assertEqual(expect, actual)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_symint_from_mutation_index(self):
-            x = torch.zeros(7, device=GPU_TYPE)
-
-            def fn(n, a):
-                a[n] = -1
-                return a
-
-            opt_fn = torch.compile(fn, fullgraph=True)
-
-            for n in range(2, x.shape[0]):
-                opt_fn(n, x)
-                self.assertEqual(x[n], -1)
-
-            # Negative index triggers new compilation.
-            opt_fn(-x.shape[0], x)
-
-            self.assertEqual(x[0], -1)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_unbacked_symint(self):
-            def f(x, y):
-                x1 = x + 1
-                y1 = y + 1
-                y_cpu = y1.cpu() + 1
-                z = x @ y
-                return x1 + y1 + z + y_cpu.to(GPU_TYPE)
-
-            f_compiled = torch.compile(f)
-            x, y = (
-                torch.ones(3, 3, device=self.device),
-                torch.randn(3, 3, device=self.device),
-            )
-
-            torch._dynamo.decorators.mark_unbacked(x, 0)
-            torch._dynamo.decorators.mark_unbacked(y, 1)
-
-            compiled_out = f_compiled(x, y)
-            eager_out = f(x, y)
-            self.assertEqual(compiled_out, eager_out)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_dynamic_scalar_inputs(self):
-            def f(x, y, integer):
-                x1 = x + 1
-                y1 = y + 1
-                y_cpu = y1.cpu() + 1
-                z = x @ y
-                z += integer
-                return x1 + y1 + z + y_cpu.to(GPU_TYPE)
-
-            f_compiled = torch.compile(f)
-            x, y = (
-                torch.ones(3, 3, device=self.device),
-                torch.randn(3, 3, device=self.device),
-            )
-
-            torch._dynamo.decorators.mark_unbacked(x, 0)
-            torch._dynamo.decorators.mark_unbacked(y, 1)
-
-            compiled_out = f_compiled(x, y, 5)
-            self.assertEqual(compiled_out, f(x, y, 5))
-
-            compiled_out = f_compiled(x, y, 6)
-            self.assertEqual(compiled_out, f(x, y, 6))
-
-        @torch._inductor.config.patch("graph_partition", True)
-        @torch._dynamo.config.patch("capture_scalar_outputs", True)
-        def test_graph_partition_item(self):
-            def f(x):
-                y = x + 1
-                scalar = y.item()
-                return x + y + scalar
-
-            compiled_f = torch.compile(f)
-            compiled_out = f(torch.tensor(1, device=GPU_TYPE))
-            self.assertEqual(compiled_out, f(torch.tensor(1, device=GPU_TYPE)))
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_buffer_reuse(self):
-            def f(x, y):
-                x1 = x + 1
-                y1 = y + 1
-                y_cpu = y1.cpu() + 1
-                z = x1 + y1 + x @ y
-                u = (y_cpu.to(GPU_TYPE) + 2) @ y + 3
-                u_cpu = u.cpu() + 2
-                return z + u_cpu.to(GPU_TYPE)
-
-            x, y = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(2)]
-            x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
-            eager_out = f(x, y)
-
-            f_compiled = torch.compile(f)
-            compiled_out = f_compiled(x_cloned, y_cloned)
-
-            self.assertEqual(eager_out, compiled_out)
-
-        @torch._inductor.config.patch("graph_partition", True)
-        def test_graph_partition_fused_scheduler_node(self):
-            def foo(x):
-                x = x * 20
-                x_alias = x[0]
-                y = x * 10
-                y_alias = y[0]
-                torch._dynamo.graph_break()
-                ind = torch.tensor(4, device=GPU_TYPE)
-                x_alias2 = x[ind:]
-                y_alias2 = y[ind:]
-                return x, x_alias, x_alias2, y_alias, y_alias2
-
-            foo = torch.compile(foo)
-            x = torch.rand([20, 20], device=GPU_TYPE)
-            _, code = run_and_get_code(foo, x)
-
-            if not config.cpp_wrapper:
-                FileCheck().check("def partition_0(args):").run(code[0])
-
         @unittest.skipIf(TEST_WITH_ROCM or not IS_SM90, "no scaled_grouped_mm support")
         def test_respect_scaled_grouped_mm_layout_tag(self):
             # scaled_grouped_mm needs `mat2` to be column-major
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 49f8549170b6..a5ff9bd7b754 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -50,6 +50,7 @@ from ..utils import (
     get_benchmark_name,
     IndentedBuffer,
     is_codegen_graph_partition_subgraph,
+    is_using_cudagraph_partition,
     LineContext,
     sympy_product,
     sympy_str,
@@ -1197,7 +1198,14 @@ class PythonWrapperCodegen(CodeGen):
         self.write_args(graph_input_names)
 
         self.codegen_inputs()
-        self.codegen_input_size_and_nan_asserts()
+
+        # avoid duplicating asserts for both partition functions and
+        # the call function when using cudagraph partition
+        if not (
+            is_using_cudagraph_partition()
+            and (not is_codegen_graph_partition_subgraph(self))
+        ):
+            self.codegen_input_size_and_nan_asserts()
 
     def codegen_input_size_and_nan_asserts(self) -> None:
         if config.size_asserts:
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 8d3b4cd7ed49..770da725a9aa 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -437,7 +437,11 @@ max_autotune_report_choices_stats = (
 )
 
 # enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph
-graph_partition = False
+graph_partition: bool = (
+    os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0")
+    == "1"
+)
+
 
 # force cublas and triton to use the same precision; cublas supports TF32 for matmul operations
 # when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations
diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py
index 2686d1d2ddde..7826c797d36b 100644
--- a/torch/_inductor/cudagraph_utils.py
+++ b/torch/_inductor/cudagraph_utils.py
@@ -10,6 +10,8 @@ from torch._dynamo.utils import counters, get_metrics_context
 from torch._inductor.utils import GraphPartitionMap, InputType
 from torch.utils._ordered_set import OrderedSet
 
+from .utils import is_using_cudagraph_partition
+
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
@@ -170,7 +172,8 @@ def check_multiple_devices_or_any_cpu_nodes(
     # meta tensors are supported since there is no compute
     device_node_mapping.pop(torch.device("meta"), None)
 
-    if torch._inductor.config.graph_partition:
+    # dynamo cudagraphs do not support graph partition
+    if is_using_cudagraph_partition():
         # graph partition supports splitting on cpu op. So we can ignore cpu nodes.
         device_node_mapping.pop(torch.device("cpu"), None)
 
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index e0a0309d1c81..d8a96c573b32 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -2179,7 +2179,10 @@ class Scheduler:
             self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
         self.process_grouped_nodes()
 
-        if torch._inductor.config.graph_partition:
+        if (
+            torch._inductor.config.graph_partition
+            and torch._inductor.config.triton.cudagraphs
+        ):
             self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
             self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)
 
@@ -4312,6 +4315,12 @@ class Scheduler:
     ) -> bool:
         """Return True if we should partition the inductor graph on this node"""
 
+        # When not using cudagraphs, keep all kernels in the `call` function
+        # instead of in graph partition functions, since graph partition only
+        # benefits cudagraphs
+        if not torch._inductor.config.triton.cudagraphs:
+            return True
+
         # avoid duplicating logs when should_partition is called multiple times
         # on the same node
        def noop_log(msg: str, node: Optional[BaseSchedulerNode]) -> None:
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index f21905e16e9d..0418edb2a115 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -3329,6 +3329,13 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
     )
 
 
+def is_using_cudagraph_partition() -> bool:
+    return (
+        torch._inductor.config.triton.cudagraphs
+        and torch._inductor.config.graph_partition
+    )
+
+
 def dtype_from_size(size: int) -> torch.dtype:
     from .virtualized import V
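
---

For reviewers, a minimal sketch of the behavior this PR enables, adapted from the new tests above. Using `mode="reduce-overhead"` to turn cudagraphs on is an assumption for illustration; the tests instead drive cudagraphs through the cudagraph-trees test harness.

```python
# Hedged sketch: with graph partition enabled, a compiled graph that mixes
# cpu and cuda ops no longer forces a cudagraph skip. The cpu section is
# split into its own partition and the remaining cuda kernels are still
# cudagraphed.
import torch
from torch._dynamo.utils import counters

# Equivalent to setting TORCHINDUCTOR_GRAPH_PARTITION=1 in the environment.
torch._inductor.config.graph_partition = True

def f(x, y):
    x1 = x + 1
    y1 = y + 1
    y_cpu = y1.cpu() + 1  # cpu op: previously caused "skipping cudagraphs due to cpu device"
    z = x @ y
    return x1 + y1 + z + y_cpu.to("cuda")

f_compiled = torch.compile(f, mode="reduce-overhead")
x, y = (torch.ones(2, 2, device="cuda") for _ in range(2))
out = f_compiled(x, y)

# With graph partition on, no skip is recorded for the cpu device.
print(counters["inductor"]["cudagraph_skips"])  # expected: 0
```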