Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[inductor][triton] more JITCallable._hash_lock support (#162244)
Follow-up to #161768.

Context: ProcessPool pickles its outputs before sending them back to the main process. Triton kernels have some un-pickleable fields, so `prepare_for_pickle()` is used to strip those fields out. Previously, in the standard case (without triton_bundler.py), `prepare_for_pickle()` would strip the un-pickleable fields and they would never be added back after unpickling, because they were not actually needed once compilation finished.

#161768 updated `prepare_for_pickle()` to also strip out `fn._hash_lock`, a newly added field on JITCallable instances; it is a `threading.RLock()`, which is not pickleable. It turns out that we do need to restore `fn._hash_lock` even in the non-triton_bundler case: the MultiKernel path uses the hash lock. To do this, we add `restore_after_unpickle()`, which restores the stripped fields (or, if the old field values are not provided, initializes just the hash lock).

Compile-time benchmarks look good; at most a very minor regression (see the comment below on the PR).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162244
Approved by: https://github.com/atalman
Committed by: PyTorch MergeBot
Parent: 1e0656f063
Commit: fb0afa853e
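As background for the diff below, here is a minimal self-contained sketch (not code from the PR; the `StripDemo` class and its method bodies are made up for illustration) of why the lock has to be stripped before pickling and re-created afterwards: a `threading.RLock` cannot be pickled, so an object that still holds one cannot cross a ProcessPool boundary.

    import pickle
    import threading

    class StripDemo:
        """Stand-in for a compiled-kernel wrapper that carries an RLock."""

        def __init__(self) -> None:
            self._hash_lock = threading.RLock()

        def prepare_for_pickle(self):
            # Strip the un-pickleable field and hand the old value back
            # so a caller that still needs it can restore it later.
            old = self._hash_lock
            self._hash_lock = None
            return old

        def restore_after_unpickle(self, old=None) -> None:
            # Either put the original value back, or make a fresh RLock so
            # later code that acquires the lock keeps working.
            self._hash_lock = old if old is not None else threading.RLock()

    demo = StripDemo()
    try:
        pickle.dumps(demo)          # fails: RLock is not pickleable
    except TypeError as exc:
        print(exc)

    demo.prepare_for_pickle()
    restored = pickle.loads(pickle.dumps(demo))   # now round-trips fine
    restored.restore_after_unpickle()             # _hash_lock is a working RLock again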
@@ -465,6 +465,8 @@ class AsyncCompile:
                 kernel.set_compile_info(compile_id, is_backward)
                 CompiledTritonKernels.remove_future(source_code)
 
+                kernel.restore_after_unpickle(old_values=None)
+
                 kernel.precompile(
                     warm_cache_only=False,
                     reload_kernel=reload_kernel_in_parent,
@@ -640,6 +640,23 @@ class CachingAutotuner(KernelInterface):
         self.fn._hash_lock = None
         return old_values
 
+    def restore_after_unpickle(
+        self, old_values: Optional[tuple[Any, Any, Any, Any, Any, Any]]
+    ) -> None:
+        if old_values:
+            (
+                self.fn.fn,
+                self.fn.__globals__,
+                self.fn.used_global_vals,
+                self.fn.repr,
+                self.launchers,
+                self.fn._hash_lock,
+            ) = old_values
+        else:
+            # even if we don't need/have specific values, we do need the
+            # _hash_lock to be a valid RLock
+            self.fn._hash_lock = threading.RLock()
+
     def prepare_for_caching(self) -> None:
         """
         Statically Launched CUDA Kernels have a raw cubin on them
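The `else` branch above matters because downstream users of the kernel (per the commit message, the MultiKernel path) acquire the hash lock; if the field stayed `None` after unpickling, any attempt to take the lock would fail. A tiny self-contained illustration (not PR code; the exact exception type varies by Python version):

    import threading

    lock = None                       # what _hash_lock looks like if it is never restored
    try:
        with lock:                    # any later locked section then blows up
            pass
    except (AttributeError, TypeError) as exc:
        print(exc)                    # None does not support the context manager protocol

    lock = threading.RLock()          # what restore_after_unpickle(None) re-creates
    with lock:                        # fine again
        pass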
@@ -185,15 +185,7 @@ class TritonBundler:
             )
 
             # Put the values back since we need it to use now
-            (
-                kernel.fn.fn,
-                kernel.fn.__globals__,
-                kernel.fn.used_global_vals,
-                kernel.fn.repr,
-                kernel.launchers,
-                hash_lock,
-            ) = old_values
-            kernel.fn._hash_lock = hash_lock
+            kernel.restore_after_unpickle(old_values)
 
     @classmethod
     def collect_static_autotuners(
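With this hunk, triton_bundler.py no longer unpacks the stripped-field tuple by hand; both call sites go through `restore_after_unpickle`, so the tuple layout produced by `prepare_for_pickle` only has to stay in sync with a single place.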