Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-11 22:34:53 +08:00
[inductor] fix cpp_wrapper inputs mismatch (#116197)
Summary: fixes https://github.com/pytorch/pytorch/issues/115035: in the cpp_wrapper JIT Inductor, the input args need to contain the lifted parameters.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116197
Approved by: https://github.com/jansel
Commit e5bcfe205e (parent 7571511af9), committed by PyTorch MergeBot
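Background for the diff below: with cpp_wrapper, JIT Inductor compiles in two passes — it first emits a Python wrapper and runs it once with real inputs (per the compile_fx.py comments below, this dry run is used for Triton kernel tuning), then emits the C++ wrapper. Parameters that dynamo lifts out of the module were missing from that dry-run call, so the lowered graph saw fewer arguments than it expected. A minimal sketch of the mismatch, using hypothetical stand-ins (graph_runner, lifted_params, user_inputs are illustrative only, not PyTorch APIs):

    import itertools

    lifted_params = [1.0, 2.0]  # parameters lifted out of the traced module
    user_inputs = [3.0, 4.0]    # the inputs the caller actually passes

    def graph_runner(*args):
        # The lowered graph's signature is lifted params followed by user inputs.
        assert len(args) == len(lifted_params) + len(user_inputs), "inputs mismatch"
        return sum(args)

    # Before the fix the dry run supplied only user_inputs and hit the mismatch;
    # the fix prepends the lifted parameters (see the graph.py hunk below).
    graph_runner(*itertools.chain(lifted_params, user_inputs))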
test/inductor/test_cpp_wrapper.py
@@ -165,6 +165,7 @@ if RUN_CUDA:
         BaseTest("test_index_put_deterministic_fallback"),
+        BaseTest("test_adding_tensor_offsets"),
         BaseTest("test_index_tensor"),
         BaseTest("test_layer_norm"),
         BaseTest("test_linear1"),
         BaseTest("test_linear2"),
         BaseTest("test_mm_views"),
test/inductor/test_torchinductor.py
@@ -557,6 +557,12 @@ def _run_and_assert_no_indirect_indexing(test_case, func, *args, **kwargs):
     return result


+def assertGeneratedKernelCountEqual(self: TestCase, expected: int):
+    if config.cpp_wrapper:
+        expected *= 2
+    self.assertEqual(torch._inductor.metrics.generated_kernel_count, expected)
+
+
 class SweepInputs2:
     input_gen_types1 = [
         "dense",
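Why the helper: cpp_wrapper's two-pass compilation generates each kernel twice, so the expected generated_kernel_count doubles when config.cpp_wrapper is set. The assertions rewritten below delegate to assertGeneratedKernelCountEqual so a single expected value works in both modes, e.g. assertGeneratedKernelCountEqual(self, 1) accepts a count of 2 under cpp_wrapper.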
@@ -804,7 +810,7 @@ class CommonTemplate:
                 torch.randn(26),
             ),
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+        assertGeneratedKernelCountEqual(self, 1)

     def test_forced_buffer_realize(self):
         # Test torch._test_inductor_realize forces a buffer to be realized
@@ -839,10 +845,7 @@ class CommonTemplate:
             ),
         )
         self.assertEqual(torch._inductor.metrics.ir_nodes_pre_fusion, 5)
-        self.assertEqual(
-            torch._inductor.metrics.generated_kernel_count,
-            1 if self.device == "cuda" else 3,
-        )
+        assertGeneratedKernelCountEqual(self, 1 if self.device == "cuda" else 3)

     def test_index_propagation(self):
         def flip(x):
@@ -2902,7 +2905,7 @@ class CommonTemplate:
             (torch.randn(2, 4, 21, 21),),
             check_lowp=False,
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)

     def test_multi_threading(self):
         model = torch.nn.Linear(2, 3).eval()
@@ -3159,7 +3162,7 @@ class CommonTemplate:
             fn,
             (torch.randn([16, 64, 55, 55]),),
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)

     # From https://github.com/pytorch/pytorch/issues/94775
     def test_max_pool2d7(self):
@@ -3185,7 +3188,7 @@ class CommonTemplate:
             fn,
             (torch.randn([2, 2, 3, 6]),),
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)

     def test_avg_pool2d1(self):
         def fn(x):
@@ -3265,7 +3268,7 @@ class CommonTemplate:
             fn,
             (-torch.arange(1 * 24 * 24, dtype=torch.float32).view(1, 1, 24, 24),),
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)

     def test_avg_pool2d8(self):
         # https://github.com/pytorch/pytorch/issues/100987
@@ -3569,7 +3572,7 @@ class CommonTemplate:
         with torch.no_grad():
             self.common(m, (torch.randn([16, 32]),), check_lowp=False)
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+            assertGeneratedKernelCountEqual(self, 1)

     def test_transpose_add(self):
         def fn(a, b):
@@ -3579,7 +3582,7 @@ class CommonTemplate:
             fn, (torch.randn([16, 32]), torch.randn([32, 16])), check_lowp=False
         )
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+            assertGeneratedKernelCountEqual(self, 1)

     @patch.object(config.triton, "persistent_reductions", True)
     def test_softmax_one_kernel_persist(self):
@@ -3592,7 +3595,7 @@ class CommonTemplate:

         self.common(fn, (torch.randn([16, 32]),), check_lowp=False)
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+            assertGeneratedKernelCountEqual(self, 1)

     @patch.object(config.triton, "persistent_reductions", False)
     def test_softmax_one_kernel_loop(self):
@@ -3604,7 +3607,7 @@ class CommonTemplate:

         self.common(fn, (torch.randn([16, 32]),), check_lowp=False)
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+            assertGeneratedKernelCountEqual(self, 1)

     def test_complex_fallback(self):
         def fn(x):
@@ -3614,7 +3617,7 @@ class CommonTemplate:
             fn,
             (torch.randn([1, 2, 4, 8]).to(dtype=torch.complex64),),
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)

         class ToComplex(nn.Module):
             def forward(self, x):
@@ -3623,7 +3626,7 @@ class CommonTemplate:
         self.common(ToComplex(), (torch.rand([1, 2, 4, 8]),), check_lowp=False)

         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+            assertGeneratedKernelCountEqual(self, 1)

     def test_view_as_complex(self):
         class Repro(torch.nn.Module):
@@ -3675,7 +3678,7 @@ class CommonTemplate:
             check_lowp=False,
         )
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+            assertGeneratedKernelCountEqual(self, 1)

     def test_gather_scatter(self):
         def fn(node_feat, edge_index):
@@ -3701,7 +3704,7 @@ class CommonTemplate:
             check_lowp=False,
         )
         if self.device != "cpu":
-            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
+            assertGeneratedKernelCountEqual(self, 2)

     @config.patch(max_fusion_size=1)
     def test_no_mega_fusion_during_lowering(self):
@@ -3728,7 +3731,7 @@ class CommonTemplate:

         self.common(fn, (torch.randn([32]),), check_lowp=False)
         # if we have a copy there will be more than 1 kernel
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+        assertGeneratedKernelCountEqual(self, 1)

     def test_leaky_relu(self):
         def fn(x):
@@ -6230,7 +6233,7 @@ class CommonTemplate:
                 indices,
             ],
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+        assertGeneratedKernelCountEqual(self, 1)

     def test_max_pool2d_with_indices_backward5(self):
         # Window size is too big. Should fallback
@@ -6257,7 +6260,7 @@ class CommonTemplate:
                 indices,
             ],
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)

     # From https://github.com/pytorch/pytorch/issues/93384
     def test_max_pool2d_with_indices_backward6(self):
@@ -6285,7 +6288,7 @@ class CommonTemplate:
                 indices,
             ],
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)

     def test_issue102546(self):
         def fn(x):
@@ -6356,7 +6359,7 @@ class CommonTemplate:
                 torch.randn([1, 2016, 21, 21]),
             ],
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+        assertGeneratedKernelCountEqual(self, 1)

     def test_avg_pool2d_backward4(self):
         def fn(a, b):
@@ -6380,7 +6383,7 @@ class CommonTemplate:
             ],
             check_lowp=False,
         )
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)

     @config.patch(search_autotune_cache=False)
     def test_mm_views(self):
@@ -7316,7 +7319,7 @@ class CommonTemplate:
         )

         # expanded dim should not cause copy in require_stride_order
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+        assertGeneratedKernelCountEqual(self, 0)

     @requires_cuda()
     @unittest.skipIf(
@@ -7913,7 +7916,7 @@ class CommonTemplate:

         self.common(fn, (input, boundaries, add_value), check_lowp=False)

-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+        assertGeneratedKernelCountEqual(self, 1)

     def test_bucketize_computed_offsets(self):
         def fn(inp, offsets):
torch/_inductor/compile_fx.py
@@ -518,7 +518,7 @@ def fx_codegen_and_compile(
             # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
             # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
             # we currently use fake tensors and defake them later.
-            example_inputs=V.real_inputs if is_inference else example_inputs,
+            example_inputs=example_inputs,
             shape_env=shape_env,
             num_static_inputs=num_fixed,
             graph_id=graph_id,
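With the dry-run inputs now assembled inside GraphLowering itself (see the graph.py hunks below), fx_codegen_and_compile no longer swaps in V.real_inputs for inference and simply forwards the caller's example_inputs.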
torch/_inductor/graph.py
@@ -1,3 +1,4 @@
+import itertools
 import logging
 import operator
 import os
@@ -1038,9 +1039,23 @@ class GraphLowering(torch.fx.Interpreter):
             ), "Unknown type when creating real inputs" + str(type(x))
             return x

+        if tracing_context := torch._guards.TracingContext.try_get():
+            if tracing_context.output_strides:
+                tracing_context.output_strides.clear()
+
+            params_flat = [
+                param
+                for param in tracing_context.params_flat  # type: ignore[union-attr]
+                if param is not None
+            ]
+            real_inputs = [
+                materialize(x) for x in itertools.chain(params_flat, V.real_inputs)
+            ]
+        else:
+            real_inputs = [materialize(x) for x in V.real_inputs]
+
         with torch.utils._python_dispatch._disable_current_modes():
-            assert self.example_inputs is not None
-            real_inputs = [materialize(x) for x in self.example_inputs]
             compiled(real_inputs)
         del real_inputs
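This hunk is the core of the fix: when a TracingContext is available, the real inputs for the dry run are rebuilt as the lifted parameters (params_flat, with None placeholders dropped) followed by V.real_inputs, rather than taken from self.example_inputs. A minimal sketch of that list-building pattern, with plain floats standing in for parameter and input tensors:

    import itertools

    params_flat = [1.0, None, 2.0]  # lifted parameters; None entries are placeholders
    real_user_inputs = [3.0, 4.0]   # stand-in for V.real_inputs

    params = [p for p in params_flat if p is not None]
    real_inputs = list(itertools.chain(params, real_user_inputs))
    assert real_inputs == [1.0, 2.0, 3.0, 4.0]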