[Inductor][Tritonparse] Get Inductor kernel params (#161953)

Summary: Save the config args that Inductor burns into its generated kernels under `inductor_metadata`, so we can optionally pass them to any JIT post-compile hooks that are set. This allows us to pass them along to Tritonparse.
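For context, a minimal sketch of what a Tritonparse-style consumer of the new keyword could look like. Only the hook attribute (`knobs.runtime.jit_post_compile_hook`) is taken from the diff below; the import path, hook name, and body are illustrative assumptions.

```python
# Sketch only: assumes Triton exposes its knobs module as `triton.knobs` and
# that the post-compile hook can be assigned directly, as the caller below implies.
from triton import knobs


def tritonparse_style_hook(*, inductor_args=None, **kwargs):
    # `inductor_args` carries Inductor's config_args (the burned-in kernel
    # params); hooks that do not declare this parameter are still called,
    # because the caller checks the hook's signature before passing it.
    if inductor_args is not None:
        print("Inductor kernel params:", inductor_args)


knobs.runtime.jit_post_compile_hook = tritonparse_style_hook
```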

Reviewed By: davidberard98, FindHao

Differential Revision: D80994791

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161953
Approved by: https://github.com/FindHao
Author: Nikhil Patel
Date: 2025-09-03 14:11:27 +00:00
Committed by: PyTorch MergeBot
Parent: b16d3f4c8c
Commit: aed33a8fcb

2 changed files with 14 additions and 1 deletion


@@ -772,7 +772,10 @@ class CachingAutotuner(KernelInterface):
                 and getattr(knobs.runtime, "jit_post_compile_hook", None)
             ):
                 try:
-                    knobs.runtime.jit_post_compile_hook(
+                    hook = knobs.runtime.jit_post_compile_hook
+                    # base args everyone should get
+                    call_kwargs = dict(
                         key=getattr(self.fn, "cache_key", self.kernel_hash or str(self.fn)),
                         repr=getattr(self.fn, "src", None),
                         fn=self.fn,
@@ -780,6 +783,14 @@ class CachingAutotuner(KernelInterface):
                         is_manual_warmup=False,
                         already_compiled=True,
                     )
+                    # only add inductor_args if the hook takes it
+                    sig = inspect.signature(hook)
+                    params = sig.parameters
+                    if "inductor_args" in params:
+                        call_kwargs["inductor_args"] = self.inductor_meta["config_args"]
+                    hook(**call_kwargs)
                 except Exception:
                     log.exception("jit_post_compile_hook failed")
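
The signature check above is a small backward-compatibility pattern: the new keyword is only forwarded when the hook declares it, so existing hooks keep working unchanged. A self-contained sketch of the same idea (the helper and hook names here are hypothetical):

```python
import inspect


def call_hook_compat(hook, base_kwargs, extra_kwargs):
    """Call `hook` with base_kwargs, forwarding an extra kwarg only when the
    hook's signature declares a parameter of that name."""
    params = inspect.signature(hook).parameters
    kwargs = dict(base_kwargs)
    for name, value in extra_kwargs.items():
        if name in params:
            kwargs[name] = value
    return hook(**kwargs)


def old_hook(*, key):
    print("compiled", key)  # predates inductor_args, still works


def new_hook(*, key, inductor_args):
    print("compiled", key, "with config", inductor_args)


call_hook_compat(old_hook, {"key": "k0"}, {"inductor_args": {"num_warps": 4}})
call_hook_compat(new_hook, {"key": "k0"}, {"inductor_args": {"num_warps": 4}})
```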


@@ -613,6 +613,8 @@ class TritonTemplateKernel(TritonKernel):
         flops = self.estimate_flops()
         inductor_meta["kernel_flop"] = flops
+        inductor_meta["config_args"] = self.meta
         template_args = f"""
             num_stages={self.num_stages},
             num_warps={self.num_warps},