[OpenReg] Integrate Event&Stream from OpenReg Backend into PyTorch (#160100)

This change integrates the OpenReg backend's `Stream` and `Event` into PyTorch. Both follow the
same design as their counterparts on other accelerators such as `CUDA` and `XPU`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160100
Approved by: https://github.com/albanD
ghstack dependencies: #161603, #160099, #161773
This commit is contained in:
FFFrog
2025-08-30 18:10:48 +08:00
committed by PyTorch MergeBot
parent 6284881b2a
commit b93f87d67b
10 changed files with 606 additions and 11 deletions

View File

@ -36,6 +36,7 @@ else()
message(FATAL_ERROR "Cannot find Python directory")
endif()
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include(${PROJECT_SOURCE_DIR}/cmake/TorchPythonTargets.cmake)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg)

View File

@ -0,0 +1,146 @@
#pragma once
#include <include/openreg.h>
#include "OpenRegException.h"
#include "OpenRegStream.h"
namespace c10::openreg {
struct OpenRegEvent {
OpenRegEvent(bool enable_timing) noexcept : enable_timing_{enable_timing} {}
~OpenRegEvent() {
if (is_created_) {
OPENREG_CHECK(orEventDestroy(event_));
}
}
OpenRegEvent(const OpenRegEvent&) = delete;
OpenRegEvent& operator=(const OpenRegEvent&) = delete;
OpenRegEvent(OpenRegEvent&& other) noexcept {
moveHelper(std::move(other));
}
OpenRegEvent& operator=(OpenRegEvent&& other) noexcept {
if (this != &other) {
moveHelper(std::move(other));
}
return *this;
}
operator orEvent_t() const {
return event();
}
std::optional<at::Device> device() const {
if (is_created_) {
return at::Device(at::kPrivateUse1, device_index_);
} else {
return std::nullopt;
}
}
bool isCreated() const {
return is_created_;
}
DeviceIndex device_index() const {
return device_index_;
}
orEvent_t event() const {
return event_;
}
bool query() const {
if (!is_created_) {
return true;
}
orError_t err = orEventQuery(event_);
if (err == orSuccess) {
return true;
}
return false;
}
void record() {
record(getCurrentOpenRegStream());
}
void recordOnce(const OpenRegStream& stream) {
if (!was_recorded_)
record(stream);
}
void record(const OpenRegStream& stream) {
if (!is_created_) {
createEvent(stream.device_index());
}
TORCH_CHECK(
device_index_ == stream.device_index(),
"Event device ",
device_index_,
" does not match recording stream's device ",
stream.device_index(),
".");
OPENREG_CHECK(orEventRecord(event_, stream));
was_recorded_ = true;
}
void block(const OpenRegStream& stream) {
if (is_created_) {
OPENREG_CHECK(orStreamWaitEvent(stream, event_, 0));
}
}
float elapsed_time(const OpenRegEvent& other) const {
TORCH_CHECK_VALUE(
!(enable_timing_ & orEventDisableTiming) &&
!(other.enable_timing_ & orEventDisableTiming),
"Both events must be created with argument 'enable_timing=True'.");
TORCH_CHECK_VALUE(
is_created_ && other.isCreated(),
"Both events must be recorded before calculating elapsed time.");
TORCH_CHECK(
query() && other.query(),
"Both events must be completed before calculating elapsed time.");
float time_ms = 0;
OPENREG_CHECK(orEventElapsedTime(&time_ms, event_, other.event_));
return time_ms;
}
void synchronize() const {
if (is_created_) {
OPENREG_CHECK(orEventSynchronize(event_));
}
}
private:
unsigned int enable_timing_{orEventDisableTiming};
bool is_created_{false};
bool was_recorded_{false};
DeviceIndex device_index_{-1};
orEvent_t event_{};
void createEvent(DeviceIndex device_index) {
device_index_ = device_index;
OPENREG_CHECK(orEventCreateWithFlags(&event_, enable_timing_));
is_created_ = true;
}
void moveHelper(OpenRegEvent&& other) {
std::swap(enable_timing_, other.enable_timing_);
std::swap(is_created_, other.is_created_);
std::swap(was_recorded_, other.was_recorded_);
std::swap(device_index_, other.device_index_);
std::swap(event_, other.event_);
}
};
} // namespace c10::openreg

