Add kernel information JSON generation for AOTI packages (#160540)
Summary:
Build on D80031559. Generate kernel_information.json in AOTI compiled artifacts by combining stack traces and node mappings from provenance tracking. This implementation delivers exactly what the Zoomer team requested:

**1. Core Function**: `create_kernel_information_json()` in debug.py combines 3 data sources:
- `_inductor_kernel_stack_trace` → `stack_traces` field
- `_inductor_triton_kernel_to_post_grad_node_info` → `post_grad_nodes` field
- `_inductor_post_to_pre_grad_nodes["postToPre"]` → `pre_grad_nodes` field

**2. AOTI Integration**: codecache.py writes `kernel_information.json` into pt2 packages when both AOTI packaging and provenance tracking are enabled.

**3. Test Coverage**: The TestKernelInformationAOTI class validates:
- JSON file creation in AOTI packages using zipfile
- Exact format compliance
- Proper disabling without provenance tracking

**Output Format** (exact specification):
```json
{
  "triton_kernel_name_1": {
    "stack_traces": [str, str, ...],
    "post_grad_nodes": [str, str, ...],
    "pre_grad_nodes": [str, str, ...]
  }
}
```

Test Plan:
```
buck test fbcode//caffe2/test/inductor:provenance_tracing -- TestKernelInformationAOTI
```

Manual validation:
```python
import torch
model = torch.nn.Linear(10, 1)
with torch._inductor.config.patch("aot_inductor.package", True):
    with torch._inductor.config.patch("trace.basic_provenance_tracking", True):
        # AOTI compilation should generate kernel_information.json
        compiled = torch.export.export(model, (torch.randn(1, 10),))
```

---

Rollback Plan:

Differential Revision: D80139160

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160540
Approved by: https://github.com/yushangdi
committed by PyTorch MergeBot
parent 54cc63b467
commit 2b62ef7420
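Note: the manual-validation snippet in the summary stops at `torch.export.export`, which on its own does not run AOT Inductor or produce a package. A minimal end-to-end sketch that actually yields a `.pt2` archive, assuming the `trace.provenance_tracking_level` config used by the tests in this commit rather than the `trace.basic_provenance_tracking` flag named in the summary:

```python
import torch

model = torch.nn.Linear(10, 1)

# Provenance tracking must be enabled for kernel_information.json to be emitted.
with torch._inductor.config.patch("trace.provenance_tracking_level", 1):
    ep = torch.export.export(model, (torch.randn(1, 10),))
    # Compiles with AOT Inductor and writes the .pt2 package, which should
    # now contain kernel_information.json alongside the other artifacts.
    torch._inductor.aoti_compile_and_package(ep, package_path="model.pt2")
```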
@@ -4,16 +4,19 @@ import contextlib
 import io
+import json
 import logging
 import os
 import re
 import shutil
 import tempfile
 import unittest
+import zipfile
 from pathlib import Path
 
 import torch
 from torch._dynamo.utils import detect_fake_mode
 from torch._inductor import config
 from torch._inductor.debug import (
+    create_kernel_information_json,
     create_mapping_pre_post_grad_nodes,
     create_node_mapping_kernel_to_post_grad,
 )
@@ -66,6 +69,23 @@ class Model3(torch.nn.Module):
         return torch.nn.functional.linear(a, self.weight, self.bias)
 
 
+class Model4(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = torch.nn.Linear(10, 16)
+        self.relu = torch.nn.ReLU()
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x, a, b, c):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.sigmoid(x)
+        d = a * 3.14
+        y = torch.addmm(c, d, b)
+        z = torch.nn.functional.gelu(y)
+        return x, z
+
+
 @config.patch("trace.enabled", True)
 @config.patch("trace.provenance_tracking_level", 1)
 class TestProvenanceTracingArtifact(TestCase):
@@ -527,24 +547,8 @@ class TestProvenanceTracingStackTraces(TestCase):
         )
 
     @requires_cuda_and_triton
     def test_tlparse_kernel_stack_traces(self):
-        class Model(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.fc1 = torch.nn.Linear(10, 16)
-                self.relu = torch.nn.ReLU()
-                self.sigmoid = torch.nn.Sigmoid()
-
-            def forward(self, x, a, b, c):
-                x = self.fc1(x)
-                x = self.relu(x)
-                x = self.sigmoid(x)
-                d = a * 3.14
-                y = torch.addmm(c, d, b)
-                z = torch.nn.functional.gelu(y)
-                return x, z
-
         device = "cuda"
-        model = Model().to(device)
+        model = Model4().to(device)
         x = torch.randn(8, 10).to(device)
         a = torch.randn(10, 20).to(device)
         b = torch.randn(20, 30).to(device)
@@ -585,6 +589,160 @@ class TestProvenanceTracingStackTraces(TestCase):
                     f"Mismatch for key: {key}",
                 )
 
+    def _check_kernel_information_json(self, kernel_info, expected_kernels):
+        """Validate kernel information JSON structure and content."""
+        self.assertIsInstance(kernel_info, dict)
+
+        for expected in expected_kernels:
+            self.assertIn(
+                expected,
+                kernel_info,
+                f"Expected kernel {expected} not found in {list(kernel_info)}",
+            )
+
+        for data in kernel_info.values():
+            self.assertIsInstance(data, dict)
+            for field in ["stack_traces", "post_grad_nodes", "pre_grad_nodes"]:
+                self.assertIn(field, data)
+                self.assertIsInstance(data[field], list)
+                for item in data[field]:
+                    self.assertIsInstance(item, str)
+
+    @requires_cuda_and_triton
+    @torch._inductor.config.patch("trace.provenance_tracking_level", 1)
+    def test_kernel_information_generation(self):
+        """Test basic kernel information generation in AOTI packages."""
+
+        model = Model4().to("cuda")
+        x = torch.randn(8, 10, device="cuda")
+        a = torch.randn(10, 20, device="cuda")
+        b = torch.randn(20, 30, device="cuda")
+        c = torch.randn(10, 30, device="cuda")
+        inputs = (x, a, b, c)
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            ep = torch.export.export(model, inputs, strict=False)
+            pt2_file = os.path.join(temp_dir, "model.pt2")
+            torch._inductor.aoti_compile_and_package(ep, package_path=pt2_file)
+
+            # Extract and check that kernel_information.json exists in the package
+            with zipfile.ZipFile(pt2_file, "r") as zip_ref:
+                zip_ref.extractall(temp_dir)
+
+            json_path = os.path.join(
+                temp_dir,
+                "model",
+                "data",
+                "aotinductor",
+                "model",
+                "kernel_information.json",
+            )
+            self.assertTrue(
+                os.path.exists(json_path),
+                f"kernel_information.json not found in extracted package at {json_path}",
+            )
+
+            with open(json_path) as f:
+                kernel_info = json.load(f)
+
+            expected = {
+                "triton_poi_fused_addmm_relu_sigmoid_0": {
+                    "stack_traces": [
+                        "x = self.sigmoid(x)",
+                        "x = self.fc1(x)",
+                        "x = self.relu(x)",
+                    ],
+                    "post_grad_nodes": ["sigmoid", "relu", "add_tensor_1"],
+                    "pre_grad_nodes": ["sigmoid", "relu", "linear"],
+                },
+                "triton_poi_fused_mul_1": {
+                    "stack_traces": [
+                        "d = a * 3.14",
+                    ],
+                    "post_grad_nodes": ["mul"],
+                    "pre_grad_nodes": ["mul"],
+                },
+                "triton_poi_fused_addmm_gelu_2": {
+                    "stack_traces": [
+                        "z = torch.nn.functional.gelu(y)",
+                        "y = torch.addmm(c, d, b)",
+                    ],
+                    "post_grad_nodes": [
+                        "mul_3",
+                        "mul_1",
+                        "add_tensor",
+                        "add",
+                        "erf",
+                        "mul_2",
+                    ],
+                    "pre_grad_nodes": ["gelu", "addmm"],
+                },
+                "aoti_torch_cuda_mm_out": {
+                    "stack_traces": [
+                        "x = self.fc1(x)",
+                        "y = torch.addmm(c, d, b)",
+                    ],
+                    "post_grad_nodes": ["mm_default_1", "mm_default"],
+                    "pre_grad_nodes": ["linear", "addmm"],
+                },
+            }
+
+            self._check_kernel_information_json(kernel_info, expected.keys())
+
+            self.assertEqual(set(kernel_info.keys()), set(expected.keys()))
+            for key, data in expected.items():
+                all_lines = ",".join(kernel_info[key]["stack_traces"])
+                for s in data["stack_traces"]:
+                    self.assertTrue(s in all_lines)
+
+                self.assertEqual(
+                    sorted(kernel_info[key]["pre_grad_nodes"]),
+                    sorted(data["pre_grad_nodes"]),
+                    f"Mismatch for key: {key}",
+                )
+
+                self.assertEqual(
+                    sorted(kernel_info[key]["post_grad_nodes"]),
+                    sorted(data["post_grad_nodes"]),
+                    f"Mismatch for key: {key}",
+                )
+
+    @torch._inductor.config.patch("trace.provenance_tracking_level", 0)
+    def test_no_kernel_information_without_provenance_tracking(self):
+        """Test that kernel_information.json is not generated without provenance tracking."""
+
+        class SimpleModel(torch.nn.Module):
+            def forward(self, x):
+                return x * 2.0
+
+        model = SimpleModel()
+        x = torch.randn(4, 8)
+
+        # Compile with AOTI but without provenance tracking
+        with tempfile.TemporaryDirectory() as temp_dir:
+            ep = torch.export.export(model, (x,), strict=False)
+            pt2_file = os.path.join(temp_dir, "model.pt2")
+            torch._inductor.aoti_compile_and_package(ep, package_path=pt2_file)
+
+            # Extract and check that kernel_information.json was NOT created in the package
+            extract_dir = os.path.join(temp_dir, "extracted")
+            os.makedirs(extract_dir, exist_ok=True)
+            with zipfile.ZipFile(pt2_file, "r") as zip_ref:
+                zip_ref.extractall(extract_dir)
+
+            expected_json_path = os.path.join(extract_dir, "kernel_information.json")
+            self.assertFalse(
+                os.path.exists(expected_json_path),
+                "kernel_information.json should not exist in package when provenance tracking is disabled",
+            )
+
+    def test_create_kernel_information_json_function(self):
+        """Test the create_kernel_information_json function directly."""
+        # Test with empty state
+        result = create_kernel_information_json()
+        self.assertIsInstance(result, dict)
+        self.assertEqual(len(result), 0)  # Should be empty with no provenance data
+
 
 if __name__ == "__main__":
     run_tests()
@@ -174,7 +174,7 @@ class TestTorchbind(TestCase):
                 custom_objs_config = file
             elif file.endswith("/custom_obj_0"):
                 custom_obj_0 = file
-            elif file.endswith(".json") and "metadata" not in file:
+            elif file.endswith("wrapper.json") and "metadata" not in file:
                 extern_json = file
 
         self.assertIsNotNone(custom_objs_config)
@@ -2414,6 +2414,15 @@ end
         generated_files.append(output_so)
 
+        if config.aot_inductor.package:
+            if config.trace.provenance_tracking_level != 0:
+                kernel_info = torch._inductor.debug.create_kernel_information_json()
+                kernel_info_json = os.path.join(
+                    wrapper_path_operator.parent, "kernel_information.json"
+                )
+                with open(kernel_info_json, "w") as f:
+                    f.write(json.dumps(kernel_info, indent=4))
+                generated_files.append(kernel_info_json)
+
         # We want to return the directory that contains all the AOTI
         # generated files, not just the so
         # return os.path.split(output_so)[0]
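On the consumer side, a hedged sketch (not part of this commit) for reading the JSON back out of a finished package. The archive-internal path mirrors the one asserted in the test above, but the exact directory components depend on the package name, so scanning the namelist is safer:

```python
import json
import zipfile


def load_kernel_information(pt2_path: str) -> dict:
    """Return the kernel_information.json payload from a .pt2 package, or {}."""
    with zipfile.ZipFile(pt2_path) as zf:
        for name in zf.namelist():
            # e.g. model/data/aotinductor/model/kernel_information.json
            if name.endswith("kernel_information.json"):
                return json.loads(zf.read(name))
    return {}
```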
@@ -1009,6 +1009,48 @@ def dump_inductor_provenance_info(
     return {}
 
 
+def create_kernel_information_json() -> dict[str, dict[str, list[str]]]:
+    """Create kernel information JSON"""
+    try:
+        global _inductor_post_to_pre_grad_nodes
+        global _inductor_kernel_stack_trace
+        global _inductor_triton_kernel_to_post_grad_node_info
+
+        post_to_pre = _inductor_post_to_pre_grad_nodes.get("postToPre", {})
+        all_kernels = OrderedSet(_inductor_kernel_stack_trace.keys()) | OrderedSet(
+            _inductor_triton_kernel_to_post_grad_node_info.keys()
+        )
+
+        result = {}
+        for kernel_name in all_kernels:
+            post_grad_nodes = _inductor_triton_kernel_to_post_grad_node_info.get(
+                kernel_name, []
+            )
+
+            pre_grad_nodes: OrderedSet[str] = OrderedSet()
+            for post_node in post_grad_nodes:
+                pre_grad_nodes.update(post_to_pre.get(post_node, []))
+
+            result[kernel_name] = {
+                "stack_traces": _inductor_kernel_stack_trace.get(kernel_name, []),
+                "post_grad_nodes": post_grad_nodes,
+                "pre_grad_nodes": list(pre_grad_nodes),
+            }
+
+        return result
+    except Exception as e:
+        signpost_event(
+            "inductor",
+            "provenance_tracking_error",
+            {
+                "function": "create_kernel_information_json",
+                "error_msg": str(e),
+                "stack_trace": traceback.format_exc(),
+            },
+        )
+        return {}
+
+
 def set_kernel_post_grad_provenance_tracing(
     node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernelOut],
     kernel_name: str,
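For intuition, a toy illustration with hypothetical data of the composition performed by `create_kernel_information_json` above: kernel → post-grad nodes comes from `_inductor_triton_kernel_to_post_grad_node_info`, and each post-grad node is mapped to pre-grad nodes through the `postToPre` table:

```python
# Hypothetical provenance tables, shaped like the module-level globals above.
kernel_to_post = {"triton_poi_fused_mul_1": ["mul"]}
post_to_pre = {"mul": ["mul"]}
kernel_stack_traces = {"triton_poi_fused_mul_1": ["d = a * 3.14"]}

result = {
    kernel: {
        "stack_traces": kernel_stack_traces.get(kernel, []),
        "post_grad_nodes": post_nodes,
        # Union of pre-grad nodes reached from every post-grad node.
        "pre_grad_nodes": sorted(
            {pre for node in post_nodes for pre in post_to_pre.get(node, [])}
        ),
    }
    for kernel, post_nodes in kernel_to_post.items()
}
# result == {"triton_poi_fused_mul_1": {"stack_traces": ["d = a * 3.14"],
#            "post_grad_nodes": ["mul"], "pre_grad_nodes": ["mul"]}}
```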