Add torch.xpu.get_arch_list and torch.xpu.get_gencode_flags for XPU (#137773)

# Motivation
Add `torch.xpu.get_arch_list()` and `torch.xpu.get_gencode_flags()` methods that return architecture list and AOT flags to preserve what flags PyTorch XPU was built with.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137773
Approved by: https://github.com/EikanWang, https://github.com/albanD
This commit is contained in:
Yu, Guangye
2024-10-16 03:35:05 +00:00
committed by PyTorch MergeBot
parent 6d8c9be54b
commit 8cda774a03
10 changed files with 70 additions and 0 deletions

View File

@ -306,6 +306,17 @@ macro(torch_hip_get_arch_list store_var)
string(REPLACE " " ";" ${store_var} "${_TMP}")
endmacro()
##############################################################################
# Get the XPU arch flags specified by TORCH_XPU_ARCH_LIST.
# Usage:
# torch_xpu_get_arch_list(variable_to_store_flags)
#
macro(torch_xpu_get_arch_list store_var)
if(DEFINED ENV{TORCH_XPU_ARCH_LIST})
set(${store_var} $ENV{TORCH_XPU_ARCH_LIST})
endif()
endmacro()
##############################################################################
# Get the NVCC arch flags specified by TORCH_CUDA_ARCH_LIST and CUDA_ARCH_NAME.
# Usage:

View File

@ -28,3 +28,8 @@ add_library(torch::xpurt INTERFACE IMPORTED)
set_property(
TARGET torch::xpurt PROPERTY INTERFACE_LINK_LIBRARIES
torch::sycl)
# setting xpu arch flags
torch_xpu_get_arch_list(XPU_ARCH_FLAGS)
# propagate to torch-xpu-ops
set(TORCH_XPU_ARCH_LIST ${XPU_ARCH_FLAGS})

View File

@ -13,9 +13,11 @@ torch.xpu
device
device_count
device_of
get_arch_list
get_device_capability
get_device_name
get_device_properties
get_gencode_flags
init
is_available
is_initialized

View File

@ -119,6 +119,10 @@
# These are not CUDA versions, instead, they specify what
# classes of NVIDIA hardware we should generate PTX for.
#
# TORCH_XPU_ARCH_LIST
# specify which XPU architectures to build for.
# ie `TORCH_XPU_ARCH_LIST="ats-m150,lnl-m"`
#
# PYTORCH_ROCM_ARCH
# specify which AMD GPU targets to build for.
# ie `PYTORCH_ROCM_ARCH="gfx900;gfx906"`

View File

@ -420,6 +420,14 @@ print(torch.xpu.device_count())
)
)
def test_get_arch_list(self):
arch_list = torch.xpu.get_arch_list()
if not arch_list:
return
flags = torch.xpu.get_gencode_flags()
for arch in arch_list:
self.assertTrue(arch in flags)
instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)

View File

@ -229,6 +229,7 @@ class CMake:
"STATIC_DISPATCH_BACKEND",
"SELECTED_OP_LIST",
"TORCH_CUDA_ARCH_LIST",
"TORCH_XPU_ARCH_LIST",
"TRACING_BASED",
"PYTHON_LIB_REL_PATH",
)

View File

@ -415,6 +415,12 @@ if(USE_ROCM)
set_source_files_properties(${TORCH_SRC_DIR}/csrc/cuda/Module.cpp PROPERTIES COMPILE_FLAGS "-DCUDA_ARCH_FLAGS=\"${PYTORCH_ROCM_ARCH_readable}\"")
endif()
# Preserve XPU arch flags
if(USE_XPU)
string(REPLACE "," " " _ARCH_FLAGS_readable "${TORCH_XPU_ARCH_LIST}")
set_source_files_properties(${TORCH_SRC_DIR}/csrc/xpu/Module.cpp PROPERTIES COMPILE_FLAGS "-DXPU_ARCH_FLAGS=\"${_ARCH_FLAGS_readable}\"")
endif()
target_compile_definitions(torch_python PRIVATE "-DTHP_BUILD_MAIN_LIB")
target_link_libraries(torch_python PRIVATE ${TORCH_LIB} ${TORCH_PYTHON_LINK_LIBRARIES})

View File

@ -2107,6 +2107,7 @@ def _xpu_exchangeDevice(device: _int) -> _int: ...
def _xpu_maybeExchangeDevice(device: _int) -> _int: ...
def _xpu_getDevice() -> _int: ...
def _xpu_getDeviceCount() -> _int: ...
def _xpu_getArchFlags() -> Optional[str]: ...
def _xpu_init() -> None: ...
def _xpu_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
def _xpu_getCurrentStream(device: _int) -> Tuple: ...

View File

@ -39,6 +39,17 @@ static void poison_fork() {
// XPU management methods
PyObject* THXPModule_getArchFlags(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
#ifdef XPU_ARCH_FLAGS
static const char* flags = C10_STRINGIZE(XPU_ARCH_FLAGS);
return THPUtils_packString(flags);
#else
Py_RETURN_NONE;
#endif
END_HANDLE_TH_ERRORS
}
static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
return PyBool_FromLong(in_bad_fork);
@ -404,6 +415,7 @@ static struct PyMethodDef _THXPModule_methods[] = {
THXPModule_getDeviceCount_wrap,
METH_NOARGS,
nullptr},
{"_xpu_getArchFlags", THXPModule_getArchFlags, METH_NOARGS, nullptr},
{"_xpu_isInBadFork", THXPModule_isInBadFork_wrap, METH_NOARGS, nullptr},
{"_xpu_getCurrentStream",
THXPModule_getCurrentStream_wrap,

View File

@ -395,6 +395,24 @@ def synchronize(device: _device_t = None) -> None:
return torch._C._xpu_synchronize(device)
def get_arch_list() -> List[str]:
r"""Return list XPU architectures this library was compiled for."""
if not is_available():
return []
arch_flags = torch._C._xpu_getArchFlags()
if arch_flags is None:
return []
return arch_flags.split()
def get_gencode_flags() -> str:
r"""Return XPU AOT(ahead-of-time) build flags this library was compiled with."""
arch_list = get_arch_list()
if len(arch_list) == 0:
return ""
return f'-device {",".join(arch for arch in arch_list)}'
def _get_generator(device: torch.device) -> torch._C.Generator:
r"""Return the XPU Generator object for the given device.
@ -478,9 +496,11 @@ __all__ = [
"device_of",
"device_count",
"empty_cache",
"get_arch_list",
"get_device_capability",
"get_device_name",
"get_device_properties",
"get_gencode_flags",
"get_rng_state",
"get_rng_state_all",
"get_stream",