Compare commits

...

32 Commits

Author SHA1 Message Date
59382e74d6 Update
[ghstack-poisoned]
2025-11-13 15:19:04 +00:00
3cba13fd7d Update
[ghstack-poisoned]
2025-11-13 14:32:19 +00:00
5db1b7f87e Update (base update)
[ghstack-poisoned]
2025-11-13 14:32:19 +00:00
28d19665ef Update
[ghstack-poisoned]
2025-11-13 13:59:21 +00:00
5d9154ce87 Update (base update)
[ghstack-poisoned]
2025-11-13 13:59:21 +00:00
d8ec48000f Update
[ghstack-poisoned]
2025-11-13 13:47:55 +00:00
1d81962569 Update (base update)
[ghstack-poisoned]
2025-11-13 13:47:55 +00:00
69ac9c4b72 Update
[ghstack-poisoned]
2025-11-13 11:06:28 +00:00
acfaaee5bd Update
[ghstack-poisoned]
2025-11-13 11:01:28 +00:00
af6f2b7f63 Update (base update)
[ghstack-poisoned]
2025-11-13 10:47:30 +00:00
1887633c49 Update
[ghstack-poisoned]
2025-11-13 10:47:30 +00:00
10117bb577 Update (base update)
[ghstack-poisoned]
2025-08-09 02:13:31 +00:00
42fac35faf Update
[ghstack-poisoned]
2025-08-09 02:13:31 +00:00
9b6ee7a57f Update (base update)
[ghstack-poisoned]
2025-07-29 15:42:31 +00:00
9cf97eddd4 Update
[ghstack-poisoned]
2025-07-29 15:42:31 +00:00
8b72926723 Update (base update)
[ghstack-poisoned]
2025-07-18 14:37:08 +00:00
9c90c36e1e Update
[ghstack-poisoned]
2025-07-18 14:37:08 +00:00
e2996c31eb Update
[ghstack-poisoned]
2025-07-18 14:16:24 +00:00
99aace2177 Update
[ghstack-poisoned]
2025-07-17 17:27:15 +00:00
23d08a8c38 Update
[ghstack-poisoned]
2025-07-17 11:55:07 +00:00
54d1704333 Update
[ghstack-poisoned]
2025-07-15 21:33:06 +00:00
43db4b7dac Update
[ghstack-poisoned]
2025-07-15 20:47:09 +00:00
733ced0f33 Update
[ghstack-poisoned]
2025-07-15 20:08:50 +00:00
b3cdad62e2 Update
[ghstack-poisoned]
2025-07-15 19:51:51 +00:00
f6b4d5affc Update
[ghstack-poisoned]
2025-07-15 17:57:33 +00:00
52100144a2 Update
[ghstack-poisoned]
2025-07-15 16:46:50 +00:00
2d706ffe64 Update
[ghstack-poisoned]
2025-07-15 14:39:18 +00:00
e7eeb9c8d8 Update (base update)
[ghstack-poisoned]
2025-07-15 13:22:32 +00:00
a6689c38b6 Update
[ghstack-poisoned]
2025-07-15 13:22:32 +00:00
2757f83841 Update
[ghstack-poisoned]
2025-07-15 12:34:01 +00:00
1c9efafc20 Update (base update)
[ghstack-poisoned]
2025-07-14 14:51:14 +00:00
44141db281 Update
[ghstack-poisoned]
2025-07-14 14:51:14 +00:00
5 changed files with 381 additions and 242 deletions

View File

