Files
pytorch/torch/_inductor/codegen/debug_utils.py
Blaine Burton Rister c0a0761871 [Inductor] Refactor wrapper codegen to use Wrapper IR. (#150458)
Preparatory refactor for https://github.com/pytorch/pytorch/pull/146942.

# Feature

This PR refactors the existing wrapper codegen into `WrapperLine` subclasses, extending the existing Memory Planning IR into a fully-fledged Wrapper IR. See the diagram below.

![wrapper_ir](https://github.com/user-attachments/assets/a61db21b-caf3-45d2-bfdb-91066ae4ba6b)

The IR currently supports the following ops (a rough sketch of the general shape of an IR line follows the list):
- All existing memory planning IR ops (`AllocateLine`, `FreeIfNotReusedLine`, etc.)
- Reinterpret views (`ReinterpretLine`)
- Kernel definitions (`KernelDefinitionLine`)
- Calls to defined kernels (`KernelCallLine`)
- Calls to extern kernels (`ExternKernelLine`, `ExternKernelAllocLine`)
- Ops with multiple outputs (`MultiOutputLine`)
- Tensor cleanup at the end of a graph (`FreeLine`)
- Leaving comments in code (`CommentLine`)
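
As promised above, here is a rough sketch of the general shape of a Wrapper IR line, using `CommentLine` as the simplest example. The field name and the `codegen()` signature shown here are illustrative assumptions, not the exact definitions from this PR:

```python
from dataclasses import dataclass


@dataclass
class CommentLine:
    # Hypothetical field name: the op records *what* to emit, not how to emit it.
    comment: str

    def codegen(self, code) -> None:
        # Phase 2 (see "Implementation details" below): turn the structured
        # record into an actual line of wrapper code via the active backend.
        code.writeline(f"# {self.comment}")
```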

There are two main motivations for this refactor:
1. Unlike free-form C++ and Python code, Wrapper IR lines provide structured information about what the wrapper code does. This serves as a natural extension point for other types of wrapper codegen. For example, the parent PR generates FX IR from Wrapper IR. Wrapper IR aims to give new backends enough information to generate wrapper code without needing to modify core Inductor files such as `ir.py`.
2. This design will hopefully promote stronger modularity and encapsulation.
   a. Inductor's core compilation passes don't need to worry about whether they're targeting Python, C++, FX or anything else. They can simply focus on generating Wrapper IR, and target-specific code can be refactored into the various backends.
   b. Backends do not need to know about all the details and internal state of `V.graph` IR. For example, they don't need to consider whether a buffer has been removed from the graph when generating code. Wrapper IR will hopefully provide a simpler interface for generating wrapper code, which abstracts away the details of device code.

# Implementation details

The implementation mainly consists of separating direct C++/Python codegen into two phases:
 1. Emit Wrapper IR lines describing what the wrapper code is supposed to do.
 2. Inside the `codegen()` method of each `WrapperLine`, call backend methods which generate pure Python/C++ code using the information stored in the Wrapper IR line. For example, `KernelCallLine` calls `wrapper._generate_kernel_call_helper`, which is overridden by the various Python and C++ backends to generate the final wrapper code, as sketched below.
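
A rough sketch of this dispatch pattern follows. The class layout, the simplified helper signature, and the kernel/buffer names are assumptions made for illustration; only the `_generate_kernel_call_helper` name comes from the description above:

```python
from dataclasses import dataclass
from typing import Sequence


class SketchPythonBackend:
    """Stand-in for a Python wrapper backend (simplified for this sketch)."""

    def writeline(self, line: str) -> None:
        print(line)

    def _generate_kernel_call_helper(self, kernel_name: str, call_args: Sequence[str]) -> None:
        # Python backend: emit a Python-level kernel launch.
        self.writeline(f"{kernel_name}.run({', '.join(call_args)})")


class SketchCppBackend(SketchPythonBackend):
    def _generate_kernel_call_helper(self, kernel_name: str, call_args: Sequence[str]) -> None:
        # C++ backend: same Wrapper IR op, different textual output.
        self.writeline(f"{kernel_name}({', '.join(call_args)});")


@dataclass
class KernelCallLine:
    # Assumed fields; the real op records more launch details.
    kernel_name: str
    call_args: Sequence[str]

    def codegen(self, code: SketchPythonBackend) -> None:
        # Phase 2: hand the structured info to whichever backend is active.
        code._generate_kernel_call_helper(self.kernel_name, self.call_args)


# The same IR line renders as Python or C++ purely depending on the backend.
line = KernelCallLine("triton_poi_fused_add_0", ["buf0", "arg0_1"])
line.codegen(SketchPythonBackend())  # -> triton_poi_fused_add_0.run(buf0, arg0_1)
line.codegen(SketchCppBackend())     # -> triton_poi_fused_add_0(buf0, arg0_1);
```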

The main difficulty here is ensuring that code is generated in the correct order. Wrapper codegen happens in two passes: first we write code into `self.lines`, which mainly contains Wrapper IR but can also contain raw Python or C++ lines in some situations. Then we convert the Wrapper IR into the final Python/C++ code in `self.wrapper_call`. Since the same macros may be used in both passes, it's difficult to guarantee that code is written to the correct buffer. The simplest solution was a context manager that overrides the `writeline` method to write to `self.wrapper_call` once memory planning is finished. This way, `writeline` writes to `self.lines` in the first pass and to `self.wrapper_call` in the second, and there was no need to thread `code` or `writeline` variables through the call stack, which would have touched most of the existing macros.
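
A minimal sketch of that context-manager trick, under simplified assumptions (the class and context-manager names here are hypothetical, not the ones used in the PR):

```python
import contextlib


class SketchWrapperCodegen:
    """Simplified stand-in for the two-pass buffer handling described above."""

    def __init__(self) -> None:
        self.lines: list = []         # pass 1: Wrapper IR (plus occasional raw lines)
        self.wrapper_call: list = []  # pass 2: final Python/C++ wrapper code

    def writeline(self, line) -> None:
        # Helper macros call this without caring which pass is running.
        self.lines.append(line)

    @contextlib.contextmanager
    def redirect_writeline_to_wrapper_call(self):
        # Hypothetical name: after memory planning, temporarily override
        # writeline so the second pass writes into self.wrapper_call instead.
        original = self.writeline
        self.writeline = self.wrapper_call.append
        try:
            yield
        finally:
            self.writeline = original
```

In the first pass `writeline` appends to `self.lines`; inside the `with` block the very same helper macros end up writing to `self.wrapper_call`.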

# Test plan

Since this refactor touches all the existing wrapper codegen classes, the existing CI provides good coverage.

The parent PR introduces new tests for the FX IR backend. Among other things, these tests assert that `self.lines` contains only Wrapper IR lines and no free-form code. While this would not be true of all programs today, the tests suggest that the IR implemented in this PR is sufficient to cover basic PyTorch usage.

# Future directions

The goals described above are only partially realized by this PR. There are several important steps which still undergo direct Python/C++ codegen in core files:
 - User-defined Triton kernels.
 - Reinterpret views on outputs, from `gen_output_refs()`. (In the parent PR, the FX converter has a custom way of handling this. This can eventually be ported into Wrapper IR.)
 - Fallback ops with custom `codegen()` methods, e.g. `ScatterFallback`.
 - Misc. C++ lines emitted by the various cpp backends, e.g. declaring constants.

These cases will gradually be handled in subsequent PRs, as the Inductor->FX converter expands its coverage. Given that these refactors are pretty tricky to do, it seems wiser to execute them in stages, as opposed to porting everything to Wrapper IR at once. Some Python and C++ codegen still lives in core files such as `ir.py`, as described in previous sections. Hopefully, this PR will serve as a starting point which moves the codebase towards a more modular design. Over time, we can gradually refactor the remaining codegen (mainly in `ir.py`) into backend classes.

One limitation of this PR is that codegen still happens in two phases during `PythonWrapperCodegen`. First, we generate Wrapper IR into `self.lines`, and from there we generate Python or C++ code into `self.wrapper_call`, `self.header`, etc. In the long term, it would be cleaner to split wrapper IR into its own class which doesn't deal with Python/C++ codegen at all. (See the diagram at the top.) That would strictly enforce the boundary between Wrapper IR and Python/C++ wrapper code. However, this would probably be a much larger refactor.

Another limitation of the current code is that the helper functions have a lot of call args. It's also possible to clean this up by passing Wrapper IR ops e.g. `KernelCallLine` into helper functions like `_generate_kernel_call_helper`, since they store all the arguments. However, that change would likely be prone to merge conflicts, so I would like to save it for follow-up PRs if possible.
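
To illustrate the cleanup being suggested, here is a sketch with simplified, hypothetical signatures (`grid`/`stream` are placeholder parameter names, not the helper's real argument list):

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class KernelCallLine:
    # Hypothetical fields; mirrors the sketch under "Implementation details".
    kernel_name: str
    call_args: Sequence[str]


class SketchBackend:
    def writeline(self, line: str) -> None:
        print(line)

    # Today (simplified): each launch detail is threaded through as its own argument.
    def _generate_kernel_call_helper(self, kernel_name, call_args, grid=None, stream=None):
        self.writeline(f"{kernel_name}.run({', '.join(call_args)})")

    # Possible follow-up (sketch): accept the Wrapper IR op, which already stores them.
    def _generate_kernel_call_helper_from_line(self, line: KernelCallLine) -> None:
        self._generate_kernel_call_helper(line.kernel_name, line.call_args)
```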

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150458
Approved by: https://github.com/eellison
2025-04-15 17:28:36 +00:00


# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import logging
import os
from enum import Enum
from typing import Callable, Optional

import torch
from torch import dtype as torch_dtype

from .. import config
from ..virtualized import V
from .multi_kernel import MultiKernel


log = logging.getLogger(__name__)


def _print_debugging_tensor_value_info(msg, arg):
    # helper for printing debugging stats for intermediate tensor values
    # at jit inductor level codegen
    max_numel_to_print = 64
    print(msg)
    if not isinstance(arg, torch.Tensor):
        print("Value: ", arg)
        return
    numel = arg.float().numel()
    # print the debug printing stats
    if numel <= max_numel_to_print:
        print(arg)
    print("Number of elements: ", numel)
    print("Size: ", arg.float().size())
    print("Dtype: ", arg.dtype)
    print("Mean: ", arg.float().mean().item())
    print("Min: ", arg.float().min().item())
    print("Max: ", arg.float().max().item())
    print("Std: ", arg.float().std().item())


# AOTI debug printing related configs
class IntermediateValueDebuggingLevel(Enum):
    # OFF: No intermediate tensor value debug info will be printed or saved.
    OFF = "0"
    # LEVEL 1: Save all intermediate tensor values to individual `.pt` files. No debug printing will be displayed.
    SAVE_ONLY = "1"
    # LEVEL 2: Print all intermediate tensor values by default to the console. No debug saving will be performed.
    PRINT_ONLY = "2"
    # LEVEL 3: Print all kernel names to the console only. No debug saving/printing for input tensor value info will be performed.
    # This mode can be helpful in cases when you just want to pinpoint which kernel is running into a CUDA IMA issue, etc.
    PRINT_KERNEL_NAMES_ONLY = "3"


class DebugPrinterManager:
    def __init__(
        self,
        debug_printer_level,
        use_array_ref: bool,
        writeline: Optional[Callable[..., None]] = None,
        args_to_print_or_save: Optional[list[str]] = None,
        kernel_name: str = "",
        kernel=None,
        arg_signatures: Optional[list[type]] = None,
        kernel_type=None,
    ):
        self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level)
        self.use_array_ref = use_array_ref
        if args_to_print_or_save is None:
            args_to_print_or_save = []
        self.args_to_print_or_save = args_to_print_or_save
        self.kernel_name = kernel_name
        self.arg_signatures: Optional[list[type]] = None
        self.kernel = kernel
        self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names()
        self.kernel_type = None

    def __enter__(self):
        self._perform_debug_print_or_save_helper(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch=True,
            arg_signatures=self.arg_signatures,
        )

    # Note: __exit__ receives the standard (exc_type, exc_val, exc_tb) triple here;
    # the helper below reads its state from self, so these values are effectively unused.
    def __exit__(self, args_to_print_or_save, kernel_name, arg_signatures):
        self._perform_debug_print_or_save_helper(
            args_to_print_or_save,
            kernel_name,
            before_launch=False,
            arg_signatures=arg_signatures,
        )

    def _perform_debug_print_or_save_helper(
        self,
        args_to_print_or_save,
        kernel_name,
        before_launch,
        arg_signatures: Optional[list[type]] = None,
    ):
        if self.debug_printer_level == IntermediateValueDebuggingLevel.OFF:
            return
        if self.debug_printer_level == IntermediateValueDebuggingLevel.SAVE_ONLY:
            # by default save all the tensor values before launch
            self.codegen_intermediate_tensor_value_save(
                self.args_to_print_or_save,
                self.kernel_name,
                before_launch,
                arg_signatures=self.arg_signatures,
            )
        if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY:
            # by default print all the tensor values before launch
            self.codegen_intermediate_tensor_value_print(
                self.args_to_print_or_save,
                self.kernel_name,
                before_launch,
                arg_signatures=self.arg_signatures,
            )
        if (
            self.debug_printer_level
            == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY
        ):
            # Print all kernel names to the console only
            self.codegen_intermediate_tensor_value_print(
                [],
                self.kernel_name,
                before_launch,
            )

    @functools.lru_cache  # noqa: B019
    def _get_debug_filtered_kernel_names(self) -> list[str]:
        if config.aot_inductor.filtered_kernel_names is None:
            return []
        return [
            x.strip()
            for x in config.aot_inductor.filtered_kernel_names.lower().split(",")
        ]

    def set_printer_args(
        self,
        args_to_print_or_save: list[str],
        kernel_name: str,
        arg_signatures: Optional[list[type]],
        kernel,
        kernel_type=None,
    ):
        # Note: MultiKernel debug printing is not supported for now
        if isinstance(kernel, MultiKernel):
            log.info(
                "MultiKernel type is not supported in AOTI debug printer tool yet."
            )
            self.debug_printer_level = IntermediateValueDebuggingLevel.OFF
        self.kernel_type = kernel_type
        # Note: if the kernel type is an extern kernel (or cpp kernel), we do a special handling to
        # get the list of args_to_print_or_save
        # TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls
        if kernel_type == "extern":
            args_to_print_or_save_extern = [
                arg for arg in args_to_print_or_save if arg.startswith(("buf", "arg"))
            ]
            self.args_to_print_or_save = args_to_print_or_save_extern
        elif kernel_type == "cpp":
            self.args_to_print_or_save = [
                (
                    f"copy_arrayref_tensor_to_tensor({arg})"
                    if self.use_array_ref
                    else arg
                )
                for arg in args_to_print_or_save
                if arg.startswith(("buf", "arg"))
            ]
        else:
            self.args_to_print_or_save = args_to_print_or_save
        self.kernel_name = kernel_name
        self.arg_signatures = arg_signatures
        self.kernel = kernel

    def codegen_model_inputs_value_print(self, input_args_to_print: list[str]) -> None:
        if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY:
            return
        for arg in input_args_to_print:
            if V.graph.cpp_wrapper:
                V.graph.wrapper_code.prefix.writeline(
                    f'aoti_torch_print_tensor_handle({arg}, "aoti_model_inputs - {arg}");'
                )

    def codegen_intermediate_tensor_value_save(
        self,
        args_to_save,
        kernel_name,
        before_launch=True,
        arg_signatures: Optional[list[type]] = None,
    ) -> None:
        for i, arg in enumerate(args_to_save):
            if arg_signatures is not None and not isinstance(
                arg_signatures[i], torch_dtype
            ):
                # infer from the arg data type (has torch.dtype) to see if it is a tensor type
                continue
            launch_prefix = "before_launch" if before_launch else "after_launch"
            if V.graph.cpp_wrapper:
                V.graph.wrapper_code.writeline(
                    f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");'
                )
            else:
                cwd = os.getcwd()
                saved_dir = cwd + "/tmp/jit_inductor/"
                if not os.path.exists(saved_dir):
                    log.info(
                        "Creating directory to save inductor intermediate tensor values."
                    )
                    os.makedirs(saved_dir)
                # Save the tensor to the directory
                saved_path = saved_dir + f"{launch_prefix}_{kernel_name}_{arg}.pt"
                log.info(
                    "Saved intermediate tensor %s for %s to %s",
                    arg,
                    kernel_name,
                    saved_path,
                )
                line = f"torch.save({arg}, '{saved_path}')"
                V.graph.wrapper_code.writeline(line)

    def codegen_intermediate_tensor_value_print(
        self,
        args_to_print,
        kernel_name,
        before_launch=True,
        arg_signatures: Optional[list[type]] = None,
    ) -> None:
        launch_prefix = "before_launch" if before_launch else "after_launch"
        # if the debug printing level is PRINT_KERNEL_NAMES_ONLY
        # we only print the kernel name to the console
        if (
            self.debug_printer_level
            == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY
        ):
            if V.graph.cpp_wrapper:
                V.graph.wrapper_code.writeline(
                    f'printf("[ {launch_prefix}: {kernel_name} ]\\n");'
                )
            return
        if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY:
            return
        for i, arg in enumerate(args_to_print):
            # when debug printing is enabled i.e. IntermediateValueDebuggingLevel.PRINT_ONLY,
            # check if filtered kernel name list is provided
            if (
                len(self.filtered_kernel_names_to_print) > 0
                and kernel_name.lower() not in self.filtered_kernel_names_to_print
            ):
                continue
            if V.graph.cpp_wrapper:
                if arg_signatures is not None and isinstance(
                    arg_signatures[i], torch_dtype
                ):
                    # infer from the arg data type (has torch.dtype) to see if it is a tensor type
                    V.graph.wrapper_code.writeline(
                        f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");'
                    )
                elif arg_signatures is not None and isinstance(
                    arg_signatures[i],
                    (
                        type(torch._inductor.codegen.wrapper.SymbolicCallArg),
                        type(int),
                        type(float),
                        type(bool),
                    ),
                ):
                    # scalar (symbolic or plain int/float/bool) args are printed via printf
                    V.graph.wrapper_code.writeline(
                        f'printf("[ {launch_prefix} - {kernel_name} - {arg}: %ld ]", {arg}); printf("\\n");'
                    )
                else:
                    if arg_signatures is None and self.kernel_type in ("cpp", "extern"):
                        V.graph.wrapper_code.writeline(
                            f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");'
                        )
            else:
                V.graph.wrapper_code.writeline(
                    f'_print_debugging_tensor_value_info("inductor: {launch_prefix} - {kernel_name} - {arg}", {arg})'
                )