Fix incorrect function signature in template (#165567)

Summary:
In https://github.com/pytorch/pytorch/pull/148305 we refactored the grid
argument out, but it's not reflected in our template.

Test Plan:
Included in commit.
python test/inductor/test_aot_inductor.py
AOTInductorTestABICompatibleGpu.test_cond_symint_input_disable_one_pass_cuda

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165567
Approved by: https://github.com/desertfire
This commit is contained in:
Mu-Chu Lee
2025-10-15 10:41:51 -07:00
committed by PyTorch MergeBot
parent 7dabfb07cb
commit 9fccbdd4f0
2 changed files with 33 additions and 1 deletions

View File

@ -2340,6 +2340,39 @@ class AOTInductorTestsTemplate:
dynamic_shapes=dynamic_shapes,
)
def test_cond_symint_input_disable_one_pass(self):
class M(torch.nn.Module):
def forward(self, x, y, z):
a = y.shape[0]
b = z.shape[0]
def true_fn(x):
return x + a
def false_fn(x):
return x + b * z
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
input1 = (
torch.ones(3, 3, device=self.device),
torch.ones(5, device=self.device),
torch.ones(3, 3, device=self.device),
)
input2 = (
torch.ones(10, 3, device=self.device),
torch.ones(6, device=self.device),
torch.ones(10, 3, device=self.device),
)
inputs = (input1, input2)
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
with torch._inductor.config.patch({"triton.autotune_at_compile_time": False}):
self.check_model_with_multiple_inputs(
M(),
inputs,
dynamic_shapes=dynamic_shapes,
)
def test_while_loop_simple(self):
inputs = (
torch.randn((10, 20), device=self.device),

View File

@ -2631,7 +2631,6 @@ class PythonWrapperCodegen(CodeGen):
if len(kernel.launchers) == 0:
kernel.precompile()
kernel.save_gpu_kernel(
grid=(0, 0, 0), # use dummy grid
stream="stream", # use dummy stream
launcher=kernel.launchers[0],
)