@ -3,252 +3,15 @@
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Exception.h>
#include <cuda_runtime_api.h>
#include <cstdint>
#include <utility>
/*
* `cudaEventExternal` is a torch-specific flag that is used to
* indicate that the CUDAEvent will be used only for synchronization
* with work outside of the cuda graph, rather than creation of
* cross-stream dependencies within a cuda graph. Resources:
* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events
* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47
* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e
*/
#define cudaEventExternal 0x08
namespace at::cuda {
/*
* CUDAEvents are movable not copyable wrappers around CUDA's events.
*
* CUDAEvents are constructed lazily when first recorded unless it is
* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this
* device is acquired from the first recording stream. However, if reconstructed
* from a handle, the device should be explicitly specified; or if ipc_handle() is
* called before the event is ever recorded, it will use the current device.
* Later streams that record the event must match this device.
*/
struct TORCH_CUDA_CPP_API CUDAEvent {
// Constructors
// Default value for `flags` is specified below - it's cudaEventDisableTiming
CUDAEvent() noexcept = default;
CUDAEvent(unsigned int flags) noexcept : flags_{flags} {}
CUDAEvent(
DeviceIndex device_index, const cudaIpcEventHandle_t* handle) : device_index_(device_index) {
CUDAGuard guard(device_index_);
AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle));
is_created_ = true;
}
// Note: event destruction done on creating device to avoid creating a
// CUDA context on other devices.
~CUDAEvent() {
try {
if (is_created_) {
CUDAGuard guard(device_index_);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
AT_CUDA_CHECK(cudaEventDestroy(event_));
}
} catch (...) { /* No throw */ }
}
CUDAEvent(const CUDAEvent&) = delete;
CUDAEvent& operator=(const CUDAEvent&) = delete;
CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); }
CUDAEvent& operator=(CUDAEvent&& other) noexcept {
if (this != &other) {
moveHelper(std::move(other));
}
return *this;
}
operator cudaEvent_t() const { return event(); }
// Less than operator (to allow use in sets)
friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) {
return left.event_ < right.event_;
}
std::optional<at::Device> device() const {
if (is_created_) {
return at::Device(at::kCUDA, device_index_);
} else {
return {};
}
}
bool isCreated() const { return is_created_; }
DeviceIndex device_index() const {return device_index_;}
cudaEvent_t event() const { return event_; }
// Note: cudaEventQuery can be safely called from any device
bool query() const {
if (!is_created_) {
return true;
}
cudaError_t err = cudaEventQuery(event_);
if (err == cudaSuccess) {
return true;
} else if (err != cudaErrorNotReady) {
C10_CUDA_CHECK(err);
} else {
// ignore and clear the error if not ready
(void)cudaGetLastError();
}
return false;
}
void record() { record(getCurrentCUDAStream()); }
void recordOnce(const CUDAStream& stream) {
if (!was_recorded_) record(stream);
}
// Note: cudaEventRecord must be called on the same device as the event.
void record(const CUDAStream& stream) {
if (!is_created_) {
createEvent(stream.device_index());
}
TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_,
" does not match recording stream's device ", stream.device_index(), ".");
CUDAGuard guard(device_index_);
#ifndef USE_ROCM
// it is an error to use cudaEventRecordExternal when not doing stream capture
unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventRecordExternal : cudaEventRecordDefault;
AT_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags));
#else
AT_CUDA_CHECK(cudaEventRecord(event_, stream));
#endif
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(at::kCUDA,
reinterpret_cast<uintptr_t>(event_),
reinterpret_cast<uintptr_t>(stream.stream())
);
}
was_recorded_ = true;
}
// Note: cudaStreamWaitEvent must be called on the same device as the stream.
// The event has no actual GPU resources associated with it.
void block(const CUDAStream& stream) {
if (is_created_) {
CUDAGuard guard(stream.device_index());
#ifndef USE_ROCM
// it is an error to use cudaEventWaitExternal when not doing stream capture
unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventWaitExternal : cudaEventWaitDefault;
AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags));
#else
AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_));
#endif
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(at::kCUDA,
reinterpret_cast<uintptr_t>(event_),
reinterpret_cast<uintptr_t>(stream.stream())
);
}
}
}
// Note: cudaEventElapsedTime can be safely called from any device
float elapsed_time(const CUDAEvent& other) const {
TORCH_CHECK_VALUE(
!(flags_ & cudaEventDisableTiming) && !(other.flags_ & cudaEventDisableTiming),
"Both events must be created with argument 'enable_timing=True'.");
TORCH_CHECK_VALUE(
is_created_ && other.isCreated(),
"Both events must be recorded before calculating elapsed time.");
TORCH_CHECK(
query() && other.query(),
"Both events must be completed before calculating elapsed time.");
float time_ms = 0;
// We do not strictly have to set the device index to the same as our event,
// but if we don't and the current device is not initialized, it will
// create a new cuda context, which will consume a lot of memory.
CUDAGuard guard(device_index_);
// raise cudaErrorNotReady if either event is recorded but not yet completed
AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_));
return time_ms;
}
// Note: cudaEventSynchronize can be safely called from any device
void synchronize() const {
if (is_created_) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_synchronization(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
AT_CUDA_CHECK(cudaEventSynchronize(event_));
}
}
// Note: cudaIpcGetEventHandle must be called on the same device as the event
void ipc_handle(cudaIpcEventHandle_t * handle) {
if (!is_created_) {
// this CUDAEvent object was initially constructed from flags but event_
// is not created yet.
createEvent(getCurrentCUDAStream().device_index());
}
CUDAGuard guard(device_index_);
AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_));
}
private:
unsigned int flags_ = cudaEventDisableTiming;
bool is_created_ = false;
bool was_recorded_ = false;
bool external_ = false;
DeviceIndex device_index_ = -1;
cudaEvent_t event_{};
void createEvent(DeviceIndex device_index) {
external_ = (flags_ & cudaEventExternal) != 0;
#ifdef USE_ROCM
TORCH_CHECK(!external_, "External events are disallowed in rocm");
#endif
flags_ &= ~cudaEventExternal;
device_index_ = device_index;
CUDAGuard guard(device_index_);
AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
is_created_ = true;
}
void moveHelper(CUDAEvent&& other) {
std::swap(flags_, other.flags_);
std::swap(is_created_, other.is_created_);
std::swap(was_recorded_, other.was_recorded_);
std::swap(device_index_, other.device_index_);
std::swap(event_, other.event_);
}
};
// EventPool - Thread-safe pool of CUDA events to avoid expensive cudaEventCreate
// calls. cudaEventCreate when concurrently invoked from multiple threads can be
// very expensive (especially on certain device/driver combinations).
// EventPool - Thread-safe pool of CUDA events to avoid expensive
// cudaEventCreate calls. cudaEventCreate when concurrently invoked from
// multiple threads can be very expensive (especially on certain device/driver
// combinations).
using CUDAEventPtr =
std::unique_ptr<CUDAEvent, std::function<void(CUDAEvent*)>>;

