mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Easy] Add more checks for elapsedTime of torch.xxx.Event and torch.Event (#151404)
As the title states. **Changes:** - Add **record**, **query** and **enable_timing** checks - Add related tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/151404 Approved by: https://github.com/albanD
This commit is contained in:
@ -147,8 +147,16 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
|
|||||||
|
|
||||||
// Note: cudaEventElapsedTime can be safely called from any device
|
// Note: cudaEventElapsedTime can be safely called from any device
|
||||||
float elapsed_time(const CUDAEvent& other) const {
|
float elapsed_time(const CUDAEvent& other) const {
|
||||||
TORCH_CHECK(is_created_ && other.isCreated(),
|
TORCH_CHECK_VALUE(
|
||||||
|
!(flags_ & cudaEventDisableTiming) && !(other.flags_ & cudaEventDisableTiming),
|
||||||
|
"Both events must be created with argument 'enable_timing=True'.");
|
||||||
|
TORCH_CHECK_VALUE(
|
||||||
|
is_created_ && other.isCreated(),
|
||||||
"Both events must be recorded before calculating elapsed time.");
|
"Both events must be recorded before calculating elapsed time.");
|
||||||
|
TORCH_CHECK(
|
||||||
|
query() && other.query(),
|
||||||
|
"Both events must be completed before calculating elapsed time.");
|
||||||
|
|
||||||
float time_ms = 0;
|
float time_ms = 0;
|
||||||
// We do not strictly have to set the device index to the same as our event,
|
// We do not strictly have to set the device index to the same as our event,
|
||||||
// but if we don't and the current device is not initialized, it will
|
// but if we don't and the current device is not initialized, it will
|
||||||
|
@ -106,11 +106,6 @@ struct InlineEvent final {
|
|||||||
}
|
}
|
||||||
|
|
||||||
double elapsedTime(const InlineEvent& other) const {
|
double elapsedTime(const InlineEvent& other) const {
|
||||||
TORCH_CHECK(
|
|
||||||
other.was_marked_for_recording(),
|
|
||||||
"other was not marked for recording.");
|
|
||||||
TORCH_CHECK(
|
|
||||||
was_marked_for_recording(), "self was not marked for recording.");
|
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
other.device_type() == device_type_,
|
other.device_type() == device_type_,
|
||||||
"Event device type ",
|
"Event device type ",
|
||||||
@ -118,6 +113,19 @@ struct InlineEvent final {
|
|||||||
" does not match other's device type ",
|
" does not match other's device type ",
|
||||||
DeviceTypeName(other.device_type()),
|
DeviceTypeName(other.device_type()),
|
||||||
".");
|
".");
|
||||||
|
TORCH_CHECK_VALUE(
|
||||||
|
(flag_ == EventFlag::BACKEND_DEFAULT) &&
|
||||||
|
(other.flag_ == EventFlag::BACKEND_DEFAULT),
|
||||||
|
"Both events must be created with argument 'enable_timing=True'.");
|
||||||
|
TORCH_CHECK_VALUE(
|
||||||
|
was_marked_for_recording() && other.was_marked_for_recording(),
|
||||||
|
"Both events must be recorded before calculating elapsed time.");
|
||||||
|
// elapsedTime in MPS can wait for an event to complete if the event is not ready,
|
||||||
|
// which is a little different from CUDA
|
||||||
|
TORCH_CHECK(
|
||||||
|
(query() && other.query()) || device_type_ == DeviceType::MPS,
|
||||||
|
"Both events must be completed before calculating elapsed time.");
|
||||||
|
|
||||||
return backend_.elapsedTime(event_, other.event_, device_index_);
|
return backend_.elapsedTime(event_, other.event_, device_index_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,6 +130,23 @@ class TestAccelerator(TestCase):
|
|||||||
self.assertTrue(t_host.is_pinned())
|
self.assertTrue(t_host.is_pinned())
|
||||||
self.assertEqual(t_acc.cpu(), t_host)
|
self.assertEqual(t_acc.cpu(), t_host)
|
||||||
|
|
||||||
|
def test_generic_event_behavior(self):
|
||||||
|
event1 = torch.Event(enable_timing=False)
|
||||||
|
event2 = torch.Event(enable_timing=False)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
"Both events must be created with argument 'enable_timing=True'",
|
||||||
|
):
|
||||||
|
event1.elapsed_time(event2)
|
||||||
|
|
||||||
|
event1 = torch.Event(enable_timing=True)
|
||||||
|
event2 = torch.Event(enable_timing=True)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
"Both events must be recorded before calculating elapsed time",
|
||||||
|
):
|
||||||
|
event1.elapsed_time(event2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
run_tests()
|
||||||
|
@ -81,21 +81,30 @@ class TestCppExtensionStreamAndEvent(common.TestCase):
|
|||||||
def test_stream_event(self):
|
def test_stream_event(self):
|
||||||
s = torch.Stream()
|
s = torch.Stream()
|
||||||
self.assertTrue(s.device_type, int(torch._C._autograd.DeviceType.MTIA))
|
self.assertTrue(s.device_type, int(torch._C._autograd.DeviceType.MTIA))
|
||||||
e = torch.Event()
|
e = torch.Event(enable_timing=True)
|
||||||
|
e1 = torch.Event(enable_timing=True)
|
||||||
|
e1.record()
|
||||||
self.assertTrue(e.device.type, "mtia")
|
self.assertTrue(e.device.type, "mtia")
|
||||||
# Should be nullptr by default
|
# Should be nullptr by default
|
||||||
self.assertTrue(e.event_id == 0)
|
self.assertTrue(e.event_id == 0)
|
||||||
s.record_event(event=e)
|
s.record_event(event=e)
|
||||||
print(f"recorded event 1: {e}")
|
print(f"recorded event 1: {e}")
|
||||||
self.assertTrue(e.event_id != 0)
|
self.assertTrue(e.event_id != 0)
|
||||||
|
# The enable_timing of event created by record_event() is false
|
||||||
e2 = s.record_event()
|
e2 = s.record_event()
|
||||||
print(f"recorded event 2: {e2}")
|
print(f"recorded event 2: {e2}")
|
||||||
self.assertTrue(e2.event_id != 0)
|
self.assertTrue(e2.event_id != 0)
|
||||||
self.assertTrue(e2.event_id != e.event_id)
|
self.assertTrue(e2.event_id != e.event_id)
|
||||||
e.synchronize()
|
e.synchronize()
|
||||||
|
e1.synchronize()
|
||||||
e2.synchronize()
|
e2.synchronize()
|
||||||
|
time_elapsed = e.elapsed_time(e1)
|
||||||
|
print(f"time elapsed between e and e1: {time_elapsed}")
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
"Both events must be created with argument 'enable_timing=True'",
|
||||||
|
):
|
||||||
time_elapsed = e.elapsed_time(e2)
|
time_elapsed = e.elapsed_time(e2)
|
||||||
print(f"time elapsed between e1 and e2: {time_elapsed}")
|
|
||||||
old_event_id = e.event_id
|
old_event_id = e.event_id
|
||||||
e.record(stream=s)
|
e.record(stream=s)
|
||||||
print(f"recorded event 1: {e}")
|
print(f"recorded event 1: {e}")
|
||||||
|
@ -962,6 +962,22 @@ class TestCuda(TestCase):
|
|||||||
self.assertTrue(event.query())
|
self.assertTrue(event.query())
|
||||||
self.assertGreater(start_event.elapsed_time(event), 0)
|
self.assertGreater(start_event.elapsed_time(event), 0)
|
||||||
|
|
||||||
|
def test_events_elapsedtime(self):
|
||||||
|
event1 = torch.cuda.Event(enable_timing=False)
|
||||||
|
event2 = torch.cuda.Event(enable_timing=False)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
"Both events must be created with argument 'enable_timing=True'",
|
||||||
|
):
|
||||||
|
event1.elapsed_time(event2)
|
||||||
|
|
||||||
|
event1 = torch.cuda.Event(enable_timing=True)
|
||||||
|
event2 = torch.cuda.Event(enable_timing=True)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, "Both events must be recorded before calculating elapsed time"
|
||||||
|
):
|
||||||
|
event1.elapsed_time(event2)
|
||||||
|
|
||||||
def test_generic_stream_event(self):
|
def test_generic_stream_event(self):
|
||||||
stream = torch.Stream("cuda")
|
stream = torch.Stream("cuda")
|
||||||
self.assertEqual(stream.device_index, torch.cuda.current_device())
|
self.assertEqual(stream.device_index, torch.cuda.current_device())
|
||||||
|
Reference in New Issue
Block a user