Add kernel information JSON generation for AOTI packages (#160540)

Summary:
Builds on D80031559. Generates `kernel_information.json` in AOTI-compiled artifacts by combining stack traces and node mappings from provenance tracking.

This implementation delivers exactly what the Zoomer team requested:

**1. Core Function**: `create_kernel_information_json()` in `debug.py` combines three data sources (see the sketch just after this list):
- `_inductor_kernel_stack_trace` → `stack_traces` field
- `_inductor_triton_kernel_to_post_grad_node_info` → `post_grad_nodes` field
- `_inductor_post_to_pre_grad_nodes["postToPre"]` → `pre_grad_nodes` field
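
For illustration, here is a minimal sketch of how one kernel's entry is assembled, using the `triton_poi_fused_mul_1` values from the new test below (the local names are stand-ins for the module-level globals in `debug.py`):

```python
# Illustrative provenance state; values mirror the test's expected output.
kernel_stack_trace = {"triton_poi_fused_mul_1": ["d = a * 3.14"]}
kernel_to_post_grad = {"triton_poi_fused_mul_1": ["mul"]}
post_to_pre = {"postToPre": {"mul": ["mul"]}}

# create_kernel_information_json() joins the three maps per kernel name:
kernel_name = "triton_poi_fused_mul_1"
post_grad_nodes = kernel_to_post_grad.get(kernel_name, [])
entry = {
    "stack_traces": kernel_stack_trace.get(kernel_name, []),
    "post_grad_nodes": post_grad_nodes,
    # pre-grad nodes come from chaining each post-grad node through postToPre
    "pre_grad_nodes": [
        pre
        for post in post_grad_nodes
        for pre in post_to_pre["postToPre"].get(post, [])
    ],
}
# entry == {"stack_traces": ["d = a * 3.14"],
#           "post_grad_nodes": ["mul"], "pre_grad_nodes": ["mul"]}
```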

**2. AOTI Integration**: `codecache.py` writes `kernel_information.json` into `.pt2` packages when both AOTI packaging and provenance tracking are enabled.

**3. Test Coverage**: The `TestKernelInformationAOTI` tests validate:
- JSON file creation in AOTI packages (inspected via `zipfile`)
- Exact format compliance
- That the file is omitted when provenance tracking is disabled

**Output Format** (exact specification):
```json
{
  "triton_kernel_name_1": {
    "stack_traces": [str, str, ...],
    "post_grad_nodes": [str, str, ...],
    "pre_grad_nodes": [str, str, ...]
  }
}
```

Test Plan:
```
buck test fbcode//caffe2/test/inductor:provenance_tracing -- TestKernelInformationAOTI
```

Manual validation:
```python
import torch

model = torch.nn.Linear(10, 1)
ep = torch.export.export(model, (torch.randn(1, 10),))
with torch._inductor.config.patch("trace.provenance_tracking_level", 1):
    # AOTI compilation should generate kernel_information.json in the package
    torch._inductor.aoti_compile_and_package(ep, package_path="model.pt2")
```
---

Rollback Plan:

Differential Revision: D80139160

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160540
Approved by: https://github.com/yushangdi
Author: Sandeep Narendranath Karjala
Date: 2025-08-20 02:33:45 +00:00
Committed by: PyTorch MergeBot
Parent: 54cc63b467
Commit: 2b62ef7420
4 changed files with 227 additions and 18 deletions

File: test/inductor/test_provenance_tracing.py

@@ -4,16 +4,19 @@ import contextlib
import io
import json
import logging
import os
import re
import shutil
import tempfile
import unittest
import zipfile
from pathlib import Path

import torch
from torch._dynamo.utils import detect_fake_mode
from torch._inductor import config
from torch._inductor.debug import (
    create_kernel_information_json,
    create_mapping_pre_post_grad_nodes,
    create_node_mapping_kernel_to_post_grad,
)
@@ -66,6 +69,23 @@ class Model3(torch.nn.Module):
        return torch.nn.functional.linear(a, self.weight, self.bias)


class Model4(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, a, b, c):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.sigmoid(x)
        d = a * 3.14
        y = torch.addmm(c, d, b)
        z = torch.nn.functional.gelu(y)
        return x, z


@config.patch("trace.enabled", True)
@config.patch("trace.provenance_tracking_level", 1)
class TestProvenanceTracingArtifact(TestCase):
@@ -527,24 +547,8 @@ class TestProvenanceTracingStackTraces(TestCase):
        )

    @requires_cuda_and_triton
    def test_tlparse_kernel_stack_traces(self):
-        class Model(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.fc1 = torch.nn.Linear(10, 16)
-                self.relu = torch.nn.ReLU()
-                self.sigmoid = torch.nn.Sigmoid()
-
-            def forward(self, x, a, b, c):
-                x = self.fc1(x)
-                x = self.relu(x)
-                x = self.sigmoid(x)
-                d = a * 3.14
-                y = torch.addmm(c, d, b)
-                z = torch.nn.functional.gelu(y)
-                return x, z
-
        device = "cuda"
-        model = Model().to(device)
+        model = Model4().to(device)
        x = torch.randn(8, 10).to(device)
        a = torch.randn(10, 20).to(device)
        b = torch.randn(20, 30).to(device)
@@ -585,6 +589,160 @@ class TestProvenanceTracingStackTraces(TestCase):
                f"Mismatch for key: {key}",
            )

    def _check_kernel_information_json(self, kernel_info, expected_kernels):
        """Validate kernel information JSON structure and content."""
        self.assertIsInstance(kernel_info, dict)
        for expected in expected_kernels:
            self.assertIn(
                expected,
                kernel_info,
                f"Expected kernel {expected} not found in {list(kernel_info)}",
            )
        for data in kernel_info.values():
            self.assertIsInstance(data, dict)
            for field in ["stack_traces", "post_grad_nodes", "pre_grad_nodes"]:
                self.assertIn(field, data)
                self.assertIsInstance(data[field], list)
                for item in data[field]:
                    self.assertIsInstance(item, str)

    @requires_cuda_and_triton
    @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
    def test_kernel_information_generation(self):
        """Test basic kernel information generation in AOTI packages."""
        model = Model4().to("cuda")
        x = torch.randn(8, 10, device="cuda")
        a = torch.randn(10, 20, device="cuda")
        b = torch.randn(20, 30, device="cuda")
        c = torch.randn(10, 30, device="cuda")
        inputs = (x, a, b, c)

        with tempfile.TemporaryDirectory() as temp_dir:
            ep = torch.export.export(model, inputs, strict=False)
            pt2_file = os.path.join(temp_dir, "model.pt2")
            torch._inductor.aoti_compile_and_package(ep, package_path=pt2_file)

            # Extract and check kernel_information.json exists in the package
            with zipfile.ZipFile(pt2_file, "r") as zip_ref:
                zip_ref.extractall(temp_dir)

            json_path = os.path.join(
                temp_dir,
                "model",
                "data",
                "aotinductor",
                "model",
                "kernel_information.json",
            )
            self.assertTrue(
                os.path.exists(json_path),
                f"kernel_information.json not found in extracted package at {json_path}",
            )
            with open(json_path) as f:
                kernel_info = json.load(f)

            expected = {
                "triton_poi_fused_addmm_relu_sigmoid_0": {
                    "stack_traces": [
                        "x = self.sigmoid(x)",
                        "x = self.fc1(x)",
                        "x = self.relu(x)",
                    ],
                    "post_grad_nodes": ["sigmoid", "relu", "add_tensor_1"],
                    "pre_grad_nodes": ["sigmoid", "relu", "linear"],
                },
                "triton_poi_fused_mul_1": {
                    "stack_traces": [
                        "d = a * 3.14",
                    ],
                    "post_grad_nodes": ["mul"],
                    "pre_grad_nodes": ["mul"],
                },
                "triton_poi_fused_addmm_gelu_2": {
                    "stack_traces": [
                        "z = torch.nn.functional.gelu(y)",
                        "y = torch.addmm(c, d, b)",
                    ],
                    "post_grad_nodes": [
                        "mul_3",
                        "mul_1",
                        "add_tensor",
                        "add",
                        "erf",
                        "mul_2",
                    ],
                    "pre_grad_nodes": ["gelu", "addmm"],
                },
                "aoti_torch_cuda_mm_out": {
                    "stack_traces": [
                        "x = self.fc1(x)",
                        "y = torch.addmm(c, d, b)",
                    ],
                    "post_grad_nodes": ["mm_default_1", "mm_default"],
                    "pre_grad_nodes": ["linear", "addmm"],
                },
            }
            self._check_kernel_information_json(kernel_info, expected.keys())
            self.assertEqual(set(kernel_info.keys()), set(expected.keys()))
            for key, data in expected.items():
                all_lines = ",".join(kernel_info[key]["stack_traces"])
                for s in data["stack_traces"]:
                    self.assertTrue(s in all_lines)
                self.assertEqual(
                    sorted(kernel_info[key]["pre_grad_nodes"]),
                    sorted(data["pre_grad_nodes"]),
                    f"Mismatch for key: {key}",
                )
                self.assertEqual(
                    sorted(kernel_info[key]["post_grad_nodes"]),
                    sorted(data["post_grad_nodes"]),
                    f"Mismatch for key: {key}",
                )

    @torch._inductor.config.patch("trace.provenance_tracking_level", 0)
    def test_no_kernel_information_without_provenance_tracking(self):
        """Test that kernel_information.json is not generated without provenance tracking."""

        class SimpleModel(torch.nn.Module):
            def forward(self, x):
                return x * 2.0

        model = SimpleModel()
        x = torch.randn(4, 8)

        # Compile with AOTI but without provenance tracking
        with tempfile.TemporaryDirectory() as temp_dir:
            ep = torch.export.export(model, (x,), strict=False)
            pt2_file = os.path.join(temp_dir, "model.pt2")
            torch._inductor.aoti_compile_and_package(ep, package_path=pt2_file)

            # Extract and check kernel_information.json was NOT created in the package
            extract_dir = os.path.join(temp_dir, "extracted")
            os.makedirs(extract_dir, exist_ok=True)
            with zipfile.ZipFile(pt2_file, "r") as zip_ref:
                zip_ref.extractall(extract_dir)

            expected_json_path = os.path.join(extract_dir, "kernel_information.json")
            self.assertFalse(
                os.path.exists(expected_json_path),
                "kernel_information.json should not exist in package when provenance tracking is disabled",
            )

    def test_create_kernel_information_json_function(self):
        """Test the create_kernel_information_json function directly."""
        # Test with empty state
        result = create_kernel_information_json()
        self.assertIsInstance(result, dict)
        self.assertEqual(len(result), 0)  # Should be empty with no provenance data


if __name__ == "__main__":
    run_tests()

File: test/inductor/test_torchbind.py

@@ -174,7 +174,7 @@ class TestTorchbind(TestCase):
                custom_objs_config = file
            elif file.endswith("/custom_obj_0"):
                custom_obj_0 = file
-            elif file.endswith(".json") and "metadata" not in file:
+            elif file.endswith("wrapper.json") and "metadata" not in file:
                extern_json = file
        self.assertIsNotNone(custom_objs_config)

File: torch/_inductor/codecache.py

@@ -2414,6 +2414,15 @@ end
        generated_files.append(output_so)

        if config.aot_inductor.package:
            if config.trace.provenance_tracking_level != 0:
                kernel_info = torch._inductor.debug.create_kernel_information_json()
                kernel_info_json = os.path.join(
                    wrapper_path_operator.parent, "kernel_information.json"
                )
                with open(kernel_info_json, "w") as f:
                    f.write(json.dumps(kernel_info, indent=4))
                generated_files.append(kernel_info_json)

            # We want to return the directory that contains all the AOTI
            # generated files, not just the so
            # return os.path.split(output_so)[0]

File: torch/_inductor/debug.py

@@ -1009,6 +1009,48 @@ def dump_inductor_provenance_info(
    return {}


def create_kernel_information_json() -> dict[str, dict[str, list[str]]]:
    """Create kernel information JSON"""
    try:
        global _inductor_post_to_pre_grad_nodes
        global _inductor_kernel_stack_trace
        global _inductor_triton_kernel_to_post_grad_node_info

        post_to_pre = _inductor_post_to_pre_grad_nodes.get("postToPre", {})
        all_kernels = OrderedSet(_inductor_kernel_stack_trace.keys()) | OrderedSet(
            _inductor_triton_kernel_to_post_grad_node_info.keys()
        )
        result = {}
        for kernel_name in all_kernels:
            post_grad_nodes = _inductor_triton_kernel_to_post_grad_node_info.get(
                kernel_name, []
            )
            pre_grad_nodes: OrderedSet[str] = OrderedSet()
            for post_node in post_grad_nodes:
                pre_grad_nodes.update(post_to_pre.get(post_node, []))
            result[kernel_name] = {
                "stack_traces": _inductor_kernel_stack_trace.get(kernel_name, []),
                "post_grad_nodes": post_grad_nodes,
                "pre_grad_nodes": list(pre_grad_nodes),
            }
        return result
    except Exception as e:
        signpost_event(
            "inductor",
            "provenance_tracking_error",
            {
                "function": "create_kernel_information_json",
                "error_msg": str(e),
                "stack_trace": traceback.format_exc(),
            },
        )
        return {}


def set_kernel_post_grad_provenance_tracing(
    node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernelOut],
    kernel_name: str,