From e56dcf2772220ea1a738bc5fa072a0b9ff4f355a Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 31 Jan 2025 14:59:34 -0800 Subject: [PATCH] [CPUInductor] Fix SVE256 detection (#146207) This PR removes `torch.cpu._is_arm_sve_supported()` and replaces is with stable `torch.backends.cpu.get_cpu_capability()` I should have reviewed https://github.com/pytorch/pytorch/pull/134672 more thoroughly, because it introduced duplicate, but slightly different API for detecting CPU architectures, which resulted in runtime crashes on system that do support SVE128, rather than SVE256 Fixes https://github.com/pytorch/pytorch/issues/145441 Pull Request resolved: https://github.com/pytorch/pytorch/pull/146207 Approved by: https://github.com/angelayi --- aten/src/ATen/cpu/Utils.cpp | 8 -------- aten/src/ATen/cpu/Utils.h | 3 --- torch/_C/_cpu.pyi | 1 - torch/_dynamo/trace_rules.py | 2 -- torch/_inductor/codegen/cpp_micro_gemm.py | 4 ++-- torch/_inductor/cpu_vec_isa.py | 8 ++++---- torch/cpu/__init__.py | 5 ----- torch/csrc/cpu/Module.cpp | 1 - 8 files changed, 6 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/cpu/Utils.cpp b/aten/src/ATen/cpu/Utils.cpp index b7b99e50d91b..2aff12cfa6df 100644 --- a/aten/src/ATen/cpu/Utils.cpp +++ b/aten/src/ATen/cpu/Utils.cpp @@ -92,14 +92,6 @@ bool init_amx() { #endif } -bool is_arm_sve_supported() { -#if !defined(__s390x__) && !defined(__powerpc__) - return cpuinfo_initialize() && cpuinfo_has_arm_sve(); -#else - return false; -#endif -} - static uint32_t get_cache_size(int level) { #if !defined(__s390x__) && !defined(__powerpc__) if (!cpuinfo_initialize()) { diff --git a/aten/src/ATen/cpu/Utils.h b/aten/src/ATen/cpu/Utils.h index 1214e1e0ce6d..b339cb328b9b 100644 --- a/aten/src/ATen/cpu/Utils.h +++ b/aten/src/ATen/cpu/Utils.h @@ -24,9 +24,6 @@ TORCH_API bool is_amx_fp16_supported(); // Enable the system to use AMX instructions. TORCH_API bool init_amx(); -// Detect if CPU supports Arm(R) architecture SVE ISA -TORCH_API bool is_arm_sve_supported(); - // Get the L1 cache size per core in Byte TORCH_API uint32_t L1d_cache_size(); diff --git a/torch/_C/_cpu.pyi b/torch/_C/_cpu.pyi index f03164bfa00d..a667edc721a9 100644 --- a/torch/_C/_cpu.pyi +++ b/torch/_C/_cpu.pyi @@ -9,6 +9,5 @@ def _is_avx512_bf16_supported() -> _bool: ... def _is_amx_tile_supported() -> _bool: ... def _is_amx_fp16_supported() -> _bool: ... def _init_amx() -> _bool: ... -def _is_arm_sve_supported() -> _bool: ... def _L1d_cache_size() -> _int: ... def _L2_cache_size() -> _int: ... diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index ec9dd09c90ae..48ca3fa65cd6 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -425,7 +425,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys( "torch._C._cpu._is_amx_tile_supported", "torch._C._cpu._is_amx_fp16_supported", "torch._C._cpu._init_amx", - "torch._C._cpu._is_arm_sve_supported", "torch._C._crash_if_aten_asan", "torch._C._crash_if_csrc_asan", "torch._C._crash_if_csrc_ubsan", @@ -2440,7 +2439,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys( "torch._C._cpu._is_amx_tile_supported", "torch._C._cpu._is_amx_fp16_supported", "torch.cpu._init_amx", - "torch._C._cpu._is_arm_sve_supported", "torch.cpu.current_device", "torch.cpu.current_stream", "torch.cpu.device_count", diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 21e2265c554d..50e7269742b5 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -16,7 +16,7 @@ from ..cpu_vec_isa import ( VecAVX512, VecISA, VecNEON, - VecSVE, + VecSVE256, ) from ..utils import IndentedBuffer, parallel_num_threads from ..virtualized import V @@ -339,7 +339,7 @@ class CppMicroGemmRef(CppMicroGemm): compute_dtype=torch.float, ), *generate_gemm_config( - VecSVE, + VecSVE256, [(4, 24, 1), (4, 16, 1), (8, 8, 1)], input_dtype=torch.float, input2_dtype=torch.float, diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 6d0ca8f3156b..67b0160ae6a6 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -166,7 +166,7 @@ class VecNEON(VecISA): @dataclasses.dataclass -class VecSVE(VecISA): +class VecSVE256(VecISA): # this function can be repurposed for SVE with variable vec length _bit_width = 256 _macro = [ @@ -328,7 +328,7 @@ def x86_isa_checker() -> list[str]: invalid_vec_isa = InvalidVecISA() -supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE()] +supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE256()] def get_isa_from_cpu_capability( @@ -389,8 +389,8 @@ def valid_vec_isa_list() -> list[VecISA]: elif arch == "ppc64le": isa_list.append(VecVSX()) elif arch == "aarch64": - if torch.cpu._is_arm_sve_supported(): - isa_list.append(VecSVE()) + if torch.backends.cpu.get_cpu_capability() == "SVE256": + isa_list.append(VecSVE256()) else: isa_list.append(VecNEON()) elif arch in ["x86_64", "AMD64"]: diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py index 67ebb633802f..702fbaa3d978 100644 --- a/torch/cpu/__init__.py +++ b/torch/cpu/__init__.py @@ -65,11 +65,6 @@ def _init_amx() -> bool: return torch._C._cpu._init_amx() -def _is_arm_sve_supported() -> bool: - r"""Returns a bool indicating if CPU supports Arm SVE.""" - return torch._C._cpu._is_arm_sve_supported() - - def is_available() -> bool: r"""Returns a bool indicating if CPU is currently available. diff --git a/torch/csrc/cpu/Module.cpp b/torch/csrc/cpu/Module.cpp index 5e3f4b5b18bb..38fea7f995c3 100644 --- a/torch/csrc/cpu/Module.cpp +++ b/torch/csrc/cpu/Module.cpp @@ -15,7 +15,6 @@ void initModule(PyObject* module) { cpu.def("_is_amx_tile_supported", at::cpu::is_amx_tile_supported); cpu.def("_is_amx_fp16_supported", at::cpu::is_amx_fp16_supported); cpu.def("_init_amx", at::cpu::init_amx); - cpu.def("_is_arm_sve_supported", at::cpu::is_arm_sve_supported); cpu.def("_L1d_cache_size", at::cpu::L1d_cache_size); cpu.def("_L2_cache_size", at::cpu::L2_cache_size); }