[Memory Snapshot] Add Flag to Toggle Global and Local Callbacks for Annotations (#154932)

Summary:
There are some cases where we want only local annotations for the memory snapshot, such as when executing inside a CUDA stream callback, which cannot execute CUDA operators. In those cases the following CUDA error occurs: Exception in RecordFunction callback: CUDA error: operation not permitted

However, we still need an option to turn the callbacks on globally so that the on-demand snapshot can get annotations. Additionally, there may be some cases in which auto-trace will also want annotations via record functions, so we expose the flag to auto-trace as well.

Test Plan:
Run the MVAI executable and verify that the errors go away.

Rollback Plan:

Differential Revision: D75831687

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154932
Approved by: https://github.com/mzzchy, https://github.com/sanrise
This commit is contained in:
Shivam Raikundalia
2025-06-04 23:15:16 +00:00
committed by PyTorch MergeBot
parent 7cf5b36ec2
commit 1083bc749d
4 changed files with 34 additions and 12 deletions

View File

@ -1125,7 +1125,7 @@ static void registerCudaDeviceProperties(PyObject* module) {
m.def(
"_cuda_record_memory_history_legacy",
static_cast<void (*)(bool, bool, int64_t, bool, bool, bool, bool)>(
static_cast<void (*)(bool, bool, int64_t, bool, bool, bool, bool, bool)>(
torch::cuda::_record_memory_history));
m.def(
@ -1136,6 +1136,7 @@ static void registerCudaDeviceProperties(PyObject* module) {
const std::string&,
size_t,
bool,
bool,
bool)>(torch::cuda::_record_memory_history));
m.def("_cuda_isHistoryEnabled", []() {

View File

@ -129,8 +129,11 @@ CapturedTraceback* getFromContext(
"attempting to gather stack context from the wrong StackContext type.");
}
at::CallbackHandle _initRecordAnnotations() {
return at::addGlobalCallback(
#define ADD_CALLBACK(callbackType) at::add##callbackType##Callback
at::CallbackHandle _initRecordAnnotations(bool useGlobalCallback) {
auto addCallback =
useGlobalCallback ? ADD_CALLBACK(Global) : ADD_CALLBACK(ThreadLocal);
return addCallback(
at::RecordFunctionCallback(
[](const at::RecordFunction& fn)
-> std::unique_ptr<at::ObserverContext> {
@ -169,12 +172,16 @@ at::CallbackHandle _initCompileContexts() {
.scopes({at::RecordScope::FUNCTION}));
}
void setRecordFunctionCallbacks(bool enabled, bool compileContext) {
void setRecordFunctionCallbacks(
bool enabled,
bool compileContext,
bool globalRecordAnnotations) {
// Handle Callbacks under mutex
auto lock = callbackManager.lockCallbackMutex();
if (enabled) {
if (callbackManager.getAnnotationHandle() == 0) {
callbackManager.setAnnotationHandle(_initRecordAnnotations());
callbackManager.setAnnotationHandle(
_initRecordAnnotations(globalRecordAnnotations));
}
if (compileContext && callbackManager.getCompileContextHandle() == 0) {
callbackManager.setCompileContextHandle(_initCompileContexts());
@ -200,7 +207,8 @@ void _record_memory_history(
bool trace_alloc_record_context,
bool record_cpp_context,
bool clearHistory,
bool compileContext) {
bool compileContext,
bool globalRecordAnnotations) {
c10::cuda::CUDACachingAllocator::CreateContextFn recorder = gather;
if (enabled && record_cpp_context &&
(trace_alloc_record_context || record_context)) {
@ -216,7 +224,7 @@ void _record_memory_history(
}
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
setRecordFunctionCallbacks(enabled, compileContext);
setRecordFunctionCallbacks(enabled, compileContext, globalRecordAnnotations);
c10::cuda::CUDACachingAllocator::recordHistory(
enabled, recorder, trace_alloc_max_entries, when, clearHistory);
}
@ -235,7 +243,8 @@ void _record_memory_history(
const std::string& stacks,
size_t max_entries,
bool clearHistory,
bool compileContext) {
bool compileContext,
bool globalRecordAnnotations) {
if (enabled) {
checkOptionIn(
*enabled,
@ -269,7 +278,8 @@ void _record_memory_history(
}
}
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
setRecordFunctionCallbacks(enabled.has_value(), compileContext);
setRecordFunctionCallbacks(
enabled.has_value(), compileContext, globalRecordAnnotations);
c10::cuda::CUDACachingAllocator::recordHistory(
enabled.has_value(), recorder, max_entries, when, clearHistory);
}

View File

@ -16,7 +16,8 @@ TORCH_CUDA_CU_API void _record_memory_history(
bool trace_alloc_record_context = false,
bool record_cpp_context = false,
bool clearHistory = false,
bool compileContext = false);
bool compileContext = false,
bool globalRecordAllocations = false);
TORCH_CUDA_CU_API void _record_memory_history(
std::optional<std::string> enabled = "all",
@ -24,7 +25,8 @@ TORCH_CUDA_CU_API void _record_memory_history(
const std::string& stacks = "all",
size_t max_entries = SIZE_MAX,
bool clearHistory = false,
bool compileContext = false);
bool compileContext = false,
bool globalRecordAllocations = false);
TORCH_CUDA_CU_API std::string _memory_snapshot_pickled();

View File

@ -847,6 +847,7 @@ def _record_memory_history_legacy(
record_context_cpp=False,
clear_history=False,
compile_context=False,
global_record_annotations=False,
):
_C._cuda_record_memory_history_legacy( # type: ignore[call-arg]
enabled,
@ -856,6 +857,7 @@ def _record_memory_history_legacy(
record_context_cpp,
clear_history,
compile_context,
global_record_annotations,
)
@ -912,9 +914,16 @@ def _record_memory_history_impl(
device: "Device" = None,
clear_history: bool = False,
compile_context: bool = False,
global_record_annotations: bool = False,
):
_C._cuda_record_memory_history( # type: ignore[call-arg]
enabled, context, stacks, max_entries, clear_history, compile_context
enabled,
context,
stacks,
max_entries,
clear_history,
compile_context,
global_record_annotations,
)