mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	As the title stated **Changes:** - Add **record**, **query** and **enable_timing** check - Add related tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/151404 Approved by: https://github.com/albanD
		
			
				
	
	
		
			116 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			116 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Owner(s): ["module: mtia"]
 | 
						|
 | 
						|
import os
 | 
						|
import tempfile
 | 
						|
import unittest
 | 
						|
 | 
						|
import torch
 | 
						|
import torch.testing._internal.common_utils as common
 | 
						|
import torch.utils.cpp_extension
 | 
						|
from torch.testing._internal.common_utils import (
 | 
						|
    IS_ARM64,
 | 
						|
    IS_LINUX,
 | 
						|
    skipIfTorchDynamo,
 | 
						|
    TEST_CUDA,
 | 
						|
    TEST_MPS,
 | 
						|
    TEST_PRIVATEUSE1,
 | 
						|
    TEST_XPU,
 | 
						|
)
 | 
						|
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
 | 
						|
 | 
						|
 | 
						|
# define TEST_ROCM before changing TEST_CUDA
 | 
						|
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
 | 
						|
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
 | 
						|
 | 
						|
 | 
						|
# Since a fake MTIA device backend is used to test the generic Stream/Event, this test is mutually exclusive with every other device backend.
 | 
						|
# The test will be skipped if any of the following conditions are met:
 | 
						|
@unittest.skipIf(
    IS_ARM64
    or not IS_LINUX
    or TEST_CUDA
    or TEST_XPU
    or TEST_MPS
    or TEST_PRIVATEUSE1
    or TEST_ROCM,
    "Only on linux platform and mutual exclusive to other backends",
)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionStreamAndEvent(common.TestCase):
    """Tests Stream and Event with C++ extensions.

    ``setUpClass`` compiles ``cpp_extensions/mtia_extension.cpp`` (the fake
    device guard implementation) into a temporary build directory and loads
    it; the test then drives ``torch.Stream`` and ``torch.Event`` without
    passing an explicit device.
    """

    # Handle returned by torch.utils.cpp_extension.load() in setUpClass.
    module = None
    # Temporary directory the extension is built in; removed again by
    # _remove_build_dir().
    build_dir = None

    def setUp(self):
        """Switch the working directory to this file's directory."""
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        """Restore the working directory saved by setUp."""
        super().tearDown()
        # return the working directory (see setUp)
        os.chdir(self.old_working_dir)

    @classmethod
    def _remove_build_dir(cls):
        """Delete the temporary build directory, if one was created."""
        import shutil  # local import: only this cleanup path needs it

        if cls.build_dir is not None:
            shutil.rmtree(cls.build_dir, ignore_errors=True)
            cls.build_dir = None

    @classmethod
    def tearDownClass(cls):
        """Remove the extension build root and the temporary build directory."""
        torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
        cls._remove_build_dir()

    @classmethod
    def setUpClass(cls):
        """Build and load the fake MTIA extension once for the whole class."""
        torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
        cls.build_dir = tempfile.mkdtemp()
        # Load the fake device guard impl.
        src = os.path.join(
            os.path.abspath(os.path.dirname(__file__)),
            "cpp_extensions",
            "mtia_extension.cpp",
        )
        try:
            cls.module = torch.utils.cpp_extension.load(
                name="mtia_extension",
                sources=[src],
                build_directory=cls.build_dir,
                extra_include_paths=[
                    "cpp_extensions",
                    "path / with spaces in it",
                    "path with quote'",
                ],
                is_python_module=False,
                verbose=True,
            )
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises, so
            # the temporary directory has to be removed here.
            cls._remove_build_dir()
            raise

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_event(self):
        """Check device, event ids and elapsed_time on the default Stream/Event.

        Verifies that the default stream and event report the MTIA device,
        that recording assigns a non-zero ``event_id`` which stays the same
        when the event is recorded again, and that ``elapsed_time`` raises
        ``ValueError`` unless both events were created with
        ``enable_timing=True``.
        """
        s = torch.Stream()
        # assertEqual rather than assertTrue: assertTrue(a, b) treats b as the
        # failure message and only checks that a is truthy.
        self.assertEqual(s.device_type, int(torch._C._autograd.DeviceType.MTIA))
        e = torch.Event(enable_timing=True)
        e1 = torch.Event(enable_timing=True)
        e1.record()
        self.assertEqual(e.device.type, "mtia")
        # Should be nullptr by default
        self.assertEqual(e.event_id, 0)
        s.record_event(event=e)
        print(f"recorded event 1: {e}")
        self.assertNotEqual(e.event_id, 0)
        # The enable_timing of event created by record_event() is false
        e2 = s.record_event()
        print(f"recorded event 2: {e2}")
        self.assertNotEqual(e2.event_id, 0)
        self.assertNotEqual(e2.event_id, e.event_id)
        e.synchronize()
        e1.synchronize()
        e2.synchronize()
        time_elapsed = e.elapsed_time(e1)
        print(f"time elapsed between e and e1: {time_elapsed}")
        with self.assertRaisesRegex(
            ValueError,
            "Both events must be created with argument 'enable_timing=True'",
        ):
            e.elapsed_time(e2)
        old_event_id = e.event_id
        e.record(stream=s)
        print(f"recorded event 1: {e}")
        # Recording an already-recorded event must keep its id.
        self.assertEqual(e.event_id, old_event_id)
 | 
						|
 | 
						|
 | 
						|
# When this file is executed directly, run its tests through
# common_utils.run_tests().
if __name__ == "__main__":
    common.run_tests()
 |