mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add torch.xpu.get_arch_list and torch.xpu.get_gencode_flags for XPU (#137773)
# Motivation Add `torch.xpu.get_arch_list()` and `torch.xpu.get_gencode_flags()` methods that return the architecture list and the AOT flags, so that the flags PyTorch XPU was built with are preserved. Pull Request resolved: https://github.com/pytorch/pytorch/pull/137773 Approved by: https://github.com/EikanWang, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
6d8c9be54b
commit
8cda774a03
@ -306,6 +306,17 @@ macro(torch_hip_get_arch_list store_var)
|
||||
string(REPLACE " " ";" ${store_var} "${_TMP}")
|
||||
endmacro()
|
||||
|
||||
##############################################################################
|
||||
# Get the XPU arch flags specified by TORCH_XPU_ARCH_LIST.
|
||||
# Usage:
|
||||
# torch_xpu_get_arch_list(variable_to_store_flags)
|
||||
#
|
||||
macro(torch_xpu_get_arch_list store_var)
  # This is a macro, so the set() below writes directly into the caller's
  # scope. When the TORCH_XPU_ARCH_LIST environment variable is not defined,
  # ${store_var} is left exactly as the caller had it.
  if(DEFINED ENV{TORCH_XPU_ARCH_LIST})
    set(${store_var} $ENV{TORCH_XPU_ARCH_LIST})
  endif()
endmacro()
|
||||
|
||||
##############################################################################
|
||||
# Get the NVCC arch flags specified by TORCH_CUDA_ARCH_LIST and CUDA_ARCH_NAME.
|
||||
# Usage:
|
||||
|
@ -28,3 +28,8 @@ add_library(torch::xpurt INTERFACE IMPORTED)
|
||||
set_property(
|
||||
TARGET torch::xpurt PROPERTY INTERFACE_LINK_LIBRARIES
|
||||
torch::sycl)
|
||||
|
||||
# Collect the XPU arch flags requested through the TORCH_XPU_ARCH_LIST
# environment variable into XPU_ARCH_FLAGS.
torch_xpu_get_arch_list(XPU_ARCH_FLAGS)
# Propagate to torch-xpu-ops. The expansion is intentionally left unquoted:
# when XPU_ARCH_FLAGS holds no value this call has no value argument.
set(TORCH_XPU_ARCH_LIST ${XPU_ARCH_FLAGS})
|
||||
|
@ -13,9 +13,11 @@ torch.xpu
|
||||
device
|
||||
device_count
|
||||
device_of
|
||||
get_arch_list
|
||||
get_device_capability
|
||||
get_device_name
|
||||
get_device_properties
|
||||
get_gencode_flags
|
||||
init
|
||||
is_available
|
||||
is_initialized
|
||||
|
4
setup.py
4
setup.py
@ -119,6 +119,10 @@
|
||||
# These are not CUDA versions, instead, they specify what
|
||||
# classes of NVIDIA hardware we should generate PTX for.
|
||||
#
|
||||
# TORCH_XPU_ARCH_LIST
|
||||
# specify which XPU architectures to build for.
|
||||
# ie `TORCH_XPU_ARCH_LIST="ats-m150,lnl-m"`
|
||||
#
|
||||
# PYTORCH_ROCM_ARCH
|
||||
# specify which AMD GPU targets to build for.
|
||||
# ie `PYTORCH_ROCM_ARCH="gfx900;gfx906"`
|
||||
|
@ -420,6 +420,14 @@ print(torch.xpu.device_count())
|
||||
)
|
||||
)
|
||||
|
||||
def test_get_arch_list(self):
|
||||
arch_list = torch.xpu.get_arch_list()
|
||||
if not arch_list:
|
||||
return
|
||||
flags = torch.xpu.get_gencode_flags()
|
||||
for arch in arch_list:
|
||||
self.assertTrue(arch in flags)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)
|
||||
|
||||
|
@ -229,6 +229,7 @@ class CMake:
|
||||
"STATIC_DISPATCH_BACKEND",
|
||||
"SELECTED_OP_LIST",
|
||||
"TORCH_CUDA_ARCH_LIST",
|
||||
"TORCH_XPU_ARCH_LIST",
|
||||
"TRACING_BASED",
|
||||
"PYTHON_LIB_REL_PATH",
|
||||
)
|
||||
|
@ -415,6 +415,12 @@ if(USE_ROCM)
|
||||
set_source_files_properties(${TORCH_SRC_DIR}/csrc/cuda/Module.cpp PROPERTIES COMPILE_FLAGS "-DCUDA_ARCH_FLAGS=\"${PYTORCH_ROCM_ARCH_readable}\"")
|
||||
endif()
|
||||
|
||||
# Preserve XPU arch flags
if(USE_XPU)
  # Commas in TORCH_XPU_ARCH_LIST become spaces, and the result is passed to
  # csrc/xpu/Module.cpp as the XPU_ARCH_FLAGS preprocessor definition.
  string(REPLACE "," " " _ARCH_FLAGS_readable "${TORCH_XPU_ARCH_LIST}")
  set_source_files_properties(
    ${TORCH_SRC_DIR}/csrc/xpu/Module.cpp
    PROPERTIES COMPILE_FLAGS "-DXPU_ARCH_FLAGS=\"${_ARCH_FLAGS_readable}\"")
