mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor][Tritonparse] Get Inductor kernel params (#161953)
Summary: Save the config args that Inductor burns into `inductor_metadata` so we can optionally pass them to any Jit Hooks that are set. This allows us to pass them to Tritonparse. Reviewed By: davidberard98, FindHao Differential Revision: D80994791 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161953 Approved by: https://github.com/FindHao
This commit is contained in:
committed by
PyTorch MergeBot
parent
b16d3f4c8c
commit
aed33a8fcb
@ -772,7 +772,10 @@ class CachingAutotuner(KernelInterface):
|
||||
and getattr(knobs.runtime, "jit_post_compile_hook", None)
|
||||
):
|
||||
try:
|
||||
knobs.runtime.jit_post_compile_hook(
|
||||
hook = knobs.runtime.jit_post_compile_hook
|
||||
|
||||
# base args everyone should get
|
||||
call_kwargs = dict(
|
||||
key=getattr(self.fn, "cache_key", self.kernel_hash or str(self.fn)),
|
||||
repr=getattr(self.fn, "src", None),
|
||||
fn=self.fn,
|
||||
@ -780,6 +783,14 @@ class CachingAutotuner(KernelInterface):
|
||||
is_manual_warmup=False,
|
||||
already_compiled=True,
|
||||
)
|
||||
|
||||
# only add inductor_args if the hook takes it
|
||||
sig = inspect.signature(hook)
|
||||
params = sig.parameters
|
||||
if "inductor_args" in params:
|
||||
call_kwargs["inductor_args"] = self.inductor_meta["config_args"]
|
||||
|
||||
hook(**call_kwargs)
|
||||
except Exception:
|
||||
log.exception("jit_post_compile_hook failed")
|
||||
|
||||
|
@ -613,6 +613,8 @@ class TritonTemplateKernel(TritonKernel):
|
||||
flops = self.estimate_flops()
|
||||
inductor_meta["kernel_flop"] = flops
|
||||
|
||||
inductor_meta["config_args"] = self.meta
|
||||
|
||||
template_args = f"""
|
||||
num_stages={self.num_stages},
|
||||
num_warps={self.num_warps},
|
||||
|
Reference in New Issue
Block a user