View File

@ -0,0 +1,86 @@
#pragma once
#include <c10/hip/HIPEvent.h>
// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
namespace c10 { namespace hip {
// See Note [Masquerading as CUDA] for motivation
struct HIPEventMasqueradingAsCUDA {
HIPEventMasqueradingAsCUDA() noexcept = default;
HIPEventMasqueradingAsCUDA(unsigned int flags) noexcept
: event_(HIPEvent(flags)) {}
HIPEventMasqueradingAsCUDA(
DeviceIndex device_index,
const hipIpcEventHandle_t* handle)
: event_(HIPEvent(device_index, handle)) {}
~HIPEventMasqueradingAsCUDA() = default;
HIPEventMasqueradingAsCUDA(const HIPEventMasqueradingAsCUDA&) = delete;
HIPEventMasqueradingAsCUDA& operator=(const HIPEventMasqueradingAsCUDA&) = delete;
HIPEventMasqueradingAsCUDA(HIPEventMasqueradingAsCUDA&& other) noexcept = default;
HIPEventMasqueradingAsCUDA& operator=(HIPEventMasqueradingAsCUDA&& other) noexcept = default;
operator hipEvent_t() const {
return event_.event();
}
// Less than operator (to allow use in sets)
friend bool operator<(
const HIPEventMasqueradingAsCUDA& left,
const HIPEventMasqueradingAsCUDA& right) {
return left.event_ < right.event_;
}
std::optional<c10::Device> device() const {
// Unsafely coerce HIP device into CUDA device
return Device(c10::DeviceType::CUDA, event_.device_index());
}
bool isCreated() const {
return event_.isCreated();
}
DeviceIndex device_index() const {
return event_.device_index();
}
hipEvent_t event() const {
return event_.event();
}
bool query() const {
return event_.query();
}
void record() {
return event_.record();
}
void recordOnce(const HIPStreamMasqueradingAsCUDA& stream) {
event_.recordOnce(stream.hip_stream());
}
void record(const HIPStreamMasqueradingAsCUDA& stream) {
event_.record(stream.hip_stream());
}
void block(const HIPStreamMasqueradingAsCUDA& stream) {
event_.block(stream.hip_stream());
}
float elapsed_time(const HIPEventMasqueradingAsCUDA& other) const {
return event_.elapsed_time(other.event_);
}
void synchronize() const {
event_.synchronize();
}
void ipc_handle(hipIpcEventHandle_t* handle) {
event_.ipc_handle(handle);
}
private:
HIPEvent event_;
};
}} // namespace c10::hip

View File

@ -43,6 +43,7 @@ set(C10_CUDA_HEADERS
CUDACachingAllocator.h
CUDADeviceAssertionHost.h
CUDAException.h
CUDAEvent.h
CUDAFunctions.h
CUDAGuard.h
CUDAMacros.h

278
c10/cuda/CUDAEvent.h Normal file
View File

@ -0,0 +1,278 @@
#pragma once
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Exception.h>
/*
* `cudaEventExternal` is a torch-specific flag that is used to
* indicate that the CUDAEvent will be used only for synchronization
* with work outside of the cuda graph, rather than creation of
* cross-stream dependencies within a cuda graph. Resources:
* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events
* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47
* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e
*/
#define cudaEventExternal 0x08
namespace c10::cuda {
/*
* CUDAEvents are movable not copyable wrappers around CUDA's events.
*
* CUDAEvents are constructed lazily when first recorded unless it is
* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this
* device is acquired from the first recording stream. However, if reconstructed
* from a handle, the device should be explicitly specified; or if ipc_handle()
* is called before the event is ever recorded, it will use the current device.
* Later streams that record the event must match this device.
*/
struct CUDAEvent {
// Constructors
// Default value for `flags` is specified below - it's cudaEventDisableTiming
CUDAEvent() noexcept = default;
CUDAEvent(unsigned int flags) noexcept : flags_{flags} {}
CUDAEvent(DeviceIndex device_index, const cudaIpcEventHandle_t* handle)
: device_index_(device_index) {
CUDAGuard guard(device_index_);
C10_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle));
is_created_ = true;
}
// Note: event destruction done on creating device to avoid creating a
// CUDA context on other devices.
~CUDAEvent() {
if (is_created_) {
CUDAGuard guard(device_index_);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(
c10::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
C10_CUDA_CHECK_WARN(cudaEventDestroy(event_));
}
}
CUDAEvent(const CUDAEvent&) = delete;
CUDAEvent& operator=(const CUDAEvent&) = delete;
CUDAEvent(CUDAEvent&& other) noexcept {
moveHelper(std::move(other));
}
CUDAEvent& operator=(CUDAEvent&& other) noexcept {
if (this != &other) {
moveHelper(std::move(other));
}
return *this;
}
operator cudaEvent_t() const {
return event();
}
// Less than operator (to allow use in sets)
friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) {
return left.event_ < right.event_;
}
std::optional<c10::Device> device() const {
if (is_created_) {
return c10::Device(c10::kCUDA, device_index_);
} else {
return {};
}
}
bool isCreated() const {
return is_created_;
}
DeviceIndex device_index() const {
return device_index_;
}
cudaEvent_t event() const {
return event_;
}
// Note: cudaEventQuery can be safely called from any device
bool query() const {
if (!is_created_) {
return true;
}
cudaError_t err = cudaEventQuery(event_);
if (err == cudaSuccess) {
return true;
} else if (err != cudaErrorNotReady) {
C10_CUDA_CHECK(err);
} else {
// ignore and clear the error if not ready
(void)cudaGetLastError();
}
return false;
}
void record() {
record(getCurrentCUDAStream());
}
void recordOnce(const CUDAStream& stream) {
if (!was_recorded_)
record(stream);
}
// Note: cudaEventRecord must be called on the same device as the event.
void record(const CUDAStream& stream) {
if (!is_created_) {
createEvent(stream.device_index());
}
TORCH_CHECK(
device_index_ == stream.device_index(),
"Event device ",
device_index_,
" does not match recording stream's device ",
stream.device_index(),
".");
CUDAGuard guard(device_index_);
#ifndef USE_ROCM
// it is an error to use cudaEventRecordExternal when not doing stream
// capture
unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() !=
c10::cuda::CaptureStatus::None &&
external_)
? cudaEventRecordExternal
: cudaEventRecordDefault;
C10_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags));
#else
C10_CUDA_CHECK(cudaEventRecord(event_, stream));
#endif
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(
c10::kCUDA,
reinterpret_cast<uintptr_t>(event_),
reinterpret_cast<uintptr_t>(stream.stream()));
}
was_recorded_ = true;
}
// Note: cudaStreamWaitEvent must be called on the same device as the stream.
// The event has no actual GPU resources associated with it.
void block(const CUDAStream& stream) {
if (is_created_) {
CUDAGuard guard(stream.device_index());
#ifndef USE_ROCM
// it is an error to use cudaEventWaitExternal when not doing stream
// capture
unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() !=
c10::cuda::CaptureStatus::None &&
external_)
? cudaEventWaitExternal
: cudaEventWaitDefault;
C10_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags));
#else
C10_CUDA_CHECK(cudaStreamWaitEvent(stream, event_));
#endif
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(
c10::kCUDA,
reinterpret_cast<uintptr_t>(event_),
reinterpret_cast<uintptr_t>(stream.stream()));
}
}
}
// Note: cudaEventElapsedTime can be safely called from any device
float elapsed_time(const CUDAEvent& other) const {
TORCH_CHECK_VALUE(
!(flags_ & cudaEventDisableTiming) &&
!(other.flags_ & cudaEventDisableTiming),
"Both events must be created with argument 'enable_timing=True'.");
TORCH_CHECK_VALUE(
is_created_ && other.isCreated(),
"Both events must be recorded before calculating elapsed time.");
TORCH_CHECK(
query() && other.query(),
"Both events must be completed before calculating elapsed time.");
float time_ms = 0;
// We do not strictly have to set the device index to the same as our event,
// but if we don't and the current device is not initialized, it will
// create a new cuda context, which will consume a lot of memory.
CUDAGuard guard(device_index_);
// raise cudaErrorNotReady if either event is recorded but not yet completed
C10_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_));
return time_ms;
}
// Note: cudaEventSynchronize can be safely called from any device
void synchronize() const {
if (is_created_) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_synchronization(
c10::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
C10_CUDA_CHECK(cudaEventSynchronize(event_));
}
}
// Note: cudaIpcGetEventHandle must be called on the same device as the event
void ipc_handle(cudaIpcEventHandle_t* handle) {
if (!is_created_) {
// this CUDAEvent object was initially constructed from flags but event_
// is not created yet.
createEvent(getCurrentCUDAStream().device_index());
}
CUDAGuard guard(device_index_);
C10_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_));
}
private:
unsigned int flags_ = cudaEventDisableTiming;
bool is_created_ = false;
bool was_recorded_ = false;
bool external_ = false;
DeviceIndex device_index_ = -1;
cudaEvent_t event_{};
void createEvent(DeviceIndex device_index) {
external_ = (flags_ & cudaEventExternal) != 0;
#ifdef USE_ROCM
TORCH_CHECK(!external_, "External events are disallowed in rocm");
#endif
flags_ &= ~cudaEventExternal;
device_index_ = device_index;
CUDAGuard guard(device_index_);
C10_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(
c10::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
is_created_ = true;
}
void moveHelper(CUDAEvent&& other) {
// Transfer ownership of all state from other to this
flags_ = other.flags_;
is_created_ = other.is_created_;
was_recorded_ = other.was_recorded_;
external_ = other.external_;
device_index_ = other.device_index_;
event_ = other.event_;
// Reset other to a valid empty state to prevent double-free
// The moved-from object must not attempt to destroy the event
other.is_created_ = false;
other.event_ = cudaEvent_t{};
}
};
} // namespace c10::cuda

View File

@ -9231,6 +9231,8 @@ PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict(
API_PYTORCH,
),
),
("cuda::CUDAEvent", ("hip::HIPEventMasqueradingAsCUDA", API_PYTORCH)),
("CUDAEvent", ("HIPEventMasqueradingAsCUDA", API_PYTORCH)),
("cuda::CUDAStream", ("hip::HIPStreamMasqueradingAsCUDA", API_PYTORCH)),
("CUDAStream", ("HIPStreamMasqueradingAsCUDA", API_PYTORCH)),
(
@ -9285,6 +9287,14 @@ PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict(
"c10/cuda/CUDACachingAllocator.h",
("ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h", API_PYTORCH),
),
(
"ATen/cuda/CUDAEvent.h", # To keep BC, we have to keep this mapping
("ATen/hip/HIPEvent.h", API_PYTORCH),
),
(
"c10/cuda/CUDAEvent.h",
("ATen/hip/impl/HIPEventMasqueradingAsCUDA.h", API_PYTORCH),
),
(
"c10/cuda/CUDAStream.h",
("ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h", API_PYTORCH),
@ -9425,6 +9435,7 @@ C10_MAPPINGS = collections.OrderedDict(
("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)),
("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)),
("c10/cuda/CUDAMiscFunctions.h", ("c10/hip/HIPMiscFunctions.h", API_C10)),
("c10/cuda/CUDAEvent.h", ("c10/hip/HIPEvent.h", API_C10)),
("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)),
("c10/cuda/CUDAGraphsC10Utils.h", ("c10/hip/HIPGraphsC10Utils.h", API_C10)),
("c10/cuda/CUDAAllocatorConfig.h", ("c10/hip/HIPAllocatorConfig.h", API_C10)),