s390x: use runtime detection for vectorization support (#123936)

s390x: use runtime detection for vectorization support

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123936
Approved by: https://github.com/malfet, https://github.com/jansel, https://github.com/xuhancn
This commit is contained in:
Aleksei Nikiforov
2024-05-03 21:34:34 +00:00
committed by PyTorch MergeBot
parent 5503c29357
commit 2b5ae2611e
3 changed files with 47 additions and 36 deletions

View File

@ -10,8 +10,19 @@
#include <cstdlib>
#include <cstring>
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#include <sys/auxv.h>
#endif
namespace at::native {
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
// Returns true when the HWCAP_S390_VXE bit is set in the AT_HWCAP value
// reported by getauxval() for the running process, false otherwise.
static inline bool cpu_has_vxe()
{
  const unsigned long hwcap = getauxval(AT_HWCAP);
  return (hwcap & HWCAP_S390_VXE) != 0;
}
#endif
static CPUCapability compute_cpu_capability() {
auto envar = std::getenv("ATEN_CPU_CAPABILITY");
if (envar) {
@ -60,10 +71,16 @@ static CPUCapability compute_cpu_capability() {
#endif
}
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
// vxe is needed for fp32 vector instructions
if (cpu_has_vxe()) {
return CPUCapability::ZVECTOR;
}
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
return CPUCapability::VSX;
#elif HAVE_ZVECTOR_CPU_DEFINITION
return CPUCapability::ZVECTOR;
#else
return CPUCapability::DEFAULT;
#endif

View File

@ -1,22 +1,5 @@
IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
message("-- <FindZVECTOR>")
set(Z_ARCH_LIST "")
#firstly, tries to add the arch of the platform
EXEC_PROGRAM(LD_SHOW_AUXV=1 ARGS "/bin/true" OUTPUT_VARIABLE bintrue)
if(bintrue MATCHES "AT_PLATFORM:[ \\t\\n\\r]*([a-zA-Z0-9_]+)[ \\t\\n\\r]*")
if(CMAKE_MATCH_COUNT GREATER 0)
string(TOLOWER ${CMAKE_MATCH_1} platform)
if(${platform} MATCHES "^z(14|15|16)")
message("-- Z ARCH Platform: ${platform}")
list( APPEND Z_ARCH_LIST "${platform}" )
endif()
endif()
endif()
#adds other archs in descending order. as its cached nothing will be checked twice
list( APPEND Z_ARCH_LIST "z16" )
list( APPEND Z_ARCH_LIST "z15" )
list( APPEND Z_ARCH_LIST "z14" )
SET(VECTORIZATION_CODE "
#include <vecintrin.h>
@ -32,25 +15,25 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
vuint32 selector= {0xFFFFFFFF, 0, 0xFFFFFFFF, 0xFFFFFFFF};
vfloat32 hf = vsel_ext(selector, h1,h2);
int ret = (int)(hf[0]*1000+hf[1]*100+hf[2]*10+hf[3]);
return ret==3856;
return (ret == 3856) ? 0 : -1;
}
")
foreach(Z_ARCH ${Z_ARCH_LIST})
SET(ARCH_SIMD_TEST_FLAGS_${Z_ARCH} " -mvx -mzvector -march=${Z_ARCH} -mtune=${Z_ARCH}")
message("-- check ${Z_ARCH}")
SET(ARCH_SIMD_TEST_FLAGS " -mvx -mzvector")
SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
SET(CMAKE_REQUIRED_FLAGS "${ARCH_SIMD_TEST_FLAGS_${Z_ARCH}}")
set(VECTORIZATION_CODE_${Z_ARCH} "${VECTORIZATION_CODE}")
CHECK_CXX_SOURCE_COMPILES("${VECTORIZATION_CODE_${Z_ARCH}}" COMPILE_OUT_${Z_ARCH})
SET(CMAKE_REQUIRED_FLAGS "${ARCH_SIMD_TEST_FLAGS}")
# Do a compilation check instead of a runtime check
# in case this is being compiled on older hardware
# or cross-compiled
CHECK_CXX_SOURCE_COMPILES("${VECTORIZATION_CODE}" COMPILE_OUT_ZVECTOR)
SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
if(COMPILE_OUT_${Z_ARCH})
message("-- ${Z_ARCH} SIMD flags were set.")
if(COMPILE_OUT_ZVECTOR)
message("-- ZVECTOR flags were set.")
set(CXX_ZVECTOR_FOUND TRUE)
SET(CXX_ZVECTOR_FLAGS "${ARCH_SIMD_TEST_FLAGS_${Z_ARCH}}" )
break()
SET(CXX_ZVECTOR_FLAGS "${ARCH_SIMD_TEST_FLAGS}" )
else()
message("-- ZVECTOR flags were NOT set.")
endif()
endforeach()
message("-- </FindZVECTOR>")
endif()

View File

@ -1294,7 +1294,18 @@ def valid_vec_isa_list() -> List[VecISA]:
return []
if platform.machine() == "s390x":
with open("/proc/cpuinfo") as _cpu_info:
while True:
line = _cpu_info.readline()
if not line:
break
# process line
featuresmatch = re.match(r"^features\s*:\s*(.*)$", line)
if featuresmatch:
for group in featuresmatch.groups():
if re.search(r"[\^ ]+vxe[\$ ]+", group):
return [VecZVECTOR()]
return []
isa_list = []
with open("/proc/cpuinfo") as _cpu_info: