Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Preparatory refactor for https://github.com/pytorch/pytorch/pull/146942.

# Feature

This PR refactors the existing wrapper codegen into `WrapperLine` subclasses, extending the existing Memory Planning IR into a fully-fledged Wrapper IR. See the diagram below.

The IR currently supports the following ops:
- All existing memory planning IR ops (`AllocateLine`, `FreeIfNotReusedLine`, etc.)
- Reinterpret views (`ReinterpretLine`)
- Kernel definitions (`KernelDefinitionLine`)
- Calls to defined kernels (`KernelCallLine`)
- Calls to extern kernels (`ExternKernelLine`, `ExternKernelAllocLine`)
- Ops with multiple outputs (`MultiOutputLine`)
- Tensor cleanup at the end of a graph (`FreeLine`)
- Leaving comments in code (`CommentLine`)

There are two main motivations for this refactor:
1. Unlike free-form C++ and Python code, Wrapper IR lines provide structured information about what the wrapper code does. This serves as a natural extension point for other types of wrapper codegen. For example, the parent PR generates FX IR from Wrapper IR. Wrapper IR aims to give new backends enough information to generate wrapper code without needing to modify core Inductor files such as `ir.py`.
2. This design will hopefully promote stronger modularity and encapsulation.
   a. Inductor's core compilation passes don't need to worry about whether they're targeting Python, C++, FX or anything else. They can simply focus on generating Wrapper IR, and target-specific code can be refactored into the various backends.
   b. Backends do not need to know about all the details and internal state of the `V.graph` IR. For example, they don't need to consider whether a buffer has been removed from the graph when generating code. Wrapper IR will hopefully provide a simpler interface for generating wrapper code, one which abstracts away the details of device code.

# Implementation details

The implementation mainly consists of separating direct C++/Python codegen into two phases:
1. Emit Wrapper IR lines describing what the wrapper code is supposed to do.
2. Inside the `codegen()` method of each `WrapperLine`, call backend methods which generate pure Python/C++ code using the information stored in the Wrapper IR line. For example, `KernelCallLine` calls `wrapper._generate_kernel_call_helper`, which is overridden by the various Python and C++ backends to generate the final wrapper code.

The main difficulty in implementing this is that we need to be careful that code is generated in the correct order. Wrapper codegen happens in two passes: first we write code into `self.lines`, which mainly contains Wrapper IR but can also contain raw Python or C++ lines in some situations. Then we convert the Wrapper IR into the final Python/C++ code in `self.wrapper_call`. Since the same macros may be used in both passes, it is difficult to ensure that code is written to the correct buffer. The easiest solution was a context manager that overrides the `writeline` method to write to `self.wrapper_call` once memory planning is finished. This way, `writeline` writes to `self.lines` in the first pass and to `self.wrapper_call` in the second. This obviated the need to pass `code` or `writeline` variables all the way through the call stack, which would have touched most of the existing macros.
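To make the two-pass flow concrete, here is a minimal, self-contained sketch of the pattern described above. `KernelCallLine`, `_generate_kernel_call_helper`, `self.lines`, `self.wrapper_call`, and `writeline` are names from this PR; everything else (the `SketchWrapperCodegen` class, its fields, the `set_writeline` helper, and the kernel/buffer names in the demo) is a simplified, hypothetical stand-in, not Inductor's actual implementation.

```python
import contextlib
from dataclasses import dataclass


class SketchWrapperCodegen:
    """Pass 1 collects structured Wrapper IR lines; pass 2 lowers them to target code."""

    def __init__(self):
        self.lines = []          # Wrapper IR, filled during the first pass
        self.wrapper_call = []   # final wrapper code text, filled during the second pass

    def writeline(self, line):
        self.lines.append(line)

    @contextlib.contextmanager
    def set_writeline(self, target):
        # Redirect writeline so anything emitted during the second pass lands in
        # wrapper_call, without threading a `code` buffer through every macro.
        original = self.writeline
        self.writeline = target
        try:
            yield
        finally:
            self.writeline = original

    # Backends (Python, C++, FX, ...) override this helper to pick the final syntax.
    def _generate_kernel_call_helper(self, kernel_name, call_args):
        self.writeline(f"{kernel_name}({', '.join(call_args)})")

    def generate(self):
        with self.set_writeline(self.wrapper_call.append):
            for line in self.lines:
                line.codegen(self)
        return "\n".join(self.wrapper_call)


@dataclass
class KernelCallLine:
    kernel_name: str
    call_args: tuple

    def codegen(self, wrapper):
        # The line stores structured data only; how the call is printed is the
        # backend's decision, made inside _generate_kernel_call_helper.
        wrapper._generate_kernel_call_helper(self.kernel_name, self.call_args)


wrapper = SketchWrapperCodegen()
wrapper.writeline(KernelCallLine("triton_poi_fused_add_0", ("buf0", "arg0_1")))
print(wrapper.generate())  # triton_poi_fused_add_0(buf0, arg0_1)
```

An FX backend, for instance, could override `_generate_kernel_call_helper` to emit an FX node instead of a text line, without the first pass changing at all.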
# Test plan

Since this refactor touches all the existing wrapper codegen classes, the existing CI provides good coverage. The parent PR introduces new tests for the FX IR backend. Among other things, these tests assert that `self.lines` only contains Wrapper IR lines, and no free-form code. While this would not be true of all programs today, the tests suggest that the IR implemented in this PR is sufficient to cover basic PyTorch usage.

# Future directions

These two goals are only partially realized by this PR. There are several important steps which still undergo direct Python/C++ codegen in core files:
- User-defined Triton kernels.
- Reinterpret views on outputs, from `gen_output_refs()`. (In the parent PR, the FX converter has a custom way of handling this. This can eventually be ported into Wrapper IR.)
- Fallback ops with custom `codegen()` methods, e.g. `ScatterFallback`.
- Misc. C++ lines emitted by the various cpp backends, e.g. declaring constants.

These cases will gradually be handled in subsequent PRs, as the Inductor->FX converter expands its coverage. Given that these refactors are pretty tricky to do, it seems wiser to execute them in stages, as opposed to porting everything to Wrapper IR at once. Some Python and C++ codegen still lives in core files such as `ir.py`, as described in previous sections. Hopefully, this PR will serve as a starting point which moves the codebase towards a more modular design. Over time, we can gradually refactor the remaining codegen (mainly in `ir.py`) into backend classes.

One limitation of this PR is that codegen still happens in two phases inside `PythonWrapperCodegen`: first we generate Wrapper IR into `self.lines`, and from there we generate Python or C++ code into `self.wrapper_call`, `self.header`, etc. In the long term, it would be cleaner to split Wrapper IR into its own class which doesn't deal with Python/C++ codegen at all. (See the diagram at the top.) That would strictly enforce the boundary between Wrapper IR and Python/C++ wrapper code. However, this would probably be a much larger refactor.

Another limitation of the current code is that the helper functions take a lot of call args. It would be possible to clean this up by passing Wrapper IR ops, e.g. `KernelCallLine`, into helper functions like `_generate_kernel_call_helper`, since they already store all the arguments. However, that change would likely be prone to merge conflicts, so I would like to save it for follow-up PRs if possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150458
Approved by: https://github.com/eellison
285 lines
11 KiB
Python
# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import logging
import os
from enum import Enum
from typing import Callable, Optional

import torch
from torch import dtype as torch_dtype

from .. import config
from ..virtualized import V
from .multi_kernel import MultiKernel


log = logging.getLogger(__name__)


def _print_debugging_tensor_value_info(msg, arg):
    # helper for printing debugging stats for intermediate tensor values
    # at jit inductor level codegen
    max_numel_to_print = 64
    print(msg)
    if not isinstance(arg, torch.Tensor):
        print("Value: ", arg)
        return
    numel = arg.float().numel()
    # print the debug printing stats
    if numel <= max_numel_to_print:
        print(arg)
    print("Number of elements: ", numel)
    print("Size: ", arg.float().size())
    print("Dtype: ", arg.dtype)
    print("Mean: ", arg.float().mean().item())
    print("Min: ", arg.float().min().item())
    print("Max: ", arg.float().max().item())
    print("Std: ", arg.float().std().item())


# AOTI debug printing related configs
class IntermediateValueDebuggingLevel(Enum):
    # OFF: No intermediate tensor value debug info will be printed or saved.
    OFF = "0"
    # LEVEL 1: Save all intermediate tensor values to individual `.pt` files. No debug printing will be displayed.
    SAVE_ONLY = "1"
    # LEVEL 2: Print all intermediate tensor values by default to the console. No debug saving will be performed.
    PRINT_ONLY = "2"
    # LEVEL 3: Print all kernel names to the console only. No debug saving/printing for input tensor value info will be performed.
    # This mode can be helpful when you just want to pinpoint which kernel is running into a CUDA IMA issue, etc.
    PRINT_KERNEL_NAMES_ONLY = "3"


class DebugPrinterManager:
    def __init__(
        self,
        debug_printer_level,
        use_array_ref: bool,
        writeline: Optional[Callable[..., None]] = None,
        args_to_print_or_save: Optional[list[str]] = None,
        kernel_name: str = "",
        kernel=None,
        arg_signatures: Optional[list[type]] = None,
        kernel_type=None,
    ):
        self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level)
        self.use_array_ref = use_array_ref
        if args_to_print_or_save is None:
            args_to_print_or_save = []
        self.args_to_print_or_save = args_to_print_or_save
        self.kernel_name = kernel_name
        self.arg_signatures: Optional[list[type]] = arg_signatures
        self.kernel = kernel
        self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names()
        self.kernel_type = kernel_type

    def __enter__(self):
        self._perform_debug_print_or_save_helper(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch=True,
            arg_signatures=self.arg_signatures,
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._perform_debug_print_or_save_helper(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch=False,
            arg_signatures=self.arg_signatures,
        )

    def _perform_debug_print_or_save_helper(
        self,
        args_to_print_or_save,
        kernel_name,
        before_launch,
        arg_signatures: Optional[list[type]] = None,
    ):
        if self.debug_printer_level == IntermediateValueDebuggingLevel.OFF:
            return
        if self.debug_printer_level == IntermediateValueDebuggingLevel.SAVE_ONLY:
            # by default save all the tensor values before launch
            self.codegen_intermediate_tensor_value_save(
                self.args_to_print_or_save,
                self.kernel_name,
                before_launch,
                arg_signatures=self.arg_signatures,
            )
        if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY:
            # by default print all the tensor values before launch
            self.codegen_intermediate_tensor_value_print(
                self.args_to_print_or_save,
                self.kernel_name,
                before_launch,
                arg_signatures=self.arg_signatures,
            )
        if (
            self.debug_printer_level
            == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY
        ):
            # Print all kernel names to the console only
            self.codegen_intermediate_tensor_value_print(
                [],
                self.kernel_name,
                before_launch,
            )

    @functools.lru_cache  # noqa: B019
    def _get_debug_filtered_kernel_names(self) -> list[str]:
        if config.aot_inductor.filtered_kernel_names is None:
            return []
        return [
            x.strip()
            for x in config.aot_inductor.filtered_kernel_names.lower().split(",")
        ]

    def set_printer_args(
        self,
        args_to_print_or_save: list[str],
        kernel_name: str,
        arg_signatures: Optional[list[type]],
        kernel,
        kernel_type=None,
    ):
        # Note: MultiKernel debug printing is not supported for now
        if isinstance(kernel, MultiKernel):
            log.info(
                "MultiKernel type is not supported in AOTI debug printer tool yet."
            )
            self.debug_printer_level = IntermediateValueDebuggingLevel.OFF

        self.kernel_type = kernel_type
        # Note: if the kernel type is an extern kernel (or cpp kernel), we do a special handling to
        # get the list of args_to_print_or_save
        # TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls
        if kernel_type == "extern":
            args_to_print_or_save_extern = [
                arg for arg in args_to_print_or_save if arg.startswith(("buf", "arg"))
            ]
            self.args_to_print_or_save = args_to_print_or_save_extern
        elif kernel_type == "cpp":
            self.args_to_print_or_save = [
                (
                    f"copy_arrayref_tensor_to_tensor({arg})"
                    if self.use_array_ref
                    else arg
                )
                for arg in args_to_print_or_save
                if arg.startswith(("buf", "arg"))
            ]
        else:
            self.args_to_print_or_save = args_to_print_or_save
        self.kernel_name = kernel_name
        self.arg_signatures = arg_signatures
        self.kernel = kernel

    def codegen_model_inputs_value_print(self, input_args_to_print: list[str]) -> None:
        if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY:
            return
        for arg in input_args_to_print:
            if V.graph.cpp_wrapper:
                V.graph.wrapper_code.prefix.writeline(
                    f'aoti_torch_print_tensor_handle({arg}, "aoti_model_inputs - {arg}");'
                )

    def codegen_intermediate_tensor_value_save(
        self,
        args_to_save,
        kernel_name,
        before_launch=True,
        arg_signatures: Optional[list[type]] = None,
    ) -> None:
        for i, arg in enumerate(args_to_save):
            if arg_signatures is not None and not isinstance(
                arg_signatures[i], torch_dtype
            ):
                # infer from the arg data type (has torch.dtype) to see if it is a tensor type
                continue
            launch_prefix = "before_launch" if before_launch else "after_launch"
            if V.graph.cpp_wrapper:
                V.graph.wrapper_code.writeline(
                    f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");'
                )
            else:
                cwd = os.getcwd()
                saved_dir = cwd + "/tmp/jit_inductor/"
                if not os.path.exists(saved_dir):
                    log.info(
                        "Creating directory to save inductor intermediate tensor values."
                    )
                    os.makedirs(saved_dir)
                # Save the tensor to the directory
                saved_path = saved_dir + f"{launch_prefix}_{kernel_name}_{arg}.pt"
                log.info(
                    "Saved intermediate tensor %s for %s to %s",
                    arg,
                    kernel_name,
                    saved_path,
                )
                line = f"torch.save({arg}, '{saved_path}')"
                V.graph.wrapper_code.writeline(line)

    def codegen_intermediate_tensor_value_print(
        self,
        args_to_print,
        kernel_name,
        before_launch=True,
        arg_signatures: Optional[list[type]] = None,
    ) -> None:
        launch_prefix = "before_launch" if before_launch else "after_launch"

        # if the debug printing level is PRINT_KERNEL_NAMES_ONLY
        # we only print the kernel name to the console
        if (
            self.debug_printer_level
            == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY
        ):
            if V.graph.cpp_wrapper:
                V.graph.wrapper_code.writeline(
                    f'printf("[ {launch_prefix}: {kernel_name} ]\\n");'
                )
            return

        if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY:
            return
        for i, arg in enumerate(args_to_print):
            # when debug printing is enabled i.e. IntermediateValueDebuggingLevel.PRINT_ONLY,
            # check if filtered kernel name list is provided
            if (
                len(self.filtered_kernel_names_to_print) > 0
                and kernel_name.lower() not in self.filtered_kernel_names_to_print
            ):
                continue
            if V.graph.cpp_wrapper:
                if arg_signatures is not None and isinstance(
                    arg_signatures[i], torch_dtype
                ):
                    # infer from the arg data type (has torch.dtype) to see if it is a tensor type
                    V.graph.wrapper_code.writeline(
                        f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");'
                    )
                elif arg_signatures is not None and isinstance(
                    arg_signatures[i],
                    (
                        type(torch._inductor.codegen.wrapper.SymbolicCallArg),
                        type(int),
                        type(float),
                        type(bool),
                    ),
                ):
                    V.graph.wrapper_code.writeline(
                        f'printf("[ {launch_prefix} - {kernel_name} - {arg}: %ld ]", {arg}); printf("\\n");'
                    )
                else:
                    if arg_signatures is None and self.kernel_type in ("cpp", "extern"):
                        V.graph.wrapper_code.writeline(
                            f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");'
                        )
            else:
                V.graph.wrapper_code.writeline(
                    f'_print_debugging_tensor_value_info("inductor: {launch_prefix} - {kernel_name} - {arg}", {arg})'
                )
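For reference, the JIT (Python) wrapper path above emits calls to `_print_debugging_tensor_value_info` into the generated wrapper code. Below is a small runnable demo of what such a call prints; the kernel and buffer names are made up, and the function is the one defined at the top of this file.

```python
import torch

# Assumes _print_debugging_tensor_value_info from this module is in scope, e.g.:
# from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info
buf0 = torch.randn(4, 4)
_print_debugging_tensor_value_info(
    "inductor: after_launch - triton_poi_fused_add_0 - buf0", buf0
)
# Prints the message, the tensor itself (since numel <= 64), and summary stats
# (number of elements, size, dtype, mean, min, max, std).
```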