Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Enable XPUEvent elapsed_time function (#134666)
# Motivation

This PR aims to enable the `elapsed_time` function for `XPUEvent`.

# Additional Context

This PR depends on the oneAPI 2025.0 toolchain.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134666
Approved by: https://github.com/EikanWang, https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: e9fb2c6abe
Commit: 4bbd6da331
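At the user level, this change makes event-based timing work on XPU. Below is a minimal usage sketch modeled on the test added in this PR; it assumes an available XPU device and a PyTorch build compiled with SYCL compiler 2025.0.0 or newer.

```python
import time

import torch

# Timing requires events created with enable_timing=True.
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)

stream = torch.xpu.current_stream()
stream.record_event(start_event)
time.sleep(0.1)  # stand-in for real device work
stream.record_event(end_event)
torch.xpu.synchronize()

# elapsed_time reports the time between the two recorded events in milliseconds.
print(start_event.elapsed_time(end_event))
```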
XPUEvent changes:

@@ -85,8 +85,7 @@ struct TORCH_XPU_API XPUEvent {
   void record(const XPUStream& stream) {
     if (!isCreated()) {
       device_index_ = stream.device_index();
-      event_ = std::make_unique<sycl::event>(
-          stream.queue().ext_oneapi_submit_barrier());
+      assignEvent(stream.queue());
       const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
       if (C10_UNLIKELY(interp)) {
         (*interp)->trace_gpu_event_creation(
@@ -100,9 +99,7 @@ struct TORCH_XPU_API XPUEvent {
           " does not match recording stream's device ",
           stream.device_index(),
           ".");
-      event_.reset();
-      event_ = std::make_unique<sycl::event>(
-          stream.queue().ext_oneapi_submit_barrier());
+      reassignEvent(stream.queue());
     }
     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
     if (C10_UNLIKELY(interp)) {
@@ -128,7 +125,7 @@ struct TORCH_XPU_API XPUEvent {
     }
   }

-  float elapsed_time(const XPUEvent& other) const {
+  double elapsed_time(const XPUEvent& other) const {
     TORCH_CHECK(
         isCreated() && other.isCreated(),
         "Both events must be recorded before calculating elapsed time.");
@@ -138,10 +135,20 @@ struct TORCH_XPU_API XPUEvent {
     TORCH_CHECK(
         enable_timing_ && other.enable_timing_,
         "Both events must be created with argument 'enable_timing=True'.");
-    // TODO: provides the ability to time the execution of commands in a SYCL
-    // queue without enabling profiling on the entire queue
+
+#if SYCL_COMPILER_VERSION < 20250000
     TORCH_CHECK_NOT_IMPLEMENTED(
-        false, "elapsed_time is not supported by XPUEvent.");
+        false,
+        "elapsed_time of XPUEvent requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
+#endif
+
+    using namespace sycl::info::event_profiling;
+    // Block until both of the recorded events are completed.
+    uint64_t end_time_ns = other.event().get_profiling_info<command_end>();
+    uint64_t start_time_ns = event().get_profiling_info<command_end>();
+    // Return the eplased time in milliseconds.
+    return 1e-6 *
+        (static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
   }

   void synchronize() const {
@@ -156,6 +163,24 @@ struct TORCH_XPU_API XPUEvent {
   }

  private:
+  void assignEvent(sycl::queue& queue) {
+#if SYCL_COMPILER_VERSION >= 20250000
+    if (enable_timing_) {
+      event_ = std::make_unique<sycl::event>(
+          sycl::ext::oneapi::experimental::submit_profiling_tag(queue));
+    } else {
+      event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
+    }
+#else
+    event_ = std::make_unique<sycl::event>(queue.ext_oneapi_submit_barrier());
+#endif
+  }
+
+  void reassignEvent(sycl::queue& queue) {
+    event_.reset();
+    assignEvent(queue);
+  }
+
   bool enable_timing_ = false;
   DeviceIndex device_index_ = -1;
   // Only need to track the last event, as events in an in-order queue are
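The hunks above replace the direct barrier submission with `assignEvent`/`reassignEvent` and gate timing support on the SYCL compiler version: on builds with an older toolchain, `elapsed_time` still raises `NotImplementedError`. A small runtime check can mirror the convention used in the tests below (`torch.version.xpu` parsed as an integer); the helper name here is purely illustrative.

```python
import torch


def xpu_event_timing_supported() -> bool:
    """Illustrative helper: True if this build compiles in the XPU timing path."""
    # The timing code is only built with SYCL compiler 2025.0.0 or newer,
    # which torch.version.xpu reports as an integer-like string, e.g. "20250000".
    return torch.xpu.is_available() and int(torch.version.xpu) >= 20250000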
XPUGuardImpl changes:

@@ -140,6 +140,30 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
         event_command_status::complete;
   }

+  double elapsedTime(
+      void* start_event,
+      void* end_event,
+      const DeviceIndex device_index) const override {
+#if SYCL_COMPILER_VERSION < 20250000
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
+        "elapsedTime requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
+#endif
+    TORCH_CHECK(
+        start_event && end_event,
+        "Both events must be recorded before calculating elapsed time.");
+    auto* xpu_start_event = reinterpret_cast<sycl::event*>(start_event);
+    auto* xpu_end_event = reinterpret_cast<sycl::event*>(end_event);
+
+    using namespace sycl::info::event_profiling;
+    // Block until both of the recorded events are completed.
+    uint64_t end_time_ns = xpu_end_event->get_profiling_info<command_end>();
+    uint64_t start_time_ns = xpu_start_event->get_profiling_info<command_end>();
+    // Return the eplased time in milliseconds.
+    return 1e-6 *
+        (static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
+  }
+
   // Stream-related functions
   bool queryStream(const Stream& stream) const override {
     const XPUStream xpu_stream{stream};
@@ -176,12 +200,6 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     const XPUStream xpu_stream{stream};
     XPUCachingAllocator::recordStream(data_ptr, xpu_stream);
   }
-
-  double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
-      const override {
-    TORCH_CHECK_NOT_IMPLEMENTED(
-        false, "elapsedTime is not supported by XPU backend.");
-  }
 };

 } // namespace c10::xpu::impl
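The `elapsedTime` override added above is what backs the device-generic event API, so timing is also reachable through `torch.Event("xpu", ...)` rather than only `torch.xpu.Event`. A rough sketch following the updated generic test is shown below; the `record(stream)` calls and the tensor arithmetic are illustrative assumptions, while the constructors, the `torch.version.xpu` check, and the `elapsed_time` calls come from the diff.

```python
import torch

# Device-generic stream/event objects routed to the XPU backend.
stream = torch.Stream("xpu")
event1 = torch.Event("xpu", enable_timing=True)
event2 = torch.Event("xpu", enable_timing=True)

a = torch.randn(1000, device="xpu")
b = torch.randn(1000, device="xpu")

event1.record(stream)  # assumed generic Event.record(stream) usage
c = a + b              # some device work between the two events
event2.record(stream)
torch.xpu.synchronize()

if int(torch.version.xpu) >= 20250000:
    # Routed through XPUGuardImpl::elapsedTime; the result is in milliseconds.
    print(event1.elapsed_time(event2))
else:
    # Older toolchains raise NotImplementedError instead.
    pass
```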
Test changes (torch.xpu unit tests):

@@ -3,6 +3,7 @@
 import subprocess
 import sys
 import tempfile
+import time
 import unittest

 import torch
@@ -235,6 +236,21 @@ print(torch.xpu.device_count())
         stream.record_event(event)
         event.synchronize()
         self.assertTrue(event.query())
+        start_event = torch.xpu.Event(enable_timing=True)
+        end_event = torch.xpu.Event(enable_timing=True)
+        stream.record_event(start_event)
+        time.sleep(0.1)
+        stream.record_event(end_event)
+        torch.xpu.synchronize()
+        if int(torch.version.xpu) >= 20250000:
+            self.assertGreater(start_event.elapsed_time(end_event), 0)
+            self.assertLess(end_event.elapsed_time(start_event), 0)
+        else:
+            with self.assertRaisesRegex(
+                NotImplementedError,
+                "elapsed_time of XPUEvent requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.",
+            ):
+                start_event.elapsed_time(end_event)

     def test_generic_stream_event(self):
         stream = torch.Stream("xpu")
@@ -250,8 +266,8 @@ print(torch.xpu.device_count())
         self.assertEqual(stream.stream_id, xpu_stream.stream_id)
         self.assertNotEqual(stream.stream_id, torch.xpu.current_stream().stream_id)

-        event1 = torch.Event("xpu")
-        event2 = torch.Event("xpu")
+        event1 = torch.Event("xpu", enable_timing=True)
+        event2 = torch.Event("xpu", enable_timing=True)
         self.assertEqual(event1.event_id, 0)
         a = torch.randn(1000)
         b = torch.randn(1000)
@@ -268,10 +284,15 @@ print(torch.xpu.device_count())
         self.assertTrue(event2.query())
         self.assertNotEqual(event1.event_id, event2.event_id)
         self.assertEqual(c_xpu.cpu(), a + b)
-        with self.assertRaisesRegex(
-            NotImplementedError, "elapsedTime is not supported by XPU backend."
-        ):
-            event1.elapsed_time(event2)
+        if int(torch.version.xpu) >= 20250000:
+            self.assertGreater(event1.elapsed_time(event2), 0)
+            self.assertLess(event2.elapsed_time(event1), 0)
+        else:
+            with self.assertRaisesRegex(
+                NotImplementedError,
+                "elapsedTime requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.",
+            ):
+                event1.elapsed_time(event2)
         xpu_event = torch.xpu.Event()
         self.assertIsInstance(xpu_event, torch.Event)
         self.assertTrue(issubclass(type(xpu_event), torch.Event))