mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 17:24:59 +08:00
Summary: 1. Move the debug printer call a level lower -> at here :https://www.internalfb.com/code/fbsource/[931d7bbb9e7cf2dcb926f42718f56fc940903eec]/fbcode/caffe2/torch/_inductor/codegen/cpp_wrapper_cuda.py?lines=335 2. Add UT for validating debug printer for user defined triton kernel codegen The benefit of having the debug printer call happens at a more centralized place is 1) reduce the duplicate debug printer related logic code scattered everywhere in the codebase 2) it can handle more triton kernel codegen path as long as it invokes this `generate_kernel_call()` for example, it can automatically handle/support user_defined_kernel 's debug printing which is a pretty common use case we encounter in debugging Test Plan: ```AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_aoti_debug_printer_user_defined_triton_kernel_abi_compatible_cuda``` Also verified that templateKernel codegen path still works Differential Revision: D61949020 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134789 Approved by: https://github.com/ColinPeppler
176 lines
6.5 KiB
Python
176 lines
6.5 KiB
Python
# mypy: allow-untyped-defs
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
import logging
|
|
from enum import Enum
|
|
from typing import List, Optional
|
|
|
|
from torch import dtype as torch_dtype
|
|
|
|
from .. import config
|
|
from ..virtualized import V
|
|
from .multi_kernel import MultiKernel
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
# AOTI debug printing related configs
class IntermediateValueDebuggingLevel(Enum):
    """Debugging level for AOTI intermediate tensor values.

    Each member's string value matches the accepted setting of the
    ``AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER`` environment knob
    (the constructor is fed the raw string, e.g. ``"2"``).
    """

    # OFF: No intermediate tensor value debug info will be printed or saved.
    OFF = "0"
    # LEVEL 1: Save all intermediate tensor values to individual `.pt` files. No debug printing will be displayed.
    SAVE_ONLY = "1"
    # LEVEL 2: Print all intermediate tensor values by default to the console. No debug saving will be performed.
    PRINT_ONLY = "2"
