mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] add cpp builder code. (#124045)
Previous full PR https://github.com/pytorch/pytorch/pull/115248 is failed to merge due to fb_code is hard to debug. I also tried to submit them as two pieces, https://github.com/pytorch/pytorch/pull/118514 https://github.com/pytorch/pytorch/pull/118515. And they have passed PreCI at that time. Now I tried to split https://github.com/pytorch/pytorch/pull/115248 into smaller piece, and it is the first step of RFC https://github.com/pytorch/pytorch/issues/124245. Changes: 1. Add cpp builder code, the new cpp_builder support Windows OS. 2. Add CPU ISA checker which is cross OS and exported from backend cpuinfo. 3. Switch compiler ISA checker to new cpp builder. 4. CppCodeCache use the new ISA checker. 5. Add temprary `test_new_cpp_build_logical` UT to help on transfer to new code. <img width="1853" alt="Image" src="https://github.com/pytorch/pytorch/assets/8433590/ce6519ab-ba92-4204-b1d6-7d15d2ba2cbe"> Pull Request resolved: https://github.com/pytorch/pytorch/pull/124045 Approved by: https://github.com/jgong5, https://github.com/jansel
This commit is contained in:
@ -5,6 +5,22 @@
|
||||
|
||||
namespace at::cpu {
|
||||
|
||||
bool is_cpu_support_avx2() {
|
||||
#if !defined(__s390x__) && !defined(__powerpc__)
|
||||
return cpuinfo_initialize() && cpuinfo_has_x86_avx2();
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool is_cpu_support_avx512() {
|
||||
#if !defined(__s390x__) && !defined(__powerpc__)
|
||||
return cpuinfo_initialize() && cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512dq();
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool is_cpu_support_vnni() {
|
||||
#if !defined(__s390x__) && !defined(__powerpc__)
|
||||
return cpuinfo_initialize() && cpuinfo_has_x86_avx512vnni();
|
||||
|
@ -4,6 +4,9 @@
|
||||
|
||||
namespace at::cpu {
|
||||
|
||||
TORCH_API bool is_cpu_support_avx2();
|
||||
TORCH_API bool is_cpu_support_avx512();
|
||||
|
||||
// Detect if CPU support Vector Neural Network Instruction.
|
||||
TORCH_API bool is_cpu_support_vnni();
|
||||
|
||||
|
@ -6245,6 +6245,11 @@ class CommonTemplate:
|
||||
|
||||
self.common(fn, [torch.randn(64, 64)])
|
||||
|
||||
def test_new_cpp_build_logical(self):
|
||||
from torch._inductor.codecache import validate_new_cpp_commands
|
||||
|
||||
validate_new_cpp_commands()
|
||||
|
||||
def test_as_strided(self):
|
||||
def fn(x):
|
||||
return (
|
||||
|
@ -2,4 +2,6 @@ from torch.types import _bool
|
||||
|
||||
# Defined in torch/csrc/cpu/Module.cpp
|
||||
|
||||
def _is_cpu_support_avx2() -> _bool: ...
|
||||
def _is_cpu_support_avx512() -> _bool: ...
|
||||
def _is_cpu_support_vnni() -> _bool: ...
|
||||
|
@ -406,6 +406,8 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
|
||||
"torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata",
|
||||
"torch._C._construct_storage_from_data_pointer",
|
||||
"torch._C._conv_determine_backend_memory_format",
|
||||
"torch._C._cpu._is_cpu_support_avx2",
|
||||
"torch._C._cpu._is_cpu_support_avx512",
|
||||
"torch._C._cpu._is_cpu_support_vnni",
|
||||
"torch._C._crash_if_aten_asan",
|
||||
"torch._C._crash_if_csrc_asan",
|
||||
@ -2420,6 +2422,8 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
|
||||
"torch.chain_matmul",
|
||||
"torch.compile",
|
||||
"torch.compiled_with_cxx11_abi",
|
||||
"torch.cpu._is_cpu_support_avx2",
|
||||
"torch.cpu._is_cpu_support_avx512",
|
||||
"torch.cpu._is_cpu_support_vnni",
|
||||
"torch.cpu.current_device",
|
||||
"torch.cpu.current_stream",
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import copyreg
|
||||
import ctypes
|
||||
import dataclasses
|
||||
import functools
|
||||
import hashlib
|
||||
@ -82,6 +83,8 @@ _HERE = os.path.abspath(__file__)
|
||||
_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
|
||||
_LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld")
|
||||
|
||||
_IS_WINDOWS = sys.platform == "win32"
|
||||
|
||||
if config.is_fbcode():
|
||||
from triton.fb import build_paths
|
||||
from triton.fb.build import _run_build_command
|
||||
@ -1191,7 +1194,7 @@ def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
|
||||
|
||||
class VecISA:
|
||||
_bit_width: int
|
||||
_macro: str
|
||||
_macro: List[str]
|
||||
_arch_flags: str
|
||||
_dtype_nelements: Dict[torch.dtype, int]
|
||||
|
||||
@ -1237,7 +1240,7 @@ cdll.LoadLibrary("__lib_path__")
|
||||
def nelements(self, dtype: torch.dtype = torch.float) -> int:
|
||||
return self._dtype_nelements[dtype]
|
||||
|
||||
def build_macro(self) -> str:
|
||||
def build_macro(self) -> List[str]:
|
||||
return self._macro
|
||||
|
||||
def build_arch_flags(self) -> str:
|
||||
@ -1248,6 +1251,8 @@ cdll.LoadLibrary("__lib_path__")
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def __bool__(self) -> bool:
|
||||
from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions
|
||||
|
||||
if config.cpp.vec_isa_ok is not None:
|
||||
return config.cpp.vec_isa_ok
|
||||
|
||||
@ -1264,16 +1269,21 @@ cdll.LoadLibrary("__lib_path__")
|
||||
lock_dir = get_lock_dir()
|
||||
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
|
||||
with lock:
|
||||
output_path = input_path[:-3] + "so"
|
||||
build_cmd = shlex.split(
|
||||
cpp_compile_command(
|
||||
input_path, output_path, warning_all=False, vec_isa=self
|
||||
)
|
||||
output_dir = os.path.dirname(input_path)
|
||||
buid_options = CppTorchOptions(chosen_isa=self, warning_all=False)
|
||||
x86_isa_help_builder = CppBuilder(
|
||||
key,
|
||||
[input_path],
|
||||
buid_options,
|
||||
output_dir,
|
||||
)
|
||||
try:
|
||||
# Check if the output file exist, and compile when not.
|
||||
output_path = x86_isa_help_builder.get_target_file_path()
|
||||
if not os.path.isfile(output_path):
|
||||
compile_file(input_path, output_path, build_cmd)
|
||||
status, target_file = x86_isa_help_builder.build()
|
||||
if status:
|
||||
return False
|
||||
|
||||
# Check build result
|
||||
subprocess.check_call(
|
||||
@ -1294,7 +1304,7 @@ cdll.LoadLibrary("__lib_path__")
|
||||
@dataclasses.dataclass
|
||||
class VecNEON(VecISA):
|
||||
_bit_width = 256 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
|
||||
_macro = "-DCPU_CAPABILITY_NEON"
|
||||
_macro = ["CPU_CAPABILITY_NEON"]
|
||||
_arch_flags = "" # Unused
|
||||
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
||||
|
||||
@ -1307,8 +1317,12 @@ class VecNEON(VecISA):
|
||||
@dataclasses.dataclass
|
||||
class VecAVX512(VecISA):
|
||||
_bit_width = 512
|
||||
_macro = "-DCPU_CAPABILITY_AVX512"
|
||||
_arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
|
||||
_macro = ["CPU_CAPABILITY_AVX512"]
|
||||
_arch_flags = (
|
||||
"-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
|
||||
if not _IS_WINDOWS
|
||||
else "/arch:AVX512"
|
||||
) # TODO: use cflags
|
||||
_dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -1320,8 +1334,10 @@ class VecAVX512(VecISA):
|
||||
@dataclasses.dataclass
|
||||
class VecAVX2(VecISA):
|
||||
_bit_width = 256
|
||||
_macro = "-DCPU_CAPABILITY_AVX2"
|
||||
_arch_flags = "-mavx2 -mfma"
|
||||
_macro = ["CPU_CAPABILITY_AVX2"]
|
||||
_arch_flags = (
|
||||
"-mavx2 -mfma" if not _IS_WINDOWS else "/arch:AVX2"
|
||||
) # TODO: use cflags
|
||||
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -1333,7 +1349,11 @@ class VecAVX2(VecISA):
|
||||
@dataclasses.dataclass
|
||||
class VecZVECTOR(VecISA):
|
||||
_bit_width = 256
|
||||
_macro = "-DCPU_CAPABILITY_ZVECTOR -DCPU_CAPABILITY=ZVECTOR -DHAVE_ZVECTOR_CPU_DEFINITION"
|
||||
_macro = [
|
||||
"CPU_CAPABILITY_ZVECTOR",
|
||||
"CPU_CAPABILITY=ZVECTOR",
|
||||
"HAVE_ZVECTOR_CPU_DEFINITION",
|
||||
]
|
||||
_arch_flags = "-mvx -mzvector"
|
||||
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
||||
|
||||
@ -1345,7 +1365,7 @@ class VecZVECTOR(VecISA):
|
||||
|
||||
class InvalidVecISA(VecISA):
|
||||
_bit_width = 0
|
||||
_macro = ""
|
||||
_macro = [""]
|
||||
_arch_flags = ""
|
||||
_dtype_nelements = {}
|
||||
|
||||
@ -1358,6 +1378,31 @@ class InvalidVecISA(VecISA):
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
def x86_isa_checker() -> List[str]:
|
||||
supported_isa: List[str] = []
|
||||
|
||||
def _check_and_append_supported_isa(
|
||||
dest: List[str], isa_supported: bool, isa_name: str
|
||||
):
|
||||
if isa_supported is True:
|
||||
dest.append(isa_name)
|
||||
|
||||
Arch = platform.machine()
|
||||
"""
|
||||
Arch value is x86_64 on Linux, and the value is AMD64 on Windows.
|
||||
"""
|
||||
if Arch != "x86_64" and Arch != "AMD64":
|
||||
return supported_isa
|
||||
|
||||
avx2 = torch.cpu._is_cpu_support_avx2()
|
||||
avx512 = torch.cpu._is_cpu_support_avx512()
|
||||
|
||||
_check_and_append_supported_isa(supported_isa, avx2, "avx2")
|
||||
_check_and_append_supported_isa(supported_isa, avx512, "avx512")
|
||||
|
||||
return supported_isa
|
||||
|
||||
|
||||
invalid_vec_isa = InvalidVecISA()
|
||||
supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()]
|
||||
|
||||
@ -1370,7 +1415,8 @@ def valid_vec_isa_list() -> List[VecISA]:
|
||||
if sys.platform == "darwin" and platform.processor() == "arm":
|
||||
return [VecNEON()]
|
||||
|
||||
if sys.platform != "linux":
|
||||
cur_os = sys.platform
|
||||
if cur_os != "linux" and cur_os != "win32":
|
||||
return []
|
||||
|
||||
if platform.machine() == "s390x":
|
||||
@ -1388,12 +1434,11 @@ def valid_vec_isa_list() -> List[VecISA]:
|
||||
return []
|
||||
|
||||
isa_list = []
|
||||
with open("/proc/cpuinfo") as _cpu_info:
|
||||
_cpu_info_content = _cpu_info.read()
|
||||
for isa in supported_vec_isa_list:
|
||||
if str(isa) in _cpu_info_content and isa:
|
||||
isa_list.append(isa)
|
||||
return isa_list
|
||||
_cpu_supported_isa = x86_isa_checker()
|
||||
for isa in supported_vec_isa_list:
|
||||
if str(isa) in _cpu_supported_isa:
|
||||
isa_list.append(isa)
|
||||
return isa_list
|
||||
|
||||
|
||||
def pick_vec_isa() -> VecISA:
|
||||
@ -1401,6 +1446,7 @@ def pick_vec_isa() -> VecISA:
|
||||
return VecAVX2()
|
||||
|
||||
_valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
|
||||
|
||||
if not _valid_vec_isa_list:
|
||||
return invalid_vec_isa
|
||||
|
||||
@ -1569,7 +1615,14 @@ def get_include_and_linking_paths(
|
||||
_set_gpu_runtime_env()
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
macros = vec_isa.build_macro() if vec_isa != invalid_vec_isa else ""
|
||||
# Remove below in the further
|
||||
# macros = "-D {}".format(vec_isa.build_macro()) if vec_isa != invalid_vec_isa else ""
|
||||
macros = ""
|
||||
if vec_isa != invalid_vec_isa:
|
||||
for x in vec_isa.build_macro():
|
||||
macros_def = f"-D{x} "
|
||||
macros += macros_def
|
||||
|
||||
build_arch_flags = ""
|
||||
if sys.platform == "linux" and (
|
||||
include_pytorch
|
||||
@ -1789,7 +1842,7 @@ def cpp_compile_command(
|
||||
{get_warning_all_flag(warning_all)} {cpp_flags()}
|
||||
{get_glibcxx_abi_build_flags()}
|
||||
{ipaths_str} {lpaths} {libs} {build_arch_flags}
|
||||
{macros} {linker_paths} {clang_flags}
|
||||
{macros} {linker_paths} {clang_flags} {cpp_wrapper_flags()}
|
||||
{optimization_flags()}
|
||||
{use_custom_generated_macros()}
|
||||
{use_fb_internal_macros()}
|
||||
@ -2041,7 +2094,6 @@ class AotCodeCompiler:
|
||||
def _to_bytes(t: torch.Tensor) -> bytes:
|
||||
# This serializes the tensor's untyped_storage to bytes by accessing
|
||||
# the raw data of the underlying structure.
|
||||
import ctypes
|
||||
|
||||
if t.numel() == 0:
|
||||
return b""
|
||||
@ -2265,8 +2317,18 @@ class CppCodeCache:
|
||||
"cuda": cuda,
|
||||
"vec_isa": pick_vec_isa(),
|
||||
}
|
||||
cpp_command = repr(cpp_compile_command("i", "o", **compile_command))
|
||||
key, input_path = write(source_code, "cpp", extra=cpp_command)
|
||||
|
||||
from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions
|
||||
|
||||
picked_vec_isa = pick_vec_isa()
|
||||
dummy_builder = CppBuilder("i", ["o"], CppTorchOptions(picked_vec_isa))
|
||||
# write function will calc source_code hash, the same source code with different
|
||||
# ISA level should be generate different hash.
|
||||
# So we need get a command_line which contains isa related parameter as a part of hash key.
|
||||
# And then pass the command_line to below write function as extra parameter to
|
||||
# guarantee the source code hash contains ISA difference.
|
||||
dummy_cmd = dummy_builder.get_command_line()
|
||||
key, input_path = write(source_code, "cpp", extra=dummy_cmd)
|
||||
|
||||
if key not in cls.cache:
|
||||
from filelock import FileLock
|
||||
@ -2535,7 +2597,85 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
|
||||
)
|
||||
|
||||
|
||||
@clear_on_fresh_inductor_cache
|
||||
# TODO: Will remove the temp code after switch to new cpp_builder
|
||||
def _temp_validate_new_and_old_command(new_cmd: List[str], old_cmd: List[str]):
|
||||
new_diff: List[str] = [x for x in new_cmd if x not in old_cmd]
|
||||
old_diff: List[str] = [y for y in old_cmd if y not in new_cmd]
|
||||
|
||||
if new_diff or old_diff:
|
||||
print("!!! new_cmd: ", new_cmd)
|
||||
print("!!! old_cmd: ", old_cmd)
|
||||
print("!!! new_diff: ", new_diff)
|
||||
print("!!! old_diff: ", old_diff)
|
||||
raise RuntimeError("Error in new and old command different.")
|
||||
|
||||
|
||||
def _do_validate_cpp_commands(
|
||||
include_pytorch: bool, cuda: bool, compile_only: bool, mmap_weights: bool
|
||||
):
|
||||
# PreCI will failed if test machine can't run cuda.
|
||||
test_cuda = torch.cuda.is_available() and cuda
|
||||
input_path = "/temp/dummy_input.cpp"
|
||||
output_path = "/temp/dummy_output.so"
|
||||
if compile_only:
|
||||
output_path = "/temp/dummy_output.o"
|
||||
picked_isa = pick_vec_isa()
|
||||
|
||||
old_cmd = cpp_compile_command(
|
||||
input=input_path,
|
||||
output=output_path,
|
||||
include_pytorch=include_pytorch,
|
||||
vec_isa=picked_isa,
|
||||
cuda=test_cuda,
|
||||
aot_mode=False,
|
||||
compile_only=compile_only,
|
||||
use_absolute_path=False,
|
||||
use_mmap_weights=mmap_weights,
|
||||
).split(" ")
|
||||
|
||||
from torch._inductor.cpp_builder import CppBuilder, CppTorchCudaOptions
|
||||
|
||||
dummy_build_option = CppTorchCudaOptions(
|
||||
chosen_isa=picked_isa,
|
||||
include_pytorch=include_pytorch,
|
||||
use_cuda=test_cuda,
|
||||
compile_only=compile_only,
|
||||
use_mmap_weights=mmap_weights,
|
||||
)
|
||||
|
||||
dummy_builder = CppBuilder(
|
||||
name="dummy_output",
|
||||
sources=input_path,
|
||||
BuildOption=dummy_build_option,
|
||||
output_dir="/temp/",
|
||||
compile_only=compile_only,
|
||||
use_absolute_path=False,
|
||||
)
|
||||
new_cmd = dummy_builder.get_command_line().split(" ")
|
||||
|
||||
_temp_validate_new_and_old_command(new_cmd, old_cmd)
|
||||
|
||||
|
||||
# TODO: Will remove the temp code after switch to new cpp_builder
|
||||
# It could help on sync new cpp_builder generate same command line as the old one.
|
||||
def validate_new_cpp_commands():
|
||||
cuda = [True, False]
|
||||
use_mmap_weights = [True, False]
|
||||
compile_only = [True, False]
|
||||
include_pytorch = [True, False]
|
||||
|
||||
for x in cuda:
|
||||
for y in use_mmap_weights:
|
||||
for z in compile_only:
|
||||
for m in include_pytorch:
|
||||
print(
|
||||
f"!!! cuda:{x}, use_mmap_weights:{y}, compile_only:{z}, include_pytorch:{m}"
|
||||
)
|
||||
_do_validate_cpp_commands(
|
||||
include_pytorch=m, cuda=x, mmap_weights=y, compile_only=z
|
||||
)
|
||||
|
||||
|
||||
class PyCodeCache:
|
||||
cache: Dict[str, ModuleType] = dict()
|
||||
linemaps: Dict[str, List[Tuple[Any, ...]]] = dict()
|
||||
|
1139
torch/_inductor/cpp_builder.py
Normal file
1139
torch/_inductor/cpp_builder.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -27,6 +27,16 @@ __all__ = [
|
||||
_device_t = Union[_device, str, int, None]
|
||||
|
||||
|
||||
def _is_cpu_support_avx2() -> bool:
|
||||
r"""Returns a bool indicating if CPU supports AVX2."""
|
||||
return torch._C._cpu._is_cpu_support_avx2()
|
||||
|
||||
|
||||
def _is_cpu_support_avx512() -> bool:
|
||||
r"""Returns a bool indicating if CPU supports AVX512."""
|
||||
return torch._C._cpu._is_cpu_support_avx512()
|
||||
|
||||
|
||||
def _is_cpu_support_vnni() -> bool:
|
||||
r"""Returns a bool indicating if CPU supports VNNI."""
|
||||
return torch._C._cpu._is_cpu_support_vnni()
|
||||
|
@ -2,15 +2,15 @@
|
||||
#include <torch/csrc/cpu/Module.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch {
|
||||
namespace cpu {
|
||||
namespace torch::cpu {
|
||||
|
||||
void initModule(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
|
||||
auto cpu = m.def_submodule("_cpu", "cpu related pybind.");
|
||||
cpu.def("_is_cpu_support_avx2", at::cpu::is_cpu_support_avx2);
|
||||
cpu.def("_is_cpu_support_avx512", at::cpu::is_cpu_support_avx512);
|
||||
cpu.def("_is_cpu_support_vnni", at::cpu::is_cpu_support_vnni);
|
||||
}
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace torch
|
||||
} // namespace torch::cpu
|
||||
|
Reference in New Issue
Block a user