diff --git a/aten/src/ATen/xpu/XPUContext.cpp b/aten/src/ATen/xpu/XPUContext.cpp
index 2157e34648b8..e956ec9a1659 100644
--- a/aten/src/ATen/xpu/XPUContext.cpp
+++ b/aten/src/ATen/xpu/XPUContext.cpp
@@ -76,4 +76,23 @@ int32_t getGlobalIdxFromDevice(DeviceIndex device) {
   return device_global_idxs[device];
 }
 
+// Check if `device` can directly access `peer`'s memory (-1 = current device).
+bool canDeviceAccessPeer(DeviceIndex device, DeviceIndex peer) {
+  if (device == -1) {
+    device = c10::xpu::current_device();
+  }
+  if (peer == -1) {
+    peer = c10::xpu::current_device();
+  }
+  check_device_index(device);
+  check_device_index(peer);
+  // A device can always access itself, so skip the runtime peer query.
+  if (device == peer) {
+    return true;
+  }
+  return c10::xpu::get_raw_device(device).ext_oneapi_can_access_peer(
+      c10::xpu::get_raw_device(peer),
+      sycl::ext::oneapi::peer_access::access_supported);
+}
+
 } // namespace at::xpu
diff --git a/aten/src/ATen/xpu/XPUContext.h b/aten/src/ATen/xpu/XPUContext.h
index fb8fbe9c0aa4..a473f317ca3d 100644
--- a/aten/src/ATen/xpu/XPUContext.h
+++ b/aten/src/ATen/xpu/XPUContext.h
@@ -17,4 +17,6 @@ TORCH_XPU_API DeviceProp* getDeviceProperties(DeviceIndex device);
 
 TORCH_XPU_API int32_t getGlobalIdxFromDevice(DeviceIndex device);
 
+TORCH_XPU_API bool canDeviceAccessPeer(DeviceIndex device, DeviceIndex peer);
+
 } // namespace at::xpu
diff --git a/docs/source/xpu.md b/docs/source/xpu.md
index 53a5fadeca35..1496a7f82c58 100644
--- a/docs/source/xpu.md
+++ b/docs/source/xpu.md
@@ -12,6 +12,7 @@
     :nosignatures:
 
     StreamContext
+    can_device_access_peer
     current_device
     current_stream
     device
diff --git a/test/test_xpu.py b/test/test_xpu.py
index 04d045b00d8b..3474e4031ef2 100644
--- a/test/test_xpu.py
+++ b/test/test_xpu.py
@@ -585,6 +585,16 @@ if __name__ == "__main__":
         for arch in arch_list:
             self.assertTrue(arch in flags)
 
+    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
+    def test_can_device_access_peer(self):
+        device_count = torch.xpu.device_count()
+        for device in range(device_count):
+            for peer in range(device_count):
+                self.assertEqual(
+                    torch.xpu.can_device_access_peer(device, peer),
+                    torch.xpu.can_device_access_peer(peer, device),
+                )
+
     def test_torch_version_xpu(self):
         self.assertEqual(len(torch.version.xpu), 8)
         compiler_version = int(torch.version.xpu)
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index e55137c3d2bf..3b183c8af835 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -2369,6 +2369,7 @@ def _xpu_memoryStats(device: _int) -> dict[str, Any]: ...
 def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ...
 def _xpu_resetPeakMemoryStats(device: _int) -> None: ...
 def _xpu_getMemoryInfo(device: _int) -> tuple[_int, _int]: ...
+def _xpu_canDeviceAccessPeer(device: _int, peer: _int) -> _bool: ...
 
 class _XpuDeviceProperties:
     name: str
diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp
index d49fc0539a08..8f1aead1900c 100644
--- a/torch/csrc/xpu/Module.cpp
+++ b/torch/csrc/xpu/Module.cpp
@@ -415,6 +415,11 @@ static void initXpuMethodBindings(PyObject* module) {
         return std::make_tuple(
             stream.id(), stream.device_index(), stream.device_type());
       });
+  m.def(
+      "_xpu_canDeviceAccessPeer",
+      [](c10::DeviceIndex device, c10::DeviceIndex peer) {
+        return at::xpu::canDeviceAccessPeer(device, peer);
+      });
 }
 
 // Callback for python part. Used for additional initialization of python
diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py
index 79aae38a3168..6e15bf4380e3 100644
--- a/torch/xpu/__init__.py
+++ b/torch/xpu/__init__.py
@@ -280,6 +280,22 @@ def _get_device(device: Union[int, str, torch.device]) -> torch.device:
     return device
 
 
+def can_device_access_peer(device: _device_t, peer: _device_t) -> bool:
+    r"""Query whether a device can access a peer device's memory.
+
+    Args:
+        device (torch.device or int or str): device whose access is queried.
+        peer (torch.device or int or str): peer device to query access to.
+
+    Returns:
+        bool: ``True`` if ``device`` can access ``peer``'s memory, else ``False``.
+    """
+    _lazy_init()
+    device = _get_device_index(device, optional=True)
+    peer = _get_device_index(peer, optional=True)
+    return torch._C._xpu_canDeviceAccessPeer(device, peer)
+
+
 class StreamContext:
     r"""Context-manager that selects a given stream.
 
@@ -518,6 +534,7 @@ __all__ = [
     "Event",
     "Stream",
     "StreamContext",
+    "can_device_access_peer",
     "current_device",
     "current_stream",
     "default_generators",