[OpenReg] Improve the Event and Stream capabilities of DeviceGuardImplInterface (#160101)

**Changes:**

- Based on `OpenRegStream` and `OpenRegEvent`, we improve the implementation of Device Guard for `OpenReg`
- Add some related testcases
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160101
Approved by: https://github.com/albanD
ghstack dependencies: #161917, #161918
This commit is contained in:
FFFrog
2025-09-13 02:06:25 +08:00
committed by PyTorch MergeBot
parent 27daa6af6a
commit 29f84b0f61
8 changed files with 272 additions and 60 deletions

View File

@ -36,7 +36,7 @@ OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
static int count = []() {
try {
auto result = device_count_impl();
TORCH_INTERNAL_ASSERT(
TORCH_CHECK(
result <= std::numeric_limits<c10::DeviceIndex>::max(),
"Too many devices, DeviceIndex overflowed");
return result;

View File

@ -1,19 +1,26 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <include/openreg.h>
#include "OpenRegEvent.h"
#include "OpenRegFunctions.h"
#include "OpenRegStream.h"
namespace c10::openreg {
// Device guard registration
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1;
static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1;
OpenRegGuardImpl() = default;
explicit OpenRegGuardImpl(c10::DeviceType t) {
TORCH_INTERNAL_ASSERT(t == static_type);
TORCH_CHECK(
t == static_type,
"OpenRegGuardImpl initialized with non-PrivateUse1 DeviceType: ",
t);
}
/**
@ -27,7 +34,8 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
* Set the current device to Device, and return the previous c10::Device.
*/
c10::Device exchangeDevice(c10::Device d) const override {
TORCH_CHECK(d.is_privateuseone());
TORCH_CHECK(
d.is_privateuseone(), "Excepted a PrivateUse1 device, but got ", d);
auto old_device_index = ExchangeDevice(d.index());
return c10::Device(static_type, old_device_index);
@ -45,7 +53,8 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
* Set the current device to c10::Device.
*/
void setDevice(c10::Device d) const override {
TORCH_CHECK(d.is_privateuseone());
TORCH_CHECK(
d.is_privateuseone(), "Excepted a PrivateUse1 device, but got ", d);
set_device(d.index());
}
@ -55,8 +64,6 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
* (so, e.g., this can be called from a destructor).
*/
void uncheckedSetDevice(c10::Device d) const noexcept override {
TORCH_CHECK(d.is_privateuseone());
set_device(d.index());
}
@ -64,23 +71,14 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
* Get the current stream for a given device.
*/
c10::Stream getStream(c10::Device d) const noexcept override {
return c10::Stream(c10::Stream::DEFAULT, d);
return getCurrentOpenRegStream(d.index()).unwrap();
}
/**
* Get the default stream for a given device.
*/
c10::Stream getDefaultStream(c10::Device d) const override {
return c10::Stream(c10::Stream::DEFAULT, d);
}
/**
* Get a stream from the global pool for a given device.
*/
c10::Stream getStreamFromGlobalPool(
c10::Device d,
bool isHighPriority = false) const override {
return c10::Stream(c10::Stream::DEFAULT, d);
return getDefaultOpenRegStream(d.index());
}
/**
@ -88,8 +86,16 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
* copied and shared around, device backend should be able to correctly handle
* the lifetime of the stream.
*/
c10::Stream getNewStream(c10::Device d, int priority = 0) const override {
return c10::Stream(c10::Stream::DEFAULT, d);
Stream getNewStream(Device d, int priority = 0) const override {
return getStreamFromPool(priority, d.index());
}
/**
* Get a stream from the global pool for a given device.
*/
Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
const override {
return getStreamFromPool(isHighPriority, d.index());
}
/**
@ -98,14 +104,37 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
* to set the current device to match the device of this stream.
*/
c10::Stream exchangeStream(c10::Stream s) const noexcept override {
return s;
const OpenRegStream stream(s);
const auto old_stream = getCurrentOpenRegStream(s.device().index());
setCurrentOpenRegStream(stream);
return old_stream.unwrap();
}
/**
* Get the number of devices.
*
* WARNING: This is REQUIRED to not raise an exception.
* If there is some sort of problem, e.g., driver error,
* you should report that there are zero available devices.
*/
DeviceIndex deviceCount() const noexcept override {
return device_count();
}
/**
* Destroys the given event.
*/
void destroyEvent(void* event, const c10::DeviceIndex device_index)
const noexcept override {}
const noexcept override {
if (!event)
return;
auto or_event = static_cast<orEvent_t>(event);
auto orig_device = current_device();
set_device(device_index);
orEventDestroy(or_event);
set_device(orig_device);
}
/**
* Increments the event's version and enqueues a job with this version
@ -118,10 +147,40 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
const c10::Stream& stream,
const c10::DeviceIndex device_index,
const c10::EventFlag flag) const override {
static int event_id = 1;
TORCH_CHECK(
device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
" does not match recording stream's device index ",
stream.device_index(),
".");
if (!*event)
*event = reinterpret_cast<void*>(event_id++);
orEvent_t or_event = static_cast<orEvent_t>(*event);
OpenRegStream or_stream{stream};
const auto orig_device = current_device();
set_device(stream.device().index());
if (!or_event) {
auto or_flag = orEventDisableTiming;
switch (flag) {
case EventFlag::PYTORCH_DEFAULT:
or_flag = orEventDisableTiming;
break;
case EventFlag::BACKEND_DEFAULT:
or_flag = orEventEnableTiming;
break;
default:
TORCH_CHECK(false, "Received unknown flag");
}
orEventCreateWithFlags(&or_event, or_flag);
}
orEventRecord(or_event, or_stream);
*event = or_event;
set_device(orig_device);
}
/**
@ -132,7 +191,17 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
* When the stream reaches this command it will stop processing
* additional commands until that version of the event is marked as recorded.
*/
void block(void* event, const c10::Stream& stream) const override {}
void block(void* event, const c10::Stream& stream) const override {
if (!event)
return;
orEvent_t or_event = static_cast<orEvent_t>(event);
OpenRegStream or_stream{stream};
const auto orig_device = current_device();
set_device(stream.device().index());
orStreamWaitEvent(or_stream, or_event, 0);
set_device(orig_device);
}
/**
* Returns true if (and only if)
@ -141,47 +210,56 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
* Returns false otherwise.
*/
bool queryEvent(void* event) const override {
return true;
if (!event)
return true;
orEvent_t or_event = static_cast<orEvent_t>(event);
const orError_t err = orEventQuery(or_event);
return err == orSuccess ? true : false;
}
/**
* Get the number of devices. WARNING: This is REQUIRED to not raise
* an exception. If there is some sort of problem, e.g., driver error,
* you should report that there are zero available devices.
*/
c10::DeviceIndex deviceCount() const noexcept override {
int device_index = -1;
orGetDeviceCount(&device_index);
return device_index;
}
/**
* Return true if all the work previously enqueued on the stream for
* asynchronous execution has completed running on the device.
*/
bool queryStream(const c10::Stream& stream) const override {
return true;
OpenRegStream or_stream{stream};
return or_stream.query();
}
/**
* Wait (by blocking the calling thread) until all the work previously
* enqueued on the stream has completed running on the device.
*/
void synchronizeStream(const c10::Stream& stream) const override {}
void synchronizeStream(const c10::Stream& stream) const override {
OpenRegStream or_stream{stream};
or_stream.synchronize();
}
/**
* Wait (by blocking the calling thread) until all the work previously
* recorded on the event has completed running on the device.
*/
void synchronizeEvent(void* event) const override {}
void synchronizeEvent(void* event) const override {
if (!event)
return;
orEvent_t or_event = static_cast<orEvent_t>(event);
orEventSynchronize(or_event);
}
/**
* Ensure the caching allocator (if any) is aware that the given DataPtr is
* being used on the given stream, and that it should thus avoid recycling the
* DataPtr until all work on that stream is done.
* Wait (by blocking the calling thread) until all the work has
* completed running on the device.
*/
void recordDataPtrOnStream(
const c10::DataPtr& data_ptr,
const c10::Stream& stream) const override {}
void synchronizeDevice(const c10::DeviceIndex device_index) const override {
DeviceIndex orig_device{-1};
auto orig_devicec = current_device();
set_device(device_index);
orDeviceSynchronize();
set_device(orig_device);
}
/**
* Fetch the elapsed time between two recorded events.
@ -190,7 +268,20 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
void* event1,
void* event2,
const c10::DeviceIndex device_index) const override {
return 1;
TORCH_CHECK(
event1 && event2,
"Both events must be recorded before calculating elapsed time.");
auto orig_device = current_device();
set_device(device_index);
orEvent_t or_event1 = static_cast<orEvent_t>(event1);
orEvent_t or_event2 = static_cast<orEvent_t>(event2);
float time_ms = 0;
orEventElapsedTime(&time_ms, or_event1, or_event2);
set_device(orig_device);
return static_cast<double>(time_ms);
}
};

View File

@ -1,3 +1,5 @@
#pragma once
#include <ATen/core/CachingHostAllocator.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>

View File

@ -1,3 +1,5 @@
#pragma once
#include <ATen/core/CachingHostAllocator.h>
#include <c10/core/Allocator.h>

View File

@ -1,3 +1,5 @@
#pragma once
#include <torch/csrc/jit/serialization/pickler.h>
#define REGISTER_PRIVATEUSE1_SERIALIZATION( \

View File

@ -0,0 +1,32 @@
# Owner(s): ["module: PrivateUse1"]
import torch
import torch_openreg # noqa: F401
from torch.testing._internal.common_utils import run_tests, TestCase
class TestDevice(TestCase):
def test_device_count(self):
count = torch.accelerator.device_count()
self.assertEqual(count, 2)
def test_device_switch(self):
torch.accelerator.set_device_index(1)
self.assertEqual(torch.accelerator.current_device_index(), 1)
torch.accelerator.set_device_index(0)
self.assertEqual(torch.accelerator.current_device_index(), 0)
def test_device_context(self):
device = torch.accelerator.current_device_index()
with torch.accelerator.device_index(None):
self.assertEqual(torch.accelerator.current_device_index(), device)
self.assertEqual(torch.accelerator.current_device_index(), device)
with torch.accelerator.device_index(1):
self.assertEqual(torch.accelerator.current_device_index(), 1)
self.assertEqual(torch.accelerator.current_device_index(), device)
if __name__ == "__main__":
run_tests()

View File

@ -6,34 +6,72 @@ from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, T
class TestEvent(TestCase):
@skipIfTorchDynamo()
def test_record_event(self):
def test_event_create(self):
event = torch.Event(device="openreg")
self.assertEqual(event.device.type, "openreg")
self.assertEqual(event.device.index, None)
self.assertEqual(event.event_id, 0)
event = torch.Event(device="openreg:1")
self.assertEqual(event.device.type, "openreg")
self.assertEqual(event.device.index, None)
self.assertEqual(event.event_id, 0)
event = torch.Event()
self.assertEqual(event.device.type, "openreg")
self.assertEqual(event.device.index, None)
self.assertEqual(event.event_id, 0)
stream = torch.Stream(device="openreg:1")
event = stream.record_event()
self.assertEqual(event.device.type, "openreg")
self.assertEqual(event.device.index, 1)
self.assertNotEqual(event.event_id, 0)
@skipIfTorchDynamo()
def test_event_query(self):
event = torch.Event()
self.assertTrue(event.query())
stream = torch.Stream(device="openreg:1")
event = stream.record_event()
event.synchronize()
self.assertTrue(event.query())
@skipIfTorchDynamo()
def test_event_record(self):
stream = torch.Stream(device="openreg:1")
event1 = stream.record_event()
self.assertNotEqual(0, event1.event_id)
event2 = stream.record_event()
self.assertNotEqual(0, event2.event_id)
self.assertNotEqual(event1.event_id, event2.event_id)
@skipIfTorchDynamo()
def test_event_elapsed_time(self):
stream = torch.Stream(device="openreg:1")
e1 = torch.Event(device="openreg:1", enable_timing=True)
e1.record(stream)
e2 = torch.Event(device="openreg:1", enable_timing=True)
e2.record(stream)
e2.synchronize()
self.assertTrue(e2.query())
event1 = torch.Event(device="openreg:1", enable_timing=True)
event1.record(stream)
event2 = torch.Event(device="openreg:1", enable_timing=True)
event2.record(stream)
ms = e1.elapsed_time(e2)
stream.synchronize()
self.assertTrue(event1.query())
self.assertTrue(event2.query())
ms = event1.elapsed_time(event2)
self.assertTrue(ms > 0)
@skipIfTorchDynamo()
def test_event_wait_stream(self):
s1 = torch.Stream(device="openreg")
s2 = torch.Stream(device="openreg")
e1 = s1.record_event()
e1.wait(s2)
stream1 = torch.Stream(device="openreg")
stream2 = torch.Stream(device="openreg")
event = stream1.record_event()
stream2.wait_event(event)
if __name__ == "__main__":

View File

@ -5,11 +5,56 @@ from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, T
class TestStream(TestCase):
def test_stream_create(self):
stream = torch.Stream(device="openreg")
self.assertEqual(stream.device_index, torch.openreg.current_device())
stream = torch.Stream(device="openreg:1")
self.assertEqual(stream.device.type, "openreg")
self.assertEqual(stream.device_index, 1)
stream = torch.Stream(1)
self.assertEqual(stream.device.type, "openreg")
self.assertEqual(stream.device_index, 1)
stream1 = torch.Stream(
stream_id=stream.stream_id,
device_type=stream.device_type,
device_index=stream.device_index,
)
self.assertEqual(stream, stream1)
def test_stream_context(self):
with torch.Stream(device="openreg:1") as stream:
self.assertEqual(torch.accelerator.current_stream(), stream)
@skipIfTorchDynamo()
def test_stream_switch(self):
stream1 = torch.Stream(device="openreg:0")
torch.accelerator.set_stream(stream1)
current_stream = torch.accelerator.current_stream()
self.assertEqual(current_stream, stream1)
stream2 = torch.Stream(device="openreg:1")
torch.accelerator.set_stream(stream2)
current_stream = torch.accelerator.current_stream()
self.assertEqual(current_stream, stream2)
def test_stream_synchronize(self):
stream = torch.Stream(device="openreg:1")
self.assertEqual(True, stream.query())
event = torch.Event()
event.record(stream)
stream.synchronize()
self.assertEqual(True, stream.query())
def test_stream_repr(self):
stream = torch.Stream(device="openreg:1")
self.assertTrue(
"torch.Stream device_type=openreg, device_index=1" in repr(stream)
)
def test_stream_wait_stream(self):
stream_1 = torch.Stream(device="openreg:0")
stream_2 = torch.Stream(device="openreg:1")