[CPUInductor] Fix SVE256 detection (#146207)
This PR removes `torch.cpu._is_arm_sve_supported()` and replaces it with the stable `torch.backends.cpu.get_cpu_capability()`.

I should have reviewed https://github.com/pytorch/pytorch/pull/134672 more thoroughly: it introduced a duplicate, but slightly different, API for detecting CPU architectures, which resulted in runtime crashes on systems that support SVE128 rather than SVE256.

Fixes https://github.com/pytorch/pytorch/issues/145441

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146207
Approved by: https://github.com/angelayi
Committed by: PyTorch MergeBot
Parent: 8c657ae4be
Commit: e56dcf2772
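As a minimal illustration of the migration described in the commit message (my own sketch, not part of the diff): the removed private helper only answered a generic "any SVE?" question, so it could not tell SVE128 from SVE256, while the stable API reports the concrete capability as a string that callers compare against.

```python
import torch

# Old (removed by this PR): a private yes/no probe for generic SVE support.
#   torch.cpu._is_arm_sve_supported()

# New: the stable capability API returns a string such as "SVE256";
# callers check for the exact vector width they need.
if torch.backends.cpu.get_cpu_capability() == "SVE256":
    print("256-bit SVE available")
else:
    print("falling back to NEON / default vectorization")
```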
@@ -92,14 +92,6 @@ bool init_amx() {
 #endif
 }
 
-bool is_arm_sve_supported() {
-#if !defined(__s390x__) && !defined(__powerpc__)
-  return cpuinfo_initialize() && cpuinfo_has_arm_sve();
-#else
-  return false;
-#endif
-}
-
 static uint32_t get_cache_size(int level) {
 #if !defined(__s390x__) && !defined(__powerpc__)
   if (!cpuinfo_initialize()) {
@@ -24,9 +24,6 @@ TORCH_API bool is_amx_fp16_supported();
 // Enable the system to use AMX instructions.
 TORCH_API bool init_amx();
 
-// Detect if CPU supports Arm(R) architecture SVE ISA
-TORCH_API bool is_arm_sve_supported();
-
 // Get the L1 cache size per core in Byte
 TORCH_API uint32_t L1d_cache_size();
 
@@ -9,6 +9,5 @@ def _is_avx512_bf16_supported() -> _bool: ...
 def _is_amx_tile_supported() -> _bool: ...
 def _is_amx_fp16_supported() -> _bool: ...
 def _init_amx() -> _bool: ...
-def _is_arm_sve_supported() -> _bool: ...
 def _L1d_cache_size() -> _int: ...
 def _L2_cache_size() -> _int: ...
@@ -425,7 +425,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch._C._cpu._is_amx_tile_supported",
         "torch._C._cpu._is_amx_fp16_supported",
         "torch._C._cpu._init_amx",
-        "torch._C._cpu._is_arm_sve_supported",
         "torch._C._crash_if_aten_asan",
         "torch._C._crash_if_csrc_asan",
         "torch._C._crash_if_csrc_ubsan",
@@ -2440,7 +2439,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
         "torch._C._cpu._is_amx_tile_supported",
         "torch._C._cpu._is_amx_fp16_supported",
         "torch.cpu._init_amx",
-        "torch._C._cpu._is_arm_sve_supported",
         "torch.cpu.current_device",
         "torch.cpu.current_stream",
         "torch.cpu.device_count",
@@ -16,7 +16,7 @@ from ..cpu_vec_isa import (
     VecAVX512,
     VecISA,
     VecNEON,
-    VecSVE,
+    VecSVE256,
 )
 from ..utils import IndentedBuffer, parallel_num_threads
 from ..virtualized import V
@@ -339,7 +339,7 @@ class CppMicroGemmRef(CppMicroGemm):
         compute_dtype=torch.float,
     ),
     *generate_gemm_config(
-        VecSVE,
+        VecSVE256,
         [(4, 24, 1), (4, 16, 1), (8, 8, 1)],
         input_dtype=torch.float,
         input2_dtype=torch.float,
@@ -166,7 +166,7 @@ class VecNEON(VecISA):
 
 
 @dataclasses.dataclass
-class VecSVE(VecISA):
+class VecSVE256(VecISA):
     # this function can be repurposed for SVE with variable vec length
     _bit_width = 256
     _macro = [
@@ -328,7 +328,7 @@ def x86_isa_checker() -> list[str]:
 
 
 invalid_vec_isa = InvalidVecISA()
-supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE()]
+supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE256()]
 
 
 def get_isa_from_cpu_capability(
@@ -389,8 +389,8 @@ def valid_vec_isa_list() -> list[VecISA]:
     elif arch == "ppc64le":
         isa_list.append(VecVSX())
     elif arch == "aarch64":
-        if torch.cpu._is_arm_sve_supported():
-            isa_list.append(VecSVE())
+        if torch.backends.cpu.get_cpu_capability() == "SVE256":
+            isa_list.append(VecSVE256())
         else:
             isa_list.append(VecNEON())
     elif arch in ["x86_64", "AMD64"]:
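A small usage sketch of the selection logic changed in the hunk above (my own example, assuming the module path `torch._inductor.cpu_vec_isa` implied by the relative imports earlier in the diff): on an aarch64 machine without 256-bit SVE, the capability string is not "SVE256", so the list now falls back to VecNEON instead of a VecSVE variant that previously crashed on SVE128 hardware.

```python
import torch
from torch._inductor import cpu_vec_isa

# valid_vec_isa_list() applies the aarch64 branch shown above.
isas = cpu_vec_isa.valid_vec_isa_list()
print([str(isa) for isa in isas])

# On aarch64 the list contains VecSVE256 only when the stable capability
# query reports "SVE256"; otherwise VecNEON is used.
print(torch.backends.cpu.get_cpu_capability())
```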
@@ -65,11 +65,6 @@ def _init_amx() -> bool:
     return torch._C._cpu._init_amx()
 
 
-def _is_arm_sve_supported() -> bool:
-    r"""Returns a bool indicating if CPU supports Arm SVE."""
-    return torch._C._cpu._is_arm_sve_supported()
-
-
 def is_available() -> bool:
     r"""Returns a bool indicating if CPU is currently available.
 
@@ -15,7 +15,6 @@ void initModule(PyObject* module) {
   cpu.def("_is_amx_tile_supported", at::cpu::is_amx_tile_supported);
   cpu.def("_is_amx_fp16_supported", at::cpu::is_amx_fp16_supported);
   cpu.def("_init_amx", at::cpu::init_amx);
-  cpu.def("_is_arm_sve_supported", at::cpu::is_arm_sve_supported);
   cpu.def("_L1d_cache_size", at::cpu::L1d_cache_size);
   cpu.def("_L2_cache_size", at::cpu::L2_cache_size);
 }
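One practical consequence of the binding removal above (my own note, not part of the diff): out-of-tree code that called the private helper must switch to the stable capability query. A defensive sketch for code that has to run on builds from before and after this PR:

```python
import torch

# Stable on both old and new builds; "SVE256" is the value this PR checks for.
cap = torch.backends.cpu.get_cpu_capability()
print("256-bit SVE:", cap == "SVE256")

# The private binding removed here must be probed before use on newer builds.
legacy = getattr(torch._C._cpu, "_is_arm_sve_supported", None)
if legacy is not None:
    # Older builds only: answers "any SVE?", which does NOT imply SVE256;
    # conflating the two is the crash this PR fixes (issue #145441).
    print("any-width SVE (legacy probe):", legacy())
```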