From a303d6dda9532f6e6a8e0776ba866727df28b721 Mon Sep 17 00:00:00 2001
From: Shunting Zhang
Date: Wed, 15 Oct 2025 17:52:57 -0700
Subject: [PATCH] [inductor] don't try to reorder loops for template (#165601)

Fixes https://github.com/pytorch/pytorch/issues/165579

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165601
Approved by: https://github.com/yushangdi
---
 test/inductor/test_loop_ordering.py | 25 +++++++++++++++++++++++++
 torch/_inductor/scheduler.py        |  6 ++++++
 2 files changed, 31 insertions(+)

diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py
index 34f70b6ec539..efe0fbfc2837 100644
--- a/test/inductor/test_loop_ordering.py
+++ b/test/inductor/test_loop_ordering.py
@@ -589,6 +589,31 @@ class LoopOrderingTest(TestCase):
             ".run(", 1 + int(inductor_config.benchmark_kernel), exactly=True
         ).run(code[0])
 
+    @inductor_config.patch(
+        {
+            "max_autotune": True,
+            "max_autotune_gemm_backends": "TRITON",
+            "test_configs.max_mm_configs": 4,
+        }
+    )
+    @skipUnless(HAS_GPU and is_big_gpu(), "Need big gpu for max-autotune")
+    def test_interaction_with_multi_template(self):
+        """
+        Skip MultiTemplateBuffer during loop reordering.
+        """
+
+        @torch.compile
+        def f(x, y):
+            return (x @ y), x + 1
+
+        N = 2
+        x = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16)
+        y = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16)
+
+        out, code = run_and_get_code(f, x, y)
+        # The two kernels are not fused because the memory savings are too small.
+        FileCheck().check_count("@triton.jit", 2, exactly=True).run(code[0])
+
     def test_fuse_with_scalar_shared_memory(self):
         """
         Make sure if we can fuse two nodes sharing a scalar before,
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index f85b5c7e39d9..d76036d3859b 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -3994,6 +3994,12 @@ class Scheduler:
         ):
             return -1
 
+        # In some rare cases, a template node can be passed in here.
+        # See test_interaction_with_multi_template in test_loop_ordering.py
+        # and https://github.com/pytorch/pytorch/issues/165579.
+        if node1.is_template() or node2.is_template():
+            return -1
+
         node1_buffer_names = node1.read_writes.buffer_names()
         node2_buffer_names = node2.read_writes.buffer_names()
         # Fast path: no common buffers.
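For reference, a standalone sketch distilled from the new test above: under max-autotune the matmul is lowered through a Triton template (a MultiTemplateBuffer), while the pointwise add stays a regular node, which is the combination the new guard in the scheduler has to handle. This is an illustration only, not part of the patch; it assumes a CUDA GPU large enough for max-autotune, and the config keys mirror those used in the test.

import torch
from torch._inductor import config as inductor_config

@torch.compile
def f(x, y):
    # The matmul goes through a Triton template under max-autotune; x + 1 is a
    # separate pointwise kernel, so two @triton.jit kernels are expected in the
    # generated code rather than a fused one.
    return (x @ y), x + 1

if __name__ == "__main__":
    with inductor_config.patch(
        {
            "max_autotune": True,
            "max_autotune_gemm_backends": "TRITON",
        }
    ):
        n = 2
        x = torch.randn([n, n], device="cuda", dtype=torch.bfloat16)
        y = torch.randn([n, n], device="cuda", dtype=torch.bfloat16)
        mm, add = f(x, y)
        print(mm.shape, add.shape)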