mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[Easy] Add more check for elapsedTime of torch.xxx.Event and torch.Event (#151404)
As the title stated **Changes:** - Add **record**, **query** and **enable_timing** check - Add related tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/151404 Approved by: https://github.com/albanD
This commit is contained in:
		| @ -81,21 +81,30 @@ class TestCppExtensionStreamAndEvent(common.TestCase): | ||||
|     def test_stream_event(self): | ||||
|         s = torch.Stream() | ||||
|         self.assertTrue(s.device_type, int(torch._C._autograd.DeviceType.MTIA)) | ||||
|         e = torch.Event() | ||||
|         e = torch.Event(enable_timing=True) | ||||
|         e1 = torch.Event(enable_timing=True) | ||||
|         e1.record() | ||||
|         self.assertTrue(e.device.type, "mtia") | ||||
|         # Should be nullptr by default | ||||
|         self.assertTrue(e.event_id == 0) | ||||
|         s.record_event(event=e) | ||||
|         print(f"recorded event 1: {e}") | ||||
|         self.assertTrue(e.event_id != 0) | ||||
|         # The enable_timing of event created by record_event() is false | ||||
|         e2 = s.record_event() | ||||
|         print(f"recorded event 2: {e2}") | ||||
|         self.assertTrue(e2.event_id != 0) | ||||
|         self.assertTrue(e2.event_id != e.event_id) | ||||
|         e.synchronize() | ||||
|         e1.synchronize() | ||||
|         e2.synchronize() | ||||
|         time_elapsed = e.elapsed_time(e2) | ||||
|         print(f"time elapsed between e1 and e2: {time_elapsed}") | ||||
|         time_elapsed = e.elapsed_time(e1) | ||||
|         print(f"time elapsed between e and e1: {time_elapsed}") | ||||
|         with self.assertRaisesRegex( | ||||
|             ValueError, | ||||
|             "Both events must be created with argument 'enable_timing=True'", | ||||
|         ): | ||||
|             time_elapsed = e.elapsed_time(e2) | ||||
|         old_event_id = e.event_id | ||||
|         e.record(stream=s) | ||||
|         print(f"recorded event 1: {e}") | ||||
|  | ||||
		Reference in New Issue
	
	Block a user