cpp_wrapper: Move #includes to per-device header files (#143909)

This prepares us for the next PR in the stack, where we introduce pre-compiled per-device header files to save compilation time.
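
For orientation, a rough sketch of the include pattern the generated wrapper code moves to (the header paths are the ones added in this diff; the surrounding comments are illustrative, not part of the commit, and the snippet is only meaningful inside a generated wrapper translation unit in the PyTorch build):

// AOT-mode cpp_wrapper code for, e.g., a CUDA model now starts from a single
// per-device umbrella header rather than splicing a long list of includes:
#include <torch/csrc/inductor/aoti_include/cuda.h>
// Non-AOT (JIT) cpp_wrapper code uses the matching header under cpp_wrapper/:
// #include <torch/csrc/inductor/cpp_wrapper/cuda.h>
// Each umbrella header pulls in the shared common.h for its mode plus a small
// device_internal/<device>.h wrapping the device-specific c_shim and utils.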

Differential Revision: [D67938955](https://our.internmc.facebook.com/intern/diff/D67938955)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143909
Approved by: https://github.com/desertfire
Benjamin Glass
2025-01-14 17:17:19 +00:00
committed by PyTorch MergeBot
parent 05095a45f2
commit d62b3979da
25 changed files with 157 additions and 109 deletions

@ -556,7 +556,7 @@ exclude_patterns = [
command = [
'python3',
'tools/linter/adapters/grep_linter.py',
'--pattern=#include <pybind11\/',
'--pattern=#include <pybind11\/(^|[^(gil\.h)])',
'--allowlist-pattern=#include <torch\/csrc\/utils\/pybind.h>',
'--linter-name=PYBIND11_INCLUDE',
'--match-first-only',

@ -1284,6 +1284,7 @@ def main():
"include/torch/csrc/distributed/autograd/rpc_messages/*.h",
"include/torch/csrc/dynamo/*.h",
"include/torch/csrc/inductor/*.h",
"include/torch/csrc/inductor/aoti_include/*.h",
"include/torch/csrc/inductor/aoti_package/*.h",
"include/torch/csrc/inductor/aoti_runner/*.h",
"include/torch/csrc/inductor/aoti_runtime/*.h",
@ -1291,6 +1292,8 @@ def main():
"include/torch/csrc/inductor/aoti_torch/c/*.h",
"include/torch/csrc/inductor/aoti_torch/generated/*.h",
"include/torch/csrc/inductor/aoti_torch/generated/extend/*.h",
"include/torch/csrc/inductor/cpp_wrapper/*.h",
"include/torch/csrc/inductor/cpp_wrapper/device_internal/*.h",
"include/torch/csrc/jit/*.h",
"include/torch/csrc/jit/backends/*.h",
"include/torch/csrc/jit/generated/*.h",

@ -688,7 +688,6 @@ def torch_key() -> bytes:
# a hash representing the state of the source code.
extra_files = (
"codegen/aoti_runtime/interface.cpp",
"codegen/aoti_runtime/implementation.cpp",
"codegen/cpp_prefix.h",
"script.ld",
)

@ -250,9 +250,6 @@ class DeviceOpOverrides:
def kernel_driver(self):
raise NotImplementedError
def abi_compatible_header(self):
raise NotImplementedError
def cpp_stream_type(self):
raise NotImplementedError

@ -18,7 +18,7 @@ from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import config, ir
from ..utils import _align, ALIGN_BYTES, cache_on_self, normalize_name
from ..utils import _align, cache_on_self, normalize_name
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import get_device_op_overrides, IndentedBuffer, Kernel
@ -126,23 +126,17 @@ class CppWrapperCpu(PythonWrapperCodegen):
# include a hash so our code cache gives different constants different files
self.header.writeline(f"// {name} {hashed}")
def get_device_include(self):
if V.graph.aot_mode:
return f"#include <torch/csrc/inductor/aoti_include/{self.device}.h>"
return f"#include <torch/csrc/inductor/cpp_wrapper/{self.device}.h>"
def write_header(self):
if V.graph.is_const_graph:
# We do not write header for constant graph, it will be written by main module.
return
if V.graph.aot_mode:
self.header.splice(
"""
#include <torch/csrc/inductor/aoti_runtime/interface.h>
#include <torch/csrc/inductor/aoti_runtime/model.h>
"""
)
with open(
os.path.join(os.path.dirname(__file__), "aoti_runtime", "interface.cpp")
) as f:
self.header.splice(f.read())
else:
if not V.graph.aot_mode:
self.header.splice(
"""
import torch
@ -150,61 +144,17 @@ class CppWrapperCpu(PythonWrapperCodegen):
cpp_wrapper_src = (
'''
#include <optional>
#include <Python.h>
#define PYBIND11_SIMPLE_GIL_MANAGEMENT
#include <pybind11/gil.h>
namespace py = pybind11;
class RAIIPyObject {
public:
RAIIPyObject() : obj_(nullptr) {}
RAIIPyObject(PyObject* obj) : obj_(obj) {}
~RAIIPyObject() {
Py_XDECREF(obj_);
}
RAIIPyObject& operator=(const RAIIPyObject& other) {
if (this != &other) {
Py_XDECREF(obj_);
obj_ = other.obj_;
Py_XINCREF(obj_);
}
return *this;
}
operator PyObject*() {
return obj_;
}
PyObject* get() {
return obj_;
}
private:
PyObject* obj_;
};
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
using namespace torch::aot_inductor;
"""
)
self.header.splice(
f"""
#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{self.device}.h>
self.header.splice(self.get_device_include())
#include <c10/util/generic_math.h>
typedef at::Half half;
typedef at::BFloat16 bfloat16;
if V.graph.aot_mode:
with open(
os.path.join(os.path.dirname(__file__), "aoti_runtime", "interface.cpp")
) as f:
self.header.splice(f.read())
// Round up to the nearest multiple of {ALIGN_BYTES}
[[maybe_unused]] static int64_t align(int64_t nbytes) {{
return (nbytes + {ALIGN_BYTES} - 1) & -{ALIGN_BYTES};
}}
"""
)
extend_aoti_c_shim_include = (
f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h"
)
@ -1517,8 +1467,10 @@ class CppWrapperCpu(PythonWrapperCodegen):
return final_tmp_name
def codegen_device_copy(self, src, dst, non_blocking: bool):
"""This function is overridden by cpp_wrapper_cpu_array_ref, so we don't need to
handle cases where dst is not an AtenTensorHandle."""
self.writeline(
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));"
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_({dst}, {src}, {non_blocking}));"
)
def codegen_multi_output(self, name, value):

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
import os
from itertools import count
from typing import Callable, Dict, List, Optional
@ -82,18 +81,11 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
return DTYPE_TO_CPP[dtype]
return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>"
def write_header(self):
if V.graph.is_const_graph:
# We do not write header for constant graph, it will be written by main module.
return
super().write_header()
with open(
os.path.join(
os.path.dirname(__file__), "aoti_runtime", "implementation.cpp"
)
) as f:
self.header.splice(f.read())
def get_device_include(self):
assert self.device == "cpu", "ArrayRef only supported on CPU!"
if V.graph.aot_mode:
return "#include <torch/csrc/inductor/aoti_include/array_ref.h>"
return "#include <torch/csrc/inductor/cpp_wrapper/array_ref.h>"
def codegen_input_numel_asserts(self):
for name, buf in V.graph.graph_inputs.items():

@ -203,9 +203,6 @@ class CppWrapperGpu(CppWrapperCpu):
return
super().write_header()
self.header.splice("#include <filesystem>")
self.header.splice(self.device_codegen.abi_compatible_header())
self.header.splice(
maybe_hipify_code_wrapper(self.device_codegen.kernel_driver())
)

@ -225,9 +225,6 @@ class CUDADeviceOpOverrides(DeviceOpOverrides):
#endif
"""
def abi_compatible_header(self):
return "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
def cpp_stream_type(self):
return "cudaStream_t"

@ -53,13 +53,13 @@ class DebugPrinterManager:
def __init__(
self,
debug_printer_level,
use_array_ref: bool,
args_to_print_or_save: Optional[List[str]] = None,
kernel_name: str = "",
kernel=None,
arg_signatures: Optional[List[type]] = None,
kernel_type=None,
):
self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level)
self.use_array_ref = use_array_ref
if args_to_print_or_save is None:
args_to_print_or_save = []
self.args_to_print_or_save = args_to_print_or_save
@ -155,12 +155,15 @@ class DebugPrinterManager:
]
self.args_to_print_or_save = args_to_print_or_save_extern
elif kernel_type == "cpp":
args_to_print_or_save_cpp = [
f"copy_arrayref_tensor_to_tensor({arg})"
self.args_to_print_or_save = [
(
f"copy_arrayref_tensor_to_tensor({arg})"
if self.use_array_ref
else arg
)
for arg in args_to_print_or_save
if arg.startswith(("buf", "arg"))
]
self.args_to_print_or_save = args_to_print_or_save_cpp
else:
self.args_to_print_or_save = args_to_print_or_save
self.kernel_name = kernel_name

@ -721,7 +721,8 @@ class PythonWrapperCodegen(CodeGen):
# intermediate tensor value printing utility
self.debug_printer = DebugPrinterManager(
debug_printer_level=config.aot_inductor.debug_intermediate_value_printer
debug_printer_level=config.aot_inductor.debug_intermediate_value_printer,
use_array_ref=config.aot_inductor.allow_stack_allocation,
)
# Additional files that are dependent to the wrapper (ex. cubin files)

@ -57,12 +57,6 @@ class XPUDeviceOpOverrides(DeviceOpOverrides):
"""
return source_codes
def abi_compatible_header(self):
return """
#include <torch/csrc/inductor/aoti_runtime/utils_xpu.h>
#include <torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h>
"""
def cpp_stream_type(self):
return "sycl::queue*"

@ -0,0 +1,7 @@
#pragma once
#include <torch/csrc/inductor/aoti_include/common.h>
#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>
#include <torch/csrc/inductor/array_ref_impl.h>
#include <torch/csrc/inductor/cpp_wrapper/device_internal/cpu.h>

@ -0,0 +1,17 @@
#pragma once
#include <filesystem>
#include <optional>
#include <torch/csrc/inductor/aoti_runtime/interface.h>
#include <torch/csrc/inductor/aoti_runtime/model.h>
#include <c10/util/generic_math.h>
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
using half = at::Half;
using bfloat16 = at::BFloat16;
// Round up to the nearest multiple of 64
[[maybe_unused]] inline int64_t align(int64_t nbytes) {
return (nbytes + 64 - 1) & -64;
}
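
As a quick sanity check of the rounding arithmetic used by align() above (these static_asserts are not part of the commit, just an illustration of the bit trick):

// (nbytes + 63) & -64 rounds up and then clears the low six bits:
static_assert(((1 + 64 - 1) & -64) == 64, "1 byte occupies one 64-byte block");
static_assert(((64 + 64 - 1) & -64) == 64, "exact multiples are unchanged");
static_assert(((65 + 64 - 1) & -64) == 128, "65 bytes spill into a second block");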

@ -0,0 +1,4 @@
#pragma once
#include <torch/csrc/inductor/aoti_include/common.h>
#include <torch/csrc/inductor/cpp_wrapper/device_internal/cpu.h>

@ -0,0 +1,4 @@
#pragma once
#include <torch/csrc/inductor/aoti_include/common.h>
#include <torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h>

@ -0,0 +1,4 @@
#pragma once
#include <torch/csrc/inductor/aoti_include/common.h>
#include <torch/csrc/inductor/cpp_wrapper/device_internal/xpu.h>

@ -1,14 +1,11 @@
// NOTE: Like interface.cpp, this file will be copied into AOTInductor
// generated output. This file is intended to keep implementation
// details separate from the implementation of the AOTI public
// interface.
#pragma once
#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
namespace torch {
namespace aot_inductor {
namespace torch::aot_inductor {
template <typename T>
void convert_output_to_handle(
const ArrayRefTensor<T>& output,
@ -82,9 +79,9 @@ template <typename T>
void assert_numel(const ArrayRefTensor<T>& tensor, uint64_t numel) {
if (tensor.numel() != numel) {
std::stringstream err;
err << "incorrect numel for input tensor. expected " << numel << ", got " << tensor.numel();
err << "incorrect numel for input tensor. expected " << numel << ", got "
<< tensor.numel();
throw std::runtime_error(err.str());
}
}
} // namespace aot_inductor
} // namespace torch
} // namespace torch::aot_inductor

@ -0,0 +1,7 @@
#pragma once
#include <torch/csrc/inductor/aoti_runtime/arrayref_tensor.h>
#include <torch/csrc/inductor/aoti_runtime/thread_local.h>
#include <torch/csrc/inductor/array_ref_impl.h>
#include <torch/csrc/inductor/cpp_wrapper/common.h>
#include <torch/csrc/inductor/cpp_wrapper/device_internal/cpu.h>

@ -0,0 +1,49 @@
#pragma once
#include <Python.h>
#include <filesystem>
#include <optional>
#define PYBIND11_SIMPLE_GIL_MANAGEMENT
#include <pybind11/gil.h>
namespace py = pybind11;
class RAIIPyObject {
public:
RAIIPyObject() : obj_(nullptr) {}
RAIIPyObject(PyObject* obj) : obj_(obj) {}
~RAIIPyObject() {
Py_XDECREF(obj_);
}
RAIIPyObject& operator=(const RAIIPyObject& other) {
if (this != &other) {
Py_XDECREF(obj_);
obj_ = other.obj_;
Py_XINCREF(obj_);
}
return *this;
}
operator PyObject*() {
return obj_;
}
PyObject* get() {
return obj_;
}
private:
PyObject* obj_;
};
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
using namespace torch::aot_inductor;
#include <c10/util/generic_math.h>
#include <torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h>
using half = at::Half;
using bfloat16 = at::BFloat16;
// Round up to the nearest multiple of 64
[[maybe_unused]] inline int64_t align(int64_t nbytes) {
return (nbytes + 64 - 1) & -64;
}
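
A minimal usage sketch for the RAIIPyObject helper defined above (not part of the commit; callable is a hypothetical PyObject*, and the snippet assumes an initialized interpreter with the GIL held, <stdexcept> for the error path, and Python 3.9+ for PyObject_CallNoArgs):

// Take ownership of the new reference returned by the call; Py_XDECREF runs
// automatically when result goes out of scope, including on the throw below.
RAIIPyObject result(PyObject_CallNoArgs(callable));
if (result.get() == nullptr) {
    throw std::runtime_error("Python callback failed");
}
// operator PyObject*() lets the wrapper be passed straight to other C-API
// calls without an explicit .get(); PyObject_Repr returns a new reference,
// so it is wrapped as well.
RAIIPyObject repr(PyObject_Repr(result));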

@ -0,0 +1,4 @@
#pragma once
#include <torch/csrc/inductor/cpp_wrapper/common.h>
#include <torch/csrc/inductor/cpp_wrapper/device_internal/cpu.h>

@ -0,0 +1,4 @@
#pragma once
#include <torch/csrc/inductor/cpp_wrapper/common.h>
#include <torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h>

@ -0,0 +1,3 @@
#pragma once
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h>

@ -0,0 +1,4 @@
#pragma once
#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h>

@ -0,0 +1,5 @@
#pragma once
#include <torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h>
#include <torch/csrc/inductor/aoti_runtime/utils_xpu.h>
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h>

@ -0,0 +1,4 @@
#pragma once
#include <torch/csrc/inductor/cpp_wrapper/common.h>
#include <torch/csrc/inductor/cpp_wrapper/device_internal/xpu.h>