Compare commits

...

1 Commits

Author SHA1 Message Date
6b5a96913f [Pytorch] Add Armv9-a build flags (#166640)
Summary:

Adding detection and enablement of Armv9a core instruction sets

Test Plan: CI

Differential Revision: D85860431
2025-10-30 09:33:36 -07:00
3 changed files with 103 additions and 2 deletions

View File

@ -403,9 +403,17 @@ if(INTERN_BUILD_ATEN_OPS)
# Register the SVE256 CPU capability: one name, the global HAVE_* definitions,
# and exactly one matching entry in CPU_CAPABILITY_FLAGS.
list(APPEND CPU_CAPABILITY_NAMES "SVE256")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION -DHAVE_ARM_BF16_CPU_DEFINITION")
# Choose the architecture part of the flags once: the Armv9-A feature set when
# CXX_ARMV9A_FOUND is true, otherwise the Armv8-A SVE+BF16 baseline.
# NOTE(review): the Armv9-A variant does not pass -msve-vector-bits=256 while
# the Armv8-A variant does -- confirm this is intended for a capability named
# SVE256.
if(CXX_ARMV9A_FOUND)
  set(sve256_arch_flags "-march=armv9-a+sve2+fp16fml+sha3+bf16+i8mm -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE")
else()
  set(sve256_arch_flags "-march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256")
endif()
# Append a single flags entry for the single name appended above, so the two
# lists keep the same length. Clang additionally gets an explicit -O2.
if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
  list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 ${sve256_arch_flags}")
else()
  list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${sve256_arch_flags}")
endif()
# Do not leak the helper variable into the rest of the file.
unset(sve256_arch_flags)
endif()

View File

@ -106,6 +106,37 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
}
")
# Probe program compiled by the "ARMV9A" CHECK_COMPILES test. It uses one
# intrinsic for each feature named in that test's -march string, following the
# labels inside the source: sveor3_u8 (SVE2), veor3q_u8 (sha3),
# vfmlalq_low_f16 (fp16fml), vreinterpretq_bf16_f16 (bf16) and
# vmmlaq_s32 (i8mm). main() always returns 0.
SET(SVE2_CODE "
#include <arm_neon.h>
#include <arm_sve.h>
int main()
{
//SVE2
svuint8_t a = svdup_n_u8(0);
svuint8_t b = svdup_n_u8(1);
svuint8_t c = svdup_n_u8(2);
a = sveor3_u8(a, b, c);
//sha3
uint8x16_t x = vdupq_n_u8(0);
uint8x16_t y = vdupq_n_u8(1);
uint8x16_t z = vdupq_n_u8(2);
x = veor3q_u8(x, y, z);
//fp16fml
float32x4_t i = vdupq_n_f32(1.0);
float16x8_t j = vdupq_n_f16(1.0);
float16x8_t k = vdupq_n_f16(1.0);
i = vfmlalq_low_f16(i, j, k);
//bf16
bfloat16x8_t h = vreinterpretq_bf16_f16(j);
//i8mm
int32x4_t d = vdupq_n_s32(1);
int8x16_t e = vdupq_n_s8(2);
int8x16_t f = vdupq_n_s8(3);
d = vmmlaq_s32(d, e, f);
return 0;
}
")
SET(ARM_BF16_CODE "
#include <arm_neon.h>
int main()
@ -153,6 +184,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
# Check for SVE256 vector length
CHECK_COMPILES(CXX "SVE256" "-march=armv8.2-a+sve -msve-vector-bits=256" "${SVE_CODE}")
# Check that the C++ compiler accepts the Armv9-A feature set
# (sve2, fp16fml, sha3, bf16, i8mm) by compiling SVE2_CODE with it.
# This probe is run without -msve-vector-bits=256.
# NOTE(review): presumably this defines CXX_ARMV9A_FOUND -- CHECK_COMPILES is
# defined outside this view, confirm the result variable name there.
CHECK_COMPILES(CXX "ARMV9A" "-march=armv9-a+sve2+fp16fml+sha3+bf16+i8mm" "${SVE2_CODE}")
# Check that ARM_BF16_CODE compiles with SVE + BF16 at a 256-bit vector length.
CHECK_COMPILES(CXX "ARM_BF16" "-march=armv8.2-a+sve+bf16 -msve-vector-bits=256" "${ARM_BF16_CODE}")
# If SVE256 support is not found, set CXX_SVE_FOUND to FALSE and notify the user

View File

@ -74,6 +74,37 @@ from ctypes import cdll
cdll.LoadLibrary("__lib_path__")
"""
# C/C++ probe source passed to check_build() by the __bool__ overrides of
# VecNEON and VecSVE256. It uses one intrinsic per feature, following the
# labels inside the source: sveor3_u8 (SVE2), veor3q_u8 (sha3),
# vfmlalq_low_f16 (fp16fml), vreinterpretq_bf16_f16 (bf16) and
# vmmlaq_s32 (i8mm). main() always returns 0.
_armv9a_code = """
#include <arm_neon.h>
#include <arm_sve.h>
int main()
{
//SVE2
svuint8_t a = svdup_n_u8(0);
svuint8_t b = svdup_n_u8(1);
svuint8_t c = svdup_n_u8(2);
a = sveor3_u8(a, b, c);
//sha3
uint8x16_t x = vdupq_n_u8(0);
uint8x16_t y = vdupq_n_u8(1);
uint8x16_t z = vdupq_n_u8(2);
x = veor3q_u8(x, y, z);
//fp16fml
float32x4_t i = vdupq_n_f32(1.0);
float16x8_t j = vdupq_n_f16(1.0);
float16x8_t k = vdupq_n_f16(1.0);
i = vfmlalq_low_f16(i, j, k);
//bf16
bfloat16x8_t h = vreinterpretq_bf16_f16(j);
//i8mm
int32x4_t d = vdupq_n_s32(1);
int8x16_t e = vdupq_n_s8(2);
int8x16_t f = vdupq_n_s8(3);
d = vmmlaq_s32(d, e, f);
return 0;
}
""" # noqa: B950
def bit_width(self) -> int:
    """Return the value stored in this object's ``_bit_width`` attribute."""
    width = self._bit_width
    return width
@ -160,6 +191,21 @@ class VecNEON(VecISA):
_arch_flags = "" # Empty by default; __bool__ below swaps in the Armv9-A -march flags for its probe build and keeps them if that build succeeds
# Maps each torch dtype to an element count (4 for float, 8 for bfloat16/float16).
_dtype_nelements = {torch.float: 4, torch.bfloat16: 8, torch.float16: 8}
@functools.cache  # noqa: B019
# pyrefly: ignore [bad-override]
def __bool__(self) -> bool:
    """Report whether this ISA is usable, trying an Armv9-A build first.

    On non-Windows hosts the Armv9-A flags are installed in ``_arch_flags``
    and ``check_build`` is run on ``VecISA._armv9a_code``. If that build
    succeeds the Armv9-A flags are kept and ``True`` is returned. Otherwise
    the previous flags are restored and the parent class decides.

    Returns:
        ``True`` when the Armv9-A probe builds, else ``super().__bool__()``.
    """
    if not _IS_WINDOWS:
        # Remember the flags in effect before the probe.
        base_flags = self._arch_flags
        armv9a_ok = False
        self._arch_flags = "-march=armv9-a+sve2+fp16fml+sha3+bf16+i8mm"
        try:
            armv9a_ok = bool(self.check_build(VecISA._armv9a_code))
        finally:
            # Restore on failure *and* when check_build raises. Otherwise a
            # later call would save the Armv9-A string as its "base" flags
            # and the fallback path would keep using them.
            if not armv9a_ok:
                self._arch_flags = base_flags
        if armv9a_ok:
            return True
    return super().__bool__()
def __str__(self) -> str:
if config.is_fbcode():
return "neon"
@ -182,6 +228,21 @@ class VecSVE256(VecISA):
# Maps each torch dtype to an element count (8 for float, 16 for bfloat16/float16).
# NOTE(review): presumably lanes per 256-bit vector (8 x 32-bit, 16 x 16-bit) -- confirm.
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
def __bool__(self) -> bool:
    """Report whether this ISA is usable, trying an Armv9-A build first.

    On non-Windows hosts the Armv9-A flags (with ``-msve-vector-bits=256``)
    are installed in ``_arch_flags`` and ``check_build`` is run on
    ``VecISA._armv9a_code``. If that build succeeds the Armv9-A flags are
    kept and ``True`` is returned. Otherwise the previous flags are restored
    and the parent class decides.

    Returns:
        ``True`` when the Armv9-A probe builds, else ``super().__bool__()``.
    """
    # NOTE(review): unlike VecNEON.__bool__, this override is not wrapped in
    # functools.cache, so every truthiness test repeats the probe unless
    # check_build caches internally -- confirm whether that is intended.
    if not _IS_WINDOWS:
        # Remember the flags in effect before the probe.
        base_flags = self._arch_flags
        armv9a_ok = False
        self._arch_flags = (
            "-march=armv9-a+sve2+fp16fml+sha3+bf16+i8mm -msve-vector-bits=256"
        )
        try:
            armv9a_ok = bool(self.check_build(VecISA._armv9a_code))
        finally:
            # Restore on failure *and* when check_build raises. Otherwise a
            # later call would save the Armv9-A string as its "base" flags
            # and the fallback path would keep using them.
            if not armv9a_ok:
                self._arch_flags = base_flags
        if armv9a_ok:
            return True
    return super().__bool__()
def __str__(self) -> str:
if config.is_fbcode():
return "neon"