mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix record issue on XPUGuardImpl (#123523)
# Motivation Previously, `xpu_event` became a dangling pointer because the stack variable it pointed to was destroyed when its scope ended. As a result, the event-related functions (`destroyEvent`, `record`, `block`, and `queryEvent`) used in `c10/core/impl/InlineEvent.h`, which serves `c10::Event`, did not work correctly. # Solution Allocate the event on the heap with `new` and assign it to `xpu_event` to avoid the dangling pointer. # Additional Context Add a unit test (UT) to cover this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123523 Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/gujinghui, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
266e278ccf
commit
270dd99180
@ -67,7 +67,18 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
|
||||
// Event-related functions
|
||||
void destroyEvent(void* event, const DeviceIndex device_index)
|
||||
const noexcept override {}
|
||||
const noexcept override {
|
||||
if (!event)
|
||||
return;
|
||||
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_deletion(
|
||||
c10::kXPU, reinterpret_cast<uintptr_t>(event));
|
||||
}
|
||||
|
||||
delete reinterpret_cast<sycl::event*>(event);
|
||||
}
|
||||
|
||||
void record(
|
||||
void** event,
|
||||
@ -84,7 +95,12 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
|
||||
auto* xpu_event = reinterpret_cast<sycl::event*>(*event);
|
||||
const XPUStream xpu_stream{stream};
|
||||
*xpu_event = xpu_stream.queue().ext_oneapi_submit_barrier();
|
||||
|
||||
// Delete the event previously recorded.
|
||||
if (xpu_event)
|
||||
delete xpu_event;
|
||||
xpu_event = new sycl::event(xpu_stream.queue().ext_oneapi_submit_barrier());
|
||||
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_record(
|
||||
|
@ -1,7 +1,9 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/core/DeviceGuard.h>
|
||||
#include <c10/core/Event.h>
|
||||
#include <c10/xpu/XPUStream.h>
|
||||
#include <c10/xpu/test/impl/XPUTest.h>
|
||||
|
||||
bool has_xpu() {
|
||||
return c10::xpu::device_count() > 0;
|
||||
@ -42,3 +44,51 @@ TEST(XPUGuardTest, GuardBehavior) {
|
||||
EXPECT_EQ(streams1[1].device_index(), 1);
|
||||
EXPECT_EQ(c10::xpu::current_device(), 0);
|
||||
}
|
||||
|
||||
// Exercises c10::Event on XPU: record/block/query across two stream handles,
// including re-recording an event that has already been recorded.
TEST(XPUGuardTest, EventBehavior) {
  // Nothing to test on machines without an XPU device.
  if (!has_xpu()) {
    return;
  }

  auto device = c10::Device(c10::kXPU, c10::xpu::current_device());
  c10::impl::VirtualGuardImpl impl(device.type());
  // NOTE(review): both handles come from impl.getStream(device); confirm they
  // refer to distinct queues, otherwise cross-stream ordering is not really
  // exercised by this test.
  c10::Stream stream1 = impl.getStream(device);
  c10::Stream stream2 = impl.getStream(device);
  c10::Event event(device.type());

  constexpr int numel = 1024;
  int hostData1[numel];
  initHostData(hostData1, numel);
  int hostData2[numel];
  clearHostData(hostData2, numel);

  auto xpu_stream1 = c10::xpu::XPUStream(stream1);
  int* deviceData = sycl::malloc_device<int>(numel, xpu_stream1);

  // Copy hostData1 to deviceData via stream1, and then copy deviceData to
  // hostData2 via stream2.
  xpu_stream1.queue().memcpy(deviceData, hostData1, sizeof(int) * numel);
  // stream2 wait on stream1's completion.
  event.record(stream1);
  event.block(stream2);
  auto xpu_stream2 = c10::xpu::XPUStream(stream2);
  xpu_stream2.queue().memcpy(hostData2, deviceData, sizeof(int) * numel);
  xpu_stream2.synchronize();

  EXPECT_TRUE(event.query());
  validateHostData(hostData2, numel);
  // Re-recording an already recorded event must leave it in a queryable state.
  event.record(stream2);
  EXPECT_TRUE(event.query());

  clearHostData(hostData2, numel);
  xpu_stream1.queue().memcpy(deviceData, hostData1, sizeof(int) * numel);
  // stream2 wait on stream1's completion.
  event.record(stream1);
  event.block(stream2);
  // event will overwrite the previously captured state.
  event.record(stream2);
  xpu_stream2.queue().memcpy(hostData2, deviceData, sizeof(int) * numel);
  xpu_stream2.synchronize();
  EXPECT_TRUE(event.query());
  validateHostData(hostData2, numel);

  // All work touching deviceData has been synchronized above, so the USM
  // allocation can be released; previously it was leaked.
  sycl::free(deviceData, xpu_stream1.queue());
}
|
||||
|
Reference in New Issue
Block a user