[pt2_provenance_tracking] add support for cpp kernel (#149185)

Summary:
As title.

Add the inductor cpp kernel to the post-grad graph node mapping,
plus a unit test.

Context:
Raised as a feature request for the AOTI CPU case.

https://fb.workplace.com/groups/1028545332188949/permalink/1169020841474730/
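
For context, a minimal sketch of how the new CPU mapping can be produced and inspected, mirroring the unit test added below; the toy model, shapes, and output location are illustrative assumptions, not part of this PR:

import json
import tempfile
from pathlib import Path

import torch
from torch._inductor import config


class ToyModel(torch.nn.Module):  # hypothetical model, stands in for the test's Model
    def forward(self, a, b, c):
        # matmul -> mul -> gelu, so inductor emits cpp_fused_* kernels on CPU
        return torch.nn.functional.gelu((a @ b) * c)


debug_dir = tempfile.mkdtemp()
with config.patch({"trace.enabled": True, "trace.debug_dir": debug_dir}):
    compiled = torch.compile(ToyModel(), backend="inductor")
    compiled(torch.randn(10, 20), torch.randn(20, 30), torch.randn(10, 30))

# torch._inductor.debug logs the actual trace directory (the test parses that
# log line); assuming it lands under debug_dir, the JSON artifact now also maps
# cpp kernels to post-grad node names, e.g. {"cpp_fused_mul_0": ["mul"], ...}.
for mapping_file in Path(debug_dir).rglob(
    "inductor_triton_kernel_to_post_grad_nodes.json"
):
    print(json.loads(mapping_file.read_text()))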

Differential Revision: D71181284

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149185
Approved by: https://github.com/jingsh
Authored by Rachel Guo on 2025-03-18 04:43:07 +00:00
Committed by PyTorch MergeBot
Parent 7869196482, commit b8f91bcb14
2 changed files with 63 additions and 1 deletion


@@ -5,12 +5,14 @@ import logging
import re
import shutil
import tempfile
import unittest
from pathlib import Path
import torch
from torch._inductor import config
from torch._inductor.debug import create_node_mapping
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.testing._internal.triton_utils import requires_cuda
@@ -31,7 +33,6 @@ class Model(torch.nn.Module):
return z
@requires_cuda
@config.patch("trace.enabled", True)
class TestProvenanceTracingArtifact(TestCase):
"""
@@ -39,6 +40,7 @@ class TestProvenanceTracingArtifact(TestCase):
corresponding "inductor triton kernel node" is expected.
"""
@requires_cuda
def _check_provenance_tracing_artifact(self, filepath):
self.assertTrue(filepath.is_dir())
filename = Path(filepath) / "inductor_triton_kernel_to_post_grad_nodes.json"
@@ -113,6 +115,7 @@ class TestProvenanceTracingArtifact(TestCase):
]
self.assertEqual(sorted(actual_data.items()), sorted(expected_data))
@requires_cuda
def test_triton_kernel_to_post_grad_tracing(self):
a = torch.randn(10, 20, device="cuda")
b = torch.randn(20, 30, device="cuda")
@@ -146,6 +149,60 @@ class TestProvenanceTracingArtifact(TestCase):
finally:
shutil.rmtree(filepath)
@unittest.skipIf(HAS_GPU, "the test is only for cpu")
def test_triton_kernel_to_post_grad_tracing_cpu(self):
a = torch.randn(10, 20, device="cpu")
b = torch.randn(20, 30, device="cpu")
c = torch.randn(10, 30, device="cpu")
example_inputs = (a, b, c)
model = Model()
ep = torch.export._trace._export(model, example_inputs)
gm = ep.module()
filepath = None
for backend in ["aot_inductor", "inductor"]:
try:
with config.patch(
{
"trace.debug_dir": tempfile.mkdtemp(),
"force_disable_caches": True,
}
):
with self.assertLogs(
logging.getLogger("torch._inductor.debug"),
level=logging.WARNING,
) as cm:
if backend == "aot_inductor":
so_path = torch._inductor.aot_compile(gm, example_inputs)
optimized = AOTIRunnerUtil.load("cpu", so_path)
optimized(*example_inputs)
else:
compiled = torch.compile(gm, backend=backend)
compiled(*example_inputs)
self.assertEqual(len(cm.output), 1)
m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
self.assertTrue(m)
filepath = Path(m.group(1))
filename = (
Path(filepath)
/ "inductor_triton_kernel_to_post_grad_nodes.json"
)
with open(filename) as f:
actual_data = json.load(f)
# check that the inductor kernel to post-grad nodes mapping is as expected for cpu
expected_data = {
"cpp_fused_mul_0": ["mul"],
"cpp_fused_gelu_1": ["mul_3", "mul_1", "add", "erf", "mul_2"],
}
self.assertEqual(
sorted(actual_data.items()), sorted(expected_data.items())
)
finally:
if filepath:
shutil.rmtree(filepath)
class TestProvenanceTracingNodeMapping(TestCase):
def test_create_node_mapping(self):


@@ -43,6 +43,7 @@ from ..utils import (
is_welford_reduction,
parallel_num_threads,
Placeholder,
set_kernel_post_grad_provenance_tracing,
sympy_index_symbol,
sympy_index_symbol_with_prefix,
sympy_product,
@@ -5150,6 +5151,10 @@ class CppScheduling(BaseScheduling):
else ""
)
kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
# add provenance tracing info for cpu CppKernel types
if config.trace.enabled:
set_kernel_post_grad_provenance_tracing(nodes, kernel_name)
kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name)
src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)