Refactor gpu trace to be device-agnostic (#121794)

# Motivation
Refactor GPU trace to be device-agnostic. GPU trace is used across the runtime components, including Device, Stream, Event, Guard, and Allocator, so it should be device-agnostic and shareable by every device backend.

# Solution
Move `_cuda_trace.py` to `_gpu_trace.py`, so that each device backend owns its own set of trace callbacks.
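
For illustration, registering a hook under the new layout looks roughly like this (a minimal sketch based on the `torch/cuda/_gpu_trace.py` added below; it assumes a CUDA-enabled build):

```python
import torch
import torch.cuda._gpu_trace as gpu_trace

# Install the Python trace hooks into the active interpreter.
torch._C._activate_gpu_trace()

def on_event_creation(event_id: int) -> None:
    print(f"CUDA event created: {event_id:#x}")

gpu_trace.register_callback_for_event_creation(on_event_creation)

# The underlying cudaEvent_t is created lazily, so the first record()
# is what fires the creation callback.
event = torch.cuda.Event()
event.record()
```

Other backends would expose the same registration functions from a `_gpu_trace` module of their own, which the C++ side resolves per device type.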

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121794
Approved by: https://github.com/jgong5, https://github.com/albanD, https://github.com/EikanWang, https://github.com/gujinghui
Author: Yu, Guangye
Date: 2024-03-15 11:38:53 +00:00
Committed by: PyTorch MergeBot
Parent: 09ce76809c
Commit: 0ff1109e26
20 changed files with 262 additions and 203 deletions

View File

@@ -1905,6 +1905,7 @@ exclude_patterns = [
'torch/compiler/__init__.py',
'torch/contrib/__init__.py',
'torch/contrib/_tensorboard_vis.py',
"torch/cuda/_gpu_trace.py",
'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable
'torch/distributed/__init__.py',
'torch/distributed/_composable_state.py',
@@ -2432,7 +2433,6 @@ exclude_patterns = [
'torch/utils/_contextlib.py',
'torch/utils/_cpp_extension_versioner.py',
'torch/utils/_crash_handler.py',
'torch/utils/_cuda_trace.py',
'torch/utils/_device.py',
'torch/utils/_foreach_utils.py',
'torch/utils/_freeze.py',

View File

@@ -48,7 +48,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
CUDAGuard guard(device_index_);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(reinterpret_cast<uintptr_t>(event_));
(*interp)->trace_gpu_event_deletion(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
AT_CUDA_CHECK(cudaEventDestroy(event_));
}
@@ -122,7 +122,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
AT_CUDA_CHECK(cudaEventRecord(event_, stream));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(
(*interp)->trace_gpu_event_record(at::kCUDA,
reinterpret_cast<uintptr_t>(event_),
reinterpret_cast<uintptr_t>(stream.stream())
);
@@ -138,7 +138,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(
(*interp)->trace_gpu_event_wait(at::kCUDA,
reinterpret_cast<uintptr_t>(event_),
reinterpret_cast<uintptr_t>(stream.stream())
);
@@ -161,7 +161,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
if (is_created_) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_synchronization(reinterpret_cast<uintptr_t>(event_));
(*interp)->trace_gpu_event_synchronization(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
AT_CUDA_CHECK(cudaEventSynchronize(event_));
}
@@ -191,7 +191,7 @@ private:
AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(reinterpret_cast<uintptr_t>(event_));
(*interp)->trace_gpu_event_creation(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
}
is_created_ = true;
}

View File

@@ -94,17 +94,32 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
}
// Just swallow the event, don't do anything
void trace_gpu_event_creation(uintptr_t event) const override {}
void trace_gpu_event_deletion(uintptr_t event) const override {}
void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
void trace_gpu_event_creation(c10::DeviceType device_type, uintptr_t event)
const override {}
void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) const override {}
void trace_gpu_memory_allocation(uintptr_t ptr) const override {}
void trace_gpu_memory_deallocation(uintptr_t ptr) const override {}
void trace_gpu_stream_creation(uintptr_t stream) const override {}
void trace_gpu_device_synchronization() const override {}
void trace_gpu_stream_synchronization(uintptr_t stream) const override {}
void trace_gpu_event_synchronization(uintptr_t event) const override {}
void trace_gpu_event_deletion(c10::DeviceType device_type, uintptr_t event)
const override {}
void trace_gpu_event_record(
c10::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const override {}
void trace_gpu_event_wait(
c10::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const override {}
void trace_gpu_memory_allocation(c10::DeviceType device_type, uintptr_t ptr)
const override {}
void trace_gpu_memory_deallocation(c10::DeviceType device_type, uintptr_t ptr)
const override {}
void trace_gpu_stream_creation(c10::DeviceType device_type, uintptr_t stream)
const override {}
void trace_gpu_device_synchronization(
c10::DeviceType device_type) const override {}
void trace_gpu_stream_synchronization(
c10::DeviceType device_type,
uintptr_t stream) const override {}
void trace_gpu_event_synchronization(
c10::DeviceType device_type,
uintptr_t event) const override {}
void reset_backward_hooks(const TensorImpl* self) const override {
PANIC(reset_backward_hooks);

View File

@@ -177,18 +177,37 @@ struct C10_API PyInterpreterVTable {
virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0;
virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0;
virtual void trace_gpu_event_creation(uintptr_t event) const = 0;
virtual void trace_gpu_event_deletion(uintptr_t event) const = 0;
virtual void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
const = 0;
virtual void trace_gpu_event_wait(uintptr_t event, uintptr_t stream)
const = 0;
virtual void trace_gpu_memory_allocation(uintptr_t ptr) const = 0;
virtual void trace_gpu_memory_deallocation(uintptr_t ptr) const = 0;
virtual void trace_gpu_stream_creation(uintptr_t stream) const = 0;
virtual void trace_gpu_device_synchronization() const = 0;
virtual void trace_gpu_stream_synchronization(uintptr_t stream) const = 0;
virtual void trace_gpu_event_synchronization(uintptr_t event) const = 0;
virtual void trace_gpu_event_creation(
c10::DeviceType device_type,
uintptr_t event) const = 0;
virtual void trace_gpu_event_deletion(
c10::DeviceType device_type,
uintptr_t event) const = 0;
virtual void trace_gpu_event_record(
c10::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const = 0;
virtual void trace_gpu_event_wait(
c10::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const = 0;
virtual void trace_gpu_memory_allocation(
c10::DeviceType device_type,
uintptr_t ptr) const = 0;
virtual void trace_gpu_memory_deallocation(
c10::DeviceType device_type,
uintptr_t ptr) const = 0;
virtual void trace_gpu_stream_creation(
c10::DeviceType device_type,
uintptr_t stream) const = 0;
virtual void trace_gpu_device_synchronization(
c10::DeviceType device_type) const = 0;
virtual void trace_gpu_stream_synchronization(
c10::DeviceType device_type,
uintptr_t stream) const = 0;
virtual void trace_gpu_event_synchronization(
c10::DeviceType device_type,
uintptr_t event) const = 0;
virtual void reset_backward_hooks(const TensorImpl* self) const = 0;
};

View File

@@ -2840,7 +2840,8 @@ static void uncached_delete(void* ptr) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_deallocation(reinterpret_cast<uintptr_t>(ptr));
(*interp)->trace_gpu_memory_deallocation(
c10::kCUDA, reinterpret_cast<uintptr_t>(ptr));
}
C10_CUDA_CHECK(cudaFree(ptr));
}
@@ -2924,7 +2925,7 @@ class NativeCachingAllocator : public CUDAAllocator {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_allocation(
reinterpret_cast<uintptr_t>(*devPtr));
c10::kCUDA, reinterpret_cast<uintptr_t>(*devPtr));
}
}
@@ -2939,7 +2940,7 @@ class NativeCachingAllocator : public CUDAAllocator {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_deallocation(
reinterpret_cast<uintptr_t>(block->ptr));
c10::kCUDA, reinterpret_cast<uintptr_t>(block->ptr));
}
device_allocator[block->device]->free(block);
}
@@ -3128,7 +3129,7 @@ class NativeCachingAllocator : public CUDAAllocator {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_allocation(
reinterpret_cast<uintptr_t>(devPtr));
c10::kCUDA, reinterpret_cast<uintptr_t>(devPtr));
}
} else {
if (size != 0) {

View File

@@ -136,7 +136,7 @@ void set_device(DeviceIndex device) {
void device_synchronize() {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_device_synchronization();
(*interp)->trace_gpu_device_synchronization(c10::kCUDA);
}
C10_CUDA_CHECK(cudaDeviceSynchronize());
}

View File

@@ -87,7 +87,7 @@ C10_CUDA_API void __inline__ memcpy_and_sync(
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_stream_synchronization(
reinterpret_cast<uintptr_t>(stream));
c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
}
#if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301)
C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
@@ -105,7 +105,7 @@ C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_stream_synchronization(
reinterpret_cast<uintptr_t>(stream));
c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
}
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}

View File

@@ -206,7 +206,8 @@ static void initSingleStream(int p, DeviceIndex device_index, int i) {
C10_CUDA_CHECK(cudaStreamCreateWithPriority(&stream, kDefaultFlags, pri));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_stream_creation(reinterpret_cast<uintptr_t>(stream));
(*interp)->trace_gpu_stream_creation(
c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
priority_counters[p][device_index] = 0;
}
}

View File

@@ -96,7 +96,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(
reinterpret_cast<uintptr_t>(cuda_event));
c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
}
}
@@ -111,7 +111,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(
reinterpret_cast<uintptr_t>(cuda_event));
c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
}
C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device));
@@ -146,6 +146,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(
c10::kCUDA,
reinterpret_cast<uintptr_t>(cuda_event),
reinterpret_cast<uintptr_t>(cuda_stream.stream()));
}
@@ -168,6 +169,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(
c10::kCUDA,
reinterpret_cast<uintptr_t>(cuda_event),
reinterpret_cast<uintptr_t>(cuda_stream.stream()));
}

View File

@@ -5,7 +5,7 @@ import unittest
import unittest.mock
import torch
import torch.utils._cuda_trace as cuda_trace
import torch.cuda._gpu_trace as gpu_trace
from torch.testing._internal.common_utils import TestCase, run_tests, NoTest, TEST_CUDA
# NOTE: Each test needs to be run in a brand new process, to reset the registered hooks
@@ -19,18 +19,18 @@ if not TEST_CUDA:
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaTrace(TestCase):
def setUp(self):
torch._C._activate_cuda_trace()
torch._C._activate_gpu_trace()
self.mock = unittest.mock.MagicMock()
def test_event_creation_callback(self):
cuda_trace.register_callback_for_cuda_event_creation(self.mock)
gpu_trace.register_callback_for_event_creation(self.mock)
event = torch.cuda.Event()
event.record()
self.mock.assert_called_once_with(event._as_parameter_.value)
def test_event_deletion_callback(self):
cuda_trace.register_callback_for_cuda_event_deletion(self.mock)
gpu_trace.register_callback_for_event_deletion(self.mock)
event = torch.cuda.Event()
event.record()
@@ -39,7 +39,7 @@ class TestCudaTrace(TestCase):
self.mock.assert_called_once_with(event_id)
def test_event_record_callback(self):
cuda_trace.register_callback_for_cuda_event_record(self.mock)
gpu_trace.register_callback_for_event_record(self.mock)
event = torch.cuda.Event()
event.record()
@@ -48,7 +48,7 @@ class TestCudaTrace(TestCase):
)
def test_event_wait_callback(self):
cuda_trace.register_callback_for_cuda_event_wait(self.mock)
gpu_trace.register_callback_for_event_wait(self.mock)
event = torch.cuda.Event()
event.record()
@@ -58,13 +58,13 @@ class TestCudaTrace(TestCase):
)
def test_memory_allocation_callback(self):
cuda_trace.register_callback_for_cuda_memory_allocation(self.mock)
gpu_trace.register_callback_for_memory_allocation(self.mock)
tensor = torch.empty(10, 4, device="cuda")
self.mock.assert_called_once_with(tensor.data_ptr())
def test_memory_deallocation_callback(self):
cuda_trace.register_callback_for_cuda_memory_deallocation(self.mock)
gpu_trace.register_callback_for_memory_deallocation(self.mock)
tensor = torch.empty(3, 8, device="cuda")
data_ptr = tensor.data_ptr()
@@ -72,7 +72,7 @@ class TestCudaTrace(TestCase):
self.mock.assert_called_once_with(data_ptr)
def test_stream_creation_callback(self):
cuda_trace.register_callback_for_cuda_stream_creation(self.mock)
gpu_trace.register_callback_for_stream_creation(self.mock)
# see Note [HIP Lazy Streams]
if torch.version.hip:
@@ -85,20 +85,20 @@ class TestCudaTrace(TestCase):
self.mock.assert_called()
def test_device_synchronization_callback(self):
cuda_trace.register_callback_for_cuda_device_synchronization(self.mock)
gpu_trace.register_callback_for_device_synchronization(self.mock)
torch.cuda.synchronize()
self.mock.assert_called()
def test_stream_synchronization_callback(self):
cuda_trace.register_callback_for_cuda_stream_synchronization(self.mock)
gpu_trace.register_callback_for_stream_synchronization(self.mock)
stream = torch.cuda.Stream()
stream.synchronize()
self.mock.assert_called_once_with(stream.cuda_stream)
def test_event_synchronization_callback(self):
cuda_trace.register_callback_for_cuda_event_synchronization(self.mock)
gpu_trace.register_callback_for_event_synchronization(self.mock)
event = torch.cuda.Event()
event.record()
@@ -106,7 +106,7 @@ class TestCudaTrace(TestCase):
self.mock.assert_called_once_with(event._as_parameter_.value)
def test_memcpy_synchronization(self):
cuda_trace.register_callback_for_cuda_stream_synchronization(self.mock)
gpu_trace.register_callback_for_stream_synchronization(self.mock)
tensor = torch.rand(5, device="cuda")
tensor.nonzero()
@@ -114,8 +114,8 @@ class TestCudaTrace(TestCase):
def test_all_trace_callbacks_called(self):
other = unittest.mock.MagicMock()
cuda_trace.register_callback_for_cuda_memory_allocation(self.mock)
cuda_trace.register_callback_for_cuda_memory_allocation(other)
gpu_trace.register_callback_for_memory_allocation(self.mock)
gpu_trace.register_callback_for_memory_allocation(other)
tensor = torch.empty(10, 4, device="cuda")
self.mock.assert_called_once_with(tensor.data_ptr())

View File

@@ -1273,6 +1273,7 @@ def _unset_dispatch_mode(mode: torch._C._TorchDispatchModeKey) -> Any: ...
def _set_dispatch_mode(mode: Any) -> None: ...
def _get_dispatch_stack_at(idx: _int) -> Any: ...
def _len_torch_dispatch_stack() -> _int: ...
def _activate_gpu_trace() -> None: ...
class _DisableTorchDispatch:
def __init__(self): ...
@@ -2150,7 +2151,6 @@ def _c10d_init() -> _bool: ...
# Defined in torch/csrc/distributed/rpc/testing/init.cpp
def _faulty_agent_init() -> _bool: ...
def _register_py_class_for_device(device: str, cls: Any) -> None: ...
def _activate_cuda_trace() -> None: ...
# Defined in torch/csrc/Module.cpp
def _current_graph_task_id() -> _int: ...

View File

@@ -291,7 +291,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._assert_async",
"torch._assert_tensor_metadata",
"torch._batch_norm_impl_index",
"torch._C._activate_cuda_trace",
"torch._C._activate_gpu_trace",
"torch._C._add_cached_tensor",
"torch._C._add_docstr",
"torch._C._are_functorch_transforms_active",

View File

@@ -1,10 +1,13 @@
import copyreg
import functools
import logging
import sys
import traceback
import warnings
from collections import defaultdict
from typing import Any, DefaultDict, List, Optional
from typing import Any, Callable, DefaultDict, Generic, List, Optional
from typing_extensions import ParamSpec
import torch
@@ -935,3 +938,25 @@ class _LazySeedTracker:
def get_calls(self) -> List:
return self.call_order
logger = logging.getLogger(__name__)
P = ParamSpec("P")
class CallbackRegistry(Generic[P]):
def __init__(self, name: str):
self.name = name
self.callback_list: List[Callable[P, None]] = []
def add_callback(self, cb: Callable[P, None]) -> None:
self.callback_list.append(cb)
def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
for cb in self.callback_list:
try:
cb(*args, **kwargs)
except Exception as e:
logger.exception(
"Exception in callback for %s registered with gpu trace", self.name
)
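
`CallbackRegistry` is generic over a `ParamSpec`, so each registry pins down the signature its callbacks must accept. A hypothetical standalone use (names invented for illustration, assuming only the class above):

```python
from torch._utils import CallbackRegistry

# A registry whose callbacks take two int handles: (event, stream).
record_registry: "CallbackRegistry[int, int]" = CallbackRegistry("example event record")
record_registry.add_callback(lambda event, stream: print(hex(event), hex(stream)))

# fire_callbacks() invokes callbacks in registration order; an exception
# raised by a callback is logged via logger.exception(), not propagated.
record_registry.fire_callbacks(0x1234, 0x5678)
```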

View File

@@ -18,17 +18,22 @@ namespace {
// because passing in constexpr char* as template argument breaks some
// versions of MSVC that are being used internally at Meta.
// MSVC 14.16.27023 (vs2017_15.9)
#define CONCRETE_TRACE_CUDA(func_name, ...) \
at::impl::MaybeSetTLSOnEntryGuard guard; \
if (Py_IsInitialized()) { \
pybind11::gil_scoped_acquire gil; \
try { \
py::module mod = py::module::import("torch.utils._cuda_trace"); \
py::object hook = mod.attr(func_name).attr("fire_callbacks"); \
hook(__VA_ARGS__); \
} catch (const std::exception& e) { \
LOG(ERROR) << "CUDA trace hook execution failed: " << e.what(); \
} \
#define CONCRETE_GPU_TRACE(device_type, func_name, ...) \
at::impl::MaybeSetTLSOnEntryGuard guard; \
if (Py_IsInitialized()) { \
pybind11::gil_scoped_acquire gil; \
try { \
py::module utils_mod = py::module::import("torch._utils"); \
py::object get_device_module = utils_mod.attr("_get_device_module"); \
py::object hook = get_device_module(DeviceTypeName(device_type, true)) \
.attr("_gpu_trace") \
.attr(func_name) \
.attr("fire_callbacks"); \
hook(__VA_ARGS__); \
} catch (const std::exception& e) { \
LOG(ERROR) << device_type \
<< " trace hook execution failed: " << e.what(); \
} \
}
struct ConcretePyInterpreterVTable final
@@ -83,36 +88,51 @@ struct ConcretePyInterpreterVTable final
c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;
void trace_gpu_event_creation(uintptr_t event) const override {
CONCRETE_TRACE_CUDA("CUDAEventCreationCallbacks", event);
}
void trace_gpu_event_deletion(uintptr_t event) const override {
CONCRETE_TRACE_CUDA("CUDAEventDeletionCallbacks", event);
}
void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
void trace_gpu_event_creation(at::DeviceType device_type, uintptr_t event)
const override {
CONCRETE_TRACE_CUDA("CUDAEventRecordCallbacks", event, stream);
CONCRETE_GPU_TRACE(device_type, "EventCreationCallbacks", event);
}
void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) const override {
CONCRETE_TRACE_CUDA("CUDAEventWaitCallbacks", event, stream);
void trace_gpu_event_deletion(at::DeviceType device_type, uintptr_t event)
const override {
CONCRETE_GPU_TRACE(device_type, "EventDeletionCallbacks", event);
}
void trace_gpu_memory_allocation(uintptr_t ptr) const override {
CONCRETE_TRACE_CUDA("CUDAMemoryAllocationCallbacks", ptr);
void trace_gpu_event_record(
at::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const override {
CONCRETE_GPU_TRACE(device_type, "EventRecordCallbacks", event, stream);
}
void trace_gpu_memory_deallocation(uintptr_t ptr) const override {
CONCRETE_TRACE_CUDA("CUDAMemoryDeallocationCallbacks", ptr);
void trace_gpu_event_wait(
at::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const override {
CONCRETE_GPU_TRACE(device_type, "EventWaitCallbacks", event, stream);
}
void trace_gpu_stream_creation(uintptr_t stream) const override {
CONCRETE_TRACE_CUDA("CUDAStreamCreationCallbacks", stream);
void trace_gpu_memory_allocation(at::DeviceType device_type, uintptr_t ptr)
const override {
CONCRETE_GPU_TRACE(device_type, "MemoryAllocationCallbacks", ptr);
}
void trace_gpu_device_synchronization() const override {
CONCRETE_TRACE_CUDA("CUDADeviceSynchronizationCallbacks");
void trace_gpu_memory_deallocation(at::DeviceType device_type, uintptr_t ptr)
const override {
CONCRETE_GPU_TRACE(device_type, "MemoryDeallocationCallbacks", ptr);
}
void trace_gpu_stream_synchronization(uintptr_t stream) const override {
CONCRETE_TRACE_CUDA("CUDAStreamSynchronizationCallbacks", stream);
void trace_gpu_stream_creation(at::DeviceType device_type, uintptr_t stream)
const override {
CONCRETE_GPU_TRACE(device_type, "StreamCreationCallbacks", stream);
}
void trace_gpu_event_synchronization(uintptr_t event) const override {
CONCRETE_TRACE_CUDA("CUDAEventSynchronizationCallbacks", event);
void trace_gpu_device_synchronization(
at::DeviceType device_type) const override {
CONCRETE_GPU_TRACE(device_type, "DeviceSynchronizationCallbacks");
}
void trace_gpu_stream_synchronization(
at::DeviceType device_type,
uintptr_t stream) const override {
CONCRETE_GPU_TRACE(device_type, "StreamSynchronizationCallbacks", stream);
}
void trace_gpu_event_synchronization(
at::DeviceType device_type,
uintptr_t event) const override {
CONCRETE_GPU_TRACE(device_type, "EventSynchronizationCallbacks", event);
}
void reset_backward_hooks(const c10::TensorImpl* self) const override;
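
Unpacking the macro above: `CONCRETE_GPU_TRACE(device_type, "EventCreationCallbacks", event)` amounts to roughly the following Python, sketched here for the CUDA case (`event_id` is a stand-in for the raw handle passed from C++):

```python
import torch._utils

# DeviceTypeName(device_type, /*lower_case=*/true) yields "cuda" for kCUDA,
# and _get_device_module("cuda") resolves to the torch.cuda module.
device_module = torch._utils._get_device_module("cuda")

event_id = 0x1234  # placeholder handle, for illustration only
device_module._gpu_trace.EventCreationCallbacks.fire_callbacks(event_id)
```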

View File

@@ -431,7 +431,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
}
});
_C_m.def("_activate_cuda_trace", []() { activateCUDATrace(); });
_C_m.def("_activate_gpu_trace", []() { activateGPUTrace(); });
py_context_manager_DEPRECATED<c10::InferenceMode, bool>(
_C_m, "_InferenceMode");

View File

@@ -250,7 +250,7 @@ static PyObject* getPythonTensorClass(c10::Device d) {
return device_to_py_class_[static_cast<size_t>(d.type())];
}
void activateCUDATrace() {
void activateGPUTrace() {
c10::impl::GPUTrace::set_trace(getPyInterpreter());
}

View File

@@ -31,7 +31,7 @@ TORCH_PYTHON_API void registerPythonTensorClass(
const std::string& device,
PyObject* python_tensor_class);
TORCH_PYTHON_API void activateCUDATrace();
TORCH_PYTHON_API void activateGPUTrace();
TORCH_PYTHON_API extern PyObject* THPVariableClass;
TORCH_PYTHON_API extern PyObject* ParameterClass;

torch/cuda/_gpu_trace.py (new file, 75 lines)
View File

@@ -0,0 +1,75 @@
from typing import Callable
from torch._utils import CallbackRegistry
EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event creation"
)
EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event deletion"
)
EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"CUDA event record"
)
EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"CUDA event wait"
)
MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA memory allocation"
)
MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA memory deallocation"
)
StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA stream creation"
)
DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
"CUDA device synchronization"
)
StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA stream synchronization"
)
EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event synchronization"
)
def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
EventCreationCallbacks.add_callback(cb)
def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
EventDeletionCallbacks.add_callback(cb)
def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
EventRecordCallbacks.add_callback(cb)
def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
EventWaitCallbacks.add_callback(cb)
def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
MemoryAllocationCallbacks.add_callback(cb)
def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
MemoryDeallocationCallbacks.add_callback(cb)
def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
StreamCreationCallbacks.add_callback(cb)
def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
DeviceSynchronizationCallbacks.add_callback(cb)
def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
StreamSynchronizationCallbacks.add_callback(cb)
def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
EventSynchronizationCallbacks.add_callback(cb)

View File

@@ -22,7 +22,7 @@ from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar
import torch
import torch.utils._cuda_trace as cuda_trace
import torch.cuda._gpu_trace as gpu_trace
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
@@ -528,35 +528,35 @@ class ArgumentHandler:
class CUDASanitizerDispatchMode(TorchDispatchMode):
def __init__(self):
self.event_handler = EventHandler()
torch._C._activate_cuda_trace()
cuda_trace.register_callback_for_cuda_event_creation(
torch._C._activate_gpu_trace()
gpu_trace.register_callback_for_event_creation(
self.event_handler._handle_event_creation
)
cuda_trace.register_callback_for_cuda_event_deletion(
gpu_trace.register_callback_for_event_deletion(
self.event_handler._handle_event_deletion
)
cuda_trace.register_callback_for_cuda_event_record(
gpu_trace.register_callback_for_event_record(
self.event_handler._handle_event_record
)
cuda_trace.register_callback_for_cuda_event_wait(
gpu_trace.register_callback_for_event_wait(
self.event_handler._handle_event_wait
)
cuda_trace.register_callback_for_cuda_memory_allocation(
gpu_trace.register_callback_for_memory_allocation(
self.event_handler._handle_memory_allocation
)
cuda_trace.register_callback_for_cuda_memory_deallocation(
gpu_trace.register_callback_for_memory_deallocation(
self.event_handler._handle_memory_deallocation
)
cuda_trace.register_callback_for_cuda_stream_creation(
gpu_trace.register_callback_for_stream_creation(
self.event_handler._handle_stream_creation
)
cuda_trace.register_callback_for_cuda_device_synchronization(
gpu_trace.register_callback_for_device_synchronization(
self.event_handler._handle_device_synchronization
)
cuda_trace.register_callback_for_cuda_stream_synchronization(
gpu_trace.register_callback_for_stream_synchronization(
self.event_handler._handle_stream_synchronization
)
cuda_trace.register_callback_for_cuda_event_synchronization(
gpu_trace.register_callback_for_event_synchronization(
self.event_handler._handle_event_synchronization
)

View File

@@ -1,99 +0,0 @@
import logging
from typing import Callable, Generic, List
from typing_extensions import ParamSpec # Python 3.10+
logger = logging.getLogger(__name__)
P = ParamSpec("P")
class CallbackRegistry(Generic[P]):
def __init__(self, name: str):
self.name = name
self.callback_list: List[Callable[P, None]] = []
def add_callback(self, cb: Callable[P, None]) -> None:
self.callback_list.append(cb)
def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
for cb in self.callback_list:
try:
cb(*args, **kwargs)
except Exception as e:
logger.exception(
"Exception in callback for %s registered with CUDA trace", self.name
)
CUDAEventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event creation"
)
CUDAEventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event deletion"
)
CUDAEventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"CUDA event record"
)
CUDAEventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"CUDA event wait"
)
CUDAMemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA memory allocation"
)
CUDAMemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA memory deallocation"
)
CUDAStreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA stream creation"
)
CUDADeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
"CUDA device synchronization"
)
CUDAStreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA stream synchronization"
)
CUDAEventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event synchronization"
)
def register_callback_for_cuda_event_creation(cb: Callable[[int], None]) -> None:
CUDAEventCreationCallbacks.add_callback(cb)
def register_callback_for_cuda_event_deletion(cb: Callable[[int], None]) -> None:
CUDAEventDeletionCallbacks.add_callback(cb)
def register_callback_for_cuda_event_record(cb: Callable[[int, int], None]) -> None:
CUDAEventRecordCallbacks.add_callback(cb)
def register_callback_for_cuda_event_wait(cb: Callable[[int, int], None]) -> None:
CUDAEventWaitCallbacks.add_callback(cb)
def register_callback_for_cuda_memory_allocation(cb: Callable[[int], None]) -> None:
CUDAMemoryAllocationCallbacks.add_callback(cb)
def register_callback_for_cuda_memory_deallocation(cb: Callable[[int], None]) -> None:
CUDAMemoryDeallocationCallbacks.add_callback(cb)
def register_callback_for_cuda_stream_creation(cb: Callable[[int], None]) -> None:
CUDAStreamCreationCallbacks.add_callback(cb)
def register_callback_for_cuda_device_synchronization(cb: Callable[[], None]) -> None:
CUDADeviceSynchronizationCallbacks.add_callback(cb)
def register_callback_for_cuda_stream_synchronization(
cb: Callable[[int], None]
) -> None:
CUDAStreamSynchronizationCallbacks.add_callback(cb)
def register_callback_for_cuda_event_synchronization(cb: Callable[[int], None]) -> None:
CUDAEventSynchronizationCallbacks.add_callback(cb)