[Mem Snapshot] Add Metadata Field (#165490)

Summary:
The implementation adds the ability to:

Set custom metadata strings that will be attached to all subsequent allocations
Clear or change the metadata at any point
View the metadata in memory snapshots via _dump_snapshot()

Test Plan: Added test in test_cuda.py and check manually in snapshot to see that metadata was added.

Differential Revision: D84654933

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165490
Approved by: https://github.com/yushangdi
This commit is contained in:
Shivam Raikundalia
2025-10-17 23:46:02 +00:00
committed by PyTorch MergeBot
parent 69c33898fa
commit a25a649e70
7 changed files with 103 additions and 3 deletions

View File

@ -1260,6 +1260,9 @@ class DeviceCachingAllocator {
// thread local compile context for each device
static thread_local std::stack<std::string> compile_context;
// thread local user metadata for annotating allocations
static thread_local std::string user_metadata;
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
explicit DeviceCachingAllocator(c10::DeviceIndex id)
@ -1302,6 +1305,14 @@ class DeviceCachingAllocator {
}
}
void setUserMetadata(const std::string& metadata) {
user_metadata = metadata;
}
std::string getUserMetadata() {
return user_metadata;
}
bool checkPoolLiveAllocations(
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) const {
@ -3682,7 +3693,8 @@ class DeviceCachingAllocator {
mempool_id,
getApproximateTime(),
record_context_ >= RecordContext::ALLOC ? std::move(context) : nullptr,
compile_string);
compile_string,
user_metadata);
// Callbacks should not include any Pytorch call
for (const auto& cb : trace_trackers_) {
@ -3737,6 +3749,7 @@ static void uncached_delete(void* ptr) {
static void local_raw_delete(void* ptr);
thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
thread_local std::string DeviceCachingAllocator::user_metadata;
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
@ -3934,6 +3947,18 @@ class NativeCachingAllocator : public CUDAAllocator {
device_allocator[device]->popCompileContext();
}
void setUserMetadata(const std::string& metadata) override {
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
device_allocator[device]->setUserMetadata(metadata);
}
std::string getUserMetadata() override {
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
return device_allocator[device]->getUserMetadata();
}
bool isHistoryEnabled() override {
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));

View File

@ -118,7 +118,8 @@ struct TraceEntry {
MempoolId_t mempool,
approx_time_t time,
std::shared_ptr<GatheredContext> context = nullptr,
std::string compile_context = "")
std::string compile_context = "",
std::string user_metadata = "")
: action_(action),
device_(device),
addr_(addr),
@ -126,7 +127,8 @@ struct TraceEntry {
stream_(stream),
size_(size),
mempool_(std::move(mempool)),
compile_context_(std::move(compile_context)) {
compile_context_(std::move(compile_context)),
user_metadata_(std::move(user_metadata)) {
time_.approx_t_ = time;
}
Action action_;
@ -138,6 +140,7 @@ struct TraceEntry {
MempoolId_t mempool_;
trace_time_ time_{};
std::string compile_context_;
std::string user_metadata_;
};
// Calls made by record_function will save annotations
@ -297,6 +300,10 @@ class CUDAAllocator : public DeviceAllocator {
const std::vector<std::pair<std::string, std::string>>& /*md*/) {}
virtual void pushCompileContext(std::string& md) {}
virtual void popCompileContext() {}
virtual void setUserMetadata(const std::string& metadata) {}
virtual std::string getUserMetadata() {
return "";
}
virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;
// Attached AllocatorTraceTracker callbacks will be called while the
@ -536,6 +543,14 @@ inline void enablePeerAccess(
get()->enablePeerAccess(dev, dev_to_access);
}
inline void setUserMetadata(const std::string& metadata) {
get()->setUserMetadata(metadata);
}
inline std::string getUserMetadata() {
return get()->getUserMetadata();
}
} // namespace c10::cuda::CUDACachingAllocator
namespace c10::cuda {