diff --git a/aten/src/ATen/cpu/Utils.cpp b/aten/src/ATen/cpu/Utils.cpp index b7b99e50d91b..2aff12cfa6df 100644 --- a/aten/src/ATen/cpu/Utils.cpp +++ b/aten/src/ATen/cpu/Utils.cpp @@ -92,14 +92,6 @@ bool init_amx() { #endif } -bool is_arm_sve_supported() { -#if !defined(__s390x__) && !defined(__powerpc__) - return cpuinfo_initialize() && cpuinfo_has_arm_sve(); -#else - return false; -#endif -} - static uint32_t get_cache_size(int level) { #if !defined(__s390x__) && !defined(__powerpc__) if (!cpuinfo_initialize()) { diff --git a/aten/src/ATen/cpu/Utils.h b/aten/src/ATen/cpu/Utils.h index 1214e1e0ce6d..b339cb328b9b 100644 --- a/aten/src/ATen/cpu/Utils.h +++ b/aten/src/ATen/cpu/Utils.h @@ -24,9 +24,6 @@ TORCH_API bool is_amx_fp16_supported(); // Enable the system to use AMX instructions. TORCH_API bool init_amx(); -// Detect if CPU supports Arm(R) architecture SVE ISA -TORCH_API bool is_arm_sve_supported(); - // Get the L1 cache size per core in Byte TORCH_API uint32_t L1d_cache_size(); diff --git a/torch/_C/_cpu.pyi b/torch/_C/_cpu.pyi index f03164bfa00d..a667edc721a9 100644 --- a/torch/_C/_cpu.pyi +++ b/torch/_C/_cpu.pyi @@ -9,6 +9,5 @@ def _is_avx512_bf16_supported() -> _bool: ... def _is_amx_tile_supported() -> _bool: ... def _is_amx_fp16_supported() -> _bool: ... def _init_amx() -> _bool: ... -def _is_arm_sve_supported() -> _bool: ... def _L1d_cache_size() -> _int: ... def _L2_cache_size() -> _int: ... diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index ec9dd09c90ae..48ca3fa65cd6 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -425,7 +425,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys( "torch._C._cpu._is_amx_tile_supported", "torch._C._cpu._is_amx_fp16_supported", "torch._C._cpu._init_amx", - "torch._C._cpu._is_arm_sve_supported", "torch._C._crash_if_aten_asan", "torch._C._crash_if_csrc_asan", "torch._C._crash_if_csrc_ubsan", @@ -2440,7 +2439,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys( "torch._C._cpu._is_amx_tile_supported", "torch._C._cpu._is_amx_fp16_supported", "torch.cpu._init_amx", - "torch._C._cpu._is_arm_sve_supported", "torch.cpu.current_device", "torch.cpu.current_stream", "torch.cpu.device_count", diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 21e2265c554d..50e7269742b5 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -16,7 +16,7 @@ from ..cpu_vec_isa import ( VecAVX512, VecISA, VecNEON, - VecSVE, + VecSVE256, ) from ..utils import IndentedBuffer, parallel_num_threads from ..virtualized import V @@ -339,7 +339,7 @@ class CppMicroGemmRef(CppMicroGemm): compute_dtype=torch.float, ), *generate_gemm_config( - VecSVE, + VecSVE256, [(4, 24, 1), (4, 16, 1), (8, 8, 1)], input_dtype=torch.float, input2_dtype=torch.float, diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 6d0ca8f3156b..67b0160ae6a6 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -166,7 +166,7 @@ class VecNEON(VecISA): @dataclasses.dataclass -class VecSVE(VecISA): +class VecSVE256(VecISA): # this function can be repurposed for SVE with variable vec length _bit_width = 256 _macro = [ @@ -328,7 +328,7 @@ def x86_isa_checker() -> list[str]: invalid_vec_isa = InvalidVecISA() -supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE()] +supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE256()] def get_isa_from_cpu_capability( @@ -389,8 +389,8 @@ def valid_vec_isa_list() -> list[VecISA]: elif arch == "ppc64le": isa_list.append(VecVSX()) elif arch == "aarch64": - if torch.cpu._is_arm_sve_supported(): - isa_list.append(VecSVE()) + if torch.backends.cpu.get_cpu_capability() == "SVE256": + isa_list.append(VecSVE256()) else: isa_list.append(VecNEON()) elif arch in ["x86_64", "AMD64"]: diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py index 67ebb633802f..702fbaa3d978 100644 --- a/torch/cpu/__init__.py +++ b/torch/cpu/__init__.py @@ -65,11 +65,6 @@ def _init_amx() -> bool: return torch._C._cpu._init_amx() -def _is_arm_sve_supported() -> bool: - r"""Returns a bool indicating if CPU supports Arm SVE.""" - return torch._C._cpu._is_arm_sve_supported() - - def is_available() -> bool: r"""Returns a bool indicating if CPU is currently available. diff --git a/torch/csrc/cpu/Module.cpp b/torch/csrc/cpu/Module.cpp index 5e3f4b5b18bb..38fea7f995c3 100644 --- a/torch/csrc/cpu/Module.cpp +++ b/torch/csrc/cpu/Module.cpp @@ -15,7 +15,6 @@ void initModule(PyObject* module) { cpu.def("_is_amx_tile_supported", at::cpu::is_amx_tile_supported); cpu.def("_is_amx_fp16_supported", at::cpu::is_amx_fp16_supported); cpu.def("_init_amx", at::cpu::init_amx); - cpu.def("_is_arm_sve_supported", at::cpu::is_arm_sve_supported); cpu.def("_L1d_cache_size", at::cpu::L1d_cache_size); cpu.def("_L2_cache_size", at::cpu::L2_cache_size); }