mirror of
https://github.com/pytorch/pytorch.git
# Motivation
Refactor GPU trace to be device-agnostic. GPU trace is used across runtime components, including Device, Stream, Event, Guard, and Allocator, so it should be device-agnostic and shareable among device backends.

# Solution
Move `_cuda_trace.py` to `_gpu_trace.py`, so that each device backend owns its own callbacks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121794
Approved by: https://github.com/jgong5, https://github.com/albanD, https://github.com/EikanWang, https://github.com/gujinghui
76 lines
2.3 KiB
Python
from typing import Callable

from torch._utils import CallbackRegistry


EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    "CUDA event creation"
)
EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    "CUDA event deletion"
)
EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
    "CUDA event record"
)
EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
    "CUDA event wait"
)
MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    "CUDA memory allocation"
)
MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    "CUDA memory deallocation"
)
StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    "CUDA stream creation"
)
DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
    "CUDA device synchronization"
)
StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    "CUDA stream synchronization"
)
EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    "CUDA event synchronization"
)


def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
    EventCreationCallbacks.add_callback(cb)


def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
    EventDeletionCallbacks.add_callback(cb)


def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
    EventRecordCallbacks.add_callback(cb)


def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
    EventWaitCallbacks.add_callback(cb)


def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
    MemoryAllocationCallbacks.add_callback(cb)


def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
    MemoryDeallocationCallbacks.add_callback(cb)


def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
    StreamCreationCallbacks.add_callback(cb)


def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
    DeviceSynchronizationCallbacks.add_callback(cb)


def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
    StreamSynchronizationCallbacks.add_callback(cb)


def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
    EventSynchronizationCallbacks.add_callback(cb)
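
A minimal usage sketch (not part of the file above) of how a tool might hook into one of these registries. The import path `torch.cuda._gpu_trace` is an assumption based on the move described in the PR; the exact location depends on the backend owning the module.

    # Sketch: register a trace hook for event creation.
    # Assumed import path -- after the refactor each device backend ships its
    # own copy of this module (e.g. the CUDA backend).
    from torch.cuda import _gpu_trace


    def on_event_creation(event_id: int) -> None:
        # Invoked by the backend runtime whenever a new event is created;
        # the single int argument is the event handle passed to the registry.
        print(f"event created: {event_id}")


    _gpu_trace.register_callback_for_event_creation(on_event_creation)
    # From here on, the runtime fires EventCreationCallbacks for each traced
    # event creation, which in turn calls on_event_creation.

The registries are module-level, so registration is process-wide; other device backends can mirror the same pattern with their own `_gpu_trace` module, which is the point of the refactor.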