Refactor Provenance Tracking (#163378)

Summary:
- Move the `provenance_tracking_level` flag check inside `set_kernel_post_grad_provenance_tracing` to simplify the code.

- Move the `set_kernel_post_grad_provenance_tracing` and `write_provenance_debug_handle` calls into `codegen_comment` (see the sketch after this list).

- If a `call_kernel` call site doesn't have a preceding `codegen_comment` call, add one. Now every `call_kernel` call site is paired with a `codegen_comment` call.

- Add a `codegen_comment` method to `BaseScheduling` and remove the no-op `codegen_comment` method from `SIMDScheduling`.

- Remove the `debug_handle` parameter from `call_kernel`.
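
For orientation, here is a condensed sketch of the pattern the PR converges on. It paraphrases the new `BaseScheduling.codegen_comment` added in the diff below (backends such as `TritonScheduling` extend it to also emit origin comments); the `MyScheduling` class and `_emit_call` helper are hypothetical stand-ins for the concrete backends and their `codegen_template`/`flush` call sites:

```
from typing import Optional, Sequence

from torch._inductor.debug import set_kernel_post_grad_provenance_tracing
from torch._inductor.scheduler import BaseSchedulerNode, BaseScheduling
from torch._inductor.virtualized import V


class MyScheduling(BaseScheduling):  # hypothetical backend, for illustration only
    def codegen_comment(
        self,
        node_schedule: Sequence[BaseSchedulerNode],
        kernel_name: Optional[str] = None,
    ) -> None:
        # Provenance tracking now happens here. The provenance_tracking_level == 0
        # check lives inside set_kernel_post_grad_provenance_tracing, which
        # returns None when tracking is disabled.
        if kernel_name:
            debug_handle = set_kernel_post_grad_provenance_tracing(
                node_schedule, kernel_name
            )
            V.graph.wrapper_code.write_provenance_debug_handle(
                kernel_name, debug_handle
            )

    def _emit_call(self, node_schedule, kernel, kernel_name, ctb):
        # Every call_kernel site is now preceded by a codegen_comment call,
        # and call_kernel no longer takes a debug_handle argument.
        self.codegen_comment(node_schedule, kernel_name)
        kernel.call_kernel(kernel_name, ctb)
```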

Test Plan:
CI

```
buck run @//mode/opt-split-dwarf fbcode//caffe2/test/inductor:provenance_tracing
```
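
For an open-source checkout, the same tests can presumably be run directly with pytest (file path inferred from the test classes touched in the diff, so treat it as an assumption):

```
pytest test/inductor/test_provenance_tracing.py
```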

Differential Revision: D82839271

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163378
Approved by: https://github.com/angelayi
Author: Shangdi Yu
Date: 2025-09-25 22:55:59 +00:00
Committed by: PyTorch MergeBot
Parent: 908bcfd403
Commit: 520fca82c8
12 changed files with 101 additions and 103 deletions


@@ -150,7 +150,7 @@ class TestProvenanceTracingArtifact(TestCase):
"cppCodeToPost",
{
"triton_poi_fused_mul_0:1": ["mul"],
"triton_poi_fused_addmm_gelu_1:2": [
"triton_poi_fused_addmm_gelu_1:3": [
"mul_3",
"mul_1",
"add_tensor",
@@ -164,12 +164,12 @@ class TestProvenanceTracingArtifact(TestCase):
"postToCppCode",
{
"mul": ["triton_poi_fused_mul_0:1"],
"mul_3": ["triton_poi_fused_addmm_gelu_1:2"],
"mul_1": ["triton_poi_fused_addmm_gelu_1:2"],
"add_tensor": ["triton_poi_fused_addmm_gelu_1:2"],
"add": ["triton_poi_fused_addmm_gelu_1:2"],
"erf": ["triton_poi_fused_addmm_gelu_1:2"],
"mul_2": ["triton_poi_fused_addmm_gelu_1:2"],
"mul_3": ["triton_poi_fused_addmm_gelu_1:3"],
"mul_1": ["triton_poi_fused_addmm_gelu_1:3"],
"add_tensor": ["triton_poi_fused_addmm_gelu_1:3"],
"add": ["triton_poi_fused_addmm_gelu_1:3"],
"erf": ["triton_poi_fused_addmm_gelu_1:3"],
"mul_2": ["triton_poi_fused_addmm_gelu_1:3"],
},
),
(
@@ -195,18 +195,18 @@ class TestProvenanceTracingArtifact(TestCase):
),
]
if backend == "aot_inductor":
expected_mapping[0][1]["aoti_torch_cuda_mm_out:3"] = [
expected_mapping[0][1]["aoti_torch_cuda_mm_out:2"] = [
"mm_default"
]
expected_mapping[1][1]["mm_default"] = [
"aoti_torch_cuda_mm_out:3"
"aoti_torch_cuda_mm_out:2"
]
else:
expected_mapping[0][1]["extern_kernels.mm:3"] = [
expected_mapping[0][1]["extern_kernels.mm:2"] = [
"mm_default"
]
expected_mapping[1][1]["mm_default"] = [
"extern_kernels.mm:3"
"extern_kernels.mm:2"
]
self._check_provenance_tracking_node_mappings(
filepath, expected_mapping
@@ -217,8 +217,8 @@ class TestProvenanceTracingArtifact(TestCase):
if backend == "aot_inductor":
expected_data = {
"cpp_fused_mul_0:1": ["mul"],
"aoti_torch_cpu_addmm_out:3": ["addmm"],
"cpp_fused_gelu_1:2": [
"aoti_torch_cpu_addmm_out:2": ["addmm"],
"cpp_fused_gelu_1:3": [
"mul_3",
"mul_1",
"add",
@@ -230,14 +230,14 @@ class TestProvenanceTracingArtifact(TestCase):
# backend == "inductor"
expected_data = {
"cpp_fused_mul_0:1": ["mul"],
"cpp_fused_gelu_1:2": [
"cpp_fused_gelu_1:3": [
"mul_3",
"mul_1",
"add",
"erf",
"mul_2",
],
"extern_kernels.addmm:3": ["addmm"],
"extern_kernels.addmm:2": ["addmm"],
}
self._check_provenance_tracing_kernel_to_post_grad(
filepath, expected_data
@@ -550,22 +550,22 @@ class TestProvenanceTracingStackTraces(TestCase):
example_inputs = (x, a, b, c)
expected = {
"triton_poi_fused_addmm_relu_sigmoid_threshold_backward_0:1": [
"triton_poi_fused_addmm_relu_sigmoid_threshold_backward_0:2": [
"x = self.sigmoid(x)",
"x = self.fc1(x)",
"x = self.relu(x)",
],
"triton_poi_fused_mul_1:2": [
"triton_poi_fused_mul_1:3": [
"d = a * 3.14",
],
"triton_poi_fused_addmm_gelu_2:3": [
"triton_poi_fused_addmm_gelu_2:5": [
"z = torch.nn.functional.gelu(y)",
"y = torch.addmm(c, d, b)",
],
"extern_kernels.mm:4": [
"extern_kernels.mm:1": [
"x = self.fc1(x)",
],
"extern_kernels.mm:5": [
"extern_kernels.mm:4": [
"y = torch.addmm(c, d, b)",
],
}
@@ -648,7 +648,7 @@ class TestProvenanceTracingStackTraces(TestCase):
kernel_info = json.load(f)
expected = {
"triton_poi_fused_addmm_relu_sigmoid_0:1": {
"triton_poi_fused_addmm_relu_sigmoid_0:2": {
"stack_traces": [
"x = self.sigmoid(x)",
"x = self.fc1(x)",
@@ -657,14 +657,14 @@ class TestProvenanceTracingStackTraces(TestCase):
"post_grad_nodes": ["sigmoid", "relu", "add_tensor_1"],
"pre_grad_nodes": ["sigmoid", "relu", "linear"],
},
"triton_poi_fused_mul_1:2": {
"triton_poi_fused_mul_1:3": {
"stack_traces": [
"d = a * 3.14",
],
"post_grad_nodes": ["mul"],
"pre_grad_nodes": ["mul"],
},
"triton_poi_fused_addmm_gelu_2:3": {
"triton_poi_fused_addmm_gelu_2:5": {
"stack_traces": [
"z = torch.nn.functional.gelu(y)",
"y = torch.addmm(c, d, b)",
@@ -679,14 +679,14 @@ class TestProvenanceTracingStackTraces(TestCase):
],
"pre_grad_nodes": ["gelu", "addmm"],
},
"aoti_torch_cuda_mm_out:4": {
"aoti_torch_cuda_mm_out:1": {
"stack_traces": [
"x = self.fc1(x)",
],
"post_grad_nodes": ["mm_default_1"],
"pre_grad_nodes": ["linear"],
},
"aoti_torch_cuda_mm_out:5": {
"aoti_torch_cuda_mm_out:4": {
"stack_traces": [
"y = torch.addmm(c, d, b)",
],


@@ -5376,6 +5376,7 @@ class CppScheduling(BaseScheduling):
)
user.node.mark_run()
self.codegen_comment(node_schedule, kernel_name)
kernel.call_kernel(kernel_name, ctb)
V.graph.removed_buffers |= kernel.removed_buffers
self.free_buffers_in_scheduler()
@@ -5441,18 +5442,20 @@
kernel_name = self.define_kernel(
src_code, self.kernel_group.scheduled_nodes
)
# below add provenance tracing info for cpu CppKernel types
debug_handle: Optional[int] = None
if config.trace.provenance_tracking_level != 0:
debug_handle = set_kernel_post_grad_provenance_tracing(
self.kernel_group.scheduled_nodes, kernel_name
)
self.kernel_group.call_kernel(
V.graph.wrapper_code, kernel_name, debug_handle=debug_handle
)
self.codegen_comment(self.kernel_group.scheduled_nodes, kernel_name)
self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name)
self.reset_kernel_group()
self._set_flush_status(False)
def codegen_comment(self, node_schedule, kernel_name=None):
# below add provenance tracing info for cpu CppKernel types
wrapper = V.graph.wrapper_code
debug_handle = set_kernel_post_grad_provenance_tracing(
node_schedule, # type: ignore[arg-type]
kernel_name,
)
wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
class KernelGroup:
def __init__(self):
@@ -5524,14 +5527,13 @@
code.splice(self.loops_code)
return code.getvalue()
def call_kernel(self, wrapper, kernel_name, debug_handle: Optional[int] = None):
def call_kernel(self, wrapper, kernel_name):
_, call_args, arg_types = self.args.cpp_argdefs()
wrapper.generate_kernel_call(
kernel_name,
call_args,
triton=False,
arg_types=arg_types,
debug_handle=debug_handle,
)


@@ -22,7 +22,6 @@ from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import config, cpp_builder, ir
from ..debug import set_kernel_post_grad_provenance_tracing
from ..utils import _align, DeferredLineBase, LineContext, normalize_name
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
@@ -1297,14 +1296,8 @@ class CppWrapperCpu(PythonWrapperCodegen):
device = d.type if (d := extern_kernel.get_device()) else self.device
debug_handle = None
if config.trace.provenance_tracking_level != 0:
debug_handle = set_kernel_post_grad_provenance_tracing(
extern_kernel, extern_kernel.get_kernel_name(), is_extern=True
)
self.generate_c_shim_extern_kernel_call(
extern_kernel.get_kernel_name(), args, device, debug_handle=debug_handle
extern_kernel.get_kernel_name(), args, device
)
if extern_kernel.python_kernel_name in (
@@ -1362,19 +1355,10 @@ class CppWrapperCpu(PythonWrapperCodegen):
args = args + output_args
device = d.type if (d := fallback_kernel.get_device()) else self.device
debug_handle = None
if config.trace.provenance_tracking_level != 0:
shim_fn = self.get_c_shim_func_name(fallback_kernel.cpp_kernel_name, device) # type: ignore[arg-type]
debug_handle = set_kernel_post_grad_provenance_tracing(
fallback_kernel,
shim_fn,
is_extern=True,
)
self.generate_c_shim_extern_kernel_call(
fallback_kernel.cpp_kernel_name, # type: ignore[arg-type]
args,
device,
debug_handle=debug_handle,
)
for raii_handle in output_raii_handles:
self.writeline(raii_handle)


@@ -175,6 +175,7 @@ class CUDACPPScheduling(BaseScheduling):
call_args, kernel_name, arg_signatures, kernel
)
with debug_printer_manager:
self.codegen_comment(node_schedule, kernel_name)
kernel.call_kernel(kernel_name, ctb)
V.graph.removed_buffers |= kernel.removed_buffers


@@ -135,6 +135,7 @@ class CuteDSLScheduling(BaseScheduling):
with V.set_kernel_handler(kernel):
node_schedule = [template_node]
kernel_name = self.define_kernel(src_code_str, node_schedule)
self.codegen_comment(node_schedule, kernel_name)
kernel.call_kernel(kernel_name, ctb)
V.graph.removed_buffers |= kernel.removed_buffers
self.free_buffers_in_scheduler()


@@ -94,6 +94,7 @@ class ROCmCPPScheduling(BaseScheduling):
with V.set_kernel_handler(kernel):
node_schedule = [template_node]
kernel_name = self.define_kernel(src_code, node_schedule)
self.codegen_comment(node_schedule, kernel_name)
kernel.call_kernel(kernel_name, ctb)
V.graph.removed_buffers |= kernel.removed_buffers
self.free_buffers_in_scheduler()


@@ -41,7 +41,6 @@ from ..dependencies import MemoryDep, StarDep, WeakDep
if TYPE_CHECKING:
from ..ir import IRNode
from ..debug import set_kernel_post_grad_provenance_tracing
from ..optimize_indexing import indexing_dtype_strength_reduction
from ..runtime.runtime_utils import green_text, yellow_text
from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
@@ -1471,17 +1470,10 @@ class SIMDScheduling(BaseScheduling):
for kernel in kernels:
self.codegen_node_schedule_with_kernel(node_schedule, kernel)
MultiKernel.merge_workspaces_inplace(kernels)
debug_handles: list[tuple[str, Optional[int]]] = []
for kernel in kernels:
with V.set_kernel_handler(kernel):
src_code = kernel.codegen_kernel()
kernel_name = self.define_kernel(src_code, node_schedule, kernel)
if config.trace.provenance_tracking_level != 0:
debug_handle = set_kernel_post_grad_provenance_tracing(
node_schedule, # type: ignore[arg-type]
kernel_name,
)
debug_handles.append((kernel_name, debug_handle))
log.debug("Generating kernel code with kernel_name: %s", kernel_name)
kernel.kernel_name = kernel_name
kernel.code_hash = code_hash(src_code)
@@ -1497,11 +1489,11 @@
for node in kernel_features.scheduler_nodes():
node.mark_run()
self.codegen_comment(node_schedule)
for kernel_name, debug_handle in debug_handles:
V.graph.wrapper_code.write_provenance_debug_handle(
kernel_name, debug_handle
)
# filter out NodeScheduleMarker
base_scheduler_nodes = [
node for node in node_schedule if isinstance(node, BaseSchedulerNode)
]
self.codegen_comment(base_scheduler_nodes, final_kernel.kernel_name)
final_kernel.call_kernel(final_kernel.kernel_name)
if config.nan_asserts:
@@ -1696,11 +1688,6 @@
kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel)
if config.trace.provenance_tracking_level != 0:
set_kernel_post_grad_provenance_tracing(
node_schedule, kernel.kernel_name
)
return kernel
def _get_multikernel_shapes(
@@ -1819,8 +1806,7 @@
MultiKernel.merge_workspaces_inplace(list(kernels.values()))
multi_kernel = SizeHintMultiKernel(kernels)
node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
self.codegen_comment(node_schedule)
self.codegen_comment(node_schedule, multi_kernel.kernel_name)
multi_kernel.call_kernel(multi_kernel.kernel_name)
V.graph.removed_buffers |= multi_kernel.removed_buffers
V.graph.inplaced_to_remove |= multi_kernel.inplaced_to_remove
@@ -1851,7 +1837,7 @@
)
node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
self.codegen_comment(node_schedule)
self.codegen_comment(node_schedule, kernel.kernel_name)
kernel.call_kernel(kernel.kernel_name, template_node.node)
V.graph.removed_buffers |= kernel.removed_buffers
@@ -1937,12 +1923,7 @@
for src_code, kernel, _ in kernel_code_list:
kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel)
# dump provenance node info for ComboKernelNode/ForeachKernel type
if config.trace.provenance_tracking_level != 0:
set_kernel_post_grad_provenance_tracing(
combo_kernel_node.snodes, kernel_name
)
self.codegen_comment([combo_kernel_node])
self.codegen_comment(combo_kernel_node.snodes, kernel_name)
log.debug("ComboKernels: generated kernel %s.", kernel_name)
kernel.call_kernel(V.graph.wrapper_code, kernel_name)
@@ -2663,9 +2644,6 @@
src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
return src_code
def codegen_comment(self, node_schedule):
pass
def define_kernel(self, src_code, node_schedule, kernel):
raise NotImplementedError


@@ -34,6 +34,7 @@ from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir, metrics
from ..async_compile import AsyncCompile
from ..codecache import code_hash, get_path, PyCodeCache, write_atomic
from ..debug import set_kernel_post_grad_provenance_tracing
from ..ops_handler import DefaultHandler
from ..runtime import triton_heuristics
from ..runtime.benchmarking import benchmarker
@@ -4867,7 +4868,7 @@ class TritonScheduling(SIMDScheduling):
)
return cls.backend_features
def codegen_comment(self, node_schedule):
def codegen_comment(self, node_schedule, kernel_name=None):
wrapper = V.graph.wrapper_code
origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper)
if origins:
@@ -4893,6 +4894,13 @@
f"{wrapper.comment} Fused node name list: {', '.join(node_names)}"
)
if kernel_name:
debug_handle = set_kernel_post_grad_provenance_tracing(
node_schedule, # type: ignore[arg-type]
kernel_name,
)
wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
def define_kernel(self, src_code, node_schedule, kernel):
wrapper = V.graph.wrapper_code
if src_code in wrapper.src_to_kernel:


@@ -40,7 +40,6 @@ from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import async_compile, config, ir
from ..codecache import output_code_log
from ..debug import set_kernel_post_grad_provenance_tracing
from ..ir import IRNode, ReinterpretView
from ..runtime import triton_heuristics
from ..runtime.hints import DeviceProperties
@@ -509,19 +508,12 @@ class ExternKernelOutLine(WrapperLine):
else:
kernel_name = node.get_kernel_name()
device = d.type if (d := node.get_device()) else V.graph.device_type
provenance_debug_handle: Optional[int] = None
# set provenance tracing kernel mapping for ExternKernel types
if config.trace.provenance_tracking_level != 0:
provenance_debug_handle = set_kernel_post_grad_provenance_tracing(
node, kernel_name, is_extern=True
)
self.wrapper._generate_extern_kernel_out_helper(
kernel_name,
node.codegen_reference(),
node.output_view.codegen_reference() if node.output_view else None,
args,
device,
provenance_debug_handle,
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
@@ -1516,13 +1508,11 @@ class PythonWrapperCodegen(CodeGen):
out_view: Optional[str],
args: list[str],
device: str,
debug_handle: Optional[int] = None,
) -> None:
# add debug printer code for triton kernel calls at (jit) inductor level
debug_printer_manager = V.graph.wrapper_code.debug_printer
debug_printer_manager.set_printer_args(args, kernel, None, None, "extern")
args.append(f"out={out_view if out_view else out}")
self.write_provenance_debug_handle(kernel, debug_handle)
with debug_printer_manager:
self.writeline(f"{kernel}({', '.join(args)})")
@@ -2696,7 +2686,6 @@ class PythonWrapperCodegen(CodeGen):
raw_args=None,
triton_meta=None,
original_fxnode_name=None,
debug_handle: Optional[int] = None,
):
"""
Generates kernel call code.
@@ -2716,7 +2705,6 @@
)
device = device or V.graph.get_current_device_or_throw()
self.write_provenance_debug_handle(kernel_name, debug_handle)
self.writeline(
KernelCallLine(
self,


@@ -1103,6 +1103,9 @@ def set_kernel_post_grad_provenance_tracing(
Returns a unique int debug handler for each call to this function.
"""
if config.trace.provenance_tracking_level == 0:
return None
try:
from .codegen.simd_kernel_features import DisableReduction, EnableReduction


@@ -5605,11 +5605,23 @@ class ExternKernel(InputsKernel):
self.apply_constraint()
self.freeze_layout()
def codegen_comment(self, wrapper: PythonWrapperCodegen) -> None:
def codegen_comment(
self, wrapper: PythonWrapperCodegen, kernel_name: Optional[str] = None
) -> None:
origin_str, _detailed_origin_str = get_kernel_metadata(self, wrapper)
if origin_str:
wrapper.make_comment(origin_str)
if not kernel_name:
kernel_name = self.try_get_kernel_name()
if kernel_name:
from .debug import set_kernel_post_grad_provenance_tracing
debug_handle = set_kernel_post_grad_provenance_tracing(
self, kernel_name, is_extern=True
)
wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
raise NotImplementedError
@@ -5652,25 +5664,29 @@
f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"
)
def get_kernel_name(self) -> str:
def try_get_kernel_name(self) -> Optional[str]:
from .codegen.cpp_wrapper_cpu import CppWrapperCpu
device = d.type if (d := self.get_device()) else V.graph.device_type
if V.graph.fx_wrapper:
assert self.python_kernel_name is not None
return self.python_kernel_name
elif V.graph.cpp_wrapper:
assert isinstance(V.graph.wrapper_code, CppWrapperCpu), type(
V.graph.wrapper_code
)
assert self.cpp_kernel_name is not None
if self.cpp_kernel_name is None:
return None
return V.graph.wrapper_code.get_c_shim_func_name(
self.cpp_kernel_name, device
)
else:
assert self.python_kernel_name is not None
return self.python_kernel_name
def get_kernel_name(self) -> str:
name = self.try_get_kernel_name()
assert name is not None
return name
@staticmethod
def copy_input(x: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]:
pw = Pointwise.create(
@@ -6803,7 +6819,7 @@ class UserDefinedTritonKernel(ExternKernel):
else:
raise NotImplementedError(f"Unsupported arg type: {type(arg)}: {arg}")
self.codegen_comment(wrapper)
self.codegen_comment(wrapper, new_name)
wrapper.generate_kernel_call(
new_name,
args,


@@ -5606,7 +5606,7 @@ class Scheduler:
V.graph.zero_dim_cpu_tensor_list.add(read.name)
class BaseScheduling:
class BaseScheduling: # noqa: docstring_linter
def __init__(self, scheduler: Optional[Scheduler]):
super().__init__()
self.scheduler = scheduler
@@ -5749,3 +5749,19 @@
and memory copy time in milliseconds on randomly generated inputs.
"""
raise NotImplementedError
def codegen_comment(
self,
node_schedule: Sequence[BaseSchedulerNode],
kernel_name: Optional[str] = None,
) -> None:
if kernel_name:
from torch._inductor.debug import set_kernel_post_grad_provenance_tracing
debug_handle = set_kernel_post_grad_provenance_tracing(
node_schedule, # type: ignore[arg-type]
kernel_name,
)
V.graph.wrapper_code.write_provenance_debug_handle(
kernel_name, debug_handle
)