Add a new API torch.xpu.can_device_access_peer for Intel GPU (#162705)

# Motivation
Aligned with other backends, this PR introduces a new API `torch.xpu.can_device_access_peer`, which is used in vLLM distributed [scenarios](https://github.com/vllm-project/vllm/blob/2048c4e379/vllm/distributed/device_communicators/custom_all_reduce.py#L37).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162705
Approved by: https://github.com/EikanWang, https://github.com/ezyang
This commit is contained in:
Yu, Guangye
2025-09-15 13:07:09 +00:00
committed by PyTorch MergeBot
parent 6db37d7206
commit 0819de412d
7 changed files with 55 additions and 0 deletions

View File

@ -76,4 +76,23 @@ int32_t getGlobalIdxFromDevice(DeviceIndex device) {
return device_global_idxs[device]; return device_global_idxs[device];
} }
// Reports whether `device` can directly access memory owned by `peer`.
// An index of -1 for either argument is replaced by the current XPU device.
bool canDeviceAccessPeer(DeviceIndex device, DeviceIndex peer) {
  // Map the -1 placeholder to the current device; other indices pass through.
  const auto resolve = [](DeviceIndex index) -> DeviceIndex {
    return index == -1 ? c10::xpu::current_device() : index;
  };
  const DeviceIndex accessor = resolve(device);
  const DeviceIndex target = resolve(peer);
  check_device_index(accessor);
  check_device_index(target);
  // Only distinct devices need a runtime query; self-access is always possible.
  if (accessor != target) {
    return c10::xpu::get_raw_device(accessor).ext_oneapi_can_access_peer(
        c10::xpu::get_raw_device(target),
        sycl::ext::oneapi::peer_access::access_supported);
  }
  return true;
}
} // namespace at::xpu } // namespace at::xpu

View File

@ -17,4 +17,6 @@ TORCH_XPU_API DeviceProp* getDeviceProperties(DeviceIndex device);
TORCH_XPU_API int32_t getGlobalIdxFromDevice(DeviceIndex device); TORCH_XPU_API int32_t getGlobalIdxFromDevice(DeviceIndex device);
// Returns whether `device` can directly access the memory of `peer`.
// An index of -1 for either argument is resolved to the current device.
TORCH_XPU_API bool canDeviceAccessPeer(DeviceIndex device, DeviceIndex peer);
} // namespace at::xpu } // namespace at::xpu

View File

@ -12,6 +12,7 @@
:nosignatures: :nosignatures:
StreamContext StreamContext
can_device_access_peer
current_device current_device
current_stream current_stream
device device

View File

@ -585,6 +585,16 @@ if __name__ == "__main__":
for arch in arch_list: for arch in arch_list:
self.assertTrue(arch in flags) self.assertTrue(arch in flags)
@unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
def test_can_device_access_peer(self):
    # For every ordered pair of devices, the query must give the same
    # answer in both directions.
    num_devices = torch.xpu.device_count()
    for src in range(num_devices):
        for dst in range(num_devices):
            forward = torch.xpu.can_device_access_peer(src, dst)
            backward = torch.xpu.can_device_access_peer(dst, src)
            self.assertEqual(forward, backward)
def test_torch_version_xpu(self): def test_torch_version_xpu(self):
self.assertEqual(len(torch.version.xpu), 8) self.assertEqual(len(torch.version.xpu), 8)
compiler_version = int(torch.version.xpu) compiler_version = int(torch.version.xpu)

View File

@ -2369,6 +2369,7 @@ def _xpu_memoryStats(device: _int) -> dict[str, Any]: ...
def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ... def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ...
def _xpu_resetPeakMemoryStats(device: _int) -> None: ... def _xpu_resetPeakMemoryStats(device: _int) -> None: ...
def _xpu_getMemoryInfo(device: _int) -> tuple[_int, _int]: ... def _xpu_getMemoryInfo(device: _int) -> tuple[_int, _int]: ...
# Registered in initXpuMethodBindings; forwards to at::xpu::canDeviceAccessPeer.
def _xpu_canDeviceAccessPeer(device: _int, peer: _int) -> _bool: ...
class _XpuDeviceProperties: class _XpuDeviceProperties:
name: str name: str

View File

@ -415,6 +415,11 @@ static void initXpuMethodBindings(PyObject* module) {
return std::make_tuple( return std::make_tuple(
stream.id(), stream.device_index(), stream.device_type()); stream.id(), stream.device_index(), stream.device_type());
}); });
// Expose at::xpu::canDeviceAccessPeer to Python as `_xpu_canDeviceAccessPeer`.
m.def(
"_xpu_canDeviceAccessPeer",
[](c10::DeviceIndex device, c10::DeviceIndex peer) {
return at::xpu::canDeviceAccessPeer(device, peer);
});
} }
// Callback for python part. Used for additional initialization of python // Callback for python part. Used for additional initialization of python

View File

@ -280,6 +280,22 @@ def _get_device(device: Union[int, str, torch.device]) -> torch.device:
return device return device
def can_device_access_peer(device: _device_t, peer: _device_t) -> bool:
    r"""Check whether memory on a peer device is accessible from a given device.

    Args:
        device (torch.device or int or str): the device that would perform the access.
        peer (torch.device or int or str): the device whose memory would be accessed.

    Returns:
        bool: ``True`` when ``device`` can access ``peer``, otherwise ``False``.
    """
    _lazy_init()
    # Convert both arguments to device indices before calling the native binding.
    device_index = _get_device_index(device, optional=True)
    peer_index = _get_device_index(peer, optional=True)
    return torch._C._xpu_canDeviceAccessPeer(device_index, peer_index)
class StreamContext: class StreamContext:
r"""Context-manager that selects a given stream. r"""Context-manager that selects a given stream.
@ -518,6 +534,7 @@ __all__ = [
"Event", "Event",
"Stream", "Stream",
"StreamContext", "StreamContext",
"can_device_access_peer",
"current_device", "current_device",
"current_stream", "current_stream",
"default_generators", "default_generators",