mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Fix edge case in inductor triton clean script (#130837)
The regex in the script is too restrictive, as it excludes examples with parentheses in args, like the following: ``` triton_poi_fused_add_0.run(arg0_1.item(), arg1_1.item(), buf0, 1, grid=grid(1), stream=streamNone) ^ ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/130837 Approved by: https://github.com/Chillee
This commit is contained in:
committed by
PyTorch MergeBot
parent
65b3e42074
commit
9a998d98f1
@ -81,6 +81,7 @@ class TestKernelBenchmark(TestCase):
|
||||
compiled_module.__file__, f"{compiled_module.__file__}.cleaned"
|
||||
)
|
||||
self.assertTrue("@triton_heuristics" not in cleaned_triton)
|
||||
self.assertTrue(".run(" not in cleaned_triton)
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
f"{sys.executable} {compiled_module.__file__}.cleaned".split(),
|
||||
@ -470,6 +471,20 @@ class TestKernelBenchmark(TestCase):
|
||||
compiled_module = self.get_compiled_module()
|
||||
self.verify_remove_inductor_deps(compiled_module)
|
||||
|
||||
@config.patch("triton.unique_kernel_names", True)
|
||||
@config.patch(benchmark_kernel=False)
|
||||
@config.patch(compile_threads=1)
|
||||
def test_remove_inductor_deps_scalar(self):
|
||||
@torch.compile
|
||||
def f(a, b):
|
||||
return a + b
|
||||
|
||||
a = torch.tensor(1.0, device=GPU_TYPE)
|
||||
b = torch.tensor(2.0, device=GPU_TYPE)
|
||||
f(a, b)
|
||||
compiled_module = self.get_compiled_module()
|
||||
self.verify_remove_inductor_deps(compiled_module)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if HAS_GPU:
|
||||
|
@ -55,7 +55,7 @@ def merge_params(original_params: List[str], new_params: List[str]) -> List[str]
|
||||
|
||||
def add_launch_params(original: str, kernel_to_params: Dict[str, str]) -> str:
|
||||
# Regex to match the function call in the original string
|
||||
pattern = r"(\w+)\.run\(([^)]*), grid=(.*\)), [^)]*\)"
|
||||
pattern = r"(\w+)\.run\((.*), grid=(.*\)), [^)]*\)"
|
||||
|
||||
def replace(match) -> str:
|
||||
# Extract parts from the regex match
|
||||
|
Reference in New Issue
Block a user