[CPUInductor] Fix SVE256 detection (#146207)

This PR removes `torch.cpu._is_arm_sve_supported()` and replaces it with the stable `torch.backends.cpu.get_cpu_capability()`

I should have reviewed https://github.com/pytorch/pytorch/pull/134672 more thoroughly, because it introduced a duplicate, but slightly different, API for detecting CPU architectures, which resulted in runtime crashes on systems that support SVE128 rather than SVE256

Fixes https://github.com/pytorch/pytorch/issues/145441

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146207
Approved by: https://github.com/angelayi
This commit is contained in:
Nikita Shulga
2025-01-31 14:59:34 -08:00
committed by PyTorch MergeBot
parent 8c657ae4be
commit e56dcf2772
8 changed files with 6 additions and 26 deletions

View File

@ -92,14 +92,6 @@ bool init_amx() {
#endif
}
bool is_arm_sve_supported() {
#if !defined(__s390x__) && !defined(__powerpc__)
return cpuinfo_initialize() && cpuinfo_has_arm_sve();
#else
return false;
#endif
}
static uint32_t get_cache_size(int level) {
#if !defined(__s390x__) && !defined(__powerpc__)
if (!cpuinfo_initialize()) {

View File

@ -24,9 +24,6 @@ TORCH_API bool is_amx_fp16_supported();
// Enable the system to use AMX instructions.
TORCH_API bool init_amx();
// Detect if CPU supports Arm(R) architecture SVE ISA
TORCH_API bool is_arm_sve_supported();
// Get the L1 cache size per core in Byte
TORCH_API uint32_t L1d_cache_size();

View File

@ -9,6 +9,5 @@ def _is_avx512_bf16_supported() -> _bool: ...
def _is_amx_tile_supported() -> _bool: ...
def _is_amx_fp16_supported() -> _bool: ...
def _init_amx() -> _bool: ...
def _is_arm_sve_supported() -> _bool: ...
def _L1d_cache_size() -> _int: ...
def _L2_cache_size() -> _int: ...

View File

@ -425,7 +425,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._cpu._is_amx_tile_supported",
"torch._C._cpu._is_amx_fp16_supported",
"torch._C._cpu._init_amx",
"torch._C._cpu._is_arm_sve_supported",
"torch._C._crash_if_aten_asan",
"torch._C._crash_if_csrc_asan",
"torch._C._crash_if_csrc_ubsan",
@ -2440,7 +2439,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._cpu._is_amx_tile_supported",
"torch._C._cpu._is_amx_fp16_supported",
"torch.cpu._init_amx",
"torch._C._cpu._is_arm_sve_supported",
"torch.cpu.current_device",
"torch.cpu.current_stream",
"torch.cpu.device_count",

View File

@ -16,7 +16,7 @@ from ..cpu_vec_isa import (
VecAVX512,
VecISA,
VecNEON,
VecSVE,
VecSVE256,
)
from ..utils import IndentedBuffer, parallel_num_threads
from ..virtualized import V
@ -339,7 +339,7 @@ class CppMicroGemmRef(CppMicroGemm):
compute_dtype=torch.float,
),
*generate_gemm_config(
VecSVE,
VecSVE256,
[(4, 24, 1), (4, 16, 1), (8, 8, 1)],
input_dtype=torch.float,
input2_dtype=torch.float,

View File

@ -166,7 +166,7 @@ class VecNEON(VecISA):
@dataclasses.dataclass
class VecSVE(VecISA):
class VecSVE256(VecISA):
# this function can be repurposed for SVE with variable vec length
_bit_width = 256
_macro = [
@ -328,7 +328,7 @@ def x86_isa_checker() -> list[str]:
invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE()]
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE256()]
def get_isa_from_cpu_capability(
@ -389,8 +389,8 @@ def valid_vec_isa_list() -> list[VecISA]:
elif arch == "ppc64le":
isa_list.append(VecVSX())
elif arch == "aarch64":
if torch.cpu._is_arm_sve_supported():
isa_list.append(VecSVE())
if torch.backends.cpu.get_cpu_capability() == "SVE256":
isa_list.append(VecSVE256())
else:
isa_list.append(VecNEON())
elif arch in ["x86_64", "AMD64"]:

View File

@ -65,11 +65,6 @@ def _init_amx() -> bool:
return torch._C._cpu._init_amx()
def _is_arm_sve_supported() -> bool:
r"""Returns a bool indicating if CPU supports Arm SVE."""
return torch._C._cpu._is_arm_sve_supported()
def is_available() -> bool:
r"""Returns a bool indicating if CPU is currently available.

View File

@ -15,7 +15,6 @@ void initModule(PyObject* module) {
cpu.def("_is_amx_tile_supported", at::cpu::is_amx_tile_supported);
cpu.def("_is_amx_fp16_supported", at::cpu::is_amx_fp16_supported);
cpu.def("_init_amx", at::cpu::init_amx);
cpu.def("_is_arm_sve_supported", at::cpu::is_arm_sve_supported);
cpu.def("_L1d_cache_size", at::cpu::L1d_cache_size);
cpu.def("_L2_cache_size", at::cpu::L2_cache_size);
}