[inductor] fix cpp_wrapper inputs mismatch (#116197)

Summary: Fixes https://github.com/pytorch/pytorch/issues/115035. In the cpp_wrapper JIT Inductor mode, the input args used to run the compiled graph must also contain the lifted model parameters, not just the user-visible inputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116197
Approved by: https://github.com/jansel
Author:    Bin Bao
Date:      2023-12-26 16:17:05 +00:00
Committer: PyTorch MergeBot
Parent:    7571511af9
Commit:    e5bcfe205e

4 changed files with 46 additions and 27 deletions
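
For context, the failure mode from the linked issue can be reproduced along these lines (a minimal sketch; the module, sizes, and option spelling are illustrative, not taken from the issue). The key ingredient is compiling a module that owns parameters with cpp_wrapper enabled: Dynamo lifts the parameters to graph inputs, so the generated wrapper must accept them alongside the user inputs.

import torch

# Minimal repro sketch (hypothetical): a module with parameters, compiled
# with the Inductor cpp_wrapper option. Before this fix, the generated C++
# wrapper and the lowered graph disagreed on whether the lifted parameters
# were part of the input args, causing a mismatch.
model = torch.nn.Linear(16, 16).eval()
x = torch.randn(4, 16)

with torch.no_grad():
    compiled = torch.compile(model, options={"cpp_wrapper": True})
    out = compiled(x)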

View File

@@ -165,6 +165,7 @@ if RUN_CUDA:
         BaseTest("test_index_put_deterministic_fallback"),
         BaseTest("test_adding_tensor_offsets"),
         BaseTest("test_index_tensor"),
+        BaseTest("test_layer_norm"),
         BaseTest("test_linear1"),
         BaseTest("test_linear2"),
         BaseTest("test_mm_views"),

View File

@@ -557,6 +557,12 @@ def _run_and_assert_no_indirect_indexing(test_case, func, *args, **kwargs):
     return result
 
 
+def assertGeneratedKernelCountEqual(self: TestCase, expected: int):
+    if config.cpp_wrapper:
+        expected *= 2
+    self.assertEqual(torch._inductor.metrics.generated_kernel_count, expected)
+
+
 class SweepInputs2:
     input_gen_types1 = [
         "dense",
@@ -804,7 +810,7 @@ class CommonTemplate:
                 torch.randn(26),
             ),
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+        assertGeneratedKernelCountEqual(self, 1)
 
     def test_forced_buffer_realize(self):
         # Test torch._test_inductor_realize forces a buffer to be realized
@@ -839,10 +845,7 @@ class CommonTemplate:
             ),
         )
         self.assertEqual(torch._inductor.metrics.ir_nodes_pre_fusion, 5)
-        self.assertEqual(
-            torch._inductor.metrics.generated_kernel_count,
-            1 if self.device == "cuda" else 3,
-        )
+        assertGeneratedKernelCountEqual(self, 1 if self.device == "cuda" else 3)
 
     def test_index_propagation(self):
         def flip(x):
@@ -2902,7 +2905,7 @@ class CommonTemplate:
             (torch.randn(2, 4, 21, 21),),
             check_lowp=False,
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)
 
     def test_multi_threading(self):
         model = torch.nn.Linear(2, 3).eval()
@@ -3159,7 +3162,7 @@ class CommonTemplate:
             fn,
             (torch.randn([16, 64, 55, 55]),),
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)
 
     # From https://github.com/pytorch/pytorch/issues/94775
     def test_max_pool2d7(self):
@@ -3185,7 +3188,7 @@ class CommonTemplate:
             fn,
             (torch.randn([2, 2, 3, 6]),),
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)
 
     def test_avg_pool2d1(self):
         def fn(x):
@@ -3265,7 +3268,7 @@ class CommonTemplate:
             fn,
             (-torch.arange(1 * 24 * 24, dtype=torch.float32).view(1, 1, 24, 24),),
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)
 
     def test_avg_pool2d8(self):
         # https://github.com/pytorch/pytorch/issues/100987
@@ -3569,7 +3572,7 @@ class CommonTemplate:
         with torch.no_grad():
             self.common(m, (torch.randn([16, 32]),), check_lowp=False)
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+            assertGeneratedKernelCountEqual(self, 1)
 
     def test_transpose_add(self):
         def fn(a, b):
@@ -3579,7 +3582,7 @@ class CommonTemplate:
             fn, (torch.randn([16, 32]), torch.randn([32, 16])), check_lowp=False
         )
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+            assertGeneratedKernelCountEqual(self, 1)
 
     @patch.object(config.triton, "persistent_reductions", True)
     def test_softmax_one_kernel_persist(self):
@@ -3592,7 +3595,7 @@ class CommonTemplate:
 
         self.common(fn, (torch.randn([16, 32]),), check_lowp=False)
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+            assertGeneratedKernelCountEqual(self, 1)
 
     @patch.object(config.triton, "persistent_reductions", False)
     def test_softmax_one_kernel_loop(self):
@@ -3604,7 +3607,7 @@ class CommonTemplate:
 
         self.common(fn, (torch.randn([16, 32]),), check_lowp=False)
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+            assertGeneratedKernelCountEqual(self, 1)
 
     def test_complex_fallback(self):
         def fn(x):
@@ -3614,7 +3617,7 @@ class CommonTemplate:
             fn,
             (torch.randn([1, 2, 4, 8]).to(dtype=torch.complex64),),
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)
 
         class ToComplex(nn.Module):
             def forward(self, x):
@@ -3623,7 +3626,7 @@ class CommonTemplate:
 
         self.common(ToComplex(), (torch.rand([1, 2, 4, 8]),), check_lowp=False)
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+            assertGeneratedKernelCountEqual(self, 1)
 
     def test_view_as_complex(self):
         class Repro(torch.nn.Module):
@@ -3675,7 +3678,7 @@ class CommonTemplate:
             check_lowp=False,
         )
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+            assertGeneratedKernelCountEqual(self, 1)
 
     def test_gather_scatter(self):
         def fn(node_feat, edge_index):
@@ -3701,7 +3704,7 @@ class CommonTemplate:
             check_lowp=False,
         )
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
+            assertGeneratedKernelCountEqual(self, 2)
 
     @config.patch(max_fusion_size=1)
     def test_no_mega_fusion_during_lowering(self):
@@ -3728,7 +3731,7 @@ class CommonTemplate:
 
         self.common(fn, (torch.randn([32]),), check_lowp=False)
         # if we have a copy there will be more than 1 kernel
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+        assertGeneratedKernelCountEqual(self, 1)
 
     def test_leaky_relu(self):
         def fn(x):
@@ -6230,7 +6233,7 @@ class CommonTemplate:
                 indices,
             ],
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+        assertGeneratedKernelCountEqual(self, 1)
 
     def test_max_pool2d_with_indices_backward5(self):
         # Window size is too big. Should fallback
@@ -6257,7 +6260,7 @@ class CommonTemplate:
                 indices,
             ],
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)
 
     # From https://github.com/pytorch/pytorch/issues/93384
     def test_max_pool2d_with_indices_backward6(self):
@@ -6285,7 +6288,7 @@ class CommonTemplate:
                 indices,
             ],
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)
 
     def test_issue102546(self):
         def fn(x):
@@ -6356,7 +6359,7 @@ class CommonTemplate:
                 torch.randn([1, 2016, 21, 21]),
             ],
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+        assertGeneratedKernelCountEqual(self, 1)
 
     def test_avg_pool2d_backward4(self):
         def fn(a, b):
@@ -6380,7 +6383,7 @@ class CommonTemplate:
             ],
             check_lowp=False,
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)
 
     @config.patch(search_autotune_cache=False)
     def test_mm_views(self):
@@ -7316,7 +7319,7 @@ class CommonTemplate:
         )
         # expanded dim should not cause copy in require_stride_order
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)
 
     @requires_cuda()
     @unittest.skipIf(
@@ -7913,7 +7916,7 @@ class CommonTemplate:
 
         self.common(fn, (input, boundaries, add_value), check_lowp=False)
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+        assertGeneratedKernelCountEqual(self, 1)
 
     def test_bucketize_computed_offsets(self):
         def fn(inp, offsets):

View File

@@ -518,7 +518,7 @@ def fx_codegen_and_compile(
             # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
             # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
             # we currently use fake tensors and defake them later.
-            example_inputs=V.real_inputs if is_inference else example_inputs,
+            example_inputs=example_inputs,
             shape_env=shape_env,
             num_static_inputs=num_fixed,
             graph_id=graph_id,
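
With this change the caller always passes the (possibly fake) example_inputs through, and the cpp_wrapper path in graph.py (next file) reconstructs the concrete run inputs itself from the tracing context and V.real_inputs. V here is Inductor's set of thread-local "virtualized" globals; a minimal sketch of that pattern, not Inductor's actual implementation:

import contextlib
import threading

_local = threading.local()

@contextlib.contextmanager
def set_real_inputs(inputs):
    # The caller installs the real inputs for the duration of a compilation...
    prev = getattr(_local, "real_inputs", None)
    _local.real_inputs = inputs
    try:
        yield
    finally:
        _local.real_inputs = prev

def real_inputs():
    # ...and deeply nested code reads them back without every function in
    # between having to thread them through its signature.
    return getattr(_local, "real_inputs", None)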

View File

@@ -1,3 +1,4 @@
+import itertools
 import logging
 import operator
 import os
@@ -1038,9 +1039,23 @@ class GraphLowering(torch.fx.Interpreter):
                     ), "Unknown type when creating real inputs" + str(type(x))
                     return x
 
+            if tracing_context := torch._guards.TracingContext.try_get():
+                if tracing_context.output_strides:
+                    tracing_context.output_strides.clear()
+
+                params_flat = [
+                    param
+                    for param in tracing_context.params_flat  # type: ignore[union-attr]
+                    if param is not None
+                ]
+                real_inputs = [
+                    materialize(x) for x in itertools.chain(params_flat, V.real_inputs)
+                ]
+            else:
+                real_inputs = [materialize(x) for x in V.real_inputs]
+
             with torch.utils._python_dispatch._disable_current_modes():
-                assert self.example_inputs is not None
-                real_inputs = [materialize(x) for x in self.example_inputs]
                 compiled(real_inputs)
             del real_inputs
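
Net effect: when Dynamo lifts parameters out of the graph, the compiled artifact's calling convention is (*params_flat, *user_inputs), so the dry run that drives the second (C++ wrapper) pass must prepend the materialized parameters. A self-contained illustration of the input assembly, with placeholder tensors and materialize reduced to identity (the real helper also handles SymInt/SymFloat hints and fake tensors):

import itertools
import torch

# Placeholder stand-ins: lifted parameters (a None entry is filtered out,
# as in the diff) plus the user-visible inputs.
params_flat = [p for p in [torch.randn(4, 4), None, torch.randn(4)] if p is not None]
user_inputs = [torch.randn(2, 4)]

def materialize(x):
    return x  # identity for concrete tensors

real_inputs = [materialize(x) for x in itertools.chain(params_flat, user_inputs)]
assert len(real_inputs) == 3  # two kept parameters + one user input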