Switch storage ID refresh to insert_or_assign
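Context for the title: std::unordered_map::insert and emplace leave the mapped value untouched when the key is already present, which is why the old code erased the key before reassigning through operator[]; insert_or_assign (C++17) overwrites in a single call with a single hash lookup. A minimal standalone sketch of the difference, using a plain map with made-up values rather than the observer's actual state:

    #include <cstdint>
    #include <iostream>
    #include <unordered_map>

    int main() {
      std::unordered_map<const void*, uint64_t> id_map;
      int slot = 0;
      const void* key = &slot;

      id_map.emplace(key, 1);            // inserts: key was absent
      id_map.emplace(key, 2);            // no-op: emplace never overwrites
      std::cout << id_map[key] << '\n';  // prints 1

      // Old refresh pattern: erase, then assign through operator[].
      id_map.erase(key);
      id_map[key] = 2;
      std::cout << id_map[key] << '\n';  // prints 2

      // New refresh pattern: one overwriting call.
      id_map.insert_or_assign(key, 3);
      std::cout << id_map[key] << '\n';  // prints 3
    }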
@@ -122,29 +122,47 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT
   ID get_tensor_storage_ID(const c10::Storage& t_storage) {
     const std::lock_guard<std::recursive_mutex> lock(gMutex);
 
-    const void* raw_data_ptr = t_storage.data();
-    auto iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr);
-    if (iter == data_ptr_to_weak_storage_ptr.end()) {
+    const void* raw_data_ptr = nullptr;
+    bool should_track_liveness = false;
+    // FakeTensor/FunctionalTensor may clear the Storage handle entirely or use
+    // a nullptr data pointer. Treat both cases as a shared cache key but avoid
+    // touching the weak-ref table so they can reuse the same ID without
+    // tripping the liveness check.
+    if (t_storage.unsafeGetStorageImpl()) {
+      raw_data_ptr = t_storage.data();
+      should_track_liveness = raw_data_ptr != nullptr;
+    }
+
+    auto id_iter = data_ptr_to_storage_id.find(raw_data_ptr);
+    if (!should_track_liveness) {
+      if (id_iter != data_ptr_to_storage_id.end()) {
+        return id_iter->second;
+      }
+      ID id = storage_id_++;
+      data_ptr_to_storage_id.emplace(raw_data_ptr, id);
+      return id;
+    }
+
+    auto weak_iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr);
+    if (weak_iter == data_ptr_to_weak_storage_ptr.end()) {
       ID id = storage_id_++;
       data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id);
       data_ptr_to_weak_storage_ptr.emplace(
           raw_data_ptr, t_storage.getWeakStorageImpl());
       return id;
-    } else {
-      // check if the storage is still alive
-      if (iter->second.expired()) {
-        ID id = storage_id_++;
-        // std::unorder_map does not change if the key is already in the map.
-        // So we need to remove the key and insert the key with the new value.
-        data_ptr_to_storage_id.erase(raw_data_ptr);
-        data_ptr_to_storage_id[raw_data_ptr] = id;
-        data_ptr_to_weak_storage_ptr.insert_or_assign(
-            raw_data_ptr, t_storage.getWeakStorageImpl());
-        return id;
-      } else {
-        return data_ptr_to_storage_id[raw_data_ptr];
-      }
-    }
+    }
+
+    if (weak_iter->second.expired()) {
+      ID id = storage_id_++;
+      data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id);
+      data_ptr_to_weak_storage_ptr.insert_or_assign(
+          raw_data_ptr, t_storage.getWeakStorageImpl());
+      return id;
+    }
+
+    id_iter = data_ptr_to_storage_id.find(raw_data_ptr);
+    TORCH_INTERNAL_ASSERT(id_iter != data_ptr_to_storage_id.end());
+    return id_iter->second;
   }
 
   // Observer run state.
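For readers tracing the new control flow: once a storage address has an entry in data_ptr_to_weak_storage_ptr, a later lookup that finds the weak reference expired means the original storage died and its data pointer may have been reused, so both caches must be overwritten with a fresh ID. A self-contained model of that path, with std::shared_ptr/std::weak_ptr to an int standing in for the c10 storage handle (names and types here are illustrative, not the observer's real ones):

    #include <cstdint>
    #include <memory>
    #include <unordered_map>

    using ID = uint64_t;

    // Simplified stand-ins for the observer's two caches.
    std::unordered_map<const void*, ID> id_cache;
    std::unordered_map<const void*, std::weak_ptr<int>> weak_cache;
    ID next_id = 1;

    // Models the post-commit live-storage path of get_tensor_storage_ID.
    ID get_storage_id(const std::shared_ptr<int>& storage) {
      const void* raw_data_ptr = storage.get();

      auto weak_iter = weak_cache.find(raw_data_ptr);
      if (weak_iter == weak_cache.end()) {
        // First sighting of this address: mint an ID, remember a weak handle.
        ID id = next_id++;
        id_cache.insert_or_assign(raw_data_ptr, id);
        weak_cache.emplace(raw_data_ptr, storage);
        return id;
      }
      if (weak_iter->second.expired()) {
        // The old storage died and the address was reused: refresh both
        // entries, overwriting the stale ID in place.
        ID id = next_id++;
        id_cache.insert_or_assign(raw_data_ptr, id);
        weak_cache.insert_or_assign(raw_data_ptr, storage);
        return id;
      }
      // Storage still alive: return the cached ID.
      return id_cache.at(raw_data_ptr);
    }

Two calls on the same live pointer return the same cached ID; if the owner is destroyed and a new allocation lands on the same address, the expired weak handle forces a new ID, which is exactly the overwrite insert_or_assign performs.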