mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
s390x: use runtime detection for vectorization support (#123936)
s390x: use runtime detection for vectorization support Pull Request resolved: https://github.com/pytorch/pytorch/pull/123936 Approved by: https://github.com/malfet, https://github.com/jansel, https://github.com/xuhancn
This commit is contained in:
committed by
PyTorch MergeBot
parent
5503c29357
commit
2b5ae2611e
@ -10,8 +10,19 @@
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
#include <sys/auxv.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
// Runtime check for the s390x VXE facility: queries the AT_HWCAP entry via
// getauxval() and returns true when the HWCAP_S390_VXE bit is set in it.
static inline bool cpu_has_vxe()
|
||||
{
|
||||
return (getauxval(AT_HWCAP) & HWCAP_S390_VXE);
|
||||
}
|
||||
#endif
|
||||
|
||||
static CPUCapability compute_cpu_capability() {
|
||||
auto envar = std::getenv("ATEN_CPU_CAPABILITY");
|
||||
if (envar) {
|
||||
@ -60,10 +71,16 @@ static CPUCapability compute_cpu_capability() {
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
||||
// vxe is needed for fp32 vector instructions
|
||||
if (cpu_has_vxe()) {
|
||||
return CPUCapability::ZVECTOR;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_VSX_CPU_DEFINITION
|
||||
return CPUCapability::VSX;
|
||||
#elif HAVE_ZVECTOR_CPU_DEFINITION
|
||||
return CPUCapability::ZVECTOR;
|
||||
#else
|
||||
return CPUCapability::DEFAULT;
|
||||
#endif
|
||||
|
@ -1,22 +1,5 @@
|
||||
|
||||
IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
message("-- <FindZVECTOR>")
|
||||
set(Z_ARCH_LIST "")
|
||||
#firstly, tries to add the arch of the platform
|
||||
EXEC_PROGRAM(LD_SHOW_AUXV=1 ARGS "/bin/true" OUTPUT_VARIABLE bintrue)
|
||||
if(bintrue MATCHES "AT_PLATFORM:[ \\t\\n\\r]*([a-zA-Z0-9_]+)[ \\t\\n\\r]*")
|
||||
if(CMAKE_MATCH_COUNT GREATER 0)
|
||||
string(TOLOWER ${CMAKE_MATCH_1} platform)
|
||||
if(${platform} MATCHES "^z(14|15|16)")
|
||||
message("-- Z ARCH Platform: ${platform}")
|
||||
list( APPEND Z_ARCH_LIST "${platform}" )
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
#adds other archs in descending order. as its cached nothing will be checked twice
|
||||
list( APPEND Z_ARCH_LIST "z16" )
|
||||
list( APPEND Z_ARCH_LIST "z15" )
|
||||
list( APPEND Z_ARCH_LIST "z14" )
|
||||
|
||||
SET(VECTORIZATION_CODE "
|
||||
#include <vecintrin.h>
|
||||
@ -32,25 +15,25 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
vuint32 selector= {0xFFFFFFFF, 0, 0xFFFFFFFF, 0xFFFFFFFF};
|
||||
vfloat32 hf = vsel_ext(selector, h1,h2);
|
||||
int ret = (int)(hf[0]*1000+hf[1]*100+hf[2]*10+hf[3]);
|
||||
return ret==3856;
|
||||
return (ret == 3856) ? 0 : -1;
|
||||
}
|
||||
")
|
||||
|
||||
foreach(Z_ARCH ${Z_ARCH_LIST})
|
||||
SET(ARCH_SIMD_TEST_FLAGS_${Z_ARCH} " -mvx -mzvector -march=${Z_ARCH} -mtune=${Z_ARCH}")
|
||||
message("-- check ${Z_ARCH}")
|
||||
SET(ARCH_SIMD_TEST_FLAGS " -mvx -mzvector")
|
||||
SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||
SET(CMAKE_REQUIRED_FLAGS "${ARCH_SIMD_TEST_FLAGS_${Z_ARCH}}")
|
||||
set(VECTORIZATION_CODE_${Z_ARCH} "${VECTORIZATION_CODE}")
|
||||
CHECK_CXX_SOURCE_COMPILES("${VECTORIZATION_CODE_${Z_ARCH}}" COMPILE_OUT_${Z_ARCH})
|
||||
SET(CMAKE_REQUIRED_FLAGS "${ARCH_SIMD_TEST_FLAGS}")
|
||||
# Do compilation check instead of runtime check
|
||||
# in case it is compiled on older hardware
|
||||
# or crosscompiled
|
||||
CHECK_CXX_SOURCE_COMPILES("${VECTORIZATION_CODE}" COMPILE_OUT_ZVECTOR)
|
||||
SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
|
||||
if(COMPILE_OUT_${Z_ARCH})
|
||||
message("-- ${Z_ARCH} SIMD flags were set.")
|
||||
if(COMPILE_OUT_ZVECTOR)
|
||||
message("-- ZVECTOR flags were set.")
|
||||
set(CXX_ZVECTOR_FOUND TRUE)
|
||||
SET(CXX_ZVECTOR_FLAGS "${ARCH_SIMD_TEST_FLAGS_${Z_ARCH}}" )
|
||||
break()
|
||||
SET(CXX_ZVECTOR_FLAGS "${ARCH_SIMD_TEST_FLAGS}" )
|
||||
else()
|
||||
message("-- ZVECTOR flags were NOT set.")
|
||||
endif()
|
||||
endforeach()
|
||||
message("-- </FindZVECTOR>")
|
||||
|
||||
endif()
|
||||
|
@ -1294,7 +1294,18 @@ def valid_vec_isa_list() -> List[VecISA]:
|
||||
return []
|
||||
|
||||
if platform.machine() == "s390x":
|
||||
with open("/proc/cpuinfo") as _cpu_info:
|
||||
while True:
|
||||
line = _cpu_info.readline()
|
||||
if not line:
|
||||
break
|
||||
# process line
|
||||
featuresmatch = re.match(r"^features\s*:\s*(.*)$", line)
|
||||
if featuresmatch:
|
||||
for group in featuresmatch.groups():
|
||||
if re.search(r"[\^ ]+vxe[\$ ]+", group):
|
||||
return [VecZVECTOR()]
|
||||
return []
|
||||
|
||||
isa_list = []
|
||||
with open("/proc/cpuinfo") as _cpu_info:
|
||||
|
Reference in New Issue
Block a user