Mirror of https://github.com/pytorch/pytorch.git
[inductor] don't try to reorder loops for template (#165601)
fix https://github.com/pytorch/pytorch/issues/165579

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165601
Approved by: https://github.com/yushangdi
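For context, the failure in issue #165579 comes from compiling a matmul together with a pointwise op under max-autotune, which lowers the matmul to a Triton template (a MultiTemplateBuffer) that can then reach the scheduler's loop-reordering path. Below is a trimmed reproduction sketch adapted from the new test added in this PR; the CUDA device, bfloat16 dtype, small shapes, and config values mirror that test, and the script is illustrative rather than part of the diff:

# Reproduction sketch, adapted from test_interaction_with_multi_template below.
# Requires a GPU; not part of the commit itself.
import torch
from torch._inductor import config as inductor_config


@torch.compile
def f(x, y):
    # A matmul (lowered to a Triton template under max-autotune) returned
    # alongside a pointwise op the scheduler may consider fusing/reordering.
    return (x @ y), x + 1


x = torch.randn([2, 2], device="cuda", dtype=torch.bfloat16)
y = torch.randn([2, 2], device="cuda", dtype=torch.bfloat16)

with inductor_config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
    }
):
    out = f(x, y)  # compiles on first call; this exercised the failing path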
commit a303d6dda9
parent 7669ac9402
committed by PyTorch MergeBot
@@ -589,6 +589,31 @@ class LoopOrderingTest(TestCase):
             ".run(", 1 + int(inductor_config.benchmark_kernel), exactly=True
         ).run(code[0])
 
+    @inductor_config.patch(
+        {
+            "max_autotune": True,
+            "max_autotune_gemm_backends": "TRITON",
+            "test_configs.max_mm_configs": 4,
+        }
+    )
+    @skipUnless(HAS_GPU and is_big_gpu(), "Need big gpu for max-autotune")
+    def test_interaction_with_multi_template(self):
+        """
+        Skip MultiTemplateBuffer during loop reordering
+        """
+
+        @torch.compile
+        def f(x, y):
+            return (x @ y), x + 1
+
+        N = 2
+        x = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16)
+        y = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16)
+
+        out, code = run_and_get_code(f, x, y)
+        # didn't fuse due to small savings
+        FileCheck().check_count("@triton.jit", 2, exactly=True).run(code[0])
+
     def test_fuse_with_scalar_shared_memory(self):
         """
         Make sure if we can fuse two nodes sharing a scalar before,
@@ -3994,6 +3994,12 @@ class Scheduler:
         ):
             return -1
 
+        # in some rare case, a template can be passed in.
+        # Check test_interaction_with_multi_template in test_loop_ordering.py
+        # and https://github.com/pytorch/pytorch/issues/165579
+        if node1.is_template() or node2.is_template():
+            return -1
+
         node1_buffer_names = node1.read_writes.buffer_names()
         node2_buffer_names = node2.read_writes.buffer_names()
         # Fast path: no common buffers.
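For readers skimming the scheduler change: the added guard simply reports "no shared memory" (-1) whenever either node is a template, so loop reordering is never attempted on it. A minimal standalone sketch of that control flow follows; FakeNode and shared_memory_score are hypothetical stand-ins, not the real Scheduler API:

# Standalone illustration of the early-exit pattern added above (hypothetical
# types; only the control flow mirrors the diff).
class FakeNode:
    def __init__(self, is_template: bool, buffers: set[str]) -> None:
        self._is_template = is_template
        self.buffers = buffers

    def is_template(self) -> bool:
        return self._is_template


def shared_memory_score(node1: FakeNode, node2: FakeNode) -> int:
    # Mirrors the new guard: templates (e.g. MultiTemplateBuffer) opt out
    # of loop reordering by reporting no shared memory.
    if node1.is_template() or node2.is_template():
        return -1
    # Stand-in for the real byte-size estimate over common buffers.
    return len(node1.buffers & node2.buffers)


mm_template = FakeNode(is_template=True, buffers={"buf0", "buf1"})
pointwise = FakeNode(is_template=False, buffers={"buf1", "buf2"})
assert shared_memory_score(mm_template, pointwise) == -1
assert shared_memory_score(pointwise, pointwise) == 2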