Switch storage ID refresh to insert_or_assign

This commit is contained in:
Natalia Gimelshein
2025-11-07 20:11:46 -08:00
parent 29d6bb79e1
commit 9653671a67

View File

@ -122,29 +122,47 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT
ID get_tensor_storage_ID(const c10::Storage& t_storage) {
const std::lock_guard<std::recursive_mutex> lock(gMutex);
const void* raw_data_ptr = t_storage.data();
auto iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr);
if (iter == data_ptr_to_weak_storage_ptr.end()) {
const void* raw_data_ptr = nullptr;
bool should_track_liveness = false;
// FakeTensor/FunctionalTensor may clear the Storage handle entirely or use
// a nullptr data pointer. Treat both cases as a shared cache key but avoid
// touching the weak-ref table so they can reuse the same ID without
// tripping the liveness check.
if (t_storage.unsafeGetStorageImpl()) {
raw_data_ptr = t_storage.data();
should_track_liveness = raw_data_ptr != nullptr;
}
auto id_iter = data_ptr_to_storage_id.find(raw_data_ptr);
if (!should_track_liveness) {
if (id_iter != data_ptr_to_storage_id.end()) {
return id_iter->second;
}
ID id = storage_id_++;
data_ptr_to_storage_id.emplace(raw_data_ptr, id);
return id;
}
auto weak_iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr);
if (weak_iter == data_ptr_to_weak_storage_ptr.end()) {
ID id = storage_id_++;
data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id);
data_ptr_to_weak_storage_ptr.emplace(
raw_data_ptr, t_storage.getWeakStorageImpl());
return id;
} else {
// check if the storage is still alive
if (iter->second.expired()) {
ID id = storage_id_++;
// std::unorder_map does not change if the key is already in the map.
// So we need to remove the key and insert the key with the new value.
data_ptr_to_storage_id.erase(raw_data_ptr);
data_ptr_to_storage_id[raw_data_ptr] = id;
data_ptr_to_weak_storage_ptr.insert_or_assign(
raw_data_ptr, t_storage.getWeakStorageImpl());
return id;
} else {
return data_ptr_to_storage_id[raw_data_ptr];
}
}
if (weak_iter->second.expired()) {
ID id = storage_id_++;
data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id);
data_ptr_to_weak_storage_ptr.insert_or_assign(
raw_data_ptr, t_storage.getWeakStorageImpl());
return id;
}
id_iter = data_ptr_to_storage_id.find(raw_data_ptr);
TORCH_INTERNAL_ASSERT(id_iter != data_ptr_to_storage_id.end());
return id_iter->second;
}
// Observer run state.