Files
pytorch/torch/_inductor/codegen/debug_utils.py
Rachel Guo ae3aa8ff73 [AOTI][Tooling][5/n] Refactor the debug printer call to a level lower (#134789)
Summary:
1. Move the debug printer call one level lower, to here:
https://www.internalfb.com/code/fbsource/[931d7bbb9e7cf2dcb926f42718f56fc940903eec]/fbcode/caffe2/torch/_inductor/codegen/cpp_wrapper_cuda.py?lines=335
2. Add UT for validating debug printer for user defined triton kernel codegen

The benefit of having the debug printer call happen at a more centralized place is: 1) it reduces the duplicated debug-printer-related logic scattered across the codebase; 2) it can handle more triton kernel codegen paths, as long as they invoke `generate_kernel_call()` — for example, it automatically handles/supports a user-defined kernel's debug printing, which is a common use case we encounter in debugging.

Test Plan:
```AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1  TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_aoti_debug_printer_user_defined_triton_kernel_abi_compatible_cuda```

Also verified that templateKernel codegen path still works

Differential Revision: D61949020

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134789
Approved by: https://github.com/ColinPeppler
2024-09-04 02:41:30 +00:00

176 lines
6.5 KiB
Python

# mypy: allow-untyped-defs
from __future__ import annotations
import functools
import logging
from enum import Enum
from typing import List, Optional
from torch import dtype as torch_dtype
from .. import config
from ..virtualized import V
from .multi_kernel import MultiKernel
log = logging.getLogger(__name__)
# AOTI debug printing related configs
class IntermediateValueDebuggingLevel(Enum):
    """Debugging levels for AOTI intermediate tensor values.

    Values are strings so the level can be taken directly from the
    ``AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER`` environment setting.
    """

    # No intermediate tensor value debug info is printed or saved.
    OFF = "0"
    # Save every intermediate tensor value to its own `.pt` file; nothing is printed.
    SAVE_ONLY = "1"
    # Print every intermediate tensor value to the console; nothing is saved.
    PRINT_ONLY = "2"
