Refactor Provenance Tracking (#163378)
Summary:
- Move the `provenance_level` flag check inside the `set_kernel_post_grad_provenance_tracing` call to simplify the code.
- Move the `set_kernel_post_grad_provenance_tracing` and `write_provenance_debug_handle` calls into `codegen_comment`.
- If a `call_kernel` call site had no preceding `codegen_comment` call, add one. Every `call_kernel` call site is now accompanied by a `codegen_comment` call.
- Add a `codegen_comment` method to `BaseScheduling` and remove the no-op `codegen_comment` method in `SIMDScheduling`.
- Remove `debug_handle` from `call_kernel`.

Test Plan: CI

```
buck run @//mode/opt-split-dwarf fbcode//caffe2/test/inductor:provenance_tracing
```

Differential Revision: D82839271

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163378
Approved by: https://github.com/angelayi
Committed by: PyTorch MergeBot
Parent: 908bcfd403
Commit: 520fca82c8
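The net effect, condensed from the hunks below: provenance bookkeeping moves out of every call site and into `codegen_comment`, and the helper performs the level check itself. A minimal sketch follows (an excerpt of the new `BaseScheduling.codegen_comment`, not the full implementations; `V` and `BaseSchedulerNode` come from `torch._inductor`'s module context and are not redefined here):

```python
from typing import Optional, Sequence

class BaseScheduling:
    def codegen_comment(
        self,
        node_schedule: Sequence["BaseSchedulerNode"],
        kernel_name: Optional[str] = None,
    ) -> None:
        # Provenance tracking lives here now. The helper checks
        # config.trace.provenance_tracking_level itself and returns None
        # when tracking is disabled, so call sites no longer gate on the flag.
        if kernel_name:
            from torch._inductor.debug import set_kernel_post_grad_provenance_tracing

            debug_handle = set_kernel_post_grad_provenance_tracing(
                node_schedule, kernel_name
            )
            V.graph.wrapper_code.write_provenance_debug_handle(
                kernel_name, debug_handle
            )

# Every call_kernel site is now preceded by a codegen_comment call, and
# call_kernel no longer takes a debug_handle:
#
#     self.codegen_comment(node_schedule, kernel_name)
#     kernel.call_kernel(kernel_name, ctb)
```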
@@ -150,7 +150,7 @@ class TestProvenanceTracingArtifact(TestCase):
                     "cppCodeToPost",
                     {
                         "triton_poi_fused_mul_0:1": ["mul"],
-                        "triton_poi_fused_addmm_gelu_1:2": [
+                        "triton_poi_fused_addmm_gelu_1:3": [
                             "mul_3",
                             "mul_1",
                             "add_tensor",
@@ -164,12 +164,12 @@ class TestProvenanceTracingArtifact(TestCase):
                     "postToCppCode",
                     {
                         "mul": ["triton_poi_fused_mul_0:1"],
-                        "mul_3": ["triton_poi_fused_addmm_gelu_1:2"],
-                        "mul_1": ["triton_poi_fused_addmm_gelu_1:2"],
-                        "add_tensor": ["triton_poi_fused_addmm_gelu_1:2"],
-                        "add": ["triton_poi_fused_addmm_gelu_1:2"],
-                        "erf": ["triton_poi_fused_addmm_gelu_1:2"],
-                        "mul_2": ["triton_poi_fused_addmm_gelu_1:2"],
+                        "mul_3": ["triton_poi_fused_addmm_gelu_1:3"],
+                        "mul_1": ["triton_poi_fused_addmm_gelu_1:3"],
+                        "add_tensor": ["triton_poi_fused_addmm_gelu_1:3"],
+                        "add": ["triton_poi_fused_addmm_gelu_1:3"],
+                        "erf": ["triton_poi_fused_addmm_gelu_1:3"],
+                        "mul_2": ["triton_poi_fused_addmm_gelu_1:3"],
                     },
                 ),
                 (
@@ -195,18 +195,18 @@ class TestProvenanceTracingArtifact(TestCase):
                 ),
             ]
             if backend == "aot_inductor":
-                expected_mapping[0][1]["aoti_torch_cuda_mm_out:3"] = [
+                expected_mapping[0][1]["aoti_torch_cuda_mm_out:2"] = [
                     "mm_default"
                 ]
                 expected_mapping[1][1]["mm_default"] = [
-                    "aoti_torch_cuda_mm_out:3"
+                    "aoti_torch_cuda_mm_out:2"
                 ]
             else:
-                expected_mapping[0][1]["extern_kernels.mm:3"] = [
+                expected_mapping[0][1]["extern_kernels.mm:2"] = [
                     "mm_default"
                 ]
                 expected_mapping[1][1]["mm_default"] = [
-                    "extern_kernels.mm:3"
+                    "extern_kernels.mm:2"
                 ]
             self._check_provenance_tracking_node_mappings(
                 filepath, expected_mapping
@@ -217,8 +217,8 @@ class TestProvenanceTracingArtifact(TestCase):
             if backend == "aot_inductor":
                 expected_data = {
                     "cpp_fused_mul_0:1": ["mul"],
-                    "aoti_torch_cpu_addmm_out:3": ["addmm"],
-                    "cpp_fused_gelu_1:2": [
+                    "aoti_torch_cpu_addmm_out:2": ["addmm"],
+                    "cpp_fused_gelu_1:3": [
                         "mul_3",
                         "mul_1",
                         "add",
@@ -230,14 +230,14 @@ class TestProvenanceTracingArtifact(TestCase):
                 # backend == "inductor"
                 expected_data = {
                     "cpp_fused_mul_0:1": ["mul"],
-                    "cpp_fused_gelu_1:2": [
+                    "cpp_fused_gelu_1:3": [
                         "mul_3",
                         "mul_1",
                         "add",
                         "erf",
                         "mul_2",
                     ],
-                    "extern_kernels.addmm:3": ["addmm"],
+                    "extern_kernels.addmm:2": ["addmm"],
                 }
             self._check_provenance_tracing_kernel_to_post_grad(
                 filepath, expected_data
@@ -550,22 +550,22 @@ class TestProvenanceTracingStackTraces(TestCase):
         example_inputs = (x, a, b, c)
 
         expected = {
-            "triton_poi_fused_addmm_relu_sigmoid_threshold_backward_0:1": [
+            "triton_poi_fused_addmm_relu_sigmoid_threshold_backward_0:2": [
                 "x = self.sigmoid(x)",
                 "x = self.fc1(x)",
                 "x = self.relu(x)",
             ],
-            "triton_poi_fused_mul_1:2": [
+            "triton_poi_fused_mul_1:3": [
                 "d = a * 3.14",
             ],
-            "triton_poi_fused_addmm_gelu_2:3": [
+            "triton_poi_fused_addmm_gelu_2:5": [
                 "z = torch.nn.functional.gelu(y)",
                 "y = torch.addmm(c, d, b)",
             ],
-            "extern_kernels.mm:4": [
+            "extern_kernels.mm:1": [
                 "x = self.fc1(x)",
             ],
-            "extern_kernels.mm:5": [
+            "extern_kernels.mm:4": [
                 "y = torch.addmm(c, d, b)",
             ],
         }
@@ -648,7 +648,7 @@ class TestProvenanceTracingStackTraces(TestCase):
             kernel_info = json.load(f)
 
         expected = {
-            "triton_poi_fused_addmm_relu_sigmoid_0:1": {
+            "triton_poi_fused_addmm_relu_sigmoid_0:2": {
                 "stack_traces": [
                     "x = self.sigmoid(x)",
                     "x = self.fc1(x)",
@@ -657,14 +657,14 @@
                 "post_grad_nodes": ["sigmoid", "relu", "add_tensor_1"],
                 "pre_grad_nodes": ["sigmoid", "relu", "linear"],
             },
-            "triton_poi_fused_mul_1:2": {
+            "triton_poi_fused_mul_1:3": {
                 "stack_traces": [
                     "d = a * 3.14",
                 ],
                 "post_grad_nodes": ["mul"],
                 "pre_grad_nodes": ["mul"],
             },
-            "triton_poi_fused_addmm_gelu_2:3": {
+            "triton_poi_fused_addmm_gelu_2:5": {
                 "stack_traces": [
                     "z = torch.nn.functional.gelu(y)",
                     "y = torch.addmm(c, d, b)",
@@ -679,14 +679,14 @@
                 ],
                 "pre_grad_nodes": ["gelu", "addmm"],
             },
-            "aoti_torch_cuda_mm_out:4": {
+            "aoti_torch_cuda_mm_out:1": {
                 "stack_traces": [
                     "x = self.fc1(x)",
                 ],
                 "post_grad_nodes": ["mm_default_1"],
                 "pre_grad_nodes": ["linear"],
             },
-            "aoti_torch_cuda_mm_out:5": {
+            "aoti_torch_cuda_mm_out:4": {
                 "stack_traces": [
                     "y = torch.addmm(c, d, b)",
                 ],
@@ -5376,6 +5376,7 @@ class CppScheduling(BaseScheduling):
                 )
                 user.node.mark_run()
 
+        self.codegen_comment(node_schedule, kernel_name)
         kernel.call_kernel(kernel_name, ctb)
         V.graph.removed_buffers |= kernel.removed_buffers
         self.free_buffers_in_scheduler()
@@ -5441,18 +5442,20 @@
             kernel_name = self.define_kernel(
                 src_code, self.kernel_group.scheduled_nodes
             )
-            # below add provenance tracing info for cpu CppKernel types
-            debug_handle: Optional[int] = None
-            if config.trace.provenance_tracking_level != 0:
-                debug_handle = set_kernel_post_grad_provenance_tracing(
-                    self.kernel_group.scheduled_nodes, kernel_name
-                )
-            self.kernel_group.call_kernel(
-                V.graph.wrapper_code, kernel_name, debug_handle=debug_handle
-            )
+            self.codegen_comment(self.kernel_group.scheduled_nodes, kernel_name)
+            self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name)
             self.reset_kernel_group()
             self._set_flush_status(False)
 
+    def codegen_comment(self, node_schedule, kernel_name=None):
+        # below add provenance tracing info for cpu CppKernel types
+        wrapper = V.graph.wrapper_code
+        debug_handle = set_kernel_post_grad_provenance_tracing(
+            node_schedule,  # type: ignore[arg-type]
+            kernel_name,
+        )
+        wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
+
 
 class KernelGroup:
     def __init__(self):
@@ -5524,14 +5527,13 @@
         code.splice(self.loops_code)
         return code.getvalue()
 
-    def call_kernel(self, wrapper, kernel_name, debug_handle: Optional[int] = None):
+    def call_kernel(self, wrapper, kernel_name):
         _, call_args, arg_types = self.args.cpp_argdefs()
         wrapper.generate_kernel_call(
             kernel_name,
             call_args,
             triton=False,
             arg_types=arg_types,
-            debug_handle=debug_handle,
         )
@@ -22,7 +22,6 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.symbol import symbol_is_type, SymT
 
 from .. import config, cpp_builder, ir
-from ..debug import set_kernel_post_grad_provenance_tracing
 from ..utils import _align, DeferredLineBase, LineContext, normalize_name
 from ..virtualized import V
 from .aoti_hipify_utils import maybe_hipify_code_wrapper
@@ -1297,14 +1296,8 @@ class CppWrapperCpu(PythonWrapperCodegen):
 
         device = d.type if (d := extern_kernel.get_device()) else self.device
 
-        debug_handle = None
-        if config.trace.provenance_tracking_level != 0:
-            debug_handle = set_kernel_post_grad_provenance_tracing(
-                extern_kernel, extern_kernel.get_kernel_name(), is_extern=True
-            )
-
         self.generate_c_shim_extern_kernel_call(
-            extern_kernel.get_kernel_name(), args, device, debug_handle=debug_handle
+            extern_kernel.get_kernel_name(), args, device
         )
 
         if extern_kernel.python_kernel_name in (
@@ -1362,19 +1355,10 @@ class CppWrapperCpu(PythonWrapperCodegen):
         args = args + output_args
         device = d.type if (d := fallback_kernel.get_device()) else self.device
 
-        debug_handle = None
-        if config.trace.provenance_tracking_level != 0:
-            shim_fn = self.get_c_shim_func_name(fallback_kernel.cpp_kernel_name, device)  # type: ignore[arg-type]
-            debug_handle = set_kernel_post_grad_provenance_tracing(
-                fallback_kernel,
-                shim_fn,
-                is_extern=True,
-            )
         self.generate_c_shim_extern_kernel_call(
             fallback_kernel.cpp_kernel_name,  # type: ignore[arg-type]
             args,
             device,
-            debug_handle=debug_handle,
         )
         for raii_handle in output_raii_handles:
             self.writeline(raii_handle)
@@ -175,6 +175,7 @@ class CUDACPPScheduling(BaseScheduling):
             call_args, kernel_name, arg_signatures, kernel
         )
         with debug_printer_manager:
+            self.codegen_comment(node_schedule, kernel_name)
             kernel.call_kernel(kernel_name, ctb)
 
         V.graph.removed_buffers |= kernel.removed_buffers
@@ -135,6 +135,7 @@ class CuteDSLScheduling(BaseScheduling):
         with V.set_kernel_handler(kernel):
             node_schedule = [template_node]
             kernel_name = self.define_kernel(src_code_str, node_schedule)
+            self.codegen_comment(node_schedule, kernel_name)
             kernel.call_kernel(kernel_name, ctb)
             V.graph.removed_buffers |= kernel.removed_buffers
             self.free_buffers_in_scheduler()
@@ -94,6 +94,7 @@ class ROCmCPPScheduling(BaseScheduling):
         with V.set_kernel_handler(kernel):
             node_schedule = [template_node]
             kernel_name = self.define_kernel(src_code, node_schedule)
+            self.codegen_comment(node_schedule, kernel_name)
             kernel.call_kernel(kernel_name, ctb)
             V.graph.removed_buffers |= kernel.removed_buffers
             self.free_buffers_in_scheduler()
@@ -41,7 +41,6 @@ from ..dependencies import MemoryDep, StarDep, WeakDep
 if TYPE_CHECKING:
     from ..ir import IRNode
 
-from ..debug import set_kernel_post_grad_provenance_tracing
 from ..optimize_indexing import indexing_dtype_strength_reduction
 from ..runtime.runtime_utils import green_text, yellow_text
 from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
@@ -1471,17 +1470,10 @@ class SIMDScheduling(BaseScheduling):
         for kernel in kernels:
             self.codegen_node_schedule_with_kernel(node_schedule, kernel)
         MultiKernel.merge_workspaces_inplace(kernels)
-        debug_handles: list[tuple[str, Optional[int]]] = []
         for kernel in kernels:
             with V.set_kernel_handler(kernel):
                 src_code = kernel.codegen_kernel()
                 kernel_name = self.define_kernel(src_code, node_schedule, kernel)
-                if config.trace.provenance_tracking_level != 0:
-                    debug_handle = set_kernel_post_grad_provenance_tracing(
-                        node_schedule,  # type: ignore[arg-type]
-                        kernel_name,
-                    )
-                    debug_handles.append((kernel_name, debug_handle))
                 log.debug("Generating kernel code with kernel_name: %s", kernel_name)
                 kernel.kernel_name = kernel_name
                 kernel.code_hash = code_hash(src_code)
@@ -1497,11 +1489,11 @@
         for node in kernel_features.scheduler_nodes():
             node.mark_run()
 
-        self.codegen_comment(node_schedule)
-        for kernel_name, debug_handle in debug_handles:
-            V.graph.wrapper_code.write_provenance_debug_handle(
-                kernel_name, debug_handle
-            )
+        # filter out NodeScheduleMarker
+        base_scheduler_nodes = [
+            node for node in node_schedule if isinstance(node, BaseSchedulerNode)
+        ]
+        self.codegen_comment(base_scheduler_nodes, final_kernel.kernel_name)
         final_kernel.call_kernel(final_kernel.kernel_name)
 
         if config.nan_asserts:
@@ -1696,11 +1688,6 @@
 
         kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel)
 
-        if config.trace.provenance_tracking_level != 0:
-            set_kernel_post_grad_provenance_tracing(
-                node_schedule, kernel.kernel_name
-            )
-
         return kernel
 
     def _get_multikernel_shapes(
@@ -1819,8 +1806,7 @@ class SIMDScheduling(BaseScheduling):
         MultiKernel.merge_workspaces_inplace(list(kernels.values()))
         multi_kernel = SizeHintMultiKernel(kernels)
         node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
-        self.codegen_comment(node_schedule)
-
+        self.codegen_comment(node_schedule, multi_kernel.kernel_name)
         multi_kernel.call_kernel(multi_kernel.kernel_name)
         V.graph.removed_buffers |= multi_kernel.removed_buffers
         V.graph.inplaced_to_remove |= multi_kernel.inplaced_to_remove
@@ -1851,7 +1837,7 @@ class SIMDScheduling(BaseScheduling):
         )
 
         node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
-        self.codegen_comment(node_schedule)
+        self.codegen_comment(node_schedule, kernel.kernel_name)
         kernel.call_kernel(kernel.kernel_name, template_node.node)
 
         V.graph.removed_buffers |= kernel.removed_buffers
@@ -1937,12 +1923,7 @@ class SIMDScheduling(BaseScheduling):
 
         for src_code, kernel, _ in kernel_code_list:
             kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel)
-            # dump provenance node info for ComboKernelNode/ForeachKernel type
-            if config.trace.provenance_tracking_level != 0:
-                set_kernel_post_grad_provenance_tracing(
-                    combo_kernel_node.snodes, kernel_name
-                )
-            self.codegen_comment([combo_kernel_node])
+            self.codegen_comment(combo_kernel_node.snodes, kernel_name)
             log.debug("ComboKernels: generated kernel %s.", kernel_name)
             kernel.call_kernel(V.graph.wrapper_code, kernel_name)
 
@@ -2663,9 +2644,6 @@ class SIMDScheduling(BaseScheduling):
         src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
         return src_code
 
-    def codegen_comment(self, node_schedule):
-        pass
-
     def define_kernel(self, src_code, node_schedule, kernel):
         raise NotImplementedError
 
@@ -34,6 +34,7 @@ from ...utils._sympy.value_ranges import ValueRanges
 from .. import config, ir, metrics
 from ..async_compile import AsyncCompile
 from ..codecache import code_hash, get_path, PyCodeCache, write_atomic
+from ..debug import set_kernel_post_grad_provenance_tracing
 from ..ops_handler import DefaultHandler
 from ..runtime import triton_heuristics
 from ..runtime.benchmarking import benchmarker
@@ -4867,7 +4868,7 @@ class TritonScheduling(SIMDScheduling):
         )
         return cls.backend_features
 
-    def codegen_comment(self, node_schedule):
+    def codegen_comment(self, node_schedule, kernel_name=None):
         wrapper = V.graph.wrapper_code
         origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper)
         if origins:
@@ -4893,6 +4894,13 @@ class TritonScheduling(SIMDScheduling):
                 f"{wrapper.comment} Fused node name list: {', '.join(node_names)}"
             )
 
+        if kernel_name:
+            debug_handle = set_kernel_post_grad_provenance_tracing(
+                node_schedule,  # type: ignore[arg-type]
+                kernel_name,
+            )
+            wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
+
     def define_kernel(self, src_code, node_schedule, kernel):
         wrapper = V.graph.wrapper_code
         if src_code in wrapper.src_to_kernel:
@@ -40,7 +40,6 @@ from torch.utils._sympy.symbol import symbol_is_type, SymT
 
 from .. import async_compile, config, ir
 from ..codecache import output_code_log
-from ..debug import set_kernel_post_grad_provenance_tracing
 from ..ir import IRNode, ReinterpretView
 from ..runtime import triton_heuristics
 from ..runtime.hints import DeviceProperties
@@ -509,19 +508,12 @@ class ExternKernelOutLine(WrapperLine):
         else:
             kernel_name = node.get_kernel_name()
         device = d.type if (d := node.get_device()) else V.graph.device_type
-        provenance_debug_handle: Optional[int] = None
-        # set provenance tracing kernel mapping for ExternKernel types
-        if config.trace.provenance_tracking_level != 0:
-            provenance_debug_handle = set_kernel_post_grad_provenance_tracing(
-                node, kernel_name, is_extern=True
-            )
         self.wrapper._generate_extern_kernel_out_helper(
             kernel_name,
             node.codegen_reference(),
             node.output_view.codegen_reference() if node.output_view else None,
             args,
             device,
-            provenance_debug_handle,
         )
 
     def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
@@ -1516,13 +1508,11 @@ class PythonWrapperCodegen(CodeGen):
         out_view: Optional[str],
         args: list[str],
         device: str,
-        debug_handle: Optional[int] = None,
     ) -> None:
         # add debug printer code for triton kernel calls at (jit) inductor level
         debug_printer_manager = V.graph.wrapper_code.debug_printer
         debug_printer_manager.set_printer_args(args, kernel, None, None, "extern")
         args.append(f"out={out_view if out_view else out}")
-        self.write_provenance_debug_handle(kernel, debug_handle)
         with debug_printer_manager:
             self.writeline(f"{kernel}({', '.join(args)})")
 
@@ -2696,7 +2686,6 @@ class PythonWrapperCodegen(CodeGen):
         raw_args=None,
         triton_meta=None,
         original_fxnode_name=None,
-        debug_handle: Optional[int] = None,
     ):
         """
         Generates kernel call code.
@@ -2716,7 +2705,6 @@ class PythonWrapperCodegen(CodeGen):
         )
 
         device = device or V.graph.get_current_device_or_throw()
-        self.write_provenance_debug_handle(kernel_name, debug_handle)
         self.writeline(
             KernelCallLine(
                 self,
@@ -1103,6 +1103,9 @@ def set_kernel_post_grad_provenance_tracing(
     Returns a unique int debug handler for each call to this function.
     """
 
+    if config.trace.provenance_tracking_level == 0:
+        return None
+
     try:
         from .codegen.simd_kernel_features import DisableReduction, EnableReduction
 
@@ -5605,11 +5605,23 @@ class ExternKernel(InputsKernel):
         self.apply_constraint()
         self.freeze_layout()
 
-    def codegen_comment(self, wrapper: PythonWrapperCodegen) -> None:
+    def codegen_comment(
+        self, wrapper: PythonWrapperCodegen, kernel_name: Optional[str] = None
+    ) -> None:
         origin_str, _detailed_origin_str = get_kernel_metadata(self, wrapper)
         if origin_str:
             wrapper.make_comment(origin_str)
 
+        if not kernel_name:
+            kernel_name = self.try_get_kernel_name()
+        if kernel_name:
+            from .debug import set_kernel_post_grad_provenance_tracing
+
+            debug_handle = set_kernel_post_grad_provenance_tracing(
+                self, kernel_name, is_extern=True
+            )
+            wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
+
     def codegen(self, wrapper: PythonWrapperCodegen) -> None:
         raise NotImplementedError
 
@@ -5652,25 +5664,29 @@ class ExternKernel(InputsKernel):
             f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"
         )
 
-    def get_kernel_name(self) -> str:
+    def try_get_kernel_name(self) -> Optional[str]:
         from .codegen.cpp_wrapper_cpu import CppWrapperCpu
 
         device = d.type if (d := self.get_device()) else V.graph.device_type
         if V.graph.fx_wrapper:
             assert self.python_kernel_name is not None
             return self.python_kernel_name
         elif V.graph.cpp_wrapper:
             assert isinstance(V.graph.wrapper_code, CppWrapperCpu), type(
                 V.graph.wrapper_code
             )
-            assert self.cpp_kernel_name is not None
+            if self.cpp_kernel_name is None:
+                return None
             return V.graph.wrapper_code.get_c_shim_func_name(
                 self.cpp_kernel_name, device
             )
         else:
             assert self.python_kernel_name is not None
             return self.python_kernel_name
 
+    def get_kernel_name(self) -> str:
+        name = self.try_get_kernel_name()
+        assert name is not None
+        return name
+
     @staticmethod
     def copy_input(x: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]:
         pw = Pointwise.create(
@@ -6803,7 +6819,7 @@ class UserDefinedTritonKernel(ExternKernel):
         else:
             raise NotImplementedError(f"Unsupported arg type: {type(arg)}: {arg}")
 
-        self.codegen_comment(wrapper)
+        self.codegen_comment(wrapper, new_name)
         wrapper.generate_kernel_call(
             new_name,
             args,
@@ -5606,7 +5606,7 @@ class Scheduler:
             V.graph.zero_dim_cpu_tensor_list.add(read.name)
 
 
-class BaseScheduling:
+class BaseScheduling:  # noqa: docstring_linter
     def __init__(self, scheduler: Optional[Scheduler]):
         super().__init__()
         self.scheduler = scheduler
@@ -5749,3 +5749,19 @@ class BaseScheduling:
         and memory copy time in milliseconds on randomly generated inputs.
         """
         raise NotImplementedError
+
+    def codegen_comment(
+        self,
+        node_schedule: Sequence[BaseSchedulerNode],
+        kernel_name: Optional[str] = None,
+    ) -> None:
+        if kernel_name:
+            from torch._inductor.debug import set_kernel_post_grad_provenance_tracing
+
+            debug_handle = set_kernel_post_grad_provenance_tracing(
+                node_schedule,  # type: ignore[arg-type]
+                kernel_name,
+            )
+            V.graph.wrapper_code.write_provenance_debug_handle(
+                kernel_name, debug_handle
+            )