mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Easy] Add more check for elapsedTime of torch.xxx.Event and torch.Event (#151404)
As the title stated **Changes:** - Add **record**, **query** and **enable_timing** check - Add related tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/151404 Approved by: https://github.com/albanD
This commit is contained in:
@ -81,21 +81,30 @@ class TestCppExtensionStreamAndEvent(common.TestCase):
|
||||
def test_stream_event(self):
|
||||
s = torch.Stream()
|
||||
self.assertTrue(s.device_type, int(torch._C._autograd.DeviceType.MTIA))
|
||||
e = torch.Event()
|
||||
e = torch.Event(enable_timing=True)
|
||||
e1 = torch.Event(enable_timing=True)
|
||||
e1.record()
|
||||
self.assertTrue(e.device.type, "mtia")
|
||||
# Should be nullptr by default
|
||||
self.assertTrue(e.event_id == 0)
|
||||
s.record_event(event=e)
|
||||
print(f"recorded event 1: {e}")
|
||||
self.assertTrue(e.event_id != 0)
|
||||
# The enable_timing of event created by record_event() is false
|
||||
e2 = s.record_event()
|
||||
print(f"recorded event 2: {e2}")
|
||||
self.assertTrue(e2.event_id != 0)
|
||||
self.assertTrue(e2.event_id != e.event_id)
|
||||
e.synchronize()
|
||||
e1.synchronize()
|
||||
e2.synchronize()
|
||||
time_elapsed = e.elapsed_time(e2)
|
||||
print(f"time elapsed between e1 and e2: {time_elapsed}")
|
||||
time_elapsed = e.elapsed_time(e1)
|
||||
print(f"time elapsed between e and e1: {time_elapsed}")
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Both events must be created with argument 'enable_timing=True'",
|
||||
):
|
||||
time_elapsed = e.elapsed_time(e2)
|
||||
old_event_id = e.event_id
|
||||
e.record(stream=s)
|
||||
print(f"recorded event 1: {e}")
|
||||
|
Reference in New Issue
Block a user