mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix incorrect function signature in template (#165567)
Summary: In https://github.com/pytorch/pytorch/pull/148305 we refactored the grid argument out, but it's not reflected in our template. Test Plan: Included in commit. python test/inductor/test_aot_inductor.py AOTInductorTestABICompatibleGpu.test_cond_symint_input_disable_one_pass_cuda Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/165567 Approved by: https://github.com/desertfire
This commit is contained in:
committed by
PyTorch MergeBot
parent
7dabfb07cb
commit
9fccbdd4f0
@ -2340,6 +2340,39 @@ class AOTInductorTestsTemplate:
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
|
||||
def test_cond_symint_input_disable_one_pass(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y, z):
|
||||
a = y.shape[0]
|
||||
b = z.shape[0]
|
||||
|
||||
def true_fn(x):
|
||||
return x + a
|
||||
|
||||
def false_fn(x):
|
||||
return x + b * z
|
||||
|
||||
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
|
||||
|
||||
input1 = (
|
||||
torch.ones(3, 3, device=self.device),
|
||||
torch.ones(5, device=self.device),
|
||||
torch.ones(3, 3, device=self.device),
|
||||
)
|
||||
input2 = (
|
||||
torch.ones(10, 3, device=self.device),
|
||||
torch.ones(6, device=self.device),
|
||||
torch.ones(10, 3, device=self.device),
|
||||
)
|
||||
inputs = (input1, input2)
|
||||
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
|
||||
with torch._inductor.config.patch({"triton.autotune_at_compile_time": False}):
|
||||
self.check_model_with_multiple_inputs(
|
||||
M(),
|
||||
inputs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
|
||||
def test_while_loop_simple(self):
|
||||
inputs = (
|
||||
torch.randn((10, 20), device=self.device),
|
||||
|
@ -2631,7 +2631,6 @@ class PythonWrapperCodegen(CodeGen):
|
||||
if len(kernel.launchers) == 0:
|
||||
kernel.precompile()
|
||||
kernel.save_gpu_kernel(
|
||||
grid=(0, 0, 0), # use dummy grid
|
||||
stream="stream", # use dummy stream
|
||||
launcher=kernel.launchers[0],
|
||||
)
|
||||
|
Reference in New Issue
Block a user