[pt2_provenance_tracking] add support for cpp kernel (#149185)
Summary: Add Inductor C++ kernels to the post-grad graph node mapping, with a unit test. Context: raised as a feature request for the AOTI CPU case. https://fb.workplace.com/groups/1028545332188949/permalink/1169020841474730/

Differential Revision: D71181284
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149185
Approved by: https://github.com/jingsh
Committed by: PyTorch MergeBot
Parent: 7869196482
Commit: b8f91bcb14
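For reviewers who want to try this locally, here is a minimal sketch (not part of the commit) of producing and inspecting the CPU mapping artifact. It assumes the same `trace.enabled` / `trace.debug_dir` / `force_disable_caches` config knobs and JSON filename that the new unit test below exercises; the toy `double` function is hypothetical.

    import glob
    import json
    import tempfile

    import torch
    from torch._inductor import config


    def double(x):  # hypothetical stand-in for a real model
        return x * 2


    debug_dir = tempfile.mkdtemp()
    with config.patch(
        {
            "trace.enabled": True,  # turn on Inductor debug tracing
            "trace.debug_dir": debug_dir,
            "force_disable_caches": True,  # force kernels to be (re)generated
        }
    ):
        compiled = torch.compile(double, backend="inductor")
        compiled(torch.randn(8, device="cpu"))

    # On CPU the mapping keys are cpp_fused_* kernel names; each maps to the
    # post-grad graph nodes fused into that kernel.
    for path in glob.glob(
        f"{debug_dir}/**/inductor_triton_kernel_to_post_grad_nodes.json",
        recursive=True,
    ):
        with open(path) as f:
            print(path, json.dumps(json.load(f), indent=2))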
@@ -5,12 +5,14 @@ import logging
 import re
 import shutil
 import tempfile
+import unittest
 from pathlib import Path

 import torch
 from torch._inductor import config
 from torch._inductor.debug import create_node_mapping
 from torch._inductor.test_case import run_tests, TestCase
+from torch.testing._internal.inductor_utils import HAS_GPU
 from torch.testing._internal.triton_utils import requires_cuda


@@ -31,7 +33,6 @@ class Model(torch.nn.Module):
         return z


-@requires_cuda
 @config.patch("trace.enabled", True)
 class TestProvenanceTracingArtifact(TestCase):
     """
@@ -39,6 +40,7 @@ class TestProvenanceTracingArtifact(TestCase):
     corresponding "inductor triton kernel node" is expected.
     """

+    @requires_cuda
     def _check_provenance_tracing_artifact(self, filepath):
         self.assertTrue(filepath.is_dir())
         filename = Path(filepath) / "inductor_triton_kernel_to_post_grad_nodes.json"
@@ -113,6 +115,7 @@ class TestProvenanceTracingArtifact(TestCase):
         ]
         self.assertEqual(sorted(actual_data.items()), sorted(expected_data))

+    @requires_cuda
     def test_triton_kernel_to_post_grad_tracing(self):
         a = torch.randn(10, 20, device="cuda")
         b = torch.randn(20, 30, device="cuda")
@@ -146,6 +149,60 @@ class TestProvenanceTracingArtifact(TestCase):
         finally:
             shutil.rmtree(filepath)

+    @unittest.skipIf(HAS_GPU, "the test is only for cpu")
+    def test_triton_kernel_to_post_grad_tracing_cpu(self):
+        a = torch.randn(10, 20, device="cpu")
+        b = torch.randn(20, 30, device="cpu")
+        c = torch.randn(10, 30, device="cpu")
+        example_inputs = (a, b, c)
+
+        model = Model()
+        ep = torch.export._trace._export(model, example_inputs)
+        gm = ep.module()
+        filepath = None
+
+        for backend in ["aot_inductor", "inductor"]:
+            try:
+                with config.patch(
+                    {
+                        "trace.debug_dir": tempfile.mkdtemp(),
+                        "force_disable_caches": True,
+                    }
+                ):
+                    with self.assertLogs(
+                        logging.getLogger("torch._inductor.debug"),
+                        level=logging.WARNING,
+                    ) as cm:
+                        if backend == "aot_inductor":
+                            so_path = torch._inductor.aot_compile(gm, example_inputs)
+                            optimized = AOTIRunnerUtil.load("cpu", so_path)
+                            optimized(*example_inputs)
+                        else:
+                            compiled = torch.compile(gm, backend=backend)
+                            compiled(*example_inputs)
+                    self.assertEqual(len(cm.output), 1)
+                    m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
+                    self.assertTrue(m)
+                    filepath = Path(m.group(1))
+                    filename = (
+                        Path(filepath)
+                        / "inductor_triton_kernel_to_post_grad_nodes.json"
+                    )
+                    with open(filename) as f:
+                        actual_data = json.load(f)
+                    # check the inductor kernel to post grad nodes mapping is expected for cpu
+                    expected_data = {
+                        "cpp_fused_mul_0": ["mul"],
+                        "cpp_fused_gelu_1": ["mul_3", "mul_1", "add", "erf", "mul_2"],
+                    }
+                    self.assertEqual(
+                        sorted(actual_data.items()), sorted(expected_data.items())
+                    )
+
+            finally:
+                if filepath:
+                    shutil.rmtree(filepath)
+

 class TestProvenanceTracingNodeMapping(TestCase):
     def test_create_node_mapping(self):
@@ -43,6 +43,7 @@ from ..utils import (
     is_welford_reduction,
     parallel_num_threads,
     Placeholder,
+    set_kernel_post_grad_provenance_tracing,
     sympy_index_symbol,
     sympy_index_symbol_with_prefix,
     sympy_product,
@@ -5150,6 +5151,10 @@ class CppScheduling(BaseScheduling):
             else ""
         )
         kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
+        # below add provenance tracing info for cpu CppKernel types
+        if config.trace.enabled:
+            set_kernel_post_grad_provenance_tracing(nodes, kernel_name)
+
         kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
         src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name)
         src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
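For reference, a sketch of the artifact this gating emits on CPU (shape and values copied from the expected_data in the new test above; illustrative, not an exhaustive spec):

    # Keys are the generated cpp_fused_* kernel names built via
    # "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]);
    # values are the post-grad FX graph node names fused into each kernel.
    {
        "cpp_fused_mul_0": ["mul"],
        "cpp_fused_gelu_1": ["mul_3", "mul_1", "add", "erf", "mul_2"],
    }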