From 2c5c793085ba1de70dd0f9bcd2ddad3d84f07153 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Wed, 23 Apr 2025 14:26:12 +0800 Subject: [PATCH] [Easy] Add more check for elapsedTime of torch.xxx.Event and torch.Event (#151404) As the title stated **Changes:** - Add **record**, **query** and **enable_timing** check - Add related tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/151404 Approved by: https://github.com/albanD --- aten/src/ATen/cuda/CUDAEvent.h | 12 ++++++++++-- c10/core/impl/InlineEvent.h | 18 +++++++++++++----- test/test_accelerator.py | 17 +++++++++++++++++ test/test_cpp_extensions_stream_and_event.py | 15 ++++++++++++--- test/test_cuda.py | 16 ++++++++++++++++ 5 files changed, 68 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/cuda/CUDAEvent.h b/aten/src/ATen/cuda/CUDAEvent.h index 94ce34645b02..63b41343f9c0 100644 --- a/aten/src/ATen/cuda/CUDAEvent.h +++ b/aten/src/ATen/cuda/CUDAEvent.h @@ -147,8 +147,16 @@ struct TORCH_CUDA_CPP_API CUDAEvent { // Note: cudaEventElapsedTime can be safely called from any device float elapsed_time(const CUDAEvent& other) const { - TORCH_CHECK(is_created_ && other.isCreated(), - "Both events must be recorded before calculating elapsed time."); + TORCH_CHECK_VALUE( + !(flags_ & cudaEventDisableTiming) && !(other.flags_ & cudaEventDisableTiming), + "Both events must be created with argument 'enable_timing=True'."); + TORCH_CHECK_VALUE( + is_created_ && other.isCreated(), + "Both events must be recorded before calculating elapsed time."); + TORCH_CHECK( + query() && other.query(), + "Both events must be completed before calculating elapsed time."); + float time_ms = 0; // We do not strictly have to set the device index to the same as our event, // but if we don't and the current device is not initialized, it will diff --git a/c10/core/impl/InlineEvent.h b/c10/core/impl/InlineEvent.h index 82fa3384907e..a731621a5bfd 100644 --- a/c10/core/impl/InlineEvent.h +++ b/c10/core/impl/InlineEvent.h @@ 
-106,11 +106,6 @@ struct InlineEvent final { } double elapsedTime(const InlineEvent& other) const { - TORCH_CHECK( - other.was_marked_for_recording(), - "other was not marked for recording."); - TORCH_CHECK( - was_marked_for_recording(), "self was not marked for recording."); TORCH_CHECK( other.device_type() == device_type_, "Event device type ", @@ -118,6 +113,19 @@ struct InlineEvent final { " does not match other's device type ", DeviceTypeName(other.device_type()), "."); + TORCH_CHECK_VALUE( + (flag_ == EventFlag::BACKEND_DEFAULT) && + (other.flag_ == EventFlag::BACKEND_DEFAULT), + "Both events must be created with argument 'enable_timing=True'."); + TORCH_CHECK_VALUE( + was_marked_for_recording() && other.was_marked_for_recording(), + "Both events must be recorded before calculating elapsed time."); + // elapsedTime in MPS can wait event to be completed if event is not ready, + // which is a little different from CUDA + TORCH_CHECK( + (query() && other.query()) || device_type_ == DeviceType::MPS, + "Both events must be completed before calculating elapsed time."); + return backend_.elapsedTime(event_, other.event_, device_index_); } diff --git a/test/test_accelerator.py b/test/test_accelerator.py index 8e607b89717f..c07f2ddc6891 100644 --- a/test/test_accelerator.py +++ b/test/test_accelerator.py @@ -130,6 +130,23 @@ class TestAccelerator(TestCase): self.assertTrue(t_host.is_pinned()) self.assertEqual(t_acc.cpu(), t_host) + def test_generic_event_behavior(self): + event1 = torch.Event(enable_timing=False) + event2 = torch.Event(enable_timing=False) + with self.assertRaisesRegex( + ValueError, + "Both events must be created with argument 'enable_timing=True'", + ): + event1.elapsed_time(event2) + + event1 = torch.Event(enable_timing=True) + event2 = torch.Event(enable_timing=True) + with self.assertRaisesRegex( + ValueError, + "Both events must be recorded before calculating elapsed time", + ): + event1.elapsed_time(event2) + if __name__ == "__main__": 
run_tests() diff --git a/test/test_cpp_extensions_stream_and_event.py b/test/test_cpp_extensions_stream_and_event.py index f6b2281e1711..a6a5ae8cd9b4 100644 --- a/test/test_cpp_extensions_stream_and_event.py +++ b/test/test_cpp_extensions_stream_and_event.py @@ -81,21 +81,30 @@ class TestCppExtensionStreamAndEvent(common.TestCase): def test_stream_event(self): s = torch.Stream() self.assertTrue(s.device_type, int(torch._C._autograd.DeviceType.MTIA)) - e = torch.Event() + e = torch.Event(enable_timing=True) + e1 = torch.Event(enable_timing=True) + e1.record() self.assertTrue(e.device.type, "mtia") # Should be nullptr by default self.assertTrue(e.event_id == 0) s.record_event(event=e) print(f"recorded event 1: {e}") self.assertTrue(e.event_id != 0) + # The enable_timing of event created by record_event() is false e2 = s.record_event() print(f"recorded event 2: {e2}") self.assertTrue(e2.event_id != 0) self.assertTrue(e2.event_id != e.event_id) e.synchronize() + e1.synchronize() e2.synchronize() - time_elapsed = e.elapsed_time(e2) - print(f"time elapsed between e1 and e2: {time_elapsed}") + time_elapsed = e.elapsed_time(e1) + print(f"time elapsed between e and e1: {time_elapsed}") + with self.assertRaisesRegex( + ValueError, + "Both events must be created with argument 'enable_timing=True'", + ): + time_elapsed = e.elapsed_time(e2) old_event_id = e.event_id e.record(stream=s) print(f"recorded event 1: {e}") diff --git a/test/test_cuda.py b/test/test_cuda.py index 5ba59fa5968a..2d21e8186ba0 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -962,6 +962,22 @@ class TestCuda(TestCase): self.assertTrue(event.query()) self.assertGreater(start_event.elapsed_time(event), 0) + def test_events_elapsedtime(self): + event1 = torch.cuda.Event(enable_timing=False) + event2 = torch.cuda.Event(enable_timing=False) + with self.assertRaisesRegex( + ValueError, + "Both events must be created with argument 'enable_timing=True'", + ): + event1.elapsed_time(event2) + + event1 = 
torch.cuda.Event(enable_timing=True) + event2 = torch.cuda.Event(enable_timing=True) + with self.assertRaisesRegex( + ValueError, "Both events must be recorded before calculating elapsed time" + ): + event1.elapsed_time(event2) + def test_generic_stream_event(self): stream = torch.Stream("cuda") self.assertEqual(stream.device_index, torch.cuda.current_device())