[inductor] add cpp builder code. (#124045)

The previous full PR https://github.com/pytorch/pytorch/pull/115248 failed to merge because fb_code was hard to debug.
I also tried to submit the changes as two pieces, https://github.com/pytorch/pytorch/pull/118514 and https://github.com/pytorch/pytorch/pull/118515, and they passed PreCI at the time.

Now I have split https://github.com/pytorch/pytorch/pull/115248 into smaller pieces; this is the first step of RFC https://github.com/pytorch/pytorch/issues/124245.
Changes:
1. Add the cpp builder code; the new cpp_builder also supports the Windows OS (a usage sketch follows after the screenshot below).
2. Add a CPU ISA checker that is cross-OS and backed by the cpuinfo backend.
3. Switch the compiler ISA checker to the new cpp builder.
4. Make CppCodeCache use the new ISA checker.
5. Add a temporary `test_new_cpp_build_logical` UT to help with the transition to the new code.
<img width="1853" alt="Image" src="https://github.com/pytorch/pytorch/assets/8433590/ce6519ab-ba92-4204-b1d6-7d15d2ba2cbe">
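As a rough illustration, here is a minimal sketch of how the new builder is driven, based only on the call sites visible in this diff (the full `cpp_builder` module itself is in the suppressed file further down); the target name, source path, and output directory are made up for illustration:

```python
from torch._inductor.codecache import pick_vec_isa
from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions

# Pick the best vector ISA for the current CPU (now backed by the new checker).
picked_isa = pick_vec_isa()

# Compile a dummy source file into a shared library.
build_options = CppTorchOptions(chosen_isa=picked_isa, warning_all=False)
builder = CppBuilder(
    "dummy_output",             # target name (illustrative)
    ["/tmp/dummy_input.cpp"],   # source file list (illustrative)
    build_options,
    "/tmp/",                    # output directory (illustrative)
)

print(builder.get_command_line())      # the exact compiler invocation
print(builder.get_target_file_path())  # where the built artifact will land
status, target_file = builder.build()  # status is truthy on failure
```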

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124045
Approved by: https://github.com/jgong5, https://github.com/jansel
Xu Han authored on 2024-05-08 05:27:11 +00:00, committed by PyTorch MergeBot
parent 08f6ef0e1c
commit 469383755f
9 changed files with 1351 additions and 32 deletions

View File

@@ -5,6 +5,22 @@
namespace at::cpu {
bool is_cpu_support_avx2() {
#if !defined(__s390x__) && !defined(__powerpc__)
return cpuinfo_initialize() && cpuinfo_has_x86_avx2();
#else
return false;
#endif
}
bool is_cpu_support_avx512() {
#if !defined(__s390x__) && !defined(__powerpc__)
return cpuinfo_initialize() && cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512dq();
#else
return false;
#endif
}
bool is_cpu_support_vnni() {
#if !defined(__s390x__) && !defined(__powerpc__)
return cpuinfo_initialize() && cpuinfo_has_x86_avx512vnni();

View File

@@ -4,6 +4,9 @@
namespace at::cpu {
TORCH_API bool is_cpu_support_avx2();
TORCH_API bool is_cpu_support_avx512();
// Detect if the CPU supports Vector Neural Network Instructions (VNNI).
TORCH_API bool is_cpu_support_vnni();

View File

@@ -6245,6 +6245,11 @@ class CommonTemplate:
self.common(fn, [torch.randn(64, 64)])
def test_new_cpp_build_logical(self):
from torch._inductor.codecache import validate_new_cpp_commands
validate_new_cpp_commands()
def test_as_strided(self):
def fn(x):
return (

View File

@@ -2,4 +2,6 @@ from torch.types import _bool
# Defined in torch/csrc/cpu/Module.cpp
def _is_cpu_support_avx2() -> _bool: ...
def _is_cpu_support_avx512() -> _bool: ...
def _is_cpu_support_vnni() -> _bool: ...

View File

@@ -406,6 +406,8 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata",
"torch._C._construct_storage_from_data_pointer",
"torch._C._conv_determine_backend_memory_format",
"torch._C._cpu._is_cpu_support_avx2",
"torch._C._cpu._is_cpu_support_avx512",
"torch._C._cpu._is_cpu_support_vnni",
"torch._C._crash_if_aten_asan",
"torch._C._crash_if_csrc_asan",
@@ -2420,6 +2422,8 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.chain_matmul",
"torch.compile",
"torch.compiled_with_cxx11_abi",
"torch.cpu._is_cpu_support_avx2",
"torch.cpu._is_cpu_support_avx512",
"torch.cpu._is_cpu_support_vnni",
"torch.cpu.current_device",
"torch.cpu.current_stream",

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import base64
import copyreg
import ctypes
import dataclasses
import functools
import hashlib
@@ -82,6 +83,8 @@ _HERE = os.path.abspath(__file__)
_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
_LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld")
_IS_WINDOWS = sys.platform == "win32"
if config.is_fbcode():
from triton.fb import build_paths
from triton.fb.build import _run_build_command
@@ -1191,7 +1194,7 @@ def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
class VecISA:
_bit_width: int
_macro: str
_macro: List[str]
_arch_flags: str
_dtype_nelements: Dict[torch.dtype, int]
@@ -1237,7 +1240,7 @@ cdll.LoadLibrary("__lib_path__")
def nelements(self, dtype: torch.dtype = torch.float) -> int:
return self._dtype_nelements[dtype]
def build_macro(self) -> str:
def build_macro(self) -> List[str]:
return self._macro
def build_arch_flags(self) -> str:
@@ -1248,6 +1251,8 @@ cdll.LoadLibrary("__lib_path__")
@functools.lru_cache(None)
def __bool__(self) -> bool:
from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions
if config.cpp.vec_isa_ok is not None:
return config.cpp.vec_isa_ok
@@ -1264,16 +1269,21 @@ cdll.LoadLibrary("__lib_path__")
lock_dir = get_lock_dir()
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
with lock:
output_path = input_path[:-3] + "so"
build_cmd = shlex.split(
cpp_compile_command(
input_path, output_path, warning_all=False, vec_isa=self
)
output_dir = os.path.dirname(input_path)
build_options = CppTorchOptions(chosen_isa=self, warning_all=False)
x86_isa_help_builder = CppBuilder(
key,
[input_path],
build_options,
output_dir,
)
try:
# Check if the output file exists, and compile if it does not.
output_path = x86_isa_help_builder.get_target_file_path()
if not os.path.isfile(output_path):
compile_file(input_path, output_path, build_cmd)
status, target_file = x86_isa_help_builder.build()
if status:
return False
# Check build result
subprocess.check_call(
@@ -1294,7 +1304,7 @@ cdll.LoadLibrary("__lib_path__")
@dataclasses.dataclass
class VecNEON(VecISA):
_bit_width = 256 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
_macro = "-DCPU_CAPABILITY_NEON"
_macro = ["CPU_CAPABILITY_NEON"]
_arch_flags = "" # Unused
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
@@ -1307,8 +1317,12 @@ class VecNEON(VecISA):
@dataclasses.dataclass
class VecAVX512(VecISA):
_bit_width = 512
_macro = "-DCPU_CAPABILITY_AVX512"
_arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
_macro = ["CPU_CAPABILITY_AVX512"]
_arch_flags = (
"-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
if not _IS_WINDOWS
else "/arch:AVX512"
) # TODO: use cflags
_dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}
def __str__(self) -> str:
@@ -1320,8 +1334,10 @@ class VecAVX512(VecISA):
@dataclasses.dataclass
class VecAVX2(VecISA):
_bit_width = 256
_macro = "-DCPU_CAPABILITY_AVX2"
_arch_flags = "-mavx2 -mfma"
_macro = ["CPU_CAPABILITY_AVX2"]
_arch_flags = (
"-mavx2 -mfma" if not _IS_WINDOWS else "/arch:AVX2"
) # TODO: use cflags
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
def __str__(self) -> str:
@@ -1333,7 +1349,11 @@ class VecAVX2(VecISA):
@dataclasses.dataclass
class VecZVECTOR(VecISA):
_bit_width = 256
_macro = "-DCPU_CAPABILITY_ZVECTOR -DCPU_CAPABILITY=ZVECTOR -DHAVE_ZVECTOR_CPU_DEFINITION"
_macro = [
"CPU_CAPABILITY_ZVECTOR",
"CPU_CAPABILITY=ZVECTOR",
"HAVE_ZVECTOR_CPU_DEFINITION",
]
_arch_flags = "-mvx -mzvector"
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
@@ -1345,7 +1365,7 @@ class VecZVECTOR(VecISA):
class InvalidVecISA(VecISA):
_bit_width = 0
_macro = ""
_macro = [""]
_arch_flags = ""
_dtype_nelements = {}
@@ -1358,6 +1378,31 @@ class InvalidVecISA(VecISA):
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
def x86_isa_checker() -> List[str]:
supported_isa: List[str] = []
def _check_and_append_supported_isa(
dest: List[str], isa_supported: bool, isa_name: str
):
if isa_supported is True:
dest.append(isa_name)
Arch = platform.machine()
"""
Arch value is x86_64 on Linux, and the value is AMD64 on Windows.
"""
if Arch != "x86_64" and Arch != "AMD64":
return supported_isa
avx2 = torch.cpu._is_cpu_support_avx2()
avx512 = torch.cpu._is_cpu_support_avx512()
_check_and_append_supported_isa(supported_isa, avx2, "avx2")
_check_and_append_supported_isa(supported_isa, avx512, "avx512")
return supported_isa
invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()]
@@ -1370,7 +1415,8 @@ def valid_vec_isa_list() -> List[VecISA]:
if sys.platform == "darwin" and platform.processor() == "arm":
return [VecNEON()]
if sys.platform != "linux":
cur_os = sys.platform
if cur_os != "linux" and cur_os != "win32":
return []
if platform.machine() == "s390x":
@@ -1388,12 +1434,11 @@ def valid_vec_isa_list() -> List[VecISA]:
return []
isa_list = []
with open("/proc/cpuinfo") as _cpu_info:
_cpu_info_content = _cpu_info.read()
for isa in supported_vec_isa_list:
if str(isa) in _cpu_info_content and isa:
isa_list.append(isa)
return isa_list
_cpu_supported_isa = x86_isa_checker()
for isa in supported_vec_isa_list:
if str(isa) in _cpu_supported_isa:
isa_list.append(isa)
return isa_list
def pick_vec_isa() -> VecISA:
@@ -1401,6 +1446,7 @@ def pick_vec_isa() -> VecISA:
return VecAVX2()
_valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
if not _valid_vec_isa_list:
return invalid_vec_isa
@@ -1569,7 +1615,14 @@ def get_include_and_linking_paths(
_set_gpu_runtime_env()
from torch.utils import cpp_extension
macros = vec_isa.build_macro() if vec_isa != invalid_vec_isa else ""
# Remove the line below in the future
# macros = "-D {}".format(vec_isa.build_macro()) if vec_isa != invalid_vec_isa else ""
macros = ""
if vec_isa != invalid_vec_isa:
for x in vec_isa.build_macro():
macros_def = f"-D{x} "
macros += macros_def
build_arch_flags = ""
if sys.platform == "linux" and (
include_pytorch
@@ -1789,7 +1842,7 @@ def cpp_compile_command(
{get_warning_all_flag(warning_all)} {cpp_flags()}
{get_glibcxx_abi_build_flags()}
{ipaths_str} {lpaths} {libs} {build_arch_flags}
{macros} {linker_paths} {clang_flags}
{macros} {linker_paths} {clang_flags} {cpp_wrapper_flags()}
{optimization_flags()}
{use_custom_generated_macros()}
{use_fb_internal_macros()}
@@ -2041,7 +2094,6 @@ class AotCodeCompiler:
def _to_bytes(t: torch.Tensor) -> bytes:
# This serializes the tensor's untyped_storage to bytes by accessing
# the raw data of the underlying structure.
import ctypes
if t.numel() == 0:
return b""
@@ -2265,8 +2317,18 @@ class CppCodeCache:
"cuda": cuda,
"vec_isa": pick_vec_isa(),
}
cpp_command = repr(cpp_compile_command("i", "o", **compile_command))
key, input_path = write(source_code, "cpp", extra=cpp_command)
from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions
picked_vec_isa = pick_vec_isa()
dummy_builder = CppBuilder("i", ["o"], CppTorchOptions(picked_vec_isa))
# The write function will calculate the source_code hash; the same source code
# built at a different ISA level should generate a different hash.
# So we need a command_line that contains the ISA-related parameters as part of the hash key,
# and then pass that command_line to the write function below as an extra parameter to
# guarantee that the source code hash reflects the ISA difference.
dummy_cmd = dummy_builder.get_command_line()
key, input_path = write(source_code, "cpp", extra=dummy_cmd)
if key not in cls.cache:
from filelock import FileLock
@@ -2535,7 +2597,85 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
)
@clear_on_fresh_inductor_cache
# TODO: Remove the temp code after the switch to the new cpp_builder.
def _temp_validate_new_and_old_command(new_cmd: List[str], old_cmd: List[str]):
new_diff: List[str] = [x for x in new_cmd if x not in old_cmd]
old_diff: List[str] = [y for y in old_cmd if y not in new_cmd]
if new_diff or old_diff:
print("!!! new_cmd: ", new_cmd)
print("!!! old_cmd: ", old_cmd)
print("!!! new_diff: ", new_diff)
print("!!! old_diff: ", old_diff)
raise RuntimeError("Error: new and old commands differ.")
def _do_validate_cpp_commands(
include_pytorch: bool, cuda: bool, compile_only: bool, mmap_weights: bool
):
# PreCI will fail if the test machine can't run CUDA.
test_cuda = torch.cuda.is_available() and cuda
input_path = "/temp/dummy_input.cpp"
output_path = "/temp/dummy_output.so"
if compile_only:
output_path = "/temp/dummy_output.o"
picked_isa = pick_vec_isa()
old_cmd = cpp_compile_command(
input=input_path,
output=output_path,
include_pytorch=include_pytorch,
vec_isa=picked_isa,
cuda=test_cuda,
aot_mode=False,
compile_only=compile_only,
use_absolute_path=False,
use_mmap_weights=mmap_weights,
).split(" ")
from torch._inductor.cpp_builder import CppBuilder, CppTorchCudaOptions
dummy_build_option = CppTorchCudaOptions(
chosen_isa=picked_isa,
include_pytorch=include_pytorch,
use_cuda=test_cuda,
compile_only=compile_only,
use_mmap_weights=mmap_weights,
)
dummy_builder = CppBuilder(
name="dummy_output",
sources=input_path,
BuildOption=dummy_build_option,
output_dir="/temp/",
compile_only=compile_only,
use_absolute_path=False,
)
new_cmd = dummy_builder.get_command_line().split(" ")
_temp_validate_new_and_old_command(new_cmd, old_cmd)
# TODO: Remove the temp code after the switch to the new cpp_builder.
# It helps verify that the new cpp_builder generates the same command line as the old one.
def validate_new_cpp_commands():
cuda = [True, False]
use_mmap_weights = [True, False]
compile_only = [True, False]
include_pytorch = [True, False]
for x in cuda:
for y in use_mmap_weights:
for z in compile_only:
for m in include_pytorch:
print(
f"!!! cuda:{x}, use_mmap_weights:{y}, compile_only:{z}, include_pytorch:{m}"
)
_do_validate_cpp_commands(
include_pytorch=m, cuda=x, mmap_weights=y, compile_only=z
)
class PyCodeCache:
cache: Dict[str, ModuleType] = dict()
linemaps: Dict[str, List[Tuple[Any, ...]]] = dict()
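The comment added in `CppCodeCache.load` above explains why the builder's command line is folded into the cache key; a minimal sketch of that idea, using a hypothetical stand-in for the real `write()` key computation:

```python
import hashlib

def cache_key(source_code: str, extra: str) -> str:
    # Hypothetical helper: mixing the command line into the hash means the
    # same source compiled at different ISA levels lands in different
    # cache entries.
    return hashlib.sha256((source_code + extra).encode()).hexdigest()

src = "int main() { return 0; }"
key_avx2 = cache_key(src, "-DCPU_CAPABILITY_AVX2 -mavx2 -mfma")
key_avx512 = cache_key(src, "-DCPU_CAPABILITY_AVX512 -mavx512f")
assert key_avx2 != key_avx512  # different ISA, different cache slot
```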

File diff suppressed because it is too large

View File

@@ -27,6 +27,16 @@ __all__ = [
_device_t = Union[_device, str, int, None]
def _is_cpu_support_avx2() -> bool:
r"""Returns a bool indicating if CPU supports AVX2."""
return torch._C._cpu._is_cpu_support_avx2()
def _is_cpu_support_avx512() -> bool:
r"""Returns a bool indicating if CPU supports AVX512."""
return torch._C._cpu._is_cpu_support_avx512()
def _is_cpu_support_vnni() -> bool:
r"""Returns a bool indicating if CPU supports VNNI."""
return torch._C._cpu._is_cpu_support_vnni()
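With these wrappers in place, the capability checks can be queried directly from Python; for example:

```python
import torch

# Private helpers added in this PR; the results depend on the host CPU.
print(torch.cpu._is_cpu_support_avx2())
print(torch.cpu._is_cpu_support_avx512())
print(torch.cpu._is_cpu_support_vnni())
```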

View File

@@ -2,15 +2,15 @@
#include <torch/csrc/cpu/Module.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
namespace cpu {
namespace torch::cpu {
void initModule(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
auto cpu = m.def_submodule("_cpu", "cpu related pybind.");
cpu.def("_is_cpu_support_avx2", at::cpu::is_cpu_support_avx2);
cpu.def("_is_cpu_support_avx512", at::cpu::is_cpu_support_avx512);
cpu.def("_is_cpu_support_vnni", at::cpu::is_cpu_support_vnni);
}
} // namespace cpu
} // namespace torch
} // namespace torch::cpu