mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[OpenReg] Improve the Event and Stream capabilities of DeviceGuardImplInterface (#160101)
**Changes:** - Build the `OpenReg` device guard on top of `OpenRegStream` and `OpenRegEvent`, improving its stream and event capabilities. - Add related test cases. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160101 Approved by: https://github.com/albanD ghstack dependencies: #161917, #161918
This commit is contained in:
@ -36,7 +36,7 @@ OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
|
||||
static int count = []() {
|
||||
try {
|
||||
auto result = device_count_impl();
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
TORCH_CHECK(
|
||||
result <= std::numeric_limits<c10::DeviceIndex>::max(),
|
||||
"Too many devices, DeviceIndex overflowed");
|
||||
return result;
|
||||
|
||||
@ -1,19 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
||||
|
||||
#include <include/openreg.h>
|
||||
|
||||
#include "OpenRegEvent.h"
|
||||
#include "OpenRegFunctions.h"
|
||||
#include "OpenRegStream.h"
|
||||
|
||||
namespace c10::openreg {
|
||||
|
||||
// Device guard registration
|
||||
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1;
|
||||
static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1;
|
||||
|
||||
OpenRegGuardImpl() = default;
|
||||
|
||||
explicit OpenRegGuardImpl(c10::DeviceType t) {
  // Reject any device type other than static_type (PrivateUse1). A single
  // TORCH_CHECK is used so that the descriptive message below is what the
  // caller actually sees on a mismatch.
  TORCH_CHECK(
      t == static_type,
      "OpenRegGuardImpl initialized with non-PrivateUse1 DeviceType: ",
      t);
}
|
||||
|
||||
/**
|
||||
@ -27,7 +34,8 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
* Set the current device to Device, and return the previous c10::Device.
|
||||
*/
|
||||
c10::Device exchangeDevice(c10::Device d) const override {
|
||||
TORCH_CHECK(d.is_privateuseone());
|
||||
TORCH_CHECK(
|
||||
d.is_privateuseone(), "Excepted a PrivateUse1 device, but got ", d);
|
||||
|
||||
auto old_device_index = ExchangeDevice(d.index());
|
||||
return c10::Device(static_type, old_device_index);
|
||||
@ -45,7 +53,8 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
* Set the current device to c10::Device.
|
||||
*/
|
||||
void setDevice(c10::Device d) const override {
  // Only PrivateUse1 devices are accepted; the message names the offending
  // device so a mismatch is easy to diagnose.
  TORCH_CHECK(
      d.is_privateuseone(), "Expected a PrivateUse1 device, but got ", d);

  set_device(d.index());
}
|
||||
@ -55,8 +64,6 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
* (so, e.g., this can be called from a destructor).
|
||||
*/
|
||||
void uncheckedSetDevice(c10::Device d) const noexcept override {
  // No TORCH_CHECK here: this method is declared noexcept, and a failing
  // check would throw out of it.
  set_device(d.index());
}
|
||||
|
||||
@ -64,23 +71,14 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
* Get the current stream for a given device.
|
||||
*/
|
||||
c10::Stream getStream(c10::Device d) const noexcept override {
  // Look up the stream currently selected for the device's index and hand it
  // back as a generic c10::Stream.
  return getCurrentOpenRegStream(d.index()).unwrap();
}
|
||||
|
||||
/**
|
||||
* Get the default stream for a given device.
|
||||
*/
|
||||
c10::Stream getDefaultStream(c10::Device d) const override {
  // Ask the OpenReg stream layer for the default stream of this device index
  // rather than fabricating a generic default c10::Stream.
  return getDefaultOpenRegStream(d.index());
}
|
||||
|
||||
/**
|
||||
@ -88,8 +86,16 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
* copied and shared around, device backend should be able to correctly handle
|
||||
* the lifetime of the stream.
|
||||
*/
|
||||
c10::Stream getNewStream(c10::Device d, int priority = 0) const override {
  // Forward the requested priority and the device index to the stream pool.
  return getStreamFromPool(priority, d.index());
}
|
||||
|
||||
/**
|
||||
* Get a stream from the global pool for a given device.
|
||||
*/
|
||||
Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
    const override {
  // Forward the high-priority flag and the device index to the stream pool.
  const auto device_index = d.index();
  return getStreamFromPool(isHighPriority, device_index);
}
|
||||
|
||||
/**
|
||||
@ -98,14 +104,37 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
* to set the current device to match the device of this stream.
|
||||
*/
|
||||
c10::Stream exchangeStream(c10::Stream s) const noexcept override {
  const OpenRegStream stream(s);
  // Capture the stream that is current on s's device before replacing it, so
  // it can be returned to the caller.
  const auto old_stream = getCurrentOpenRegStream(s.device().index());
  setCurrentOpenRegStream(stream);
  return old_stream.unwrap();
}
|
||||
|
||||
/**
|
||||
* Get the number of devices.
|
||||
*
|
||||
* WARNING: This is REQUIRED to not raise an exception.
|
||||
* If there is some sort of problem, e.g., driver error,
|
||||
* you should report that there are zero available devices.
|
||||
*/
|
||||
DeviceIndex deviceCount() const noexcept override {
  // Delegates to the free function device_count().
  const DeviceIndex count = device_count();
  return count;
}
|
||||
|
||||
/**
|
||||
* Destroys the given event.
|
||||
*/
|
||||
void destroyEvent(void* event, const c10::DeviceIndex device_index)
    const noexcept override {
  // A null handle has nothing to destroy.
  if (!event) {
    return;
  }

  auto or_event = static_cast<orEvent_t>(event);

  // Destroy the event with `device_index` selected, then reselect the device
  // that was current on entry.
  const auto orig_device = current_device();
  set_device(device_index);
  orEventDestroy(or_event);
  set_device(orig_device);
}
|
||||
|
||||
/**
|
||||
* Increments the event's version and enqueues a job with this version
|
||||
@ -118,10 +147,40 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
const c10::Stream& stream,
|
||||
const c10::DeviceIndex device_index,
|
||||
const c10::EventFlag flag) const override {
|
||||
static int event_id = 1;
|
||||
TORCH_CHECK(
|
||||
device_index == -1 || device_index == stream.device_index(),
|
||||
"Event device index ",
|
||||
device_index,
|
||||
" does not match recording stream's device index ",
|
||||
stream.device_index(),
|
||||
".");
|
||||
|
||||
if (!*event)
|
||||
*event = reinterpret_cast<void*>(event_id++);
|
||||
orEvent_t or_event = static_cast<orEvent_t>(*event);
|
||||
OpenRegStream or_stream{stream};
|
||||
|
||||
const auto orig_device = current_device();
|
||||
set_device(stream.device().index());
|
||||
|
||||
if (!or_event) {
|
||||
auto or_flag = orEventDisableTiming;
|
||||
switch (flag) {
|
||||
case EventFlag::PYTORCH_DEFAULT:
|
||||
or_flag = orEventDisableTiming;
|
||||
break;
|
||||
case EventFlag::BACKEND_DEFAULT:
|
||||
or_flag = orEventEnableTiming;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Received unknown flag");
|
||||
}
|
||||
|
||||
orEventCreateWithFlags(&or_event, or_flag);
|
||||
}
|
||||
|
||||
orEventRecord(or_event, or_stream);
|
||||
*event = or_event;
|
||||
|
||||
set_device(orig_device);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -132,7 +191,17 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
* When the stream reaches this command it will stop processing
|
||||
* additional commands until that version of the event is marked as recorded.
|
||||
*/
|
||||
void block(void* event, const c10::Stream& stream) const override {
  // A null event handle means there is nothing to wait on.
  if (!event) {
    return;
  }

  orEvent_t or_event = static_cast<orEvent_t>(event);
  OpenRegStream or_stream{stream};

  // Issue the wait with the stream's device selected, then reselect the
  // device that was current on entry.
  const auto orig_device = current_device();
  set_device(stream.device().index());
  orStreamWaitEvent(or_stream, or_event, 0);
  set_device(orig_device);
}
|
||||
|
||||
/**
|
||||
* Returns true if (and only if)
|
||||
@ -141,47 +210,56 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
* Returns false otherwise.
|
||||
*/
|
||||
bool queryEvent(void* event) const override {
|
||||
return true;
|
||||
if (!event)
|
||||
return true;
|
||||
|
||||
orEvent_t or_event = static_cast<orEvent_t>(event);
|
||||
const orError_t err = orEventQuery(or_event);
|
||||
|
||||
return err == orSuccess ? true : false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of devices. WARNING: This is REQUIRED to not raise
|
||||
* an exception. If there is some sort of problem, e.g., driver error,
|
||||
* you should report that there are zero available devices.
|
||||
*/
|
||||
c10::DeviceIndex deviceCount() const noexcept override {
  // Delegate to the noexcept free function device_count() instead of calling
  // orGetDeviceCount directly: the direct call ignored its result (leaving -1
  // in place if the output was never written) and narrowed an int to
  // c10::DeviceIndex without a range check.
  return device_count();
}
|
||||
/**
|
||||
* Return true if all the work previously enqueued on the stream for
|
||||
* asynchronous execution has completed running on the device.
|
||||
*/
|
||||
bool queryStream(const c10::Stream& stream) const override {
  // Wrap the generic stream and forward the query to the OpenReg stream.
  OpenRegStream or_stream{stream};
  return or_stream.query();
}
|
||||
|
||||
/**
|
||||
* Wait (by blocking the calling thread) until all the work previously
|
||||
* enqueued on the stream has completed running on the device.
|
||||
*/
|
||||
void synchronizeStream(const c10::Stream& stream) const override {
  // Wrap the generic stream and forward the synchronize call to the OpenReg
  // stream.
  OpenRegStream or_stream{stream};
  or_stream.synchronize();
}
|
||||
|
||||
/**
|
||||
* Wait (by blocking the calling thread) until all the work previously
|
||||
* recorded on the event has completed running on the device.
|
||||
*/
|
||||
void synchronizeEvent(void* event) const override {}
|
||||
void synchronizeEvent(void* event) const override {
|
||||
if (!event)
|
||||
return;
|
||||
|
||||
orEvent_t or_event = static_cast<orEvent_t>(event);
|
||||
orEventSynchronize(or_event);
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensure the caching allocator (if any) is aware that the given DataPtr is
|
||||
* being used on the given stream, and that it should thus avoid recycling the
|
||||
* DataPtr until all work on that stream is done.
|
||||
* Wait (by blocking the calling thread) until all the work has
|
||||
* completed running on the device.
|
||||
*/
|
||||
void recordDataPtrOnStream(
|
||||
const c10::DataPtr& data_ptr,
|
||||
const c10::Stream& stream) const override {}
|
||||
void synchronizeDevice(const c10::DeviceIndex device_index) const override {
  // Remember the device that is current on entry so it can be reselected
  // after synchronizing `device_index`. A single variable is used for both
  // saving and restoring.
  const auto orig_device = current_device();
  set_device(device_index);
  orDeviceSynchronize();
  set_device(orig_device);
}
|
||||
|
||||
/**
|
||||
* Fetch the elapsed time between two recorded events.
|
||||
@ -190,7 +268,20 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
void* event1,
|
||||
void* event2,
|
||||
const c10::DeviceIndex device_index) const override {
|
||||
return 1;
|
||||
TORCH_CHECK(
|
||||
event1 && event2,
|
||||
"Both events must be recorded before calculating elapsed time.");
|
||||
auto orig_device = current_device();
|
||||
set_device(device_index);
|
||||
|
||||
orEvent_t or_event1 = static_cast<orEvent_t>(event1);
|
||||
orEvent_t or_event2 = static_cast<orEvent_t>(event2);
|
||||
float time_ms = 0;
|
||||
orEventElapsedTime(&time_ms, or_event1, or_event2);
|
||||
|
||||
set_device(orig_device);
|
||||
|
||||
return static_cast<double>(time_ms);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/CachingHostAllocator.h>
|
||||
#include <ATen/detail/PrivateUse1HooksInterface.h>
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/CachingHostAllocator.h>
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
|
||||
#define REGISTER_PRIVATEUSE1_SERIALIZATION( \
|
||||
|
||||
@ -0,0 +1,32 @@
|
||||
# Owner(s): ["module: PrivateUse1"]
|
||||
|
||||
import torch
|
||||
import torch_openreg # noqa: F401
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class TestDevice(TestCase):
    """Device count, switching and context-manager checks via torch.accelerator."""

    def test_device_count(self):
        # The test expects exactly two devices to be reported.
        self.assertEqual(torch.accelerator.device_count(), 2)

    def test_device_switch(self):
        # Select device 1, then device 0, checking the current index each time.
        for index in (1, 0):
            torch.accelerator.set_device_index(index)
            self.assertEqual(torch.accelerator.current_device_index(), index)

    def test_device_context(self):
        initial = torch.accelerator.current_device_index()

        # A None index leaves the current device unchanged inside and after.
        with torch.accelerator.device_index(None):
            self.assertEqual(torch.accelerator.current_device_index(), initial)
        self.assertEqual(torch.accelerator.current_device_index(), initial)

        # Index 1 is current inside the block; the initial device afterwards.
        with torch.accelerator.device_index(1):
            self.assertEqual(torch.accelerator.current_device_index(), 1)
        self.assertEqual(torch.accelerator.current_device_index(), initial)
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -6,34 +6,72 @@ from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, T
|
||||
|
||||
class TestEvent(TestCase):
    """Creation, query, record, timing and wait checks for torch.Event on "openreg"."""

    @skipIfTorchDynamo()
    def test_event_create(self):
        # Events that have not been recorded report no device index and id 0.
        event = torch.Event(device="openreg")
        self.assertEqual(event.device.type, "openreg")
        self.assertEqual(event.device.index, None)
        self.assertEqual(event.event_id, 0)

        event = torch.Event(device="openreg:1")
        self.assertEqual(event.device.type, "openreg")
        self.assertEqual(event.device.index, None)
        self.assertEqual(event.event_id, 0)

        event = torch.Event()
        self.assertEqual(event.device.type, "openreg")
        self.assertEqual(event.device.index, None)
        self.assertEqual(event.event_id, 0)

        # An event obtained by recording on a stream carries that stream's
        # device index and a non-zero id.
        stream = torch.Stream(device="openreg:1")
        event = stream.record_event()
        self.assertEqual(event.device.type, "openreg")
        self.assertEqual(event.device.index, 1)
        self.assertNotEqual(event.event_id, 0)

    @skipIfTorchDynamo()
    def test_event_query(self):
        # An unrecorded event queries as complete.
        event = torch.Event()
        self.assertTrue(event.query())

        stream = torch.Stream(device="openreg:1")
        event = stream.record_event()
        event.synchronize()
        self.assertTrue(event.query())

    @skipIfTorchDynamo()
    def test_event_record(self):
        # Two record_event() calls on one stream yield distinct non-zero ids.
        stream = torch.Stream(device="openreg:1")
        event1 = stream.record_event()
        self.assertNotEqual(0, event1.event_id)

        event2 = stream.record_event()
        self.assertNotEqual(0, event2.event_id)

        self.assertNotEqual(event1.event_id, event2.event_id)

    @skipIfTorchDynamo()
    def test_event_elapsed_time(self):
        stream = torch.Stream(device="openreg:1")

        event1 = torch.Event(device="openreg:1", enable_timing=True)
        event1.record(stream)
        event2 = torch.Event(device="openreg:1", enable_timing=True)
        event2.record(stream)

        # Synchronize the stream before querying both events and reading the
        # elapsed time between them.
        stream.synchronize()
        self.assertTrue(event1.query())
        self.assertTrue(event2.query())

        ms = event1.elapsed_time(event2)
        self.assertTrue(ms > 0)

    @skipIfTorchDynamo()
    def test_event_wait_stream(self):
        stream1 = torch.Stream(device="openreg")
        stream2 = torch.Stream(device="openreg")

        event = stream1.record_event()
        stream2.wait_event(event)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -5,11 +5,56 @@ from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, T
|
||||
|
||||
|
||||
class TestStream(TestCase):
|
||||
def test_stream_create(self):
    """Create streams from a device string, a bare index, and explicit fields."""
    unindexed = torch.Stream(device="openreg")
    self.assertEqual(unindexed.device_index, torch.openreg.current_device())

    from_string = torch.Stream(device="openreg:1")
    self.assertEqual(from_string.device.type, "openreg")
    self.assertEqual(from_string.device_index, 1)

    from_index = torch.Stream(1)
    self.assertEqual(from_index.device.type, "openreg")
    self.assertEqual(from_index.device_index, 1)

    # A stream rebuilt from another stream's id/type/index compares equal to it.
    rebuilt = torch.Stream(
        stream_id=from_index.stream_id,
        device_type=from_index.device_type,
        device_index=from_index.device_index,
    )
    self.assertEqual(from_index, rebuilt)
|
||||
|
||||
def test_stream_context(self):
    """Inside the with-block, the entered stream is the accelerator's current stream."""
    with torch.Stream(device="openreg:1") as entered:
        current = torch.accelerator.current_stream()
        self.assertEqual(current, entered)
|
||||
|
||||
@skipIfTorchDynamo()
def test_stream_switch(self):
    """set_stream makes the given stream current, for a stream on each device."""
    first = torch.Stream(device="openreg:0")
    torch.accelerator.set_stream(first)
    self.assertEqual(torch.accelerator.current_stream(), first)

    second = torch.Stream(device="openreg:1")
    torch.accelerator.set_stream(second)
    self.assertEqual(torch.accelerator.current_stream(), second)
|
||||
|
||||
def test_stream_synchronize(self):
    """A stream queries as complete when fresh and again after synchronize()."""
    target = torch.Stream(device="openreg:1")
    self.assertEqual(True, target.query())

    # Record an event on the stream, then synchronize and query again.
    marker = torch.Event()
    marker.record(target)
    target.synchronize()
    self.assertEqual(True, target.query())
|
||||
|
||||
def test_stream_repr(self):
    """repr() of a stream contains its device type and device index."""
    text = repr(torch.Stream(device="openreg:1"))
    expected = "torch.Stream device_type=openreg, device_index=1"
    self.assertTrue(expected in text)
|
||||
|
||||
def test_stream_wait_stream(self):
|
||||
stream_1 = torch.Stream(device="openreg:0")
|
||||
stream_2 = torch.Stream(device="openreg:1")
|
||||
|
||||
Reference in New Issue
Block a user