Revert "Refactor gpu trace to be device-agnostic (#121794)"

This reverts commit 74deacbf31d032a2659dc1633dc3e5248921d466.

Reverted https://github.com/pytorch/pytorch/pull/121794 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it breaks ROCm jobs in trunk 74deacbf31, please help take a look and reland the change ([comment](https://github.com/pytorch/pytorch/pull/121794#issuecomment-2013674083))
PyTorch MergeBot
2024-03-21 20:33:16 +00:00
parent 13afbcfc85
commit 968c4c4154
20 changed files with 203 additions and 262 deletions

View File

@@ -1906,7 +1906,6 @@ exclude_patterns = [
'torch/compiler/__init__.py',
'torch/contrib/__init__.py',
'torch/contrib/_tensorboard_vis.py',
"torch/cuda/_gpu_trace.py",
'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable
'torch/distributed/__init__.py',
'torch/distributed/_composable_state.py',
@@ -2434,6 +2433,7 @@ exclude_patterns = [
'torch/utils/_contextlib.py',
'torch/utils/_cpp_extension_versioner.py',
'torch/utils/_crash_handler.py',
'torch/utils/_cuda_trace.py',
'torch/utils/_device.py',
'torch/utils/_foreach_utils.py',
'torch/utils/_freeze.py',

View File

@@ -48,7 +48,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
CUDAGuard guard(device_index_);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
(*interp)->trace_gpu_event_deletion(reinterpret_cast<uintptr_t>(event_));
}
AT_CUDA_CHECK(cudaEventDestroy(event_));
}
@@ -122,7 +122,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
AT_CUDA_CHECK(cudaEventRecord(event_, stream));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(at::kCUDA,
(*interp)->trace_gpu_event_record(
reinterpret_cast<uintptr_t>(event_),
reinterpret_cast<uintptr_t>(stream.stream())
);
@@ -138,7 +138,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(at::kCUDA,
(*interp)->trace_gpu_event_wait(
reinterpret_cast<uintptr_t>(event_),
reinterpret_cast<uintptr_t>(stream.stream())
);
@@ -161,7 +161,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
if (is_created_) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_synchronization(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
(*interp)->trace_gpu_event_synchronization(reinterpret_cast<uintptr_t>(event_));
}
AT_CUDA_CHECK(cudaEventSynchronize(event_));
}
@@ -191,7 +191,7 @@ private:
AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
(*interp)->trace_gpu_event_creation(reinterpret_cast<uintptr_t>(event_));
}
is_created_ = true;
}

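For reference, the event hooks this header restores are observable from Python through the re-added torch.utils._cuda_trace module. A minimal sketch based on the test changes later in this commit (requires a CUDA-enabled build):

import torch
import torch.utils._cuda_trace as cuda_trace

# Install the Python interpreter as the GPU trace sink; without this,
# GPUTrace::get_trace() returns null and none of the hooks above fire.
torch._C._activate_cuda_trace()

def on_event_record(event: int, stream: int) -> None:
    # event/stream are the uintptr_t values forwarded by trace_gpu_event_record
    print(f"cudaEventRecord: event={event:#x} stream={stream:#x}")

cuda_trace.register_callback_for_cuda_event_record(on_event_record)

e = torch.cuda.Event()
e.record()  # records on the current stream and fires on_event_record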
View File

@@ -94,32 +94,17 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
}
// Just swallow the event, don't do anything
void trace_gpu_event_creation(c10::DeviceType device_type, uintptr_t event)
void trace_gpu_event_creation(uintptr_t event) const override {}
void trace_gpu_event_deletion(uintptr_t event) const override {}
void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
const override {}
void trace_gpu_event_deletion(c10::DeviceType device_type, uintptr_t event)
const override {}
void trace_gpu_event_record(
c10::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const override {}
void trace_gpu_event_wait(
c10::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const override {}
void trace_gpu_memory_allocation(c10::DeviceType device_type, uintptr_t ptr)
const override {}
void trace_gpu_memory_deallocation(c10::DeviceType device_type, uintptr_t ptr)
const override {}
void trace_gpu_stream_creation(c10::DeviceType device_type, uintptr_t stream)
const override {}
void trace_gpu_device_synchronization(
c10::DeviceType device_type) const override {}
void trace_gpu_stream_synchronization(
c10::DeviceType device_type,
uintptr_t stream) const override {}
void trace_gpu_event_synchronization(
c10::DeviceType device_type,
uintptr_t event) const override {}
void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) const override {}
void trace_gpu_memory_allocation(uintptr_t ptr) const override {}
void trace_gpu_memory_deallocation(uintptr_t ptr) const override {}
void trace_gpu_stream_creation(uintptr_t stream) const override {}
void trace_gpu_device_synchronization() const override {}
void trace_gpu_stream_synchronization(uintptr_t stream) const override {}
void trace_gpu_event_synchronization(uintptr_t event) const override {}
void reset_backward_hooks(const TensorImpl* self) const override {
PANIC(reset_backward_hooks);

View File

@@ -177,37 +177,18 @@ struct C10_API PyInterpreterVTable {
virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0;
virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0;
virtual void trace_gpu_event_creation(
c10::DeviceType device_type,
uintptr_t event) const = 0;
virtual void trace_gpu_event_deletion(
c10::DeviceType device_type,
uintptr_t event) const = 0;
virtual void trace_gpu_event_record(
c10::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const = 0;
virtual void trace_gpu_event_wait(
c10::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const = 0;
virtual void trace_gpu_memory_allocation(
c10::DeviceType device_type,
uintptr_t ptr) const = 0;
virtual void trace_gpu_memory_deallocation(
c10::DeviceType device_type,
uintptr_t ptr) const = 0;
virtual void trace_gpu_stream_creation(
c10::DeviceType device_type,
uintptr_t stream) const = 0;
virtual void trace_gpu_device_synchronization(
c10::DeviceType device_type) const = 0;
virtual void trace_gpu_stream_synchronization(
c10::DeviceType device_type,
uintptr_t stream) const = 0;
virtual void trace_gpu_event_synchronization(
c10::DeviceType device_type,
uintptr_t event) const = 0;
virtual void trace_gpu_event_creation(uintptr_t event) const = 0;
virtual void trace_gpu_event_deletion(uintptr_t event) const = 0;
virtual void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
const = 0;
virtual void trace_gpu_event_wait(uintptr_t event, uintptr_t stream)
const = 0;
virtual void trace_gpu_memory_allocation(uintptr_t ptr) const = 0;
virtual void trace_gpu_memory_deallocation(uintptr_t ptr) const = 0;
virtual void trace_gpu_stream_creation(uintptr_t stream) const = 0;
virtual void trace_gpu_device_synchronization() const = 0;
virtual void trace_gpu_stream_synchronization(uintptr_t stream) const = 0;
virtual void trace_gpu_event_synchronization(uintptr_t event) const = 0;
virtual void reset_backward_hooks(const TensorImpl* self) const = 0;
};

View File

@@ -2844,8 +2844,7 @@ static void uncached_delete(void* ptr) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_deallocation(
c10::kCUDA, reinterpret_cast<uintptr_t>(ptr));
(*interp)->trace_gpu_memory_deallocation(reinterpret_cast<uintptr_t>(ptr));
}
C10_CUDA_CHECK(cudaFree(ptr));
}
@@ -2929,7 +2928,7 @@ class NativeCachingAllocator : public CUDAAllocator {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_allocation(
c10::kCUDA, reinterpret_cast<uintptr_t>(*devPtr));
reinterpret_cast<uintptr_t>(*devPtr));
}
}
@@ -2944,7 +2943,7 @@ class NativeCachingAllocator : public CUDAAllocator {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_deallocation(
c10::kCUDA, reinterpret_cast<uintptr_t>(block->ptr));
reinterpret_cast<uintptr_t>(block->ptr));
}
device_allocator[block->device]->free(block);
}
@@ -3133,7 +3132,7 @@ class NativeCachingAllocator : public CUDAAllocator {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_allocation(
c10::kCUDA, reinterpret_cast<uintptr_t>(devPtr));
reinterpret_cast<uintptr_t>(devPtr));
}
} else {
if (size != 0) {

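The allocator hooks restored here surface as the memory callbacks of torch.utils._cuda_trace. A short sketch of the observable behavior, mirroring test_memory_allocation_callback below (CUDA build assumed):

import torch
import torch.utils._cuda_trace as cuda_trace

torch._C._activate_cuda_trace()

cuda_trace.register_callback_for_cuda_memory_allocation(
    lambda ptr: print(f"allocated block at {ptr:#x}")
)
cuda_trace.register_callback_for_cuda_memory_deallocation(
    lambda ptr: print(f"freed block at {ptr:#x}")
)

t = torch.empty(10, 4, device="cuda")  # allocation hook fires with t.data_ptr()
del t                                  # deallocation hook fires for the same pointer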
View File

@@ -136,7 +136,7 @@ void set_device(DeviceIndex device) {
void device_synchronize() {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_device_synchronization(c10::kCUDA);
(*interp)->trace_gpu_device_synchronization();
}
C10_CUDA_CHECK(cudaDeviceSynchronize());
}

View File

@@ -87,7 +87,7 @@ C10_CUDA_API void __inline__ memcpy_and_sync(
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_stream_synchronization(
c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
reinterpret_cast<uintptr_t>(stream));
}
#if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301)
C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
@@ -105,7 +105,7 @@ C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_stream_synchronization(
c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
reinterpret_cast<uintptr_t>(stream));
}
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}

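Note that stream_synchronize (and the memcpy_and_sync path above it) also fires for implicit synchronizations, not only explicit torch.cuda.synchronize() calls. A sketch of both, following test_memcpy_synchronization later in this commit:

import torch
import torch.utils._cuda_trace as cuda_trace

torch._C._activate_cuda_trace()

cuda_trace.register_callback_for_cuda_device_synchronization(
    lambda: print("device-wide synchronize")
)
cuda_trace.register_callback_for_cuda_stream_synchronization(
    lambda stream: print(f"stream synchronize: {stream:#x}")
)

torch.cuda.synchronize()                # device synchronization hook
torch.rand(5, device="cuda").nonzero()  # implicit stream sync via memcpy_and_sync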
View File

@@ -206,8 +206,7 @@ static void initSingleStream(int p, DeviceIndex device_index, int i) {
C10_CUDA_CHECK(cudaStreamCreateWithPriority(&stream, kDefaultFlags, pri));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_stream_creation(
c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
(*interp)->trace_gpu_stream_creation(reinterpret_cast<uintptr_t>(stream));
priority_counters[p][device_index] = 0;
}
}

View File

@@ -96,7 +96,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(
c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
reinterpret_cast<uintptr_t>(cuda_event));
}
}
@@ -111,7 +111,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(
c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
reinterpret_cast<uintptr_t>(cuda_event));
}
C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device));
@@ -146,7 +146,6 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(
c10::kCUDA,
reinterpret_cast<uintptr_t>(cuda_event),
reinterpret_cast<uintptr_t>(cuda_stream.stream()));
}
@@ -169,7 +168,6 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(
c10::kCUDA,
reinterpret_cast<uintptr_t>(cuda_event),
reinterpret_cast<uintptr_t>(cuda_stream.stream()));
}

View File

@@ -5,7 +5,7 @@ import unittest
import unittest.mock
import torch
import torch.cuda._gpu_trace as gpu_trace
import torch.utils._cuda_trace as cuda_trace
from torch.testing._internal.common_utils import TestCase, run_tests, NoTest, TEST_CUDA
# NOTE: Each test needs to be run in a brand new process, to reset the registered hooks
@@ -19,18 +19,18 @@ if not TEST_CUDA:
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaTrace(TestCase):
def setUp(self):
torch._C._activate_gpu_trace()
torch._C._activate_cuda_trace()
self.mock = unittest.mock.MagicMock()
def test_event_creation_callback(self):
gpu_trace.register_callback_for_event_creation(self.mock)
cuda_trace.register_callback_for_cuda_event_creation(self.mock)
event = torch.cuda.Event()
event.record()
self.mock.assert_called_once_with(event._as_parameter_.value)
def test_event_deletion_callback(self):
gpu_trace.register_callback_for_event_deletion(self.mock)
cuda_trace.register_callback_for_cuda_event_deletion(self.mock)
event = torch.cuda.Event()
event.record()
@@ -39,7 +39,7 @@ class TestCudaTrace(TestCase):
self.mock.assert_called_once_with(event_id)
def test_event_record_callback(self):
gpu_trace.register_callback_for_event_record(self.mock)
cuda_trace.register_callback_for_cuda_event_record(self.mock)
event = torch.cuda.Event()
event.record()
@@ -48,7 +48,7 @@ class TestCudaTrace(TestCase):
)
def test_event_wait_callback(self):
gpu_trace.register_callback_for_event_wait(self.mock)
cuda_trace.register_callback_for_cuda_event_wait(self.mock)
event = torch.cuda.Event()
event.record()
@@ -58,13 +58,13 @@ class TestCudaTrace(TestCase):
)
def test_memory_allocation_callback(self):
gpu_trace.register_callback_for_memory_allocation(self.mock)
cuda_trace.register_callback_for_cuda_memory_allocation(self.mock)
tensor = torch.empty(10, 4, device="cuda")
self.mock.assert_called_once_with(tensor.data_ptr())
def test_memory_deallocation_callback(self):
gpu_trace.register_callback_for_memory_deallocation(self.mock)
cuda_trace.register_callback_for_cuda_memory_deallocation(self.mock)
tensor = torch.empty(3, 8, device="cuda")
data_ptr = tensor.data_ptr()
@@ -72,7 +72,7 @@ class TestCudaTrace(TestCase):
self.mock.assert_called_once_with(data_ptr)
def test_stream_creation_callback(self):
gpu_trace.register_callback_for_stream_creation(self.mock)
cuda_trace.register_callback_for_cuda_stream_creation(self.mock)
# see Note [HIP Lazy Streams]
if torch.version.hip:
@@ -85,20 +85,20 @@ class TestCudaTrace(TestCase):
self.mock.assert_called()
def test_device_synchronization_callback(self):
gpu_trace.register_callback_for_device_synchronization(self.mock)
cuda_trace.register_callback_for_cuda_device_synchronization(self.mock)
torch.cuda.synchronize()
self.mock.assert_called()
def test_stream_synchronization_callback(self):
gpu_trace.register_callback_for_stream_synchronization(self.mock)
cuda_trace.register_callback_for_cuda_stream_synchronization(self.mock)
stream = torch.cuda.Stream()
stream.synchronize()
self.mock.assert_called_once_with(stream.cuda_stream)
def test_event_synchronization_callback(self):
gpu_trace.register_callback_for_event_synchronization(self.mock)
cuda_trace.register_callback_for_cuda_event_synchronization(self.mock)
event = torch.cuda.Event()
event.record()
@@ -106,7 +106,7 @@ class TestCudaTrace(TestCase):
self.mock.assert_called_once_with(event._as_parameter_.value)
def test_memcpy_synchronization(self):
gpu_trace.register_callback_for_stream_synchronization(self.mock)
cuda_trace.register_callback_for_cuda_stream_synchronization(self.mock)
tensor = torch.rand(5, device="cuda")
tensor.nonzero()
@@ -114,8 +114,8 @@ class TestCudaTrace(TestCase):
def test_all_trace_callbacks_called(self):
other = unittest.mock.MagicMock()
gpu_trace.register_callback_for_memory_allocation(self.mock)
gpu_trace.register_callback_for_memory_allocation(other)
cuda_trace.register_callback_for_cuda_memory_allocation(self.mock)
cuda_trace.register_callback_for_cuda_memory_allocation(other)
tensor = torch.empty(10, 4, device="cuda")
self.mock.assert_called_once_with(tensor.data_ptr())

View File

@@ -1273,7 +1273,6 @@ def _unset_dispatch_mode(mode: torch._C._TorchDispatchModeKey) -> Any: ...
def _set_dispatch_mode(mode: Any) -> None: ...
def _get_dispatch_stack_at(idx: _int) -> Any: ...
def _len_torch_dispatch_stack() -> _int: ...
def _activate_gpu_trace() -> None: ...
class _DisableTorchDispatch:
def __init__(self): ...
@@ -2156,6 +2155,7 @@ def _c10d_init() -> _bool: ...
# Defined in torch/csrc/distributed/rpc/testing/init.cpp
def _faulty_agent_init() -> _bool: ...
def _register_py_class_for_device(device: str, cls: Any) -> None: ...
def _activate_cuda_trace() -> None: ...
# Defined in torch/csrc/Module.cpp
def _current_graph_task_id() -> _int: ...

View File

@@ -291,7 +291,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._assert_async",
"torch._assert_tensor_metadata",
"torch._batch_norm_impl_index",
"torch._C._activate_gpu_trace",
"torch._C._activate_cuda_trace",
"torch._C._add_cached_tensor",
"torch._C._add_docstr",
"torch._C._are_functorch_transforms_active",

View File

@@ -1,13 +1,10 @@
import copyreg
import functools
import logging
import sys
import traceback
import warnings
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Generic, List, Optional
from typing_extensions import ParamSpec
from typing import Any, DefaultDict, List, Optional
import torch
@@ -938,25 +935,3 @@ class _LazySeedTracker:
def get_calls(self) -> List:
return self.call_order
logger = logging.getLogger(__name__)
P = ParamSpec("P")
class CallbackRegistry(Generic[P]):
def __init__(self, name: str):
self.name = name
self.callback_list: List[Callable[P, None]] = []
def add_callback(self, cb: Callable[P, None]) -> None:
self.callback_list.append(cb)
def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
for cb in self.callback_list:
try:
cb(*args, **kwargs)
except Exception as e:
logger.exception(
"Exception in callback for %s registered with gpu trace", self.name
)

View File

@@ -18,22 +18,17 @@ namespace {
// because passing in constexpr char* as template argument breaks some
// versions of MSVC that are being used internally at Meta.
// MSVC 14.16.27023 (vs2017_15.9)
#define CONCRETE_GPU_TRACE(device_type, func_name, ...) \
at::impl::MaybeSetTLSOnEntryGuard guard; \
if (Py_IsInitialized()) { \
pybind11::gil_scoped_acquire gil; \
try { \
py::module utils_mod = py::module::import("torch._utils"); \
py::object get_device_module = utils_mod.attr("_get_device_module"); \
py::object hook = get_device_module(DeviceTypeName(device_type, true)) \
.attr("_gpu_trace") \
.attr(func_name) \
.attr("fire_callbacks"); \
hook(__VA_ARGS__); \
} catch (const std::exception& e) { \
LOG(ERROR) << device_type \
<< " trace hook execution failed: " << e.what(); \
} \
#define CONCRETE_TRACE_CUDA(func_name, ...) \
at::impl::MaybeSetTLSOnEntryGuard guard; \
if (Py_IsInitialized()) { \
pybind11::gil_scoped_acquire gil; \
try { \
py::module mod = py::module::import("torch.utils._cuda_trace"); \
py::object hook = mod.attr(func_name).attr("fire_callbacks"); \
hook(__VA_ARGS__); \
} catch (const std::exception& e) { \
LOG(ERROR) << "CUDA trace hook execution failed: " << e.what(); \
} \
}
struct ConcretePyInterpreterVTable final
@@ -88,51 +83,36 @@ struct ConcretePyInterpreterVTable final
c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;
void trace_gpu_event_creation(at::DeviceType device_type, uintptr_t event)
void trace_gpu_event_creation(uintptr_t event) const override {
CONCRETE_TRACE_CUDA("CUDAEventCreationCallbacks", event);
}
void trace_gpu_event_deletion(uintptr_t event) const override {
CONCRETE_TRACE_CUDA("CUDAEventDeletionCallbacks", event);
}
void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
const override {
CONCRETE_GPU_TRACE(device_type, "EventCreationCallbacks", event);
CONCRETE_TRACE_CUDA("CUDAEventRecordCallbacks", event, stream);
}
void trace_gpu_event_deletion(at::DeviceType device_type, uintptr_t event)
const override {
CONCRETE_GPU_TRACE(device_type, "EventDeletionCallbacks", event);
void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) const override {
CONCRETE_TRACE_CUDA("CUDAEventWaitCallbacks", event, stream);
}
void trace_gpu_event_record(
at::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const override {
CONCRETE_GPU_TRACE(device_type, "EventRecordCallbacks", event, stream);
void trace_gpu_memory_allocation(uintptr_t ptr) const override {
CONCRETE_TRACE_CUDA("CUDAMemoryAllocationCallbacks", ptr);
}
void trace_gpu_event_wait(
at::DeviceType device_type,
uintptr_t event,
uintptr_t stream) const override {
CONCRETE_GPU_TRACE(device_type, "EventWaitCallbacks", event, stream);
void trace_gpu_memory_deallocation(uintptr_t ptr) const override {
CONCRETE_TRACE_CUDA("CUDAMemoryDeallocationCallbacks", ptr);
}
void trace_gpu_memory_allocation(at::DeviceType device_type, uintptr_t ptr)
const override {
CONCRETE_GPU_TRACE(device_type, "MemoryAllocationCallbacks", ptr);
void trace_gpu_stream_creation(uintptr_t stream) const override {
CONCRETE_TRACE_CUDA("CUDAStreamCreationCallbacks", stream);
}
void trace_gpu_memory_deallocation(at::DeviceType device_type, uintptr_t ptr)
const override {
CONCRETE_GPU_TRACE(device_type, "MemoryDeallocationCallbacks", ptr);
void trace_gpu_device_synchronization() const override {
CONCRETE_TRACE_CUDA("CUDADeviceSynchronizationCallbacks");
}
void trace_gpu_stream_creation(at::DeviceType device_type, uintptr_t stream)
const override {
CONCRETE_GPU_TRACE(device_type, "StreamCreationCallbacks", stream);
void trace_gpu_stream_synchronization(uintptr_t stream) const override {
CONCRETE_TRACE_CUDA("CUDAStreamSynchronizationCallbacks", stream);
}
void trace_gpu_device_synchronization(
at::DeviceType device_type) const override {
CONCRETE_GPU_TRACE(device_type, "DeviceSynchronizationCallbacks");
}
void trace_gpu_stream_synchronization(
at::DeviceType device_type,
uintptr_t stream) const override {
CONCRETE_GPU_TRACE(device_type, "StreamSynchronizationCallbacks", stream);
}
void trace_gpu_event_synchronization(
at::DeviceType device_type,
uintptr_t event) const override {
CONCRETE_GPU_TRACE(device_type, "EventSynchronizationCallbacks", event);
void trace_gpu_event_synchronization(uintptr_t event) const override {
CONCRETE_TRACE_CUDA("CUDAEventSynchronizationCallbacks", event);
}
void reset_backward_hooks(const c10::TensorImpl* self) const override;

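In Python terms, each restored hook body is roughly equivalent to the sketch below: the macro looks up a CallbackRegistry on torch.utils._cuda_trace by attribute name and invokes fire_callbacks, swallowing errors so tracing can never crash the caller (GIL acquisition and the TLS-on-entry guard are elided here):

import logging

def concrete_trace_cuda(func_name: str, *args) -> None:
    # Rough Python analogue of CONCRETE_TRACE_CUDA(func_name, ...)
    try:
        import torch.utils._cuda_trace as mod
        getattr(mod, func_name).fire_callbacks(*args)
    except Exception:
        logging.getLogger(__name__).exception("CUDA trace hook execution failed")

# e.g. trace_gpu_event_creation(event) amounts to:
#   concrete_trace_cuda("CUDAEventCreationCallbacks", event)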
View File

@@ -434,7 +434,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
}
});
_C_m.def("_activate_gpu_trace", []() { activateGPUTrace(); });
_C_m.def("_activate_cuda_trace", []() { activateCUDATrace(); });
py_context_manager_DEPRECATED<c10::InferenceMode, bool>(
_C_m, "_InferenceMode");

View File

@@ -250,7 +250,7 @@ static PyObject* getPythonTensorClass(c10::Device d) {
return device_to_py_class_[static_cast<size_t>(d.type())];
}
void activateGPUTrace() {
void activateCUDATrace() {
c10::impl::GPUTrace::set_trace(getPyInterpreter());
}

View File

@@ -31,7 +31,7 @@ TORCH_PYTHON_API void registerPythonTensorClass(
const std::string& device,
PyObject* python_tensor_class);
TORCH_PYTHON_API void activateGPUTrace();
TORCH_PYTHON_API void activateCUDATrace();
TORCH_PYTHON_API extern PyObject* THPVariableClass;
TORCH_PYTHON_API extern PyObject* ParameterClass;

View File

@@ -1,75 +0,0 @@
from typing import Callable
from torch._utils import CallbackRegistry
EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event creation"
)
EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event deletion"
)
EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"CUDA event record"
)
EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"CUDA event wait"
)
MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA memory allocation"
)
MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA memory deallocation"
)
StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA stream creation"
)
DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
"CUDA device synchronization"
)
StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA stream synchronization"
)
EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event synchronization"
)
def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
EventCreationCallbacks.add_callback(cb)
def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
EventDeletionCallbacks.add_callback(cb)
def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
EventRecordCallbacks.add_callback(cb)
def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
EventWaitCallbacks.add_callback(cb)
def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
MemoryAllocationCallbacks.add_callback(cb)
def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
MemoryDeallocationCallbacks.add_callback(cb)
def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
StreamCreationCallbacks.add_callback(cb)
def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
DeviceSynchronizationCallbacks.add_callback(cb)
def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
StreamSynchronizationCallbacks.add_callback(cb)
def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
EventSynchronizationCallbacks.add_callback(cb)

View File

@@ -22,7 +22,7 @@ from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar
import torch
import torch.cuda._gpu_trace as gpu_trace
import torch.utils._cuda_trace as cuda_trace
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
@@ -528,35 +528,35 @@ class ArgumentHandler:
class CUDASanitizerDispatchMode(TorchDispatchMode):
def __init__(self):
self.event_handler = EventHandler()
torch._C._activate_gpu_trace()
gpu_trace.register_callback_for_event_creation(
torch._C._activate_cuda_trace()
cuda_trace.register_callback_for_cuda_event_creation(
self.event_handler._handle_event_creation
)
gpu_trace.register_callback_for_event_deletion(
cuda_trace.register_callback_for_cuda_event_deletion(
self.event_handler._handle_event_deletion
)
gpu_trace.register_callback_for_event_record(
cuda_trace.register_callback_for_cuda_event_record(
self.event_handler._handle_event_record
)
gpu_trace.register_callback_for_event_wait(
cuda_trace.register_callback_for_cuda_event_wait(
self.event_handler._handle_event_wait
)
gpu_trace.register_callback_for_memory_allocation(
cuda_trace.register_callback_for_cuda_memory_allocation(
self.event_handler._handle_memory_allocation
)
gpu_trace.register_callback_for_memory_deallocation(
cuda_trace.register_callback_for_cuda_memory_deallocation(
self.event_handler._handle_memory_deallocation
)
gpu_trace.register_callback_for_stream_creation(
cuda_trace.register_callback_for_cuda_stream_creation(
self.event_handler._handle_stream_creation
)
gpu_trace.register_callback_for_device_synchronization(
cuda_trace.register_callback_for_cuda_device_synchronization(
self.event_handler._handle_device_synchronization
)
gpu_trace.register_callback_for_stream_synchronization(
cuda_trace.register_callback_for_cuda_stream_synchronization(
self.event_handler._handle_stream_synchronization
)
gpu_trace.register_callback_for_event_synchronization(
cuda_trace.register_callback_for_cuda_event_synchronization(
self.event_handler._handle_event_synchronization
)

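CUDASanitizerDispatchMode is the main in-tree consumer of these callbacks. For context, the sanitizer is enabled through enable_cuda_sanitizer() from the same file (not part of this diff) or the TORCH_CUDA_SANITIZER environment variable; a usage sketch:

import torch
from torch.cuda._sanitizer import enable_cuda_sanitizer

# Installs CUDASanitizerDispatchMode, whose __init__ above activates the
# trace and registers every handler with torch.utils._cuda_trace.
enable_cuda_sanitizer()

x = torch.rand(4, device="cuda")
# Unsynchronized uses of x across streams would now be reported by CSAN.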
View File

@@ -0,0 +1,99 @@
import logging
from typing import Callable, Generic, List
from typing_extensions import ParamSpec # Python 3.10+
logger = logging.getLogger(__name__)
P = ParamSpec("P")
class CallbackRegistry(Generic[P]):
def __init__(self, name: str):
self.name = name
self.callback_list: List[Callable[P, None]] = []
def add_callback(self, cb: Callable[P, None]) -> None:
self.callback_list.append(cb)
def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
for cb in self.callback_list:
try:
cb(*args, **kwargs)
except Exception as e:
logger.exception(
"Exception in callback for %s registered with CUDA trace", self.name
)
CUDAEventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event creation"
)
CUDAEventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event deletion"
)
CUDAEventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"CUDA event record"
)
CUDAEventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"CUDA event wait"
)
CUDAMemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA memory allocation"
)
CUDAMemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA memory deallocation"
)
CUDAStreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA stream creation"
)
CUDADeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
"CUDA device synchronization"
)
CUDAStreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA stream synchronization"
)
CUDAEventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event synchronization"
)
def register_callback_for_cuda_event_creation(cb: Callable[[int], None]) -> None:
CUDAEventCreationCallbacks.add_callback(cb)
def register_callback_for_cuda_event_deletion(cb: Callable[[int], None]) -> None:
CUDAEventDeletionCallbacks.add_callback(cb)
def register_callback_for_cuda_event_record(cb: Callable[[int, int], None]) -> None:
CUDAEventRecordCallbacks.add_callback(cb)
def register_callback_for_cuda_event_wait(cb: Callable[[int, int], None]) -> None:
CUDAEventWaitCallbacks.add_callback(cb)
def register_callback_for_cuda_memory_allocation(cb: Callable[[int], None]) -> None:
CUDAMemoryAllocationCallbacks.add_callback(cb)
def register_callback_for_cuda_memory_deallocation(cb: Callable[[int], None]) -> None:
CUDAMemoryDeallocationCallbacks.add_callback(cb)
def register_callback_for_cuda_stream_creation(cb: Callable[[int], None]) -> None:
CUDAStreamCreationCallbacks.add_callback(cb)
def register_callback_for_cuda_device_synchronization(cb: Callable[[], None]) -> None:
CUDADeviceSynchronizationCallbacks.add_callback(cb)
def register_callback_for_cuda_stream_synchronization(
cb: Callable[[int], None]
) -> None:
CUDAStreamSynchronizationCallbacks.add_callback(cb)
def register_callback_for_cuda_event_synchronization(cb: Callable[[int], None]) -> None:
CUDAEventSynchronizationCallbacks.add_callback(cb)
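
End to end, the restored module behaves as sketched below. Because CallbackRegistry.fire_callbacks logs exceptions via logger.exception instead of re-raising, one faulty callback cannot prevent later ones from running (the behavior test_all_trace_callbacks_called relies on):

import torch
import torch.utils._cuda_trace as cuda_trace

torch._C._activate_cuda_trace()

def bad(event: int) -> None:
    raise RuntimeError("boom")  # logged by CallbackRegistry, not propagated

def good(event: int) -> None:
    print(f"event created: {event:#x}")

cuda_trace.register_callback_for_cuda_event_creation(bad)
cuda_trace.register_callback_for_cuda_event_creation(good)

e = torch.cuda.Event()
e.record()  # lazy event creation fires both callbacks; good still runs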