Fix edge case in inductor triton clean script (#130837)

The regex in the script is too restrictive, as it excludes examples with parentheses in args, like the following:
```
triton_poi_fused_add_0.run(arg0_1.item(), arg1_1.item(), buf0, 1, grid=grid(1), stream=streamNone)
                                       ^
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130837
Approved by: https://github.com/Chillee
This commit is contained in:
Ahmad Sarvmeily
2024-08-19 23:46:08 +00:00
committed by PyTorch MergeBot
parent 65b3e42074
commit 9a998d98f1
2 changed files with 16 additions and 1 deletions

View File

@ -81,6 +81,7 @@ class TestKernelBenchmark(TestCase):
compiled_module.__file__, f"{compiled_module.__file__}.cleaned"
)
self.assertTrue("@triton_heuristics" not in cleaned_triton)
self.assertTrue(".run(" not in cleaned_triton)
try:
out = subprocess.check_output(
f"{sys.executable} {compiled_module.__file__}.cleaned".split(),
@ -470,6 +471,20 @@ class TestKernelBenchmark(TestCase):
compiled_module = self.get_compiled_module()
self.verify_remove_inductor_deps(compiled_module)
@config.patch("triton.unique_kernel_names", True)
@config.patch(benchmark_kernel=False)
@config.patch(compile_threads=1)
def test_remove_inductor_deps_scalar(self):
@torch.compile
def f(a, b):
return a + b
a = torch.tensor(1.0, device=GPU_TYPE)
b = torch.tensor(2.0, device=GPU_TYPE)
f(a, b)
compiled_module = self.get_compiled_module()
self.verify_remove_inductor_deps(compiled_module)
if __name__ == "__main__":
if HAS_GPU:

View File

@ -55,7 +55,7 @@ def merge_params(original_params: List[str], new_params: List[str]) -> List[str]
def add_launch_params(original: str, kernel_to_params: Dict[str, str]) -> str:
# Regex to match the function call in the original string
pattern = r"(\w+)\.run\(([^)]*), grid=(.*\)), [^)]*\)"
pattern = r"(\w+)\.run\((.*), grid=(.*\)), [^)]*\)"
def replace(match) -> str:
# Extract parts from the regex match