[submodule] [Snapshot/Profiler] Memory Snapshot On Demand (#150559)

Summary:
Profiler side of the on-demand memory snapshot.

1. Add an API that actually takes the snapshot when the client interface is called.
2. Add ifdefs to the builds so that the kineto hooks trigger the snapshot correctly.

Design Philosophy: There is one interesting part of this implementation, and it is the export step. For export we call the Python implementation rather than C++, even though we are already in C++. This is because it is better to have a single export path rather than two. Personally, I want parity between auto-trace and on-demand, so if we can limit the side paths we will have an easier time maintaining that relationship.
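
A minimal sketch of that single export path (exportViaPython is a hypothetical name, and error handling is trimmed relative to the real implementation in the diff below):

// Sketch only: the on-demand C++ side calls back into the same Python
// exporter (torch.cuda.memory._dump_snapshot) that auto-trace users
// call directly, so there is exactly one serialization path.
#include <torch/csrc/utils/object_ptr.h>
#include <string>

void exportViaPython(const std::string& path) {
  PyGILState_STATE gil = PyGILState_Ensure();
  THPObjectPtr mod(PyImport_ImportModule("torch.cuda.memory"));
  if (mod) {
    THPObjectPtr fn(PyObject_GetAttrString(mod.get(), "_dump_snapshot"));
    if (fn) {
      // "s" packs the C string into a one-element argument tuple.
      THPObjectPtr res(PyObject_CallFunction(fn.get(), "s", path.c_str()));
    }
  }
  PyGILState_Release(gil);
}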

Test Plan: {F1976563426}

Reviewed By: sanrise

Differential Revision: D70733247

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150559
Approved by: https://github.com/sanrise
Author: Shivam Raikundalia
Date: 2025-04-07 13:04:38 +00:00
Committed by: PyTorch MergeBot
Parent: e209625334
Commit: 99c9a31386

10 changed files with 160 additions and 6 deletions

View File

@@ -1725,6 +1725,7 @@ def define_buck_targets(
compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
exported_preprocessor_flags = get_pt_preprocessor_flags() + [
"-DUSE_KINETO",
"-DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND",
# Need this otherwise USE_KINETO is undefined
# for mobile
"-DEDGE_PROFILER_USE_KINETO",
@@ -1750,6 +1751,7 @@ def define_buck_targets(
exported_preprocessor_flags = get_pt_preprocessor_flags() + [
"-DUSE_KINETO",
"-DEDGE_PROFILER_USE_KINETO",
"-DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND",
],
# @lint-ignore BUCKLINT link_whole
link_whole = True,
@@ -1836,6 +1838,7 @@ def define_buck_targets(
# Need this otherwise USE_KINETO is undefined
# for mobile
"-DEDGE_PROFILER_USE_KINETO",
"-DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND",
] + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
extra_flags = {
"fbandroid_compiler_flags": ["-frtti"],

View File

@@ -1715,7 +1715,7 @@ if(USE_KINETO)
set_property(TARGET kineto PROPERTY POSITION_INDEPENDENT_CODE ON)
endif()
list(APPEND Caffe2_DEPENDENCY_LIBS kineto)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO")
string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO -DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND")
if(LIBKINETO_NOCUPTI)
string(APPEND CMAKE_CXX_FLAGS " -DLIBKINETO_NOCUPTI")
endif()
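
The flag is threaded through the builds because the kineto client hooks are compiled conditionally; a hedged sketch of the gating pattern (maybeWireMemorySnapshotHooks is hypothetical; the real guarded call sites live in the kineto client code, not in this diff):

// Illustration only: with -DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND on the
// compile line the on-demand memory hooks are compiled in; without it
// they compile away entirely, so no runtime flag is consulted.
void maybeWireMemorySnapshotHooks() {
#ifdef TMP_IMPL_MEMORY_PROFILING_ON_DEMAND
  // register the start/stop/export memory snapshot hooks here
#else
  // hooks compiled out; nothing to do
#endif
}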

View File

@@ -27,7 +27,7 @@ add_library(backend_with_compiler SHARED
)
if(USE_KINETO)
set_target_properties(backend_with_compiler PROPERTIES COMPILE_FLAGS
"-DUSE_KINETO")
"-DUSE_KINETO -DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND")
endif()
target_link_libraries(backend_with_compiler torch)

View File

@@ -8,7 +8,6 @@
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/util/overloaded.h>
#include <torch/csrc/profiler/api.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/containers.h>
@@ -21,8 +20,6 @@
#include <torch/csrc/profiler/standalone/privateuse1_observer.h>
#include <torch/csrc/profiler/util.h>
#include <ATen/Context.h>
#include <stdexcept>
#include <utility>
@@ -860,6 +857,22 @@ std::unique_ptr<ProfilerResult> disableProfiler() {
return result;
}
namespace tracer = torch::profiler::impl::python_tracer;
std::unique_ptr<tracer::PythonMemoryTracerBase> memory_tracer;
void startMemoryProfile() {
  // Lazily create the tracer on first use; the registered factory (or a
  // no-op fallback) decides the concrete implementation.
  if (memory_tracer == nullptr) {
    memory_tracer = tracer::PythonMemoryTracerBase::make();
  }
  memory_tracer->start();
}
void stopMemoryProfile() {
  // Guard against stop/export arriving before any start created the tracer.
  if (memory_tracer != nullptr) {
    memory_tracer->stop();
  }
}
void exportMemoryProfile(const std::string& filename) {
  if (memory_tracer != nullptr) {
    memory_tracer->export_memory_history(filename);
  }
}
KinetoEvent::KinetoEvent(
const std::shared_ptr<const torch::profiler::impl::Result>& result,

View File

@@ -185,6 +185,10 @@ TORCH_API void toggleCollectionDynamic(
const bool enable,
const std::set<torch::profiler::impl::ActivityType>& activities);
TORCH_API void startMemoryProfile();
TORCH_API void stopMemoryProfile();
TORCH_API void exportMemoryProfile(const std::string& path);
/**
* When a C++ thread really has no control over how the profiler was enabled,
* for example, by some unreachable Python code, it can call these functions

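A minimal usage sketch of the three new entry points (captureSnapshot is a hypothetical caller; this assumes the declarations live in torch/csrc/autograd/profiler_kineto.h under torch::autograd::profiler, as the surrounding header suggests):

#include <string>
#include <torch/csrc/autograd/profiler_kineto.h>

void captureSnapshot(const std::string& out_path) {
  // Begin recording allocator events via the Python memory tracer.
  torch::autograd::profiler::startMemoryProfile();
  // ... run the workload of interest ...
  torch::autograd::profiler::stopMemoryProfile();
  // Serialize the recorded history through the single Python export path.
  torch::autograd::profiler::exportMemoryProfile(out_path);
}
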
View File

@@ -27,6 +27,7 @@
#include <torch/csrc/profiler/util.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <optional>
@@ -1144,6 +1145,81 @@ std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
return out;
}
// ============================================================================
// == Memory Tracer ===========================================================
// ============================================================================
class PythonMemoryTracer final : public python_tracer::PythonMemoryTracerBase {
 public:
  PythonMemoryTracer() = default;
  ~PythonMemoryTracer() override = default;
  void start() override;
  void stop() override;
  void export_memory_history(const std::string path) override;
};
static void toggle_memory_tracing(bool enable) {
  PyGILState_STATE gil_state = PyGILState_Ensure();
  THPObjectPtr torch_cuda_memory_module(
      PyImport_ImportModule("torch.cuda.memory"));
  if (!torch_cuda_memory_module) {
    PyErr_Clear();
    PyGILState_Release(gil_state);
    return;
  }
  THPObjectPtr snapshot_func(PyObject_GetAttrString(
      torch_cuda_memory_module.get(), "_record_memory_history_impl"));
  if (!snapshot_func) {
    PyErr_Clear();
    PyGILState_Release(gil_state);
    return;
  }
  // Build the positional arguments. PyTuple_SetItem steals a reference,
  // so Py_None must be INCREF'd before it is handed to the tuple.
  PyObject* enabled_arg = Py_None;
  if (enable) {
    enabled_arg = PyUnicode_FromString("all");
  } else {
    Py_INCREF(Py_None);
  }
  PyObject* args = PyTuple_New(6);
  PyTuple_SetItem(args, 0, enabled_arg); // enabled ("all" or None)
  PyTuple_SetItem(args, 1, PyUnicode_FromString("all")); // context
  PyTuple_SetItem(args, 2, PyUnicode_FromString("all")); // stacks
  PyTuple_SetItem(args, 3, THPUtils_packInt64(100000)); // max_entries
  Py_INCREF(Py_None);
  PyTuple_SetItem(args, 4, Py_None); // device (None)
  PyTuple_SetItem(args, 5, PyBool_FromLong(0)); // clear_history (False)
  PyObject* result = PyObject_Call(snapshot_func.get(), args, NULL);
  Py_DECREF(args);
  if (result == NULL) {
    PyErr_Clear(); // recording is best-effort; swallow the Python error
  }
  Py_XDECREF(result); // the call returns None on success
  // Release the GIL on every path, including failures above.
  PyGILState_Release(gil_state);
}
void PythonMemoryTracer::start() {
  toggle_memory_tracing(true);
}
void PythonMemoryTracer::export_memory_history(const std::string path) {
  PyGILState_STATE gil_state = PyGILState_Ensure();
  THPObjectPtr torch_cuda_memory_module(
      PyImport_ImportModule("torch.cuda.memory"));
  if (!torch_cuda_memory_module) {
    PyErr_Clear();
    PyGILState_Release(gil_state);
    return;
  }
  THPObjectPtr snapshot_func(
      PyObject_GetAttrString(torch_cuda_memory_module.get(), "_dump_snapshot"));
  if (!snapshot_func) {
    PyErr_Clear();
    PyGILState_Release(gil_state);
    return;
  }
  // Call _dump_snapshot(path). PyTuple_Pack does not steal references,
  // so the filename object must be released after the call.
  PyObject* py_filename = PyUnicode_FromString(path.c_str());
  PyObject* args = PyTuple_Pack(1, py_filename);
  PyObject* result = PyObject_Call(snapshot_func.get(), args, NULL);
  Py_DECREF(args);
  Py_DECREF(py_filename);
  if (result == NULL) {
    PyErr_Clear(); // export is best-effort; swallow the Python error
  }
  Py_XDECREF(result);
  PyGILState_Release(gil_state);
}
void PythonMemoryTracer::stop() {
  toggle_memory_tracing(false);
}
// ============================================================================
// == API =====================================================================
@@ -1181,6 +1257,11 @@ std::unique_ptr<python_tracer::PythonTracerBase> getTracer(
torch::profiler::impl::RecordQueue* queue) {
return std::make_unique<PythonTracer>(queue);
}
std::unique_ptr<python_tracer::PythonMemoryTracerBase> getMemoryTracer() {
return std::make_unique<PythonMemoryTracer>();
}
} // namespace
} // namespace torch::profiler::impl
@@ -1191,5 +1272,7 @@ void init() {
TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0);
torch::profiler::impl::python_tracer::registerTracer(
&torch::profiler::impl::getTracer);
torch::profiler::impl::python_tracer::registerMemoryTracer(
&torch::profiler::impl::getMemoryTracer);
}
} // namespace torch::autograd::profiler::python_tracer

View File

@@ -58,6 +58,20 @@ class LibKinetoClient : public libkineto::ClientInterface {
(void)disableProfiler();
}
void start_memory_profile() override {
LOG(INFO) << "Starting on-demand memory profile";
startMemoryProfile();
}
void stop_memory_profile() override {
LOG(INFO) << "Stopping on-demand memory profile";
stopMemoryProfile();
}
void export_memory_profile(const std::string& path) override {
exportMemoryProfile(path);
}
private:
// Temporarily disable shape collection until
// we re-roll out the feature for on-demand cases

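For context, a hypothetical driver showing the order in which the on-demand path is expected to invoke the ClientInterface hooks overridden above (onDemandMemorySnapshot is illustrative; the actual daemon side lives in the kineto submodule, not in this diff):

// Illustration only, not the real kineto daemon code; assumes
// libkineto's umbrella header is on the include path.
#include <libkineto.h>
#include <string>

void onDemandMemorySnapshot(
    libkineto::ClientInterface& client,
    const std::string& out_path) {
  client.start_memory_profile(); // begins recording allocator events
  // ... the configured profiling window elapses ...
  client.stop_memory_profile(); // stops recording
  client.export_memory_profile(out_path); // dumps via the Python exporter
}
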
View File

@@ -3,6 +3,7 @@
namespace torch::profiler::impl::python_tracer {
namespace {
MakeFn make_fn;
MakeMemoryFn memory_make_fn;
struct NoOpPythonTracer : public PythonTracerBase {
NoOpPythonTracer() = default;
@@ -17,6 +18,15 @@ struct NoOpPythonTracer : public PythonTracerBase {
return {};
}
};
struct NoOpMemoryPythonTracer : public PythonMemoryTracerBase {
NoOpMemoryPythonTracer() = default;
~NoOpMemoryPythonTracer() override = default;
void start() override {}
void stop() override {}
void export_memory_history(const std::string path) override {}
};
} // namespace
void registerTracer(MakeFn make_tracer) {
@@ -29,4 +39,15 @@ std::unique_ptr<PythonTracerBase> PythonTracerBase::make(RecordQueue* queue) {
}
return make_fn(queue);
}
void registerMemoryTracer(MakeMemoryFn make_memory_tracer) {
memory_make_fn = make_memory_tracer;
}
std::unique_ptr<PythonMemoryTracerBase> PythonMemoryTracerBase::make() {
if (memory_make_fn == nullptr) {
return std::make_unique<NoOpMemoryPythonTracer>();
}
return memory_make_fn();
}
} // namespace torch::profiler::impl::python_tracer
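
The registration above is a small function-pointer registry; a self-contained sketch of the pattern with hypothetical names, showing why the no-op fallback matters (builds that never link the Python tracer still get a safe default):

// Minimal sketch of the registry-with-fallback pattern (hypothetical names).
#include <memory>

struct TracerBase {
  virtual ~TracerBase() = default;
  virtual void start() = 0;
};
struct NoOpTracer : TracerBase {
  void start() override {} // safe default when nothing is registered
};

using MakeTracerFn = std::unique_ptr<TracerBase> (*)();
MakeTracerFn g_make_tracer = nullptr; // set once by the Python layer

void registerTracer(MakeTracerFn fn) { g_make_tracer = fn; }

std::unique_ptr<TracerBase> makeTracer() {
  // Fall back to the no-op when no real tracer was ever registered.
  return g_make_tracer ? g_make_tracer() : std::make_unique<NoOpTracer>();
}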

View File

@@ -56,5 +56,21 @@ struct TORCH_API PythonTracerBase {
using MakeFn = std::unique_ptr<PythonTracerBase> (*)(RecordQueue*);
TORCH_API void registerTracer(MakeFn make_tracer);
/**
* Memory Tracer Implementation
*/
struct TORCH_API PythonMemoryTracerBase {
static std::unique_ptr<PythonMemoryTracerBase> make();
virtual ~PythonMemoryTracerBase() = default;
virtual void start() = 0;
virtual void stop() = 0;
virtual void export_memory_history(const std::string path) = 0;
};
using MakeMemoryFn = std::unique_ptr<PythonMemoryTracerBase> (*)();
TORCH_API void registerMemoryTracer(MakeMemoryFn make_memory_tracer);
} // namespace python_tracer
} // namespace torch::profiler::impl
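
To illustrate the new extension point, a sketch of what a custom implementation and its registration could look like (MyMemoryTracer and the include path are assumptions for this sketch; upstream, the only real implementation is the PythonMemoryTracer above):

#include <memory>
#include <string>
// Header path is an assumption for this sketch.
#include <torch/csrc/profiler/orchestration/python_tracer.h>

namespace pt = torch::profiler::impl::python_tracer;

struct MyMemoryTracer : pt::PythonMemoryTracerBase {
  void start() override { /* begin recording allocator events */ }
  void stop() override { /* stop recording */ }
  void export_memory_history(const std::string path) override {
    // write the recorded history to `path`
    (void)path;
  }
};

std::unique_ptr<pt::PythonMemoryTracerBase> makeMyMemoryTracer() {
  return std::make_unique<MyMemoryTracer>();
}

// Registered once at startup, mirroring init() in the diff above:
// pt::registerMemoryTracer(&makeMyMemoryTracer);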