|
|
|
|
|
|
class DebugPrinterManager:
    """Context manager that codegens debug print/save statements around a
    kernel launch for AOT Inductor intermediate-value debugging.

    Entering the context emits the "before_launch" lines into the wrapper
    code; exiting emits the "after_launch" lines. What gets emitted is
    controlled by ``debug_printer_level`` (see
    ``IntermediateValueDebuggingLevel``): level 1 saves tensor handles to
    ``.pt`` files, level 2 prints them. ``set_printer_args`` re-targets the
    same manager instance at a new kernel between launches.
    """

    def __init__(
        self,
        debug_printer_level,
        args_to_print_or_save: Optional[List[str]] = None,
        kernel_name: str = "",
        kernel=None,
        arg_signatures: Optional[List[type]] = None,
    ):
        """
        Args:
            debug_printer_level: raw level value (string/enum) accepted by
                ``IntermediateValueDebuggingLevel``.
            args_to_print_or_save: codegen'd argument expressions to print/save.
            kernel_name: name of the kernel being launched (used in output tags
                and for filtered-kernel matching).
            kernel: the kernel object (only inspected for MultiKernel in
                ``set_printer_args``).
            arg_signatures: per-arg type info; entries that are a
                ``torch.dtype`` mark tensor args eligible for printing/saving.
        """
        self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level)
        if args_to_print_or_save is None:
            args_to_print_or_save = []
        self.args_to_print_or_save = args_to_print_or_save
        self.kernel_name = kernel_name
        # Fix: previously hard-coded to None, which silently discarded any
        # arg_signatures passed to the constructor.
        self.arg_signatures: Optional[List[type]] = arg_signatures
        self.kernel = kernel
        self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names()

    def __enter__(self):
        # Emit the "before_launch" debug print/save lines.
        self._perform_debug_print_or_save_helper(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch=True,
            arg_signatures=self.arg_signatures,
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Standard context-manager signature. The previous parameter names
        # (args_to_print_or_save/kernel_name/arg_signatures) were misleading:
        # a `with` statement passes exception info here, not printer args.
        # Emit the "after_launch" lines from the state stored on self.
        self._perform_debug_print_or_save_helper(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch=False,
            arg_signatures=self.arg_signatures,
        )

    def _perform_debug_print_or_save_helper(
        self,
        args_to_print_or_save,
        kernel_name,
        before_launch,
        arg_signatures: Optional[List[type]] = None,
    ):
        """Dispatch to save or print codegen based on the configured level.

        Now consistently uses its parameters (previously it ignored them and
        re-read the identical values from ``self``, which made the
        parameters dead and the call sites misleading).
        """
        if self.debug_printer_level == IntermediateValueDebuggingLevel.OFF:
            return
        if self.debug_printer_level == IntermediateValueDebuggingLevel.SAVE_ONLY:
            # by default save all the tensor values before launch
            self.codegen_intermediate_tensor_value_save(
                args_to_print_or_save,
                kernel_name,
                before_launch,
                arg_signatures=arg_signatures,
            )
        if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY:
            # by default print all the tensor values before launch
            self.codegen_intermediate_tensor_value_print(
                args_to_print_or_save,
                kernel_name,
                before_launch,
                arg_signatures=arg_signatures,
            )

    def _get_debug_filtered_kernel_names(self) -> List[str]:
        """Parse ``config.aot_inductor.filtered_kernel_names`` into a list of
        lower-cased kernel names; empty list means "print for all kernels".

        Note: the previous ``@functools.lru_cache`` decorator was removed —
        caching on an instance method keys on ``self`` and keeps every
        instance alive (ruff B019), and the method is only called once per
        instance (from ``__init__``) anyway, so the cache was a pure leak.
        """
        if config.aot_inductor.filtered_kernel_names is None:
            return []
        return [
            x.strip()
            for x in config.aot_inductor.filtered_kernel_names.lower().split(",")
        ]

    def set_printer_args(
        self,
        args_to_print_or_save: List[str],
        kernel_name: str,
        arg_signatures: Optional[List[type]],
        kernel,
    ):
        """Re-target this manager at a new kernel launch."""
        # Note: MultiKernel debug printing is not supported for now
        if isinstance(kernel, MultiKernel):
            log.info(
                "MultiKernel type is not supported in AOTI debug printer tool yet."
            )
            # Disable printing entirely for this launch rather than emitting
            # codegen that would not compile/run for MultiKernel.
            self.debug_printer_level = IntermediateValueDebuggingLevel.OFF
        self.args_to_print_or_save = args_to_print_or_save
        self.kernel_name = kernel_name
        self.arg_signatures = arg_signatures
        self.kernel = kernel

    def codegen_intermediate_tensor_value_save(
        self,
        args_to_save,
        kernel_name,
        before_launch=True,
        arg_signatures: Optional[List[type]] = None,
    ) -> None:
        """Emit codegen that saves each tensor arg to an individual file.

        Args whose signature entry is not a ``torch.dtype`` are skipped
        (non-tensor scalars cannot be saved as tensor handles). Currently
        only emits code for the ABI-compatible cpp-wrapper path.
        """
        for i, arg in enumerate(args_to_save):
            if arg_signatures is not None and not isinstance(
                arg_signatures[i], torch_dtype
            ):
                # infer from the arg data type (has torch.dtype) to see if it is a tensor type
                continue
            launch_prefix = "before_launch" if before_launch else "after_launch"
            if V.graph.cpp_wrapper:
                if config.abi_compatible:
                    V.graph.wrapper_code.writeline(
                        f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");'
                    )
                else:
                    # TODO: add non-abi compatible mode debug printing info
                    pass
            else:
                # TODO: non-cpp-wrapper codegen mode is not supported yet.
                pass

    def codegen_intermediate_tensor_value_print(
        self,
        args_to_print,
        kernel_name,
        before_launch=True,
        arg_signatures: Optional[List[type]] = None,
    ) -> None:
        """Emit codegen that prints each tensor arg around the launch.

        Honors the filtered-kernel-name allowlist (PRINT_ONLY level only);
        note the SAVE_ONLY path intentionally has no such filter.
        """
        for i, arg in enumerate(args_to_print):
            if arg_signatures is not None and not isinstance(
                arg_signatures[i], torch_dtype
            ):
                # infer from the arg data type (has torch.dtype) to see if it is a tensor type
                continue
            if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY:
                # when debug printing is enabled i.e. IntermediateValueDebuggingLevel.PRINT_ONLY,
                # check if filtered kernel name list is provided
                if (
                    len(self.filtered_kernel_names_to_print) > 0
                    and kernel_name not in self.filtered_kernel_names_to_print
                ):
                    continue

            launch_prefix = "before_launch" if before_launch else "after_launch"
            if V.graph.cpp_wrapper:
                if config.abi_compatible:
                    V.graph.wrapper_code.writeline(
                        f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");'
                    )
                else:
                    # TODO: add non-abi compatible mode debug printing info
                    pass
            else:
                line = f"print('{launch_prefix} - {kernel_name} - {arg}', {arg})"
                V.graph.wrapper_code.writeline(line)