mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactor gpu trace to be device-agnostic (#121794)
# Motivation Refactor gpu trace to be device-agnostic. gpu trace is usually used in runtime components, including Device, Stream, Event, Guard, and Allocator. It should be device-agnostic and can be shared among each device backend. # Solution move `_cuda_trace.py` to `_gpu_trace.py`, which makes each device backend owns their callback, respectively. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121794 Approved by: https://github.com/jgong5, https://github.com/albanD, https://github.com/EikanWang, https://github.com/gujinghui
This commit is contained in:
committed by
PyTorch MergeBot
parent
09ce76809c
commit
0ff1109e26
@ -1905,6 +1905,7 @@ exclude_patterns = [
|
||||
'torch/compiler/__init__.py',
|
||||
'torch/contrib/__init__.py',
|
||||
'torch/contrib/_tensorboard_vis.py',
|
||||
"torch/cuda/_gpu_trace.py",
|
||||
'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable
|
||||
'torch/distributed/__init__.py',
|
||||
'torch/distributed/_composable_state.py',
|
||||
@ -2432,7 +2433,6 @@ exclude_patterns = [
|
||||
'torch/utils/_contextlib.py',
|
||||
'torch/utils/_cpp_extension_versioner.py',
|
||||
'torch/utils/_crash_handler.py',
|
||||
'torch/utils/_cuda_trace.py',
|
||||
'torch/utils/_device.py',
|
||||
'torch/utils/_foreach_utils.py',
|
||||
'torch/utils/_freeze.py',
|
||||
|
@ -48,7 +48,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
|
||||
CUDAGuard guard(device_index_);
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_deletion(reinterpret_cast<uintptr_t>(event_));
|
||||
(*interp)->trace_gpu_event_deletion(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
|
||||
}
|
||||
AT_CUDA_CHECK(cudaEventDestroy(event_));
|
||||
}
|
||||
@ -122,7 +122,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
|
||||
AT_CUDA_CHECK(cudaEventRecord(event_, stream));
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_record(
|
||||
(*interp)->trace_gpu_event_record(at::kCUDA,
|
||||
reinterpret_cast<uintptr_t>(event_),
|
||||
reinterpret_cast<uintptr_t>(stream.stream())
|
||||
);
|
||||
@ -138,7 +138,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
|
||||
AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0));
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_wait(
|
||||
(*interp)->trace_gpu_event_wait(at::kCUDA,
|
||||
reinterpret_cast<uintptr_t>(event_),
|
||||
reinterpret_cast<uintptr_t>(stream.stream())
|
||||
);
|
||||
@ -161,7 +161,7 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
|
||||
if (is_created_) {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_synchronization(reinterpret_cast<uintptr_t>(event_));
|
||||
(*interp)->trace_gpu_event_synchronization(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
|
||||
}
|
||||
AT_CUDA_CHECK(cudaEventSynchronize(event_));
|
||||
}
|
||||
@ -191,7 +191,7 @@ private:
|
||||
AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_creation(reinterpret_cast<uintptr_t>(event_));
|
||||
(*interp)->trace_gpu_event_creation(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
|
||||
}
|
||||
is_created_ = true;
|
||||
}
|
||||
|
@ -94,17 +94,32 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
||||
}
|
||||
|
||||
// Just swallow the event, don't do anything
|
||||
void trace_gpu_event_creation(uintptr_t event) const override {}
|
||||
void trace_gpu_event_deletion(uintptr_t event) const override {}
|
||||
void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
|
||||
void trace_gpu_event_creation(c10::DeviceType device_type, uintptr_t event)
|
||||
const override {}
|
||||
void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) const override {}
|
||||
void trace_gpu_memory_allocation(uintptr_t ptr) const override {}
|
||||
void trace_gpu_memory_deallocation(uintptr_t ptr) const override {}
|
||||
void trace_gpu_stream_creation(uintptr_t stream) const override {}
|
||||
void trace_gpu_device_synchronization() const override {}
|
||||
void trace_gpu_stream_synchronization(uintptr_t stream) const override {}
|
||||
void trace_gpu_event_synchronization(uintptr_t event) const override {}
|
||||
void trace_gpu_event_deletion(c10::DeviceType device_type, uintptr_t event)
|
||||
const override {}
|
||||
void trace_gpu_event_record(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t event,
|
||||
uintptr_t stream) const override {}
|
||||
void trace_gpu_event_wait(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t event,
|
||||
uintptr_t stream) const override {}
|
||||
void trace_gpu_memory_allocation(c10::DeviceType device_type, uintptr_t ptr)
|
||||
const override {}
|
||||
void trace_gpu_memory_deallocation(c10::DeviceType device_type, uintptr_t ptr)
|
||||
const override {}
|
||||
void trace_gpu_stream_creation(c10::DeviceType device_type, uintptr_t stream)
|
||||
const override {}
|
||||
void trace_gpu_device_synchronization(
|
||||
c10::DeviceType device_type) const override {}
|
||||
void trace_gpu_stream_synchronization(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t stream) const override {}
|
||||
void trace_gpu_event_synchronization(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t event) const override {}
|
||||
|
||||
void reset_backward_hooks(const TensorImpl* self) const override {
|
||||
PANIC(reset_backward_hooks);
|
||||
|
@ -177,18 +177,37 @@ struct C10_API PyInterpreterVTable {
|
||||
virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0;
|
||||
virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0;
|
||||
|
||||
virtual void trace_gpu_event_creation(uintptr_t event) const = 0;
|
||||
virtual void trace_gpu_event_deletion(uintptr_t event) const = 0;
|
||||
virtual void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
|
||||
const = 0;
|
||||
virtual void trace_gpu_event_wait(uintptr_t event, uintptr_t stream)
|
||||
const = 0;
|
||||
virtual void trace_gpu_memory_allocation(uintptr_t ptr) const = 0;
|
||||
virtual void trace_gpu_memory_deallocation(uintptr_t ptr) const = 0;
|
||||
virtual void trace_gpu_stream_creation(uintptr_t stream) const = 0;
|
||||
virtual void trace_gpu_device_synchronization() const = 0;
|
||||
virtual void trace_gpu_stream_synchronization(uintptr_t stream) const = 0;
|
||||
virtual void trace_gpu_event_synchronization(uintptr_t event) const = 0;
|
||||
virtual void trace_gpu_event_creation(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t event) const = 0;
|
||||
virtual void trace_gpu_event_deletion(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t event) const = 0;
|
||||
virtual void trace_gpu_event_record(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t event,
|
||||
uintptr_t stream) const = 0;
|
||||
virtual void trace_gpu_event_wait(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t event,
|
||||
uintptr_t stream) const = 0;
|
||||
virtual void trace_gpu_memory_allocation(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t ptr) const = 0;
|
||||
virtual void trace_gpu_memory_deallocation(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t ptr) const = 0;
|
||||
virtual void trace_gpu_stream_creation(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t stream) const = 0;
|
||||
virtual void trace_gpu_device_synchronization(
|
||||
c10::DeviceType device_type) const = 0;
|
||||
virtual void trace_gpu_stream_synchronization(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t stream) const = 0;
|
||||
virtual void trace_gpu_event_synchronization(
|
||||
c10::DeviceType device_type,
|
||||
uintptr_t event) const = 0;
|
||||
|
||||
virtual void reset_backward_hooks(const TensorImpl* self) const = 0;
|
||||
};
|
||||
|
@ -2840,7 +2840,8 @@ static void uncached_delete(void* ptr) {
|
||||
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_memory_deallocation(reinterpret_cast<uintptr_t>(ptr));
|
||||
(*interp)->trace_gpu_memory_deallocation(
|
||||
c10::kCUDA, reinterpret_cast<uintptr_t>(ptr));
|
||||
}
|
||||
C10_CUDA_CHECK(cudaFree(ptr));
|
||||
}
|
||||
@ -2924,7 +2925,7 @@ class NativeCachingAllocator : public CUDAAllocator {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_memory_allocation(
|
||||
reinterpret_cast<uintptr_t>(*devPtr));
|
||||
c10::kCUDA, reinterpret_cast<uintptr_t>(*devPtr));
|
||||
}
|
||||
}
|
||||
|
||||
@ -2939,7 +2940,7 @@ class NativeCachingAllocator : public CUDAAllocator {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_memory_deallocation(
|
||||
reinterpret_cast<uintptr_t>(block->ptr));
|
||||
c10::kCUDA, reinterpret_cast<uintptr_t>(block->ptr));
|
||||
}
|
||||
device_allocator[block->device]->free(block);
|
||||
}
|
||||
@ -3128,7 +3129,7 @@ class NativeCachingAllocator : public CUDAAllocator {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_memory_allocation(
|
||||
reinterpret_cast<uintptr_t>(devPtr));
|
||||
c10::kCUDA, reinterpret_cast<uintptr_t>(devPtr));
|
||||
}
|
||||
} else {
|
||||
if (size != 0) {
|
||||
|
@ -136,7 +136,7 @@ void set_device(DeviceIndex device) {
|
||||
void device_synchronize() {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_device_synchronization();
|
||||
(*interp)->trace_gpu_device_synchronization(c10::kCUDA);
|
||||
}
|
||||
C10_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
@ -87,7 +87,7 @@ C10_CUDA_API void __inline__ memcpy_and_sync(
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_stream_synchronization(
|
||||
reinterpret_cast<uintptr_t>(stream));
|
||||
c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
|
||||
}
|
||||
#if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301)
|
||||
C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
|
||||
@ -105,7 +105,7 @@ C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_stream_synchronization(
|
||||
reinterpret_cast<uintptr_t>(stream));
|
||||
c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
|
||||
}
|
||||
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
}
|
||||
|
@ -206,7 +206,8 @@ static void initSingleStream(int p, DeviceIndex device_index, int i) {
|
||||
C10_CUDA_CHECK(cudaStreamCreateWithPriority(&stream, kDefaultFlags, pri));
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_stream_creation(reinterpret_cast<uintptr_t>(stream));
|
||||
(*interp)->trace_gpu_stream_creation(
|
||||
c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
|
||||
priority_counters[p][device_index] = 0;
|
||||
}
|
||||
}
|
||||
|
@ -96,7 +96,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_creation(
|
||||
reinterpret_cast<uintptr_t>(cuda_event));
|
||||
c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
|
||||
}
|
||||
}
|
||||
|
||||
@ -111,7 +111,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_deletion(
|
||||
reinterpret_cast<uintptr_t>(cuda_event));
|
||||
c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
|
||||
}
|
||||
C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
|
||||
C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device));
|
||||
@ -146,6 +146,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_record(
|
||||
c10::kCUDA,
|
||||
reinterpret_cast<uintptr_t>(cuda_event),
|
||||
reinterpret_cast<uintptr_t>(cuda_stream.stream()));
|
||||
}
|
||||
@ -168,6 +169,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
||||
if (C10_UNLIKELY(interp)) {
|
||||
(*interp)->trace_gpu_event_wait(
|
||||
c10::kCUDA,
|
||||
reinterpret_cast<uintptr_t>(cuda_event),
|
||||
reinterpret_cast<uintptr_t>(cuda_stream.stream()));
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ import unittest
|
||||
import unittest.mock
|
||||
|
||||
import torch
|
||||
import torch.utils._cuda_trace as cuda_trace
|
||||
import torch.cuda._gpu_trace as gpu_trace
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, NoTest, TEST_CUDA
|
||||
|
||||
# NOTE: Each test needs to be run in a brand new process, to reset the registered hooks
|
||||
@ -19,18 +19,18 @@ if not TEST_CUDA:
|
||||
@torch.testing._internal.common_utils.markDynamoStrictTest
|
||||
class TestCudaTrace(TestCase):
|
||||
def setUp(self):
|
||||
torch._C._activate_cuda_trace()
|
||||
torch._C._activate_gpu_trace()
|
||||
self.mock = unittest.mock.MagicMock()
|
||||
|
||||
def test_event_creation_callback(self):
|
||||
cuda_trace.register_callback_for_cuda_event_creation(self.mock)
|
||||
gpu_trace.register_callback_for_event_creation(self.mock)
|
||||
|
||||
event = torch.cuda.Event()
|
||||
event.record()
|
||||
self.mock.assert_called_once_with(event._as_parameter_.value)
|
||||
|
||||
def test_event_deletion_callback(self):
|
||||
cuda_trace.register_callback_for_cuda_event_deletion(self.mock)
|
||||
gpu_trace.register_callback_for_event_deletion(self.mock)
|
||||
|
||||
event = torch.cuda.Event()
|
||||
event.record()
|
||||
@ -39,7 +39,7 @@ class TestCudaTrace(TestCase):
|
||||
self.mock.assert_called_once_with(event_id)
|
||||
|
||||
def test_event_record_callback(self):
|
||||
cuda_trace.register_callback_for_cuda_event_record(self.mock)
|
||||
gpu_trace.register_callback_for_event_record(self.mock)
|
||||
|
||||
event = torch.cuda.Event()
|
||||
event.record()
|
||||
@ -48,7 +48,7 @@ class TestCudaTrace(TestCase):
|
||||
)
|
||||
|
||||
def test_event_wait_callback(self):
|
||||
cuda_trace.register_callback_for_cuda_event_wait(self.mock)
|
||||
gpu_trace.register_callback_for_event_wait(self.mock)
|
||||
|
||||
event = torch.cuda.Event()
|
||||
event.record()
|
||||
@ -58,13 +58,13 @@ class TestCudaTrace(TestCase):
|
||||
)
|
||||
|
||||
def test_memory_allocation_callback(self):
|
||||
cuda_trace.register_callback_for_cuda_memory_allocation(self.mock)
|
||||
gpu_trace.register_callback_for_memory_allocation(self.mock)
|
||||
|
||||
tensor = torch.empty(10, 4, device="cuda")
|
||||
self.mock.assert_called_once_with(tensor.data_ptr())
|
||||
|
||||
def test_memory_deallocation_callback(self):
|
||||
cuda_trace.register_callback_for_cuda_memory_deallocation(self.mock)
|
||||
gpu_trace.register_callback_for_memory_deallocation(self.mock)
|
||||
|
||||
tensor = torch.empty(3, 8, device="cuda")
|
||||
data_ptr = tensor.data_ptr()
|
||||
@ -72,7 +72,7 @@ class TestCudaTrace(TestCase):
|
||||
self.mock.assert_called_once_with(data_ptr)
|
||||
|
||||
def test_stream_creation_callback(self):
|
||||
cuda_trace.register_callback_for_cuda_stream_creation(self.mock)
|
||||
gpu_trace.register_callback_for_stream_creation(self.mock)
|
||||
|
||||
# see Note [HIP Lazy Streams]
|
||||
if torch.version.hip:
|
||||
@ -85,20 +85,20 @@ class TestCudaTrace(TestCase):
|
||||
self.mock.assert_called()
|
||||
|
||||
def test_device_synchronization_callback(self):
|
||||
cuda_trace.register_callback_for_cuda_device_synchronization(self.mock)
|
||||
gpu_trace.register_callback_for_device_synchronization(self.mock)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
self.mock.assert_called()
|
||||
|
||||
def test_stream_synchronization_callback(self):
|
||||
cuda_trace.register_callback_for_cuda_stream_synchronization(self.mock)
|
||||
gpu_trace.register_callback_for_stream_synchronization(self.mock)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
stream.synchronize()
|
||||
self.mock.assert_called_once_with(stream.cuda_stream)
|
||||
|
||||
def test_event_synchronization_callback(self):
|
||||
cuda_trace.register_callback_for_cuda_event_synchronization(self.mock)
|
||||
gpu_trace.register_callback_for_event_synchronization(self.mock)
|
||||
|
||||
event = torch.cuda.Event()
|
||||
event.record()
|
||||
@ -106,7 +106,7 @@ class TestCudaTrace(TestCase):
|
||||
self.mock.assert_called_once_with(event._as_parameter_.value)
|
||||
|
||||
def test_memcpy_synchronization(self):
|
||||
cuda_trace.register_callback_for_cuda_stream_synchronization(self.mock)
|
||||
gpu_trace.register_callback_for_stream_synchronization(self.mock)
|
||||
|
||||
tensor = torch.rand(5, device="cuda")
|
||||
tensor.nonzero()
|
||||
@ -114,8 +114,8 @@ class TestCudaTrace(TestCase):
|
||||
|
||||
def test_all_trace_callbacks_called(self):
|
||||
other = unittest.mock.MagicMock()
|
||||
cuda_trace.register_callback_for_cuda_memory_allocation(self.mock)
|
||||
cuda_trace.register_callback_for_cuda_memory_allocation(other)
|
||||
gpu_trace.register_callback_for_memory_allocation(self.mock)
|
||||
gpu_trace.register_callback_for_memory_allocation(other)
|
||||
|
||||
tensor = torch.empty(10, 4, device="cuda")
|
||||
self.mock.assert_called_once_with(tensor.data_ptr())
|
||||
|
@ -1273,6 +1273,7 @@ def _unset_dispatch_mode(mode: torch._C._TorchDispatchModeKey) -> Any: ...
|
||||
def _set_dispatch_mode(mode: Any) -> None: ...
|
||||
def _get_dispatch_stack_at(idx: _int) -> Any: ...
|
||||
def _len_torch_dispatch_stack() -> _int: ...
|
||||
def _activate_gpu_trace() -> None: ...
|
||||
|
||||
class _DisableTorchDispatch:
|
||||
def __init__(self): ...
|
||||
@ -2150,7 +2151,6 @@ def _c10d_init() -> _bool: ...
|
||||
# Defined in torch/csrc/distributed/rpc/testing/init.cpp
|
||||
def _faulty_agent_init() -> _bool: ...
|
||||
def _register_py_class_for_device(device: str, cls: Any) -> None: ...
|
||||
def _activate_cuda_trace() -> None: ...
|
||||
|
||||
# Defined in torch/csrc/Module.cpp
|
||||
def _current_graph_task_id() -> _int: ...
|
||||
|
@ -291,7 +291,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
|
||||
"torch._assert_async",
|
||||
"torch._assert_tensor_metadata",
|
||||
"torch._batch_norm_impl_index",
|
||||
"torch._C._activate_cuda_trace",
|
||||
"torch._C._activate_gpu_trace",
|
||||
"torch._C._add_cached_tensor",
|
||||
"torch._C._add_docstr",
|
||||
"torch._C._are_functorch_transforms_active",
|
||||
|
@ -1,10 +1,13 @@
|
||||
import copyreg
|
||||
import functools
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Any, DefaultDict, List, Optional
|
||||
from typing import Any, Callable, DefaultDict, Generic, List, Optional
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
|
||||
@ -935,3 +938,25 @@ class _LazySeedTracker:
|
||||
|
||||
def get_calls(self) -> List:
|
||||
return self.call_order
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class CallbackRegistry(Generic[P]):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.callback_list: List[Callable[P, None]] = []
|
||||
|
||||
def add_callback(self, cb: Callable[P, None]) -> None:
|
||||
self.callback_list.append(cb)
|
||||
|
||||
def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
||||
for cb in self.callback_list:
|
||||
try:
|
||||
cb(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Exception in callback for %s registered with gpu trace", self.name
|
||||
)
|
||||
|
@ -18,17 +18,22 @@ namespace {
|
||||
// because passing in constexpr char* as template argument breaks some
|
||||
// versions of MSVC that are being used internally at Meta.
|
||||
// MSVC 14.16.27023 (vs2017_15.9)
|
||||
#define CONCRETE_TRACE_CUDA(func_name, ...) \
|
||||
at::impl::MaybeSetTLSOnEntryGuard guard; \
|
||||
if (Py_IsInitialized()) { \
|
||||
pybind11::gil_scoped_acquire gil; \
|
||||
try { \
|
||||
py::module mod = py::module::import("torch.utils._cuda_trace"); \
|
||||
py::object hook = mod.attr(func_name).attr("fire_callbacks"); \
|
||||
hook(__VA_ARGS__); \
|
||||
} catch (const std::exception& e) { \
|
||||
LOG(ERROR) << "CUDA trace hook execution failed: " << e.what(); \
|
||||
} \
|
||||
#define CONCRETE_GPU_TRACE(device_type, func_name, ...) \
|
||||
at::impl::MaybeSetTLSOnEntryGuard guard; \
|
||||
if (Py_IsInitialized()) { \
|
||||
pybind11::gil_scoped_acquire gil; \
|
||||
try { \
|
||||
py::module utils_mod = py::module::import("torch._utils"); \
|
||||
py::object get_device_module = utils_mod.attr("_get_device_module"); \
|
||||
py::object hook = get_device_module(DeviceTypeName(device_type, true)) \
|
||||
.attr("_gpu_trace") \
|
||||
.attr(func_name) \
|
||||
.attr("fire_callbacks"); \
|
||||
hook(__VA_ARGS__); \
|
||||
} catch (const std::exception& e) { \
|
||||
LOG(ERROR) << device_type \
|
||||
<< " trace hook execution failed: " << e.what(); \
|
||||
} \
|
||||
}
|
||||
|
||||
struct ConcretePyInterpreterVTable final
|
||||
@ -83,36 +88,51 @@ struct ConcretePyInterpreterVTable final
|
||||
c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
|
||||
c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;
|
||||
|
||||
void trace_gpu_event_creation(uintptr_t event) const override {
|
||||
CONCRETE_TRACE_CUDA("CUDAEventCreationCallbacks", event);
|
||||
}
|
||||
void trace_gpu_event_deletion(uintptr_t event) const override {
|
||||
CONCRETE_TRACE_CUDA("CUDAEventDeletionCallbacks", event);
|
||||
}
|
||||
void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
|
||||
void trace_gpu_event_creation(at::DeviceType device_type, uintptr_t event)
|
||||
const override {
|
||||
CONCRETE_TRACE_CUDA("CUDAEventRecordCallbacks", event, stream);
|
||||
CONCRETE_GPU_TRACE(device_type, "EventCreationCallbacks", event);
|
||||
}
|
||||
void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) const override {
|
||||
CONCRETE_TRACE_CUDA("CUDAEventWaitCallbacks", event, stream);
|
||||
void trace_gpu_event_deletion(at::DeviceType device_type, uintptr_t event)
|
||||
const override {
|
||||
CONCRETE_GPU_TRACE(device_type, "EventDeletionCallbacks", event);
|
||||
}
|
||||
void trace_gpu_memory_allocation(uintptr_t ptr) const override {
|
||||
CONCRETE_TRACE_CUDA("CUDAMemoryAllocationCallbacks", ptr);
|
||||
void trace_gpu_event_record(
|
||||
at::DeviceType device_type,
|
||||
uintptr_t event,
|
||||
uintptr_t stream) const override {
|
||||
CONCRETE_GPU_TRACE(device_type, "EventRecordCallbacks", event, stream);
|
||||
}
|
||||
void trace_gpu_memory_deallocation(uintptr_t ptr) const override {
|
||||
CONCRETE_TRACE_CUDA("CUDAMemoryDeallocationCallbacks", ptr);
|
||||
void trace_gpu_event_wait(
|
||||
at::DeviceType device_type,
|
||||
uintptr_t event,
|
||||
uintptr_t stream) const override {
|
||||
CONCRETE_GPU_TRACE(device_type, "EventWaitCallbacks", event, stream);
|
||||
}
|
||||
void trace_gpu_stream_creation(uintptr_t stream) const override {
|
||||
CONCRETE_TRACE_CUDA("CUDAStreamCreationCallbacks", stream);
|
||||
void trace_gpu_memory_allocation(at::DeviceType device_type, uintptr_t ptr)
|
||||
const override {
|
||||
CONCRETE_GPU_TRACE(device_type, "MemoryAllocationCallbacks", ptr);
|
||||
}
|
||||
void trace_gpu_device_synchronization() const override {
|
||||
CONCRETE_TRACE_CUDA("CUDADeviceSynchronizationCallbacks");
|
||||
void trace_gpu_memory_deallocation(at::DeviceType device_type, uintptr_t ptr)
|
||||
const override {
|
||||
CONCRETE_GPU_TRACE(device_type, "MemoryDeallocationCallbacks", ptr);
|
||||
}
|
||||
void trace_gpu_stream_synchronization(uintptr_t stream) const override {
|
||||
CONCRETE_TRACE_CUDA("CUDAStreamSynchronizationCallbacks", stream);
|
||||
void trace_gpu_stream_creation(at::DeviceType device_type, uintptr_t stream)
|
||||
const override {
|
||||
CONCRETE_GPU_TRACE(device_type, "StreamCreationCallbacks", stream);
|
||||
}
|
||||
void trace_gpu_event_synchronization(uintptr_t event) const override {
|
||||
CONCRETE_TRACE_CUDA("CUDAEventSynchronizationCallbacks", event);
|
||||
void trace_gpu_device_synchronization(
|
||||
at::DeviceType device_type) const override {
|
||||
CONCRETE_GPU_TRACE(device_type, "DeviceSynchronizationCallbacks");
|
||||
}
|
||||
void trace_gpu_stream_synchronization(
|
||||
at::DeviceType device_type,
|
||||
uintptr_t stream) const override {
|
||||
CONCRETE_GPU_TRACE(device_type, "StreamSynchronizationCallbacks", stream);
|
||||
}
|
||||
void trace_gpu_event_synchronization(
|
||||
at::DeviceType device_type,
|
||||
uintptr_t event) const override {
|
||||
CONCRETE_GPU_TRACE(device_type, "EventSynchronizationCallbacks", event);
|
||||
}
|
||||
|
||||
void reset_backward_hooks(const c10::TensorImpl* self) const override;
|
||||
|
@ -431,7 +431,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
||||
}
|
||||
});
|
||||
|
||||
_C_m.def("_activate_cuda_trace", []() { activateCUDATrace(); });
|
||||
_C_m.def("_activate_gpu_trace", []() { activateGPUTrace(); });
|
||||
|
||||
py_context_manager_DEPRECATED<c10::InferenceMode, bool>(
|
||||
_C_m, "_InferenceMode");
|
||||
|
@ -250,7 +250,7 @@ static PyObject* getPythonTensorClass(c10::Device d) {
|
||||
return device_to_py_class_[static_cast<size_t>(d.type())];
|
||||
}
|
||||
|
||||
void activateCUDATrace() {
|
||||
void activateGPUTrace() {
|
||||
c10::impl::GPUTrace::set_trace(getPyInterpreter());
|
||||
}
|
||||
|
||||
|
@ -31,7 +31,7 @@ TORCH_PYTHON_API void registerPythonTensorClass(
|
||||
const std::string& device,
|
||||
PyObject* python_tensor_class);
|
||||
|
||||
TORCH_PYTHON_API void activateCUDATrace();
|
||||
TORCH_PYTHON_API void activateGPUTrace();
|
||||
|
||||
TORCH_PYTHON_API extern PyObject* THPVariableClass;
|
||||
TORCH_PYTHON_API extern PyObject* ParameterClass;
|
||||
|
75
torch/cuda/_gpu_trace.py
Normal file
75
torch/cuda/_gpu_trace.py
Normal file
@ -0,0 +1,75 @@
|
||||
from typing import Callable
|
||||
|
||||
from torch._utils import CallbackRegistry
|
||||
|
||||
|
||||
EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA event creation"
|
||||
)
|
||||
EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA event deletion"
|
||||
)
|
||||
EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
|
||||
"CUDA event record"
|
||||
)
|
||||
EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
|
||||
"CUDA event wait"
|
||||
)
|
||||
MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA memory allocation"
|
||||
)
|
||||
MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA memory deallocation"
|
||||
)
|
||||
StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA stream creation"
|
||||
)
|
||||
DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
|
||||
"CUDA device synchronization"
|
||||
)
|
||||
StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA stream synchronization"
|
||||
)
|
||||
EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA event synchronization"
|
||||
)
|
||||
|
||||
|
||||
def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
|
||||
EventCreationCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
|
||||
EventDeletionCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
|
||||
EventRecordCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
|
||||
EventWaitCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
|
||||
MemoryAllocationCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
|
||||
MemoryDeallocationCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
|
||||
StreamCreationCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
|
||||
DeviceSynchronizationCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
|
||||
StreamSynchronizationCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
|
||||
EventSynchronizationCallbacks.add_callback(cb)
|
@ -22,7 +22,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.utils._cuda_trace as cuda_trace
|
||||
import torch.cuda._gpu_trace as gpu_trace
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
@ -528,35 +528,35 @@ class ArgumentHandler:
|
||||
class CUDASanitizerDispatchMode(TorchDispatchMode):
|
||||
def __init__(self):
|
||||
self.event_handler = EventHandler()
|
||||
torch._C._activate_cuda_trace()
|
||||
cuda_trace.register_callback_for_cuda_event_creation(
|
||||
torch._C._activate_gpu_trace()
|
||||
gpu_trace.register_callback_for_event_creation(
|
||||
self.event_handler._handle_event_creation
|
||||
)
|
||||
cuda_trace.register_callback_for_cuda_event_deletion(
|
||||
gpu_trace.register_callback_for_event_deletion(
|
||||
self.event_handler._handle_event_deletion
|
||||
)
|
||||
cuda_trace.register_callback_for_cuda_event_record(
|
||||
gpu_trace.register_callback_for_event_record(
|
||||
self.event_handler._handle_event_record
|
||||
)
|
||||
cuda_trace.register_callback_for_cuda_event_wait(
|
||||
gpu_trace.register_callback_for_event_wait(
|
||||
self.event_handler._handle_event_wait
|
||||
)
|
||||
cuda_trace.register_callback_for_cuda_memory_allocation(
|
||||
gpu_trace.register_callback_for_memory_allocation(
|
||||
self.event_handler._handle_memory_allocation
|
||||
)
|
||||
cuda_trace.register_callback_for_cuda_memory_deallocation(
|
||||
gpu_trace.register_callback_for_memory_deallocation(
|
||||
self.event_handler._handle_memory_deallocation
|
||||
)
|
||||
cuda_trace.register_callback_for_cuda_stream_creation(
|
||||
gpu_trace.register_callback_for_stream_creation(
|
||||
self.event_handler._handle_stream_creation
|
||||
)
|
||||
cuda_trace.register_callback_for_cuda_device_synchronization(
|
||||
gpu_trace.register_callback_for_device_synchronization(
|
||||
self.event_handler._handle_device_synchronization
|
||||
)
|
||||
cuda_trace.register_callback_for_cuda_stream_synchronization(
|
||||
gpu_trace.register_callback_for_stream_synchronization(
|
||||
self.event_handler._handle_stream_synchronization
|
||||
)
|
||||
cuda_trace.register_callback_for_cuda_event_synchronization(
|
||||
gpu_trace.register_callback_for_event_synchronization(
|
||||
self.event_handler._handle_event_synchronization
|
||||
)
|
||||
|
||||
|
@ -1,99 +0,0 @@
|
||||
import logging
|
||||
from typing import Callable, Generic, List
|
||||
|
||||
from typing_extensions import ParamSpec # Python 3.10+
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class CallbackRegistry(Generic[P]):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.callback_list: List[Callable[P, None]] = []
|
||||
|
||||
def add_callback(self, cb: Callable[P, None]) -> None:
|
||||
self.callback_list.append(cb)
|
||||
|
||||
def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
||||
for cb in self.callback_list:
|
||||
try:
|
||||
cb(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Exception in callback for %s registered with CUDA trace", self.name
|
||||
)
|
||||
|
||||
|
||||
CUDAEventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA event creation"
|
||||
)
|
||||
CUDAEventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA event deletion"
|
||||
)
|
||||
CUDAEventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
|
||||
"CUDA event record"
|
||||
)
|
||||
CUDAEventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
|
||||
"CUDA event wait"
|
||||
)
|
||||
CUDAMemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA memory allocation"
|
||||
)
|
||||
CUDAMemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA memory deallocation"
|
||||
)
|
||||
CUDAStreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA stream creation"
|
||||
)
|
||||
CUDADeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
|
||||
"CUDA device synchronization"
|
||||
)
|
||||
CUDAStreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA stream synchronization"
|
||||
)
|
||||
CUDAEventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
|
||||
"CUDA event synchronization"
|
||||
)
|
||||
|
||||
|
||||
def register_callback_for_cuda_event_creation(cb: Callable[[int], None]) -> None:
|
||||
CUDAEventCreationCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_cuda_event_deletion(cb: Callable[[int], None]) -> None:
|
||||
CUDAEventDeletionCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_cuda_event_record(cb: Callable[[int, int], None]) -> None:
|
||||
CUDAEventRecordCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_cuda_event_wait(cb: Callable[[int, int], None]) -> None:
|
||||
CUDAEventWaitCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_cuda_memory_allocation(cb: Callable[[int], None]) -> None:
|
||||
CUDAMemoryAllocationCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_cuda_memory_deallocation(cb: Callable[[int], None]) -> None:
|
||||
CUDAMemoryDeallocationCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_cuda_stream_creation(cb: Callable[[int], None]) -> None:
|
||||
CUDAStreamCreationCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_cuda_device_synchronization(cb: Callable[[], None]) -> None:
|
||||
CUDADeviceSynchronizationCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_cuda_stream_synchronization(
|
||||
cb: Callable[[int], None]
|
||||
) -> None:
|
||||
CUDAStreamSynchronizationCallbacks.add_callback(cb)
|
||||
|
||||
|
||||
def register_callback_for_cuda_event_synchronization(cb: Callable[[int], None]) -> None:
|
||||
CUDAEventSynchronizationCallbacks.add_callback(cb)
|
Reference in New Issue
Block a user