[user triton cache] Dedup user-defined Triton kernels by config in codecache (#143353)

Previously, the same kernel source with different autotuning configs would generate the same cache key, which could lead to a wrong cache hit and silent incorrectness. Here we add the configs to the cache key in `FxGraphHashDetails`.
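To illustrate the failure mode, here is a minimal sketch (not the actual Inductor code path) of why the configs must participate in the cache key: two kernels with identical source but different autotuning configs should hash differently. `FakeConfig` and `configs_key` are hypothetical stand-ins; `all_kwargs()` mirrors the `triton.Config` accessor used in the diff below.

```python
import hashlib


class FakeConfig:
    """Hypothetical stand-in for triton.Config."""

    def __init__(self, kwargs, num_warps=4, num_stages=2):
        self.kwargs = kwargs
        self.num_warps = num_warps
        self.num_stages = num_stages

    def all_kwargs(self):
        # Combine tuning kwargs with launch parameters, as triton.Config does.
        return {**self.kwargs, "num_warps": self.num_warps, "num_stages": self.num_stages}


def configs_key(configs):
    # Double sort: within each config and across configs, so neither dict
    # iteration order nor config list order can perturb the key.
    return str(sorted(sorted(str(kv) for kv in c.all_kwargs().items()) for c in configs))


src = "def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): ..."
key_64 = hashlib.sha256((src + configs_key([FakeConfig({"BLOCK": 64})])).encode()).hexdigest()
key_128 = hashlib.sha256((src + configs_key([FakeConfig({"BLOCK": 128})])).encode()).hexdigest()
assert key_64 != key_128  # same source, different configs -> different keys
```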

Test Plan:

```
python3 test/inductor/test_codecache.py -k test_triton_higher_order_op_different_configs
...
----------------------------------------------------------------------
Ran 2 tests in 3.590s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143353
Approved by: https://github.com/oulgen
Author: Adnan Akhundov
Date: 2024-12-16 18:06:58 -08:00
Committed by: PyTorch MergeBot
Parent: 6056efc5ff
Commit: 2531543c5f

2 changed files with 93 additions and 1 deletion


```diff
@@ -825,7 +825,15 @@ class FxGraphHashDetails:
                     from triton.runtime.autotuner import Autotuner
 
                     kernel = kernel_side_table.get_kernel(node.kwargs["kernel_idx"])
+                    configs = None
                     if isinstance(kernel, Autotuner):
+                        if kernel.configs:
+                            configs = str(
+                                sorted(
+                                    sorted(str(kv) for kv in c.all_kwargs().items())
+                                    for c in kernel.configs
+                                )
+                            )
                         kernel = kernel.fn
                     kernel_source = (
@@ -837,7 +845,7 @@ class FxGraphHashDetails:
                         node.kwargs["constant_args_idx"]
                     )
                     self.user_defined_triton_source.append(
-                        (kernel_source, constant_args)
+                        (kernel_source, constant_args, configs)
                     )
 
         # Alignment checks
```
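Note that the double `sorted` in the first hunk makes the serialized configs string deterministic: neither the order of kwargs within a config nor the order of the configs themselves affects the key. A quick illustrative check, again with a hypothetical stand-in for `triton.Config`:

```python
class FakeConfig:
    # Hypothetical stand-in for triton.Config.
    def __init__(self, kwargs):
        self._kwargs = kwargs

    def all_kwargs(self):
        return self._kwargs


def configs_key(configs):
    # Same double-sorted serialization as in the diff above.
    return str(
        sorted(
            sorted(str(kv) for kv in c.all_kwargs().items())
            for c in configs
        )
    )


c1 = FakeConfig({"BLOCK": 64, "num_warps": 4})
c1_reordered = FakeConfig({"num_warps": 4, "BLOCK": 64})
c2 = FakeConfig({"BLOCK": 128, "num_warps": 8})

assert configs_key([c1]) == configs_key([c1_reordered])  # kwarg order irrelevant
assert configs_key([c1, c2]) == configs_key([c2, c1])    # config order irrelevant
```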