Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00

The correct vector length for ARM64 is 128 bits (16 bytes). We were previously using double this, apparently just because that would be the same length as AVX2.

Differential Revision: [D63984039](https://our.internmc.facebook.com/intern/diff/D63984039/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137426
Approved by: https://github.com/jgong5, https://github.com/malfet
ghstack dependencies: #138486, #138542, #138655, #138716, #138744
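As a rough illustration of what the 128-bit width implies for lane counts (a minimal sketch, not part of the patch; the element sizes are the standard dtype widths and match VecNEON._dtype_nelements below):

    # lanes = vector width in bits / element width in bits
    assert 128 // 32 == 4   # torch.float    -> 4 lanes
    assert 128 // 16 == 8   # torch.bfloat16 -> 8 lanes
    assert 128 // 16 == 8   # torch.float16  -> 8 lanes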
428 lines · 13 KiB · Python
# mypy: allow-untyped-defs
import dataclasses
import functools
import os
import platform
import re
import subprocess
import sys
import warnings
from typing import Any, Callable, Dict, List, Union

import torch
from torch._inductor import config


_IS_WINDOWS = sys.platform == "win32"


def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
    # The ISA dry compile adds roughly 1 second to every startup.
    # Please check the issue: https://github.com/pytorch/pytorch/issues/100378
    # The dry compile only verifies that the compiler can build code for the ISA,
    # so we record the compiler version, the ISA options, and the PyTorch version,
    # and fold them into the hash path of the generated binary. That lets us skip
    # the compile step when a matching binary already exists.
    from torch._inductor.cpp_builder import get_compiler_version_info, get_cpp_compiler

    compiler_info = get_compiler_version_info(get_cpp_compiler())
    torch_version = torch.__version__
    fingerprint = f"{compiler_info}={isa_flags}={torch_version}"
    return fingerprint
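

# For illustration only (hypothetical values): the fingerprint is the three pieces
# joined by "=", e.g. something like
#   "gcc (Ubuntu 11.4.0) 11.4.0=-mavx2 -mfma -mf16c=2.6.0"
# check_build() below passes this string as the `extra` argument of
# codecache.write(), so it becomes part of the cached binary's hash path.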


class VecISA:
    _bit_width: int
    _macro: List[str]
    _arch_flags: str
    _dtype_nelements: Dict[torch.dtype, int]

    # Note [Checking for Vectorized Support in Inductor]
    # TorchInductor CPU vectorization reuses PyTorch's vectorization utility
    # functions, so it depends on Sleef* to accelerate mathematical functions
    # such as exp, pow, sin, and cos.
    # However, PyTorch and TorchInductor might be built with different compilers.
    # If PyTorch is built with gcc-7/g++-7, libtorch_cpu.so will not expose the
    # Sleef* AVX512 symbols, because gcc-7/g++-7 cannot pass the AVX512 check in
    # CMake's FindAVX.cmake. TorchInductor, on the other hand, installs the latest
    # gcc/g++ by default, which does support AVX512 compilation. The two could
    # therefore disagree on the available Sleef symbols. Hence, we dry-compile the
    # following code to check whether the current HW platform and PyTorch both
    # support AVX512 or AVX2. ARM presumably needs the same logic.
    # In fbcode, however, we use the same compiler for PyTorch and for Inductor
    # codegen, which makes the runtime check unnecessary.
    _avx_code = """
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE)
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#endif

alignas(64) float in_out_ptr0[16] = {0.0};

extern "C" void __avx_chk_kernel() {
    auto tmp0 = at::vec::Vectorized<float>(1);
    auto tmp1 = tmp0.exp();
    tmp1.store(in_out_ptr0);
}
"""  # noqa: B950

    _avx_py_load = """
import torch
from ctypes import cdll
cdll.LoadLibrary("__lib_path__")
"""

    def bit_width(self) -> int:
        return self._bit_width

    def nelements(self, dtype: torch.dtype = torch.float) -> int:
        return self._dtype_nelements[dtype]

    def build_macro(self) -> List[str]:
        return self._macro

    def build_arch_flags(self) -> str:
        return self._arch_flags

    def __hash__(self) -> int:
        return hash(str(self))

    def check_build(self, code: str) -> bool:
        from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT, write
        from torch._inductor.cpp_builder import (
            CppBuilder,
            CppTorchOptions,
            normalize_path_separator,
        )

        key, input_path = write(
            code,
            "cpp",
            extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
        )
        from filelock import FileLock

        lock_dir = get_lock_dir()
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
        with lock:
            output_dir = os.path.dirname(input_path)
            build_options = CppTorchOptions(vec_isa=self, warning_all=False)
            x86_isa_help_builder = CppBuilder(
                key,
                [input_path],
                build_options,
                output_dir,
            )
            try:
                # Check whether the output file already exists, and compile only
                # when it does not.
                output_path = normalize_path_separator(
                    x86_isa_help_builder.get_target_file_path()
                )
                if not os.path.isfile(output_path):
                    status, target_file = x86_isa_help_builder.build()

                # Check the build result by loading the compiled library in a
                # subprocess.
                subprocess.check_call(
                    [
                        sys.executable,
                        "-c",
                        VecISA._avx_py_load.replace("__lib_path__", output_path),
                    ],
                    cwd=output_dir,
                    stderr=subprocess.DEVNULL,
                    env={**os.environ, "PYTHONPATH": ":".join(sys.path)},
                )
            except Exception:
                return False

            return True

    def __bool__(self) -> bool:
        return self.__bool__impl(config.cpp.vec_isa_ok)

    @functools.lru_cache(None)  # noqa: B019
    def __bool__impl(self, vec_isa_ok) -> bool:
        if vec_isa_ok is not None:
            return vec_isa_ok

        if config.is_fbcode():
            return True

        return self.check_build(VecISA._avx_code)
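
    # Usage sketch (illustrative only, not exercised by this module): a concrete
    # VecISA subclass acts as a truthy capability probe, e.g.
    #
    #   isa = VecAVX2()
    #   if isa:                                  # dry-compiles _avx_code once (cached)
    #       lanes = isa.nelements(torch.float)   # 8 float32 lanes for a 256-bit ISA
    #       flags = isa.build_arch_flags()       # e.g. "-mavx2 -mfma -mf16c"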


@dataclasses.dataclass
class VecNEON(VecISA):
    _bit_width = 128  # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h
    _macro = ["CPU_CAPABILITY_NEON", "AT_BUILD_ARM_VEC256_WITH_SLEEF"]
    _arch_flags = ""  # Unused
    _dtype_nelements = {torch.float: 4, torch.bfloat16: 8, torch.float16: 8}

    def __str__(self) -> str:
        return "asimd"  # detects the presence of advanced SIMD on armv8-a kernels

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecSVE(VecISA):
    # This class can be repurposed for SVE with a variable vector length.
    _bit_width = 256
    _macro = [
        "CPU_CAPABILITY_SVE",
        "CPU_CAPABILITY_SVE256",
        "AT_BUILD_ARM_VEC256_WITH_SLEEF",
    ]
    _arch_flags = "-march=armv8-a+sve -msve-vector-bits=256"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "asimd"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAVX512(VecISA):
    _bit_width = 512
    _macro = ["CPU_CAPABILITY_AVX512"]
    _arch_flags = (
        "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
        if not _IS_WINDOWS
        else "/arch:AVX512"
    )  # TODO: use cflags
    _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}

    def __str__(self) -> str:
        return "avx512"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAMX(VecAVX512):
    _arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8"

    def __str__(self) -> str:
        return super().__str__() + " amx_tile"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__

    _amx_code = """
#include <cstdint>
#include <immintrin.h>

struct amx_tilecfg {
    uint8_t palette_id;
    uint8_t start_row;
    uint8_t reserved_0[14];
    uint16_t colsb[16];
    uint8_t rows[16];
};

extern "C" void __amx_chk_kernel() {
    amx_tilecfg cfg = {0};
    _tile_loadconfig(&cfg);
    _tile_zero(0);
    _tile_dpbf16ps(0, 1, 2);
    _tile_dpbusd(0, 1, 2);
}
"""

    @functools.lru_cache(None)  # noqa: B019
    def __bool__(self) -> bool:
        if super().__bool__():
            if config.is_fbcode():
                return False
            if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx():
                return True
        return False


@dataclasses.dataclass
class VecAVX2(VecISA):
    _bit_width = 256
    _macro = ["CPU_CAPABILITY_AVX2"]
    _arch_flags = (
        "-mavx2 -mfma -mf16c" if not _IS_WINDOWS else "/arch:AVX2"
    )  # TODO: use cflags
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "avx2"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecZVECTOR(VecISA):
    _bit_width = 256
    _macro = [
        "CPU_CAPABILITY_ZVECTOR",
        "CPU_CAPABILITY=ZVECTOR",
        "HAVE_ZVECTOR_CPU_DEFINITION",
    ]
    _arch_flags = "-mvx -mzvector"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "zvector"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecVSX(VecISA):
    _bit_width = 256  # VSX SIMD is 128-bit wide, but ATen emulates a 256-bit vector on top of it
    _macro = ["CPU_CAPABILITY_VSX"]
    _arch_flags = "-mvsx"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "vsx"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


class InvalidVecISA(VecISA):
    _bit_width = 0
    _macro = [""]
    _arch_flags = ""
    _dtype_nelements = {}

    def __str__(self) -> str:
        return "INVALID_VEC_ISA"

    def __bool__(self) -> bool:  # type: ignore[override]
        return False

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


def x86_isa_checker() -> List[str]:
    supported_isa: List[str] = []

    def _check_and_append_supported_isa(
        dest: List[str], isa_supported: bool, isa_name: str
    ) -> None:
        if isa_supported:
            dest.append(isa_name)

    Arch = platform.machine()
    """
    Arch value is x86_64 on Linux, and the value is AMD64 on Windows.
    """
    if Arch != "x86_64" and Arch != "AMD64":
        return supported_isa

    avx2 = torch.cpu._is_avx2_supported()
    avx512 = torch.cpu._is_avx512_supported()
    amx_tile = torch.cpu._is_amx_tile_supported()

    _check_and_append_supported_isa(supported_isa, avx2, "avx2")
    _check_and_append_supported_isa(supported_isa, avx512, "avx512")
    _check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile")

    return supported_isa


invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE()]


def get_isa_from_cpu_capability(
    capability: Union[str, None],
    vec_isa_list: List[VecISA],
    invalid_vec_isa: InvalidVecISA,
):
    # The AMX setting is not supported in eager mode, and VecAMX is prioritized
    # for selection when ATEN_CPU_CAPABILITY is set to avx512.
    # TODO: add sve256 support
    capability_to_isa_str = {
        "default": "INVALID_VEC_ISA",
        "zvector": "zvector",
        "vsx": "vsx",
        "avx2": "avx2",
        "avx512": "avx512",
    }
    if capability in capability_to_isa_str.keys():
        isa_str = capability_to_isa_str[capability]
        if isa_str == "INVALID_VEC_ISA":
            return invalid_vec_isa
        for vec_isa in vec_isa_list:
            if isa_str in str(vec_isa):
                return vec_isa

    if capability:
        warnings.warn(f"ignoring invalid value for ATEN_CPU_CAPABILITY {capability}")

    return vec_isa_list[0]
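

# Illustrative example (not executed here): on an AVX-512 machine with
# ATEN_CPU_CAPABILITY=avx2,
#   get_isa_from_cpu_capability("avx2", valid_vec_isa_list(), invalid_vec_isa)
# returns the VecAVX2 instance, while capability=None (or an unrecognized string)
# falls through to vec_isa_list[0], i.e. the highest-priority supported ISA, and
# capability="default" returns invalid_vec_isa to disable vectorization.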


# Cache the cpuinfo to avoid the I/O overhead. The cpuinfo content also carries a
# lot of information that is irrelevant to the ISA check, so we only cache the key
# ISA information.
@functools.lru_cache(None)
def valid_vec_isa_list() -> List[VecISA]:
    isa_list: List[VecISA] = []
    if sys.platform == "darwin" and platform.processor() == "arm":
        isa_list.append(VecNEON())

    if sys.platform not in ["linux", "win32"]:
        return isa_list

    arch = platform.machine()
    if arch == "s390x":
        with open("/proc/cpuinfo") as _cpu_info:
            while True:
                line = _cpu_info.readline()
                if not line:
                    break
                # process line
                featuresmatch = re.match(r"^features\s*:\s*(.*)$", line)
                if featuresmatch:
                    for group in featuresmatch.groups():
                        if re.search(r"[\^ ]+vxe[\$ ]+", group):
                            isa_list.append(VecZVECTOR())
                            break
    elif arch == "ppc64le":
        isa_list.append(VecVSX())
    elif arch == "aarch64":
        if torch.cpu._is_arm_sve_supported():
            isa_list.append(VecSVE())
        else:
            isa_list.append(VecNEON())
    elif arch in ["x86_64", "AMD64"]:
        """
        arch value is x86_64 on Linux, and the value is AMD64 on Windows.
        """
        _cpu_supported_x86_isa = x86_isa_checker()
        for isa in supported_vec_isa_list:
            if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa:
                isa_list.append(isa)

    return isa_list
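

# For example (illustrative, host-dependent): on a Linux x86_64 machine with
# AVX-512 and AMX support, valid_vec_isa_list() returns
# [VecAMX(), VecAVX512(), VecAVX2()]; on aarch64 without SVE it returns [VecNEON()].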


def pick_vec_isa() -> VecISA:
    if config.is_fbcode() and (platform.machine() in ["x86_64", "AMD64"]):
        return VecAVX2()

    _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
    if not _valid_vec_isa_list:
        return invalid_vec_isa

    # If simdlen is None, pick the ISA from the ATEN_CPU_CAPABILITY environment
    # variable, which controls the CPU vec ISA.
    if config.cpp.simdlen is None:
        return get_isa_from_cpu_capability(
            os.getenv("ATEN_CPU_CAPABILITY"), _valid_vec_isa_list, invalid_vec_isa
        )

    for isa in _valid_vec_isa_list:
        if config.cpp.simdlen == isa.bit_width():
            return isa

    return invalid_vec_isa
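

# Minimal usage sketch (illustrative only): pick_vec_isa() is how the rest of
# Inductor's C++ backend chooses a vectorization strategy.
#
#   isa = pick_vec_isa()
#   if isa:
#       print(str(isa), isa.bit_width(), isa.nelements(torch.float))
#
# Forcing a particular width is possible via config.cpp.simdlen (in bits,
# e.g. 256); when simdlen is unset, the ATEN_CPU_CAPABILITY environment
# variable decides instead, as shown in the function above.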