endif()
|
||||
|
||||
target_compile_definitions(torch_python PRIVATE "-DTHP_BUILD_MAIN_LIB")
|
||||
|
||||
target_link_libraries(torch_python PRIVATE ${TORCH_LIB} ${TORCH_PYTHON_LINK_LIBRARIES})
|
||||
|
@ -2107,6 +2107,7 @@ def _xpu_exchangeDevice(device: _int) -> _int: ...
|
||||
def _xpu_maybeExchangeDevice(device: _int) -> _int: ...
|
||||
def _xpu_getDevice() -> _int: ...
|
||||
def _xpu_getDeviceCount() -> _int: ...
|
||||
def _xpu_getArchFlags() -> Optional[str]: ...
|
||||
def _xpu_init() -> None: ...
|
||||
def _xpu_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
|
||||
def _xpu_getCurrentStream(device: _int) -> Tuple: ...
|
||||
|
@ -39,6 +39,17 @@ static void poison_fork() {
|
||||
|
||||
// XPU management methods
|
||||
|
||||
// Returns the value of the XPU_ARCH_FLAGS macro this file was compiled with,
// packed as a Python string, or None when the macro was not defined.
PyObject* THXPModule_getArchFlags(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
#ifdef XPU_ARCH_FLAGS
  // Stringize the macro so that its expansion becomes a C string.
  static const char* arch_flags = C10_STRINGIZE(XPU_ARCH_FLAGS);
  return THPUtils_packString(arch_flags);
#else
  Py_RETURN_NONE;
#endif
  END_HANDLE_TH_ERRORS
}
|
||||
|
||||
static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
return PyBool_FromLong(in_bad_fork);
|
||||
@ -404,6 +415,7 @@ static struct PyMethodDef _THXPModule_methods[] = {
|
||||
THXPModule_getDeviceCount_wrap,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"_xpu_getArchFlags", THXPModule_getArchFlags, METH_NOARGS, nullptr},
|
||||
{"_xpu_isInBadFork", THXPModule_isInBadFork_wrap, METH_NOARGS, nullptr},
|
||||
{"_xpu_getCurrentStream",
|
||||
THXPModule_getCurrentStream_wrap,
|
||||
|
@ -395,6 +395,24 @@ def synchronize(device: _device_t = None) -> None:
|
||||
return torch._C._xpu_synchronize(device)
|
||||
|
||||
|
||||
def get_arch_list() -> List[str]:
    r"""Return the list of XPU architectures this library was compiled for.

    An empty list is returned when XPU is not available or when no
    architecture flags were recorded at build time.
    """
    if not is_available():
        return []
    flags = torch._C._xpu_getArchFlags()
    # The flags arrive as one whitespace-separated string (or None).
    return [] if flags is None else flags.split()
|
||||
|
||||
|
||||
def get_gencode_flags() -> str:
    r"""Return XPU AOT(ahead-of-time) build flags this library was compiled with.

    The result has the form ``"-device <arch>,<arch>,..."`` built from
    :func:`get_arch_list`, or an empty string when that list is empty.
    """
    arch_list = get_arch_list()
    if not arch_list:
        return ""
    # str.join takes the list directly; no generator wrapper is needed.
    return f'-device {",".join(arch_list)}'
|
||||
|
||||
|
||||
def _get_generator(device: torch.device) -> torch._C.Generator:
|
||||
r"""Return the XPU Generator object for the given device.
|
||||
|
||||
@ -478,9 +496,11 @@ __all__ = [
|
||||
"device_of",
|
||||
"device_count",
|
||||
"empty_cache",
|
||||
"get_arch_list",
|
||||
"get_device_capability",
|
||||
"get_device_name",
|
||||
"get_device_properties",
|
||||
"get_gencode_flags",
|
||||
"get_rng_state",
|
||||
"get_rng_state_all",
|
||||
"get_stream",
|
||||
|
Reference in New Issue
Block a user