mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR fixes the condition ``` if arg_signatures is None and self.kernel_type == "cpp" or "extern" ``` which is interpreted as ``` if (arg_signatures is None and self.kernel_type == "cpp") or ("extern"): ``` and it is always evaluated to `True`. According to the context the intention was ``` if arg_signatures is None and (self.kernel_type == "cpp" or self.kernel_type == "extern"): ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165033 Approved by: https://github.com/Skylion007
285 lines
11 KiB
Python
285 lines
11 KiB
Python
# mypy: allow-untyped-defs
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
import logging
|
|
import os
|
|
from enum import Enum
|
|
from typing import Callable, Optional
|
|
|
|
import torch
|
|
from torch import dtype as torch_dtype
|
|
|
|
from .. import config
|
|
from ..virtualized import V
|
|
from .multi_kernel import MultiKernel
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def _print_debugging_tensor_value_info(msg, arg):
|
|
# helper for printing debugging stats for intermediate tensor values
|
|
# at jit inductor level codegen
|
|
max_numel_to_print = 64
|
|
print(msg)
|
|
if not isinstance(arg, torch.Tensor):
|
|
print("Value: ", arg)
|
|
return
|
|
numel = arg.float().numel()
|
|
# print the debug printing stats
|
|
if numel <= max_numel_to_print:
|
|
print(arg)
|
|
print("Number of elements: ", numel)
|
|
print("Size: ", arg.float().size())
|
|
print("Dtype: ", arg.float().mean().item())
|
|
print("Mean: ", arg.float().mean().item())
|
|
print("Min: ", arg.float().min().item())
|
|
print("Max: ", arg.float().max().item())
|
|
print("Std: ", arg.float().std().item())
|
|
|
|
|
|
# AOTI debug printing related configs
|
|
# AOTI debug printing related configs
class IntermediateValueDebuggingLevel(Enum):
    """Debug-printing level for intermediate tensor values in AOTI codegen.

    Values are strings because the level is parsed from a raw config setting
    (see ``DebugPrinterManager.__init__``).
    """

    # OFF: No intermediate tensor value debug info will be printed or saved.
    OFF = "0"
    # LEVEL 1: Save all intermediate tensor values to individual `.pt` files. No debug printing will be displayed.
    SAVE_ONLY = "1"
    # LEVEL 2: Print all intermediate tensor values by default to the console. No debug saving will be performed.
    PRINT_ONLY = "2"
    # LEVEL 3: Print all kernel names to the console only. No debug saving/printing for input tensor value info will be performed.
    # This mode can be helpful when you just want to pinpoint which kernel is running into a CUDA IMA issue, etc.
    PRINT_KERNEL_NAMES_ONLY = "3"
|
|
|
|
|
|
class DebugPrinterManager:
|
|
def __init__(
    self,
    debug_printer_level,
    use_array_ref: bool,
    writeline: Optional[Callable[..., None]] = None,
    args_to_print_or_save: Optional[list[str]] = None,
    kernel_name: str = "",
    kernel=None,
    arg_signatures: Optional[list[type]] = None,
    kernel_type=None,
):
    """Initialize the debug printer manager.

    Args:
        debug_printer_level: Raw level value; converted to
            ``IntermediateValueDebuggingLevel`` (raises ``ValueError`` on an
            unknown value, per ``Enum`` lookup semantics).
        use_array_ref: Whether ArrayRef tensors are in use (affects how cpp
            kernel args are wrapped in ``set_printer_args``).
        writeline: NOTE(review): accepted but never used in this class —
            kept for interface compatibility; confirm whether callers rely on it.
        args_to_print_or_save: Initial arg names to print/save (defaults to []).
        kernel_name: Name of the current kernel being codegen'd.
        kernel: The kernel object (used for MultiKernel detection later).
        arg_signatures: Optional per-arg type signatures.
        kernel_type: Kernel category, e.g. "cpp" or "extern".
    """
    self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level)
    self.use_array_ref = use_array_ref
    # Normalize None to a fresh list (avoids a shared mutable default).
    if args_to_print_or_save is None:
        args_to_print_or_save = []
    self.args_to_print_or_save = args_to_print_or_save
    self.kernel_name = kernel_name
    # Bug fix: these two were hard-coded to None, silently discarding the
    # constructor arguments; they could only be set later via
    # set_printer_args(). Storing the params (which default to None) is
    # backward-compatible.
    self.arg_signatures: Optional[list[type]] = arg_signatures
    self.kernel_type = kernel_type
    self.kernel = kernel
    self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names()
|
|
|
|
def __enter__(self):
    """Emit the before-launch debug print/save codegen and enter the context.

    Returns:
        self — fixed to follow the context-manager convention so that
        ``with manager as m:`` binds the manager (previously returned None).
    """
    self._perform_debug_print_or_save_helper(
        self.args_to_print_or_save,
        self.kernel_name,
        before_launch=True,
        arg_signatures=self.arg_signatures,
    )
    return self
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
    """Emit the after-launch debug print/save codegen on context exit.

    Bug fix: the parameters were misleadingly named ``args_to_print_or_save,
    kernel_name, arg_signatures`` — a context manager's ``__exit__`` actually
    receives the exception triple, which was being forwarded as printer args
    (harmless only because the helper reads state from ``self``). Forward the
    cached ``self`` state explicitly instead.

    Returns None (falsy), so exceptions are never suppressed — unchanged.
    """
    self._perform_debug_print_or_save_helper(
        self.args_to_print_or_save,
        self.kernel_name,
        before_launch=False,
        arg_signatures=self.arg_signatures,
    )
|
|
|
|
def _perform_debug_print_or_save_helper(
    self,
    args_to_print_or_save,
    kernel_name,
    before_launch,
    arg_signatures: Optional[list[type]] = None,
):
    """Dispatch to the save/print codegen routine for the configured level.

    Note: apart from ``before_launch``, dispatch reads the state cached on
    ``self`` (set via ``set_printer_args``) rather than the positional
    arguments, mirroring the original behavior.
    """
    level = self.debug_printer_level
    if level == IntermediateValueDebuggingLevel.OFF:
        return
    if level == IntermediateValueDebuggingLevel.SAVE_ONLY:
        # Default: save all tensor values around the launch to `.pt` files.
        self.codegen_intermediate_tensor_value_save(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch,
            arg_signatures=self.arg_signatures,
        )
    elif level == IntermediateValueDebuggingLevel.PRINT_ONLY:
        # Default: print all tensor values around the launch.
        self.codegen_intermediate_tensor_value_print(
            self.args_to_print_or_save,
            self.kernel_name,
            before_launch,
            arg_signatures=self.arg_signatures,
        )
    elif level == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY:
        # Only the kernel name banner is printed; no tensor args are passed.
        self.codegen_intermediate_tensor_value_print(
            [],
            self.kernel_name,
            before_launch,
        )
|
|
|
|
def _get_debug_filtered_kernel_names(self) -> list[str]:
    """Parse the comma-separated kernel-name filter from the inductor config.

    Returns:
        Lowercased, whitespace-stripped kernel names; empty list when no
        filter is configured.

    Fix: dropped ``@functools.lru_cache`` — caching on an instance method
    keys on ``self`` and keeps every DebugPrinterManager alive for the
    cache's lifetime (flake8-bugbear B019). The result is computed once per
    instance in ``__init__`` anyway, and recomputing also reflects config
    changes between instances.
    """
    filtered = config.aot_inductor.filtered_kernel_names
    if filtered is None:
        return []
    return [name.strip() for name in filtered.lower().split(",")]
|
|
|
|
def set_printer_args(
    self,
    args_to_print_or_save: list[str],
    kernel_name: str,
    arg_signatures: Optional[list[type]],
    kernel,
    kernel_type=None,
):
    """Cache per-kernel state consumed by the next __enter__/__exit__ pass.

    Extern and cpp kernels get special arg filtering; for MultiKernel the
    debug printer is disabled entirely (not supported yet).
    """
    # Note: MultiKernel debug printing is not supported for now.
    if isinstance(kernel, MultiKernel):
        log.info(
            "MultiKernel type is not supported in AOTI debug printer tool yet."
        )
        self.debug_printer_level = IntermediateValueDebuggingLevel.OFF

    self.kernel_type = kernel_type

    def _looks_like_tensor_arg(name: str) -> bool:
        # Heuristic: only buffers/graph inputs are printable tensor handles.
        # TODO: Find a more reliable way to detect kernel args types to print
        # for extern kernel calls.
        return name.startswith(("buf", "arg"))

    # Extern/cpp kernels need special handling to build args_to_print_or_save.
    if kernel_type == "extern":
        self.args_to_print_or_save = [
            name for name in args_to_print_or_save if _looks_like_tensor_arg(name)
        ]
    elif kernel_type == "cpp":
        selected = []
        for name in args_to_print_or_save:
            if not _looks_like_tensor_arg(name):
                continue
            if self.use_array_ref:
                # ArrayRef tensors must be converted before printing.
                selected.append(f"copy_arrayref_tensor_to_tensor({name})")
            else:
                selected.append(name)
        self.args_to_print_or_save = selected
    else:
        self.args_to_print_or_save = args_to_print_or_save

    self.kernel_name = kernel_name
    self.arg_signatures = arg_signatures
    self.kernel = kernel
|
|
|
|
def codegen_model_inputs_value_print(self, input_args_to_print: list[str]) -> None:
    """Emit print statements for model inputs into the wrapper prefix.

    Only active at PRINT_ONLY level and only for the cpp wrapper; otherwise
    a no-op.
    """
    if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY:
        return
    # The wrapper-kind check is loop-invariant, so bail out once up front.
    if not V.graph.cpp_wrapper:
        return
    for input_arg in input_args_to_print:
        V.graph.wrapper_code.prefix.writeline(
            f'aoti_torch_print_tensor_handle({input_arg}, "aoti_model_inputs - {input_arg}");'
        )
|
|
|
|
def codegen_intermediate_tensor_value_save(
    self,
    args_to_save,
    kernel_name,
    before_launch=True,
    arg_signatures: Optional[list[type]] = None,
) -> None:
    """Emit codegen lines that save intermediate tensor args to `.pt` files.

    For the cpp wrapper this emits ``aoti_torch_save_tensor_handle`` calls;
    for the JIT (python) wrapper it emits ``torch.save`` lines targeting
    ``<cwd>/tmp/jit_inductor/``.

    Args:
        args_to_save: Arg names to save.
        kernel_name: Kernel the args belong to (used in file names).
        before_launch: Whether this is the pre-launch or post-launch pass.
        arg_signatures: Optional per-arg types; a ``torch.dtype`` entry marks
            a tensor arg (non-tensor args are skipped).
    """
    # Hoisted out of the loop: launch phase label is loop-invariant.
    launch_prefix = "before_launch" if before_launch else "after_launch"
    for i, arg in enumerate(args_to_save):
        if arg_signatures is not None and not isinstance(
            arg_signatures[i], torch_dtype
        ):
            # infer from the arg data type (has torch.dtype) to see if it is a tensor type
            continue
        if V.graph.cpp_wrapper:
            V.graph.wrapper_code.writeline(
                f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");'
            )
        else:
            cwd = os.getcwd()
            saved_dir = cwd + "/tmp/jit_inductor/"
            if not os.path.exists(saved_dir):
                log.info(
                    "Creating directory to save inductor intermediate tensor values."
                )
            # Fix: exist_ok=True tolerates a concurrent mkdir between the
            # exists() check above and this call (TOCTOU race would have
            # raised FileExistsError).
            os.makedirs(saved_dir, exist_ok=True)
            saved_path = saved_dir + f"{launch_prefix}_{kernel_name}_{arg}.pt"
            # NOTE: logged at codegen time — the actual save happens when the
            # generated wrapper code runs.
            log.info(
                "Saved intermediate tensor %s for %s to %s",
                arg,
                kernel_name,
                saved_path,
            )
            line = f"torch.save({arg}, '{saved_path}')"
            V.graph.wrapper_code.writeline(line)
|
|
|
|
def codegen_intermediate_tensor_value_print(
    self,
    args_to_print,
    kernel_name,
    before_launch=True,
    arg_signatures: Optional[list[type]] = None,
) -> None:
    """Emit codegen lines that print intermediate arg values around a launch.

    At PRINT_KERNEL_NAMES_ONLY level only a kernel-name banner is emitted
    (cpp wrapper only). At PRINT_ONLY level each arg is printed via
    ``aoti_torch_print_tensor_handle`` (tensors) or ``printf`` (scalars) for
    the cpp wrapper, or via ``_print_debugging_tensor_value_info`` for the
    JIT python wrapper.

    Args:
        args_to_print: Arg names to print.
        kernel_name: Kernel the args belong to (used in labels and filtering).
        before_launch: Whether this is the pre-launch or post-launch pass.
        arg_signatures: Optional per-arg types; a ``torch.dtype`` entry marks
            a tensor arg.
    """
    launch_prefix = "before_launch" if before_launch else "after_launch"

    # if the debug printing level is PRINT_KERNEL_NAMES_ONLY
    # we only print the kernel name to the console
    if (
        self.debug_printer_level
        == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY
    ):
        if V.graph.cpp_wrapper:
            V.graph.wrapper_code.writeline(
                f'printf("[ {launch_prefix}: {kernel_name} ]\\n");'
            )
        return

    if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY:
        return
    # Fix: the filtered-kernel-name check is loop-invariant; it was
    # re-evaluated (with `continue`) for every arg. Hoist it out of the loop.
    if (
        len(self.filtered_kernel_names_to_print) > 0
        and kernel_name.lower() not in self.filtered_kernel_names_to_print
    ):
        return
    for i, arg in enumerate(args_to_print):
        if V.graph.cpp_wrapper:
            if arg_signatures is not None and isinstance(
                arg_signatures[i], torch_dtype
            ):
                # infer from the arg data type (has torch.dtype) to see if it is a tensor type
                V.graph.wrapper_code.writeline(
                    f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");'
                )
            elif arg_signatures is not None and isinstance(
                arg_signatures[i],
                (
                    type(torch._inductor.codegen.wrapper.SymbolicCallArg),
                    type(int),
                    type(float),
                    type(bool),
                ),
            ):
                # NOTE(review): `type(int)`, `type(float)`, `type(bool)` all
                # evaluate to `type` itself, so this effectively checks
                # "arg_signatures[i] is a class" rather than matching
                # int/float/bool specifically — kept as-is; confirm intent.
                # Fix: was printf("\\\\n") in the Python source, which emitted
                # printf("\n" as a two-char literal) in the generated C++;
                # now emits a real newline like the banner branch above.
                V.graph.wrapper_code.writeline(
                    f'printf("[ {launch_prefix} - {kernel_name} - {arg}: %ld ]", {arg}); printf("\\n");'
                )
            else:
                if arg_signatures is None and self.kernel_type in ("cpp", "extern"):
                    V.graph.wrapper_code.writeline(
                        f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");'
                    )
        else:
            # JIT python wrapper: codegen a call to the module-level helper.
            V.graph.wrapper_code.writeline(
                f'_print_debugging_tensor_value_info("inductor: {launch_prefix} - {kernel_name} - {arg}", {arg})'
            )
|