Refactor gpu trace to be device-agnostic (#121794)

# Motivation
Refactor gpu trace to be device-agnostic. gpu trace is usually used in runtime components, including Device, Stream, Event, Guard, and Allocator. It should be device-agnostic and can be shared among each device backend.

# Solution
move `_cuda_trace.py` to `_gpu_trace.py`, which makes each device backend owns their callback, respectively.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121794
Approved by: https://github.com/jgong5, https://github.com/albanD, https://github.com/EikanWang, https://github.com/gujinghui
This commit is contained in:
Yu, Guangye
2024-03-15 11:38:53 +00:00
committed by PyTorch MergeBot
parent 09ce76809c
commit 0ff1109e26
20 changed files with 262 additions and 203 deletions

View File

@ -5,7 +5,7 @@ import unittest
import unittest.mock
import torch
import torch.utils._cuda_trace as cuda_trace
import torch.cuda._gpu_trace as gpu_trace
from torch.testing._internal.common_utils import TestCase, run_tests, NoTest, TEST_CUDA
# NOTE: Each test needs to be run in a brand new process, to reset the registered hooks
@ -19,18 +19,18 @@ if not TEST_CUDA:
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaTrace(TestCase):
def setUp(self):
torch._C._activate_cuda_trace()
torch._C._activate_gpu_trace()
self.mock = unittest.mock.MagicMock()
def test_event_creation_callback(self):
cuda_trace.register_callback_for_cuda_event_creation(self.mock)
gpu_trace.register_callback_for_event_creation(self.mock)
event = torch.cuda.Event()
event.record()
self.mock.assert_called_once_with(event._as_parameter_.value)
def test_event_deletion_callback(self):
cuda_trace.register_callback_for_cuda_event_deletion(self.mock)
gpu_trace.register_callback_for_event_deletion(self.mock)
event = torch.cuda.Event()
event.record()
@ -39,7 +39,7 @@ class TestCudaTrace(TestCase):
self.mock.assert_called_once_with(event_id)
def test_event_record_callback(self):
cuda_trace.register_callback_for_cuda_event_record(self.mock)
gpu_trace.register_callback_for_event_record(self.mock)
event = torch.cuda.Event()
event.record()
@ -48,7 +48,7 @@ class TestCudaTrace(TestCase):
)
def test_event_wait_callback(self):
cuda_trace.register_callback_for_cuda_event_wait(self.mock)
gpu_trace.register_callback_for_event_wait(self.mock)
event = torch.cuda.Event()
event.record()
@ -58,13 +58,13 @@ class TestCudaTrace(TestCase):
)
def test_memory_allocation_callback(self):
cuda_trace.register_callback_for_cuda_memory_allocation(self.mock)
gpu_trace.register_callback_for_memory_allocation(self.mock)
tensor = torch.empty(10, 4, device="cuda")
self.mock.assert_called_once_with(tensor.data_ptr())
def test_memory_deallocation_callback(self):
cuda_trace.register_callback_for_cuda_memory_deallocation(self.mock)
gpu_trace.register_callback_for_memory_deallocation(self.mock)
tensor = torch.empty(3, 8, device="cuda")
data_ptr = tensor.data_ptr()
@ -72,7 +72,7 @@ class TestCudaTrace(TestCase):
self.mock.assert_called_once_with(data_ptr)
def test_stream_creation_callback(self):
cuda_trace.register_callback_for_cuda_stream_creation(self.mock)
gpu_trace.register_callback_for_stream_creation(self.mock)
# see Note [HIP Lazy Streams]
if torch.version.hip:
@ -85,20 +85,20 @@ class TestCudaTrace(TestCase):
self.mock.assert_called()
def test_device_synchronization_callback(self):
cuda_trace.register_callback_for_cuda_device_synchronization(self.mock)
gpu_trace.register_callback_for_device_synchronization(self.mock)
torch.cuda.synchronize()
self.mock.assert_called()
def test_stream_synchronization_callback(self):
cuda_trace.register_callback_for_cuda_stream_synchronization(self.mock)
gpu_trace.register_callback_for_stream_synchronization(self.mock)
stream = torch.cuda.Stream()
stream.synchronize()
self.mock.assert_called_once_with(stream.cuda_stream)
def test_event_synchronization_callback(self):
cuda_trace.register_callback_for_cuda_event_synchronization(self.mock)
gpu_trace.register_callback_for_event_synchronization(self.mock)
event = torch.cuda.Event()
event.record()
@ -106,7 +106,7 @@ class TestCudaTrace(TestCase):
self.mock.assert_called_once_with(event._as_parameter_.value)
def test_memcpy_synchronization(self):
cuda_trace.register_callback_for_cuda_stream_synchronization(self.mock)
gpu_trace.register_callback_for_stream_synchronization(self.mock)
tensor = torch.rand(5, device="cuda")
tensor.nonzero()
@ -114,8 +114,8 @@ class TestCudaTrace(TestCase):
def test_all_trace_callbacks_called(self):
other = unittest.mock.MagicMock()
cuda_trace.register_callback_for_cuda_memory_allocation(self.mock)
cuda_trace.register_callback_for_cuda_memory_allocation(other)
gpu_trace.register_callback_for_memory_allocation(self.mock)
gpu_trace.register_callback_for_memory_allocation(other)
tensor = torch.empty(10, 4, device="cuda")
self.mock.assert_called_once_with(tensor.data_ptr())