mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Memory Snapshot] Add Flag to Toggle Global and Local Callbacks for Annotations (#154932)
Summary: There are some cases where we want only local annotations for memory snapshots, such as when executing inside a CUDA stream callback, which cannot execute CUDA operators. In those cases a CUDA error occurs: "Exception in RecordFunction callback: CUDA error: operation not permitted". However, we also need an option to turn the annotations on globally so that on-demand snapshots can get annotations. Additionally, there may be cases in which auto-trace also wants annotations using record functions, so we expose the flag to auto-trace as well. Test Plan: Run MVAI executable and see that the errors go away. Rollback Plan: Differential Revision: D75831687 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154932 Approved by: https://github.com/mzzchy, https://github.com/sanrise
This commit is contained in:
committed by
PyTorch MergeBot
parent
7cf5b36ec2
commit
1083bc749d
@ -1125,7 +1125,7 @@ static void registerCudaDeviceProperties(PyObject* module) {
|
||||
|
||||
m.def(
|
||||
"_cuda_record_memory_history_legacy",
|
||||
static_cast<void (*)(bool, bool, int64_t, bool, bool, bool, bool)>(
|
||||
static_cast<void (*)(bool, bool, int64_t, bool, bool, bool, bool, bool)>(
|
||||
torch::cuda::_record_memory_history));
|
||||
|
||||
m.def(
|
||||
@ -1136,6 +1136,7 @@ static void registerCudaDeviceProperties(PyObject* module) {
|
||||
const std::string&,
|
||||
size_t,
|
||||
bool,
|
||||
bool,
|
||||
bool)>(torch::cuda::_record_memory_history));
|
||||
|
||||
m.def("_cuda_isHistoryEnabled", []() {
|
||||
|
@ -129,8 +129,11 @@ CapturedTraceback* getFromContext(
|
||||
"attempting to gather stack context from the wrong StackContext type.");
|
||||
}
|
||||
|
||||
at::CallbackHandle _initRecordAnnotations() {
|
||||
return at::addGlobalCallback(
|
||||
// Token-pastes `callbackType` into the name of an `at::` registration
// function: ADD_CALLBACK(Global) expands to at::addGlobalCallback and
// ADD_CALLBACK(ThreadLocal) expands to at::addThreadLocalCallback.
#define ADD_CALLBACK(callbackType) at::add##callbackType##Callback
|
||||
at::CallbackHandle _initRecordAnnotations(bool useGlobalCallback) {
|
||||
auto addCallback =
|
||||
useGlobalCallback ? ADD_CALLBACK(Global) : ADD_CALLBACK(ThreadLocal);
|
||||
return addCallback(
|
||||
at::RecordFunctionCallback(
|
||||
[](const at::RecordFunction& fn)
|
||||
-> std::unique_ptr<at::ObserverContext> {
|
||||
@ -169,12 +172,16 @@ at::CallbackHandle _initCompileContexts() {
|
||||
.scopes({at::RecordScope::FUNCTION}));
|
||||
}
|
||||
|
||||
void setRecordFunctionCallbacks(bool enabled, bool compileContext) {
|
||||
void setRecordFunctionCallbacks(
|
||||
bool enabled,
|
||||
bool compileContext,
|
||||
bool globalRecordAnnotations) {
|
||||
// Handle Callbacks under mutex
|
||||
auto lock = callbackManager.lockCallbackMutex();
|
||||
if (enabled) {
|
||||
if (callbackManager.getAnnotationHandle() == 0) {
|
||||
callbackManager.setAnnotationHandle(_initRecordAnnotations());
|
||||
callbackManager.setAnnotationHandle(
|
||||
_initRecordAnnotations(globalRecordAnnotations));
|
||||
}
|
||||
if (compileContext && callbackManager.getCompileContextHandle() == 0) {
|
||||
callbackManager.setCompileContextHandle(_initCompileContexts());
|
||||
@ -200,7 +207,8 @@ void _record_memory_history(
|
||||
bool trace_alloc_record_context,
|
||||
bool record_cpp_context,
|
||||
bool clearHistory,
|
||||
bool compileContext) {
|
||||
bool compileContext,
|
||||
bool globalRecordAnnotations) {
|
||||
c10::cuda::CUDACachingAllocator::CreateContextFn recorder = gather;
|
||||
if (enabled && record_cpp_context &&
|
||||
(trace_alloc_record_context || record_context)) {
|
||||
@ -216,7 +224,7 @@ void _record_memory_history(
|
||||
}
|
||||
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
|
||||
|
||||
setRecordFunctionCallbacks(enabled, compileContext);
|
||||
setRecordFunctionCallbacks(enabled, compileContext, globalRecordAnnotations);
|
||||
c10::cuda::CUDACachingAllocator::recordHistory(
|
||||
enabled, recorder, trace_alloc_max_entries, when, clearHistory);
|
||||
}
|
||||
@ -235,7 +243,8 @@ void _record_memory_history(
|
||||
const std::string& stacks,
|
||||
size_t max_entries,
|
||||
bool clearHistory,
|
||||
bool compileContext) {
|
||||
bool compileContext,
|
||||
bool globalRecordAnnotations) {
|
||||
if (enabled) {
|
||||
checkOptionIn(
|
||||
*enabled,
|
||||
@ -269,7 +278,8 @@ void _record_memory_history(
|
||||
}
|
||||
}
|
||||
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
|
||||
setRecordFunctionCallbacks(enabled.has_value(), compileContext);
|
||||
setRecordFunctionCallbacks(
|
||||
enabled.has_value(), compileContext, globalRecordAnnotations);
|
||||
c10::cuda::CUDACachingAllocator::recordHistory(
|
||||
enabled.has_value(), recorder, max_entries, when, clearHistory);
|
||||
}
|
||||
|
@ -16,7 +16,8 @@ TORCH_CUDA_CU_API void _record_memory_history(
|
||||
bool trace_alloc_record_context = false,
|
||||
bool record_cpp_context = false,
|
||||
bool clearHistory = false,
|
||||
bool compileContext = false);
|
||||
bool compileContext = false,
|
||||
bool globalRecordAllocations = false);
|
||||
|
||||
TORCH_CUDA_CU_API void _record_memory_history(
|
||||
std::optional<std::string> enabled = "all",
|
||||
@ -24,7 +25,8 @@ TORCH_CUDA_CU_API void _record_memory_history(
|
||||
const std::string& stacks = "all",
|
||||
size_t max_entries = SIZE_MAX,
|
||||
bool clearHistory = false,
|
||||
bool compileContext = false);
|
||||
bool compileContext = false,
|
||||
bool globalRecordAllocations = false);
|
||||
|
||||
TORCH_CUDA_CU_API std::string _memory_snapshot_pickled();
|
||||
|
||||
|
@ -847,6 +847,7 @@ def _record_memory_history_legacy(
|
||||
record_context_cpp=False,
|
||||
clear_history=False,
|
||||
compile_context=False,
|
||||
global_record_annotations=False,
|
||||
):
|
||||
_C._cuda_record_memory_history_legacy( # type: ignore[call-arg]
|
||||
enabled,
|
||||
@ -856,6 +857,7 @@ def _record_memory_history_legacy(
|
||||
record_context_cpp,
|
||||
clear_history,
|
||||
compile_context,
|
||||
global_record_annotations,
|
||||
)
|
||||
|
||||
|
||||
@ -912,9 +914,16 @@ def _record_memory_history_impl(
|
||||
device: "Device" = None,
|
||||
clear_history: bool = False,
|
||||
compile_context: bool = False,
|
||||
global_record_annotations: bool = False,
|
||||
):
|
||||
_C._cuda_record_memory_history( # type: ignore[call-arg]
|
||||
enabled, context, stacks, max_entries, clear_history, compile_context
|
||||
enabled,
|
||||
context,
|
||||
stacks,
|
||||
max_entries,
|
||||
clear_history,
|
||||
compile_context,
|
||||
global_record_annotations,
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user