View File

@ -0,0 +1,9 @@
#include "OpenRegException.h"
void orCheckFail(
const char* func,
const char* file,
uint32_t line,
const char* msg) {
throw ::c10::Error({func, file, line}, msg);
}

View File

@ -0,0 +1,20 @@
#pragma once
#include <include/openreg.h>
#include <c10/util/Exception.h>
// Reports a failed OpenReg runtime call by throwing; it never returns
// normally, which the attribute makes visible to callers and to the compiler.
//
// func/file/line describe the call site (normally supplied by OPENREG_CHECK);
// msg is an optional description and defaults to an empty string.
[[noreturn]] void orCheckFail(
    const char* func,
    const char* file,
    uint32_t line,
    const char* msg = "");
// Evaluates EXPR (an expression yielding orError_t) exactly once and calls
// orCheckFail with the current function, file and line when the result is not
// orSuccess. An optional extra argument is forwarded as the message.
//
// The local uses a non-reserved name (identifiers containing a double
// underscore belong to the implementation) and EXPR is parenthesized so that
// any expression can be passed safely.
#define OPENREG_CHECK(EXPR, ...)                                               \
  do {                                                                         \
    const orError_t openreg_check_err_ = (EXPR);                               \
    if (openreg_check_err_ != orSuccess) {                                     \
      orCheckFail(                                                             \
          __func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
    }                                                                          \
  } while (0)

View File

@ -1,5 +1,6 @@
#include <include/openreg.h>
#include "OpenRegException.h"
#include "OpenRegFunctions.h"
namespace c10::openreg {

View File

@ -1,14 +1,10 @@
#pragma once
#ifdef _WIN32
#define OPENREG_EXPORT __declspec(dllexport)
#else
#define OPENREG_EXPORT __attribute__((visibility("default")))
#endif
#include <c10/core/Device.h>
#include <c10/macros/Macros.h>
#include <include/Macros.h>
#include <limits>
namespace c10::openreg {

View File

@ -0,0 +1,253 @@
#include "OpenRegStream.h"
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <array>
#include <atomic>
#include <cstdint>
#include <deque>
namespace c10::openreg {
namespace {
// Global stream state and constants
// Guards the one-time, process-wide setup done by initGlobalStreamState().
static c10::once_flag init_flag;
// Device count cached by initGlobalStreamState(); -1 until that has run.
static DeviceIndex num_devices = -1;
// Width in bits of the pool-index field of a native StreamId.
static constexpr int kStreamsPerPoolBits = 5;
// Number of streams in each per-priority pool (1 << 5 == 32).
static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;
// Width in bits of the stream-type field of a native StreamId.
static constexpr int kStreamTypeBits = 2;
/*
 * The stream pools are lazily initialized when the first stream is requested
 * from the pool of a device. The device flags track the initialization of
 * each device. When a stream is requested, the next stream in the pool is
 * returned in a round-robin fashion, see Note [Stream Management].
 */
// One once_flag per device; guards initDeviceStreamState() for that device.
static std::deque<c10::once_flag> device_flags;
// streams[device][priority][index]: backend stream handles, filled in by
// initDeviceStreamState().
static std::vector<std::array<
std::array<orStream_t, kStreamsPerPool>,
c10::openreg::max_compile_time_stream_priorities>>
streams;
// priority_counters[device][priority]: running counter that get_idx() turns
// into the next round-robin pool index.
static std::deque<
std::array<std::atomic<uint32_t>, max_compile_time_stream_priorities>>
priority_counters;
// Per-thread array holding the current StreamId of every device; allocated by
// initOpenRegStreamsOnce().
static thread_local std::unique_ptr<StreamId[]> current_streams = nullptr;
/*
 * Note [StreamId assignment]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~
 * How do we assign stream IDs?
 *
 * -- 56 bits --  -- 5 bits --      -- 2 bits --      -- 1 bit --
 *     zeros      StreamIdIndex     StreamIdType      Ext/native stream
 *                ignored for ext   ignored for ext
 *
 * Where StreamIdType:
 *  00 = default stream
 *  01 = normal stream
 *  11 = external stream
 *
 * For an external stream, the StreamId is the orStream_t pointer itself. The
 * last bit of such a pointer is assumed to always be 0, so when constructing
 * the StreamId of a native stream we set the last bit to 1 to distinguish
 * native from external streams. (In this file the EXT type is therefore
 * recognized from the last bit alone; the value 11 is never packed into an
 * id by makeStreamId.)
 *
 * StreamId is 64-bit, so we can just rely on regular promotion rules.
 * We rely on StreamIdIndex and StreamIdType being non-negative.
 */
// Integer type wide enough for an index into a stream pool.
using StreamIdIndex = uint8_t;
// Kind of stream encoded in a StreamId, see Note [StreamId assignment].
enum class StreamIdType : uint8_t {
DEFAULT = 0x0,
NORMAL = 0x1,
EXT = 0x3,
};
// Prints the symbolic name of a StreamIdType, falling back to its numeric
// value for anything that is not a known enumerator.
inline std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
  if (s == StreamIdType::DEFAULT) {
    return stream << "DEFAULT";
  }
  if (s == StreamIdType::NORMAL) {
    return stream << "NORMAL";
  }
  if (s == StreamIdType::EXT) {
    return stream << "EXT";
  }
  const auto numeric = static_cast<int16_t>(s);
  return stream << numeric;
}
// Decodes the stream type stored in a StreamId.
// An id whose lowest bit is clear is an externally allocated stream (the id
// is the orStream_t pointer). Otherwise the two bits above the lowest one
// hold the type, which must be DEFAULT or NORMAL.
static inline StreamIdType streamIdType(StreamId s) {
  const bool is_native = (s & 1) != 0;
  if (!is_native) {
    return StreamIdType::EXT;
  }

  const int type_mask = (1 << kStreamTypeBits) - 1;
  const auto decoded = static_cast<StreamIdType>((s >> 1) & type_mask);
  TORCH_CHECK(
      decoded == StreamIdType::DEFAULT || decoded == StreamIdType::NORMAL,
      "invalid StreamId: ",
      s);
  return decoded;
}
// Extracts the pool-index field of a StreamId: the kStreamsPerPoolBits bits
// located above the native bit and the type field.
static inline size_t streamIdIndex(StreamId s) {
  const auto index_mask = (1 << kStreamsPerPoolBits) - 1;
  const auto shifted = s >> (kStreamTypeBits + 1);
  return static_cast<size_t>(shifted & index_mask);
}
// Packs a stream type and a pool index into a native StreamId (lowest bit
// set). For StreamIdType::EXT the result is simply 0.
StreamId makeStreamId(StreamIdType st, size_t si) {
  if (st == StreamIdType::EXT) {
    return static_cast<StreamId>(0);
  }

  const StreamId index_field = static_cast<StreamId>(si)
      << (kStreamTypeBits + 1);
  const StreamId type_field = static_cast<StreamId>(st) << 1;
  return index_field | type_field | 1;
}
// Process-wide setup: caches the device count and sizes every per-device
// container accordingly. Run through init_flag, so it executes once.
static void initGlobalStreamState() {
  num_devices = device_count();
  const auto device_total = num_devices;
  device_flags.resize(device_total);
  streams.resize(device_total);
  priority_counters.resize(device_total);
}
static void initSingleDeviceStream(
int priority,
DeviceIndex device_index,
int i) {
auto& stream = streams[device_index][priority][i];
OPENREG_CHECK(orStreamCreateWithPriority(&stream, 0, priority));
priority_counters[device_index][priority] = 0;
}
// Creates every stream of every priority pool for the given device.
// It should be called only once per device.
static void initDeviceStreamState(DeviceIndex device_index) {
  for (int slot = 0; slot < kStreamsPerPool; ++slot) {
    for (int prio = 0; prio < max_compile_time_stream_priorities; ++prio) {
      initSingleDeviceStream(prio, device_index, slot);
    }
  }
}
static void initOpenRegStreamsOnce() {
c10::call_once(init_flag, initGlobalStreamState);
if (current_streams) {
return;
}
// Inits current streams (thread local) to the last queue in the "normal
// priority" queue pool. Note: the queue pool have not been initialized yet.
// It will be initialized in initDeviceStreamState for the specified device.
current_streams = std::make_unique<StreamId[]>(num_devices);
for (const auto i : c10::irange(num_devices)) {
current_streams[i] = makeStreamId(StreamIdType::DEFAULT, 0);
}
}
// Advances the counter by one and maps its previous value onto a pool slot.
static uint32_t get_idx(std::atomic<uint32_t>& counter) {
  return counter.fetch_add(1) % kStreamsPerPool;
}
// Wraps (device_index, stream_id) in an OpenRegStream without validating
// either value.
OpenRegStream OpenRegStreamForId(DeviceIndex device_index, StreamId stream_id) {
  const c10::Device device(DeviceType::PrivateUse1, device_index);
  const Stream generic(Stream::UNSAFE, device, stream_id);
  return OpenRegStream(OpenRegStream::UNCHECKED, generic);
}
} // anonymous namespace
// Resolves this wrapper to the backend stream handle.
// See Note [StreamId assignment]
//
// Native ids (DEFAULT / NORMAL) are looked up in the per-device pools, using
// the stream type as the pool index; external ids are the handle itself.
// Raises if the id refers to a device or pool that does not exist, instead of
// reading past the end of the pool arrays.
orStream_t OpenRegStream::stream() const {
  c10::DeviceIndex device_index = stream_.device_index();
  StreamId stream_id = stream_.id();
  StreamIdType st = streamIdType(stream_id);
  size_t si = streamIdIndex(stream_id);
  switch (st) {
    // The index 0 stream is default as well.
    case StreamIdType::DEFAULT:
    case StreamIdType::NORMAL: {
      const auto pool = static_cast<size_t>(st);
      TORCH_CHECK(
          device_index >= 0 &&
              static_cast<size_t>(device_index) < streams.size(),
          "Invalid device index ",
          static_cast<int>(device_index),
          " for stream ",
          stream_,
          ".");
      // The pool array only has max_compile_time_stream_priorities entries,
      // so a type value beyond that must not be used as an index.
      TORCH_CHECK(
          pool < static_cast<size_t>(max_compile_time_stream_priorities),
          "Stream ",
          stream_,
          " refers to a stream pool (",
          st,
          ") that does not exist.");
      // si is masked to kStreamsPerPoolBits bits, so it is always in range.
      return streams[device_index][pool][si];
    }
    case StreamIdType::EXT:
      return reinterpret_cast<orStream_t>(stream_id);
    default:
      TORCH_CHECK(
          false,
          "Unrecognized stream ",
          stream_,
          " (I didn't recognize the stream type, ",
          st,
          ").",
          " Did you manufacture the StreamId yourself? Don't do that;");
  }
}
// Returns a stream from the requested pool, chosen round-robin.
// Note: when called the first time on a device, this will create the
// stream pools for that device.
//
// priority is clamped into [0, max_compile_time_stream_priorities - 1].
// device_index == -1 selects the current device; any other value must be a
// valid device index, otherwise an error is raised.
OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) {
  initOpenRegStreamsOnce();
  if (device_index == -1) {
    device_index = current_device();
  }
  // device_index is used to index the per-device containers below.
  TORCH_CHECK(
      device_index >= 0 && device_index < num_devices,
      "Invalid device index ",
      static_cast<int>(device_index),
      ": expected a value in [0, ",
      static_cast<int>(num_devices),
      ").");
  c10::call_once(
      device_flags[device_index], initDeviceStreamState, device_index);
  auto pri_idx =
      std::clamp(priority, 0, max_compile_time_stream_priorities - 1);
  const auto idx = get_idx(priority_counters[device_index][pri_idx]);
  auto id_type = static_cast<StreamIdType>(pri_idx);
  return OpenRegStreamForId(device_index, makeStreamId(id_type, idx));
}
// Boolean convenience overload of getStreamFromPool.
//
// isHighPriority is accepted for API compatibility but currently has no
// effect: every request is forwarded with priority 0. The attribute documents
// that this is intentional and silences unused-parameter warnings.
// The priority overload performs the one-time initialization itself, so it is
// not repeated here.
OpenRegStream getStreamFromPool(
    [[maybe_unused]] const bool isHighPriority,
    DeviceIndex device) {
  constexpr int priority = 0;
  return getStreamFromPool(priority, device);
}
// Wraps an externally allocated backend stream: the pointer value itself
// becomes the StreamId. Neither argument is validated.
OpenRegStream getStreamFromExternal(
    orStream_t ext_stream,
    DeviceIndex device_index) {
  const auto external_id = reinterpret_cast<int64_t>(ext_stream);
  return OpenRegStreamForId(device_index, external_id);
}
// Returns the default stream (type DEFAULT, index 0) of a device.
// device_index == -1 selects the current device; any other value must be a
// valid device index, otherwise an error is raised.
OpenRegStream getDefaultOpenRegStream(DeviceIndex device_index) {
  initOpenRegStreamsOnce();
  if (device_index == -1) {
    device_index = current_device();
  }
  TORCH_CHECK(
      device_index >= 0 && device_index < num_devices,
      "Invalid device index ",
      static_cast<int>(device_index),
      ": expected a value in [0, ",
      static_cast<int>(num_devices),
      ").");
  // OpenRegStream::stream() resolves the default id to slot 0 of pool 0 for
  // the device. Create the device's pools (once) before handing the wrapper
  // out, so the id always maps to a created stream rather than to whatever
  // the slot held before initialization.
  c10::call_once(
      device_flags[device_index], initDeviceStreamState, device_index);
  return OpenRegStreamForId(
      device_index, makeStreamId(StreamIdType::DEFAULT, 0));
}
// Returns the calling thread's current stream for a device.
// device_index == -1 selects the current device; any other value must be a
// valid device index, otherwise an error is raised.
OpenRegStream getCurrentOpenRegStream(DeviceIndex device_index) {
  initOpenRegStreamsOnce();
  if (device_index == -1) {
    device_index = current_device();
  }
  // device_index indexes the thread-local table and the per-device flags.
  TORCH_CHECK(
      device_index >= 0 && device_index < num_devices,
      "Invalid device index ",
      static_cast<int>(device_index),
      ": expected a value in [0, ",
      static_cast<int>(num_devices),
      ").");
  // A thread's current stream starts out as the default id, which
  // OpenRegStream::stream() looks up in the device's pool. Create the pools
  // (once) here so the returned wrapper refers to a created stream.
  c10::call_once(
      device_flags[device_index], initDeviceStreamState, device_index);
  return OpenRegStreamForId(device_index, current_streams[device_index]);
}
// Makes `stream` the calling thread's current stream for the device that the
// stream belongs to.
void setCurrentOpenRegStream(OpenRegStream stream) {
  initOpenRegStreamsOnce();
  const DeviceIndex target_device = stream.device_index();
  current_streams[target_device] = stream.id();
}
std::ostream& operator<<(std::ostream& stream, const OpenRegStream& s) {
return stream << s.unwrap();
}
} // namespace c10::openreg

View File

@ -0,0 +1,162 @@
#pragma once
#include <include/openreg.h>
#include "OpenRegException.h"
#include "OpenRegFunctions.h"
#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>
namespace c10::openreg {
// Number of stream priority levels, i.e. the number of stream pools kept per
// device. Fixed at compile time because it sizes the pool arrays.
static constexpr int max_compile_time_stream_priorities = 1;
class OpenRegStream {
public:
enum Unchecked { UNCHECKED };
explicit OpenRegStream(Stream stream) : stream_(stream) {
TORCH_CHECK(stream_.device_type() == DeviceType::PrivateUse1);
}
explicit OpenRegStream(Unchecked, Stream stream) : stream_(stream) {}
bool operator==(const OpenRegStream& other) const noexcept {
return unwrap() == other.unwrap();
}
bool operator!=(const OpenRegStream& other) const noexcept {
return unwrap() != other.unwrap();
}
operator orStream_t() const {
return stream();
}
operator Stream() const {
return unwrap();
}
DeviceType device_type() const {
return DeviceType::PrivateUse1;
}
DeviceIndex device_index() const {
return stream_.device_index();
}
Device device() const {
return Device(DeviceType::PrivateUse1, device_index());
}
StreamId id() const {
return stream_.id();
}
bool query() const {
DeviceGuard guard{stream_.device()};
if (orStreamQuery(stream()) == orSuccess) {
return true;
}
return false;
}
void synchronize() const {
DeviceGuard guard{stream_.device()};
OPENREG_CHECK(orStreamSynchronize(stream()));
}
int priority() const {
DeviceGuard guard{stream_.device()};
int priority = 0;
OPENREG_CHECK(orStreamGetPriority(stream(), &priority));
return priority;
}
orStream_t stream() const;
Stream unwrap() const {
return stream_;
}
struct c10::StreamData3 pack3() const {
return stream_.pack3();
}
static OpenRegStream unpack3(
StreamId stream_id,
DeviceIndex device_index,
DeviceType device_type) {
return OpenRegStream(Stream::unpack3(stream_id, device_index, device_type));
}
private:
Stream stream_;
};
/*
 * Get a stream from the pool in a round-robin fashion.
 *
 * `isHighPriority` is meant to request a stream from the highest-priority
 * pool of the device.
 * NOTE(review): the implementation in this change forwards priority 0
 * regardless of the flag -- confirm that this is intended while only one
 * priority level exists.
 *
 * `device` == -1 (the default) selects the current device.
 */
OPENREG_EXPORT OpenRegStream
getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
/*
 * Get a stream from the pool in a round-robin fashion.
 *
 * You can request a stream by giving a priority value for a specific device.
 * The lower the priority number, the higher the priority. The value is
 * clamped to the supported range [0, max_compile_time_stream_priorities - 1].
 *
 * `device` == -1 (the default) selects the current device.
 */
OPENREG_EXPORT OpenRegStream
getStreamFromPool(const int priority, DeviceIndex device = -1);
/*
 * Get an OpenRegStream from an externally allocated stream.
 *
 * This is mainly for interoperability with other libraries, where we want to
 * operate on a stream that was not allocated by torch, for data exchange or
 * similar purposes.
 */
OPENREG_EXPORT OpenRegStream
getStreamFromExternal(orStream_t ext_stream, DeviceIndex device_index);
/*
 * Get the default OpenReg stream for the given OpenReg device, or for the
 * current device if no device index is passed (device_index == -1).
 */
OPENREG_EXPORT OpenRegStream
getDefaultOpenRegStream(DeviceIndex device_index = -1);
/*
 * Get the current OpenReg stream for the given OpenReg device, or for the
 * current device if no device index is passed (device_index == -1).
 * The current stream is tracked per thread.
 */
OPENREG_EXPORT OpenRegStream
getCurrentOpenRegStream(DeviceIndex device_index = -1);
/*
 * Make the passed-in stream the current stream (for the calling thread) of
 * the device that the stream belongs to.
 */
OPENREG_EXPORT void setCurrentOpenRegStream(OpenRegStream stream);
/*
 * Print an OpenRegStream; the output is that of the generic c10::Stream it
 * wraps.
 */
OPENREG_EXPORT std::ostream& operator<<(
std::ostream& stream,
const OpenRegStream& s);
} // namespace c10::openreg
namespace std {
// Hash support so OpenRegStream can be used as a key in unordered
// containers; the hash is that of the wrapped generic c10::Stream.
template <>
struct hash<c10::openreg::OpenRegStream> {
  size_t operator()(c10::openreg::OpenRegStream s) const noexcept {
    const c10::Stream generic = s.unwrap();
    return std::hash<c10::Stream>{}(generic);
  }
};
} // namespace std

View File

@ -0,0 +1,7 @@
#pragma once
// Marks a symbol as exported from the shared library: dllexport on Windows,
// default ELF/Mach-O visibility elsewhere.
#if defined(_WIN32)
#define OPENREG_EXPORT __declspec(dllexport)
#else
#define OPENREG_EXPORT __attribute__((visibility("default")))
#endif

View File

@ -1,9 +1,9 @@
#include <Python.h>
#ifdef _WIN32
#define OPENREG_EXPORT __declspec(dllexport)
#define OPENREG_EXPORT __declspec(dllexport)
#else
#define OPENREG_EXPORT __attribute__((visibility("default")))
#define OPENREG_EXPORT __attribute__((visibility("default")))
#endif
extern OPENREG_EXPORT PyObject* initOpenRegModule(void);
@ -12,9 +12,9 @@ extern OPENREG_EXPORT PyObject* initOpenRegModule(void);
extern "C"
#endif
OPENREG_EXPORT PyObject* PyInit__C(void);
OPENREG_EXPORT PyObject*
PyInit__C(void);
PyMODINIT_FUNC PyInit__C(void)
{
PyMODINIT_FUNC PyInit__C(void) {
return initOpenRegModule();
}