[inductor][triton] more JITCallable._hash_lock support (#162244)

Follow-up to #161768.

Context: ProcessPool pickles the outputs before sending them back to the main process. Triton kernels have some un-pickleable fields, so `prepare_for_pickle()` is used to strip out those fields. Previously, in the standard case (without triton_bundler.py), `prepare_for_pickle()` would strip out the un-pickleable fields and they would never be added back after unpickling, because the un-pickleable fields were not actually needed after compilation finished.

#161768 updated `prepare_for_pickle()` to also strip out the `fn._hash_lock` field, a newly added field on JITCallable instances that holds a `threading.RLock()`, which is not pickleable.
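
A minimal standalone sketch of why the lock has to be stripped, using only the standard library. The `state` dict is a hypothetical stand-in for a kernel's fields; it is not Triton or Inductor code.

```python
import pickle
import threading


def is_pickleable(obj) -> bool:
    """Return True if pickle.dumps(obj) succeeds."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, TypeError, AttributeError):
        return False


# Hypothetical stand-in for the fields carried by a kernel object.
state = {"name": "toy_kernel", "_hash_lock": threading.RLock()}

# An RLock cannot be pickled, so the whole container fails to pickle.
lock_ok = is_pickleable(state)

# Replacing the lock with None (as prepare_for_pickle() does for
# fn._hash_lock) makes the container pickleable again.
state["_hash_lock"] = None
stripped_ok = is_pickleable(state)
```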

It turns out that we do need to restore the `fn._hash_lock` field even in the non-triton_bundler case, because the MultiKernel case uses the hash lock.

To do this, we add `restore_after_unpickle()`, which restores the stripped fields or, if the old values are not provided, initializes only `_hash_lock` with a fresh `threading.RLock()`.
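
The strip-then-restore round trip can be sketched in isolation as follows. This is an illustration under stated assumptions, not the Inductor implementation: the `SimpleNamespace` object, the `globals_` field, and the two-element `old_values` tuple are hypothetical simplifications of the six fields handled by the real `CachingAutotuner` methods.

```python
import pickle
import threading
from types import SimpleNamespace
from typing import Any, Optional


def prepare_for_pickle(fn: SimpleNamespace) -> tuple[Any, Any]:
    """Strip un-pickleable fields and return them so they can be put back."""
    old_values = (fn.globals_, fn._hash_lock)
    fn.globals_ = None
    fn._hash_lock = None
    return old_values


def restore_after_unpickle(
    fn: SimpleNamespace, old_values: Optional[tuple[Any, Any]]
) -> None:
    """Put saved fields back, or at least recreate a usable lock."""
    if old_values:
        fn.globals_, fn._hash_lock = old_values
    else:
        # No saved values (e.g. after crossing a process boundary):
        # the lock still has to be a valid RLock.
        fn._hash_lock = threading.RLock()


# Hypothetical stand-in for kernel.fn; the lambda and the RLock are
# both un-pickleable.
fn = SimpleNamespace(
    name="toy_kernel",
    globals_={"helper": lambda x: x},
    _hash_lock=threading.RLock(),
)

old = prepare_for_pickle(fn)
payload = pickle.dumps(fn)        # succeeds once the fields are stripped
restore_after_unpickle(fn, old)   # sender side: original values are available

received = pickle.loads(payload)
restore_after_unpickle(received, None)  # receiver side: only the lock is rebuilt

with received._hash_lock:
    acquired = True
```

The `old_values=None` branch corresponds to the case described above, where the stripped fields are not needed after compilation but the lock must still exist.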

Compile-time benchmarks look good, with at most a very minor regression (see the comment below on the PR).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162244
Approved by: https://github.com/atalman
David Berard
2025-09-05 11:09:35 -07:00
committed by PyTorch MergeBot
parent 1e0656f063
commit fb0afa853e
3 changed files with 20 additions and 9 deletions

@@ -465,6 +465,8 @@ class AsyncCompile:
                 kernel.set_compile_info(compile_id, is_backward)
                 CompiledTritonKernels.remove_future(source_code)
 
+                kernel.restore_after_unpickle(old_values=None)
+
                 kernel.precompile(
                     warm_cache_only=False,
                     reload_kernel=reload_kernel_in_parent,

@@ -640,6 +640,23 @@ class CachingAutotuner(KernelInterface):
         self.fn._hash_lock = None
         return old_values
 
+    def restore_after_unpickle(
+        self, old_values: Optional[tuple[Any, Any, Any, Any, Any, Any]]
+    ) -> None:
+        if old_values:
+            (
+                self.fn.fn,
+                self.fn.__globals__,
+                self.fn.used_global_vals,
+                self.fn.repr,
+                self.launchers,
+                self.fn._hash_lock,
+            ) = old_values
+        else:
+            # even if we don't need/have specific values, we do need the
+            # _hash_lock to be a valid RLock
+            self.fn._hash_lock = threading.RLock()
+
     def prepare_for_caching(self) -> None:
         """
         Statically Launched CUDA Kernels have a raw cubin on them

@@ -185,15 +185,7 @@ class TritonBundler:
                     )
 
                     # Put the values back since we need it to use now
-                    (
-                        kernel.fn.fn,
-                        kernel.fn.__globals__,
-                        kernel.fn.used_global_vals,
-                        kernel.fn.repr,
-                        kernel.launchers,
-                        hash_lock,
-                    ) = old_values
-                    kernel.fn._hash_lock = hash_lock
+                    kernel.restore_after_unpickle(old_values)
 
     @classmethod
     def collect_static_autotuners(