[Mem Snapshot] Add Metadata Field (#165490)

Summary:
The implementation adds the ability to:

- Set custom metadata strings that will be attached to all subsequent allocations
- Clear or change the metadata at any point
- View the metadata in memory snapshots via _dump_snapshot()

Test Plan: Added a test in test_cuda.py and checked the snapshot manually to confirm that the metadata was added.
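
A minimal usage sketch (for illustration only, not part of this diff; it uses only the Python APIs touched here and assumes a CUDA device is available):

import torch

# Start recording allocation history so trace entries are produced.
torch.cuda.memory._record_memory_history(context="all")

# Attach metadata to all subsequent allocations.
torch.cuda.memory._set_memory_metadata("phase: warmup")
x = torch.rand(3, 4, device="cuda")  # trace entries carry "phase: warmup"

# Change the metadata at any point.
torch.cuda.memory._set_memory_metadata("phase: training")
y = torch.rand(3, 4, device="cuda")  # trace entries carry "phase: training"

# Clear the metadata with an empty string.
torch.cuda.memory._set_memory_metadata("")

# Read the metadata back from the snapshot's device traces
# (or write it to disk with torch.cuda.memory._dump_snapshot()).
snapshot = torch.cuda.memory._snapshot()
for event in snapshot["device_traces"][0]:
    print(event["action"], event.get("user_metadata", ""))

# Stop recording.
torch.cuda.memory._record_memory_history(None)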

Differential Revision: D84654933

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165490
Approved by: https://github.com/yushangdi
Shivam Raikundalia
2025-10-17 23:46:02 +00:00
committed by PyTorch MergeBot
parent 69c33898fa
commit a25a649e70
7 changed files with 103 additions and 3 deletions

View File

@@ -1260,6 +1260,9 @@ class DeviceCachingAllocator {
// thread local compile context for each device
static thread_local std::stack<std::string> compile_context;
// thread local user metadata for annotating allocations
static thread_local std::string user_metadata;
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
explicit DeviceCachingAllocator(c10::DeviceIndex id)
@@ -1302,6 +1305,14 @@ class DeviceCachingAllocator {
}
}
void setUserMetadata(const std::string& metadata) {
user_metadata = metadata;
}
std::string getUserMetadata() {
return user_metadata;
}
bool checkPoolLiveAllocations(
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) const {
@@ -3682,7 +3693,8 @@ class DeviceCachingAllocator {
mempool_id,
getApproximateTime(),
record_context_ >= RecordContext::ALLOC ? std::move(context) : nullptr,
compile_string);
compile_string,
user_metadata);
// Callbacks should not include any Pytorch call
for (const auto& cb : trace_trackers_) {
@@ -3737,6 +3749,7 @@ static void uncached_delete(void* ptr) {
static void local_raw_delete(void* ptr);
thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
thread_local std::string DeviceCachingAllocator::user_metadata;
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
@@ -3934,6 +3947,18 @@ class NativeCachingAllocator : public CUDAAllocator {
device_allocator[device]->popCompileContext();
}
void setUserMetadata(const std::string& metadata) override {
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
device_allocator[device]->setUserMetadata(metadata);
}
std::string getUserMetadata() override {
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
return device_allocator[device]->getUserMetadata();
}
bool isHistoryEnabled() override {
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));

View File

@@ -118,7 +118,8 @@ struct TraceEntry {
MempoolId_t mempool,
approx_time_t time,
std::shared_ptr<GatheredContext> context = nullptr,
std::string compile_context = "")
std::string compile_context = "",
std::string user_metadata = "")
: action_(action),
device_(device),
addr_(addr),
@@ -126,7 +127,8 @@ struct TraceEntry {
stream_(stream),
size_(size),
mempool_(std::move(mempool)),
compile_context_(std::move(compile_context)) {
compile_context_(std::move(compile_context)),
user_metadata_(std::move(user_metadata)) {
time_.approx_t_ = time;
}
Action action_;
@@ -138,6 +140,7 @@ struct TraceEntry {
MempoolId_t mempool_;
trace_time_ time_{};
std::string compile_context_;
std::string user_metadata_;
};
// Calls made by record_function will save annotations
@@ -297,6 +300,10 @@ class CUDAAllocator : public DeviceAllocator {
const std::vector<std::pair<std::string, std::string>>& /*md*/) {}
virtual void pushCompileContext(std::string& md) {}
virtual void popCompileContext() {}
virtual void setUserMetadata(const std::string& metadata) {}
virtual std::string getUserMetadata() {
return "";
}
virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;
// Attached AllocatorTraceTracker callbacks will be called while the
@@ -536,6 +543,14 @@ inline void enablePeerAccess(
get()->enablePeerAccess(dev, dev_to_access);
}
inline void setUserMetadata(const std::string& metadata) {
get()->setUserMetadata(metadata);
}
inline std::string getUserMetadata() {
return get()->getUserMetadata();
}
} // namespace c10::cuda::CUDACachingAllocator
namespace c10::cuda {

View File

@@ -4378,6 +4378,28 @@ class TestCudaMallocAsync(TestCase):
finally:
torch.cuda.memory._record_memory_history(None)
@unittest.skipIf(
TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
)
@requiresCppContext
def test_memory_plots_metadata(self):
for context in ["alloc", "all", "state"]:
try:
torch._C._cuda_clearCublasWorkspaces()
torch.cuda.memory.empty_cache()
torch.cuda.memory._set_memory_metadata("metadata test")
torch.cuda.memory._record_memory_history(context="all")
x = torch.rand(3, 4, device="cuda")
del x
torch.cuda.memory.empty_cache()
torch.cuda.memory._set_memory_metadata("")
ss = torch.cuda.memory._snapshot()
for event in ss["device_traces"][0]:
self.assertTrue(event["user_metadata"] == "metadata test")
finally:
torch.cuda.memory._record_memory_history(None)
@unittest.skipIf(
TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
)

View File

@@ -2081,6 +2081,8 @@ def _cuda_hostMemoryStats() -> dict[str, Any]: ...
def _cuda_resetAccumulatedHostMemoryStats() -> None: ...
def _cuda_resetPeakHostMemoryStats() -> None: ...
def _cuda_memorySnapshot(mempool_id: tuple[_int, _int] | None) -> dict[str, Any]: ...
def _cuda_setMemoryMetadata(metadata: str) -> None: ...
def _cuda_getMemoryMetadata() -> str: ...
def _cuda_record_memory_history_legacy(
enabled: _bool,
record_context: _bool,

View File

@@ -765,6 +765,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
py::str frames_s = "frames";
py::str time_us_s = "time_us";
py::str compile_context_s = "compile_context";
py::str user_metadata_s = "user_metadata";
py::list empty_frames;
std::vector<CapturedTraceback*> to_gather_frames;
@@ -882,6 +883,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
trace_entry[stream_s] = int64_t(te.stream_);
trace_entry[time_us_s] = te.time_.t_;
trace_entry[compile_context_s] = te.compile_context_;
trace_entry[user_metadata_s] = te.user_metadata_;
trace.append(trace_entry);
}
traces.append(trace);
@@ -1137,6 +1139,14 @@ static void registerCudaDeviceProperties(PyObject* module) {
return c10::cuda::CUDACachingAllocator::isHistoryEnabled();
});
m.def("_cuda_setMemoryMetadata", [](const std::string& metadata) {
c10::cuda::CUDACachingAllocator::setUserMetadata(metadata);
});
m.def("_cuda_getMemoryMetadata", []() {
return c10::cuda::CUDACachingAllocator::getUserMetadata();
});
m.def("_cuda_get_conv_benchmark_empty_cache", []() {
return at::native::_cudnn_get_conv_benchmark_empty_cache();
});

View File

@@ -311,6 +311,7 @@ std::string _memory_snapshot_pickled() {
IValue is_expandable_s = "is_expandable";
IValue time_us_s = "time_us";
IValue compile_contexts_s = "compile_context";
IValue user_metadata_s = "user_metadata";
auto empty_frames = new_list();
@@ -428,6 +429,7 @@ std::string _memory_snapshot_pickled() {
trace_entry.insert(size_s, (int64_t)te.size_);
trace_entry.insert(stream_s, int64_t(te.stream_));
trace_entry.insert(compile_contexts_s, te.compile_context_);
trace_entry.insert(user_metadata_s, te.user_metadata_);
if (te.context_) {
auto sc = getFromContext(te.context_);
frame_tracebacks.push_back(sc);

View File

@@ -1063,6 +1063,30 @@ def _dump_snapshot(filename="dump_snapshot.pickle"):
pickle.dump(s, f)
def _set_memory_metadata(metadata: str):
"""
Set custom metadata that will be attached to all subsequent CUDA memory allocations.
This metadata will be recorded in the memory snapshot for all allocations made
after this call until the metadata is cleared or changed.
Args:
metadata (str): Custom metadata string to attach to allocations.
Pass an empty string to clear the metadata.
"""
torch._C._cuda_setMemoryMetadata(metadata)
def _get_memory_metadata() -> str:
"""
Get the current custom metadata that is being attached to CUDA memory allocations.
Returns:
str: The current metadata string, or empty string if no metadata is set.
"""
return torch._C._cuda_getMemoryMetadata()
def _save_segment_usage(filename="output.svg", snapshot=None):
if snapshot is None:
snapshot = _snapshot()