mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[Inductor][Tritonparse] Get Inductor kernel params (#161953)
Summary: Save the config args that Inductor burns into `inductor_metadata` so we can optionally pass them to any Jit Hooks that are set. This allows us to pass them to Tritonparse. Reviewed By: davidberard98, FindHao Differential Revision: D80994791 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161953 Approved by: https://github.com/FindHao
This commit is contained in:
committed by
PyTorch MergeBot
parent
b16d3f4c8c
commit
aed33a8fcb
@ -772,7 +772,10 @@ class CachingAutotuner(KernelInterface):
|
|||||||
and getattr(knobs.runtime, "jit_post_compile_hook", None)
|
and getattr(knobs.runtime, "jit_post_compile_hook", None)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
knobs.runtime.jit_post_compile_hook(
|
hook = knobs.runtime.jit_post_compile_hook
|
||||||
|
|
||||||
|
# base args everyone should get
|
||||||
|
call_kwargs = dict(
|
||||||
key=getattr(self.fn, "cache_key", self.kernel_hash or str(self.fn)),
|
key=getattr(self.fn, "cache_key", self.kernel_hash or str(self.fn)),
|
||||||
repr=getattr(self.fn, "src", None),
|
repr=getattr(self.fn, "src", None),
|
||||||
fn=self.fn,
|
fn=self.fn,
|
||||||
@ -780,6 +783,14 @@ class CachingAutotuner(KernelInterface):
|
|||||||
is_manual_warmup=False,
|
is_manual_warmup=False,
|
||||||
already_compiled=True,
|
already_compiled=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# only add inductor_args if the hook takes it
|
||||||
|
sig = inspect.signature(hook)
|
||||||
|
params = sig.parameters
|
||||||
|
if "inductor_args" in params:
|
||||||
|
call_kwargs["inductor_args"] = self.inductor_meta["config_args"]
|
||||||
|
|
||||||
|
hook(**call_kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception("jit_post_compile_hook failed")
|
log.exception("jit_post_compile_hook failed")
|
||||||
|
|
||||||
|
@ -613,6 +613,8 @@ class TritonTemplateKernel(TritonKernel):
|
|||||||
flops = self.estimate_flops()
|
flops = self.estimate_flops()
|
||||||
inductor_meta["kernel_flop"] = flops
|
inductor_meta["kernel_flop"] = flops
|
||||||
|
|
||||||
|
inductor_meta["config_args"] = self.meta
|
||||||
|
|
||||||
template_args = f"""
|
template_args = f"""
|
||||||
num_stages={self.num_stages},
|
num_stages={self.num_stages},
|
||||||
num_warps={self.num_warps},
|
num_warps={self.num_warps},
|
||||||
|
Reference in New Issue
Block a user