class DebugPrinterManager:
    """Context manager that codegens AOTI debug print/save statements for
    intermediate tensor values around a kernel launch.

    Entering the manager emits the "before_launch" statements into the
    wrapper code; exiting emits the "after_launch" statements. What is
    emitted depends on ``debug_printer_level``:

    * ``OFF``: nothing.
    * ``SAVE_ONLY``: ``aoti_torch_save_tensor_handle`` calls (ABI-compatible
      cpp wrapper only).
    * ``PRINT_ONLY``: ``aoti_torch_print_tensor_handle`` calls (ABI-compatible
      cpp wrapper) or plain ``print`` statements (python wrapper), optionally
      restricted to the kernel names listed in
      ``config.aot_inductor.filtered_kernel_names``.
    """

    def __init__(
        self,
        debug_printer_level,
        args_to_print_or_save: Optional[List[str]] = None,
        kernel_name: str = "",
        kernel=None,
        arg_signatures: Optional[List[type]] = None,
    ):
        self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level)
        if args_to_print_or_save is None:
            args_to_print_or_save = []
        self.args_to_print_or_save = args_to_print_or_save
        self.kernel_name = kernel_name
        # Bug fix: this was previously hard-coded to None, silently
        # discarding the caller-provided signatures.
        self.arg_signatures: Optional[List[type]] = arg_signatures
        self.kernel = kernel
        self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names()

    def __enter__(self):
        self._perform_debug_print_or_save_helper(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch=True,
            arg_signatures=self.arg_signatures,
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Standard context-manager exit signature. The previous parameter
        # names (args_to_print_or_save, kernel_name, arg_signatures) were
        # misleading: these three positionals receive exception info from the
        # `with` machinery; the printer state is read from `self` instead.
        self._perform_debug_print_or_save_helper(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch=False,
            arg_signatures=self.arg_signatures,
        )

    def _perform_debug_print_or_save_helper(
        self,
        args_to_print_or_save,
        kernel_name,
        before_launch,
        arg_signatures: Optional[List[type]] = None,
    ):
        # Dispatch to the save or print codegen path based on the configured
        # level. The levels are mutually exclusive, hence the elif chain.
        if self.debug_printer_level == IntermediateValueDebuggingLevel.OFF:
            return
        if self.debug_printer_level == IntermediateValueDebuggingLevel.SAVE_ONLY:
            # by default save all the tensor values before launch
            self.codegen_intermediate_tensor_value_save(
                self.args_to_print_or_save,
                self.kernel_name,
                before_launch,
                arg_signatures=self.arg_signatures,
            )
        elif self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY:
            # by default print all the tensor values before launch
            self.codegen_intermediate_tensor_value_print(
                self.args_to_print_or_save,
                self.kernel_name,
                before_launch,
                arg_signatures=self.arg_signatures,
            )

    # NOTE(review): lru_cache on an instance method keys on `self` and keeps
    # the instance alive for the cache's lifetime (ruff B019). Kept as-is to
    # preserve per-instance caching behavior when the config changes between
    # manager instances.
    @functools.lru_cache  # noqa: B019
    def _get_debug_filtered_kernel_names(self) -> List[str]:
        """Return the lowercased, stripped kernel-name filter list from config.

        An empty list means "no filter" (print for every kernel).
        """
        if config.aot_inductor.filtered_kernel_names is None:
            return []
        return [
            x.strip()
            for x in config.aot_inductor.filtered_kernel_names.lower().split(",")
        ]

    def set_printer_args(
        self,
        args_to_print_or_save: List[str],
        kernel_name: str,
        arg_signatures: Optional[List[type]],
        kernel,
    ):
        """Update the printer state for the next kernel launch.

        Called by the wrapper codegen before each `with` use of this manager.
        """
        # Note: MultiKernel debug printing is not supported for now
        if isinstance(kernel, MultiKernel):
            log.info(
                "MultiKernel type is not supported in AOTI debug printer tool yet."
            )
            self.debug_printer_level = IntermediateValueDebuggingLevel.OFF
        self.args_to_print_or_save = args_to_print_or_save
        self.kernel_name = kernel_name
        self.arg_signatures = arg_signatures
        self.kernel = kernel

    def codegen_intermediate_tensor_value_save(
        self,
        args_to_save,
        kernel_name,
        before_launch=True,
        arg_signatures: Optional[List[type]] = None,
    ) -> None:
        """Emit codegen lines that save each tensor arg to a `.pt` file."""
        for i, arg in enumerate(args_to_save):
            # Infer from the arg data type (torch.dtype signature entries mark
            # non-tensor scalars) whether this arg is a tensor; skip scalars.
            if arg_signatures is not None and not isinstance(
                arg_signatures[i], torch_dtype
            ):
                continue
            launch_prefix = "before_launch" if before_launch else "after_launch"
            if V.graph.cpp_wrapper:
                if config.abi_compatible:
                    V.graph.wrapper_code.writeline(
                        f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");'
                    )
                else:
                    # TODO: add non-abi compatible mode debug printing info
                    pass
            else:
                # currently, non-cpp-wrapper codegen mode is not supported
                pass

    def codegen_intermediate_tensor_value_print(
        self,
        args_to_print,
        kernel_name,
        before_launch=True,
        arg_signatures: Optional[List[type]] = None,
    ) -> None:
        """Emit codegen lines that print each tensor arg's value."""
        for i, arg in enumerate(args_to_print):
            # Infer from the arg data type (torch.dtype signature entries mark
            # non-tensor scalars) whether this arg is a tensor; skip scalars.
            if arg_signatures is not None and not isinstance(
                arg_signatures[i], torch_dtype
            ):
                continue
            if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY:
                # when debug printing is enabled i.e. IntermediateValueDebuggingLevel.PRINT_ONLY,
                # check if filtered kernel name list is provided
                if (
                    len(self.filtered_kernel_names_to_print) > 0
                    and kernel_name not in self.filtered_kernel_names_to_print
                ):
                    continue
            launch_prefix = "before_launch" if before_launch else "after_launch"
            if V.graph.cpp_wrapper:
                if config.abi_compatible:
                    V.graph.wrapper_code.writeline(
                        f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");'
                    )
                else:
                    # TODO: add non-abi compatible mode debug printing info
                    pass
            else:
                line = f"print('{launch_prefix} - {kernel_name} - {arg}', {arg})"
                V.graph.wrapper_code.writeline(line)