mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Skip buffer in dense update (#148533)
Summary: as title. PyTorch Module buffer will not be published in delta publishing. In Quinn's previous diff, constant type annotations have been introduced. In addition to skip constant, we also need to skip buffer if it is not found in the user-provided delta weights list Test Plan: https://docs.google.com/document/d/1wiqUo0PyZ4g6YJIJlL_LE084ZEuE74iu74gZjqGGjWY/edit?tab=t.0#heading=h.dby6cwiw1xrn Differential Revision: D69553929 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148533 Approved by: https://github.com/22quinn, https://github.com/jingsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
00cd6c07b9
commit
dfb4094b9c
@ -223,9 +223,17 @@ class AOTInductorModelContainer {
|
||||
|
||||
bool _should_skip_update(const size_t idx) const {
|
||||
auto constant_type = models_[0]->constant_type(static_cast<int64_t>(idx));
|
||||
// We should skip constants
|
||||
return constant_type == ConstantType::TensorConstant;
|
||||
}
|
||||
|
||||
bool _could_skip_update(const size_t idx) const {
|
||||
auto constant_type = models_[0]->constant_type(static_cast<int64_t>(idx));
|
||||
// Buffer can be optionally skipped, so if it not provided by upstream
|
||||
// services, it is OK to relax the check.
|
||||
return constant_type == ConstantType::Buffer;
|
||||
}
|
||||
|
||||
void assert_all_constants(
|
||||
const std::unordered_map<std::string, AtenTensorHandle>& constants_map) {
|
||||
auto num_constants = models_[0]->num_constants();
|
||||
@ -238,10 +246,11 @@ class AOTInductorModelContainer {
|
||||
std::string(models_[0]->constant_name(static_cast<int64_t>(idx)));
|
||||
auto it = constants_map.find(constant_name);
|
||||
if (it == constants_map.end()) {
|
||||
if (_should_skip_update(idx)) {
|
||||
if (_should_skip_update(idx) || _could_skip_update(idx)) {
|
||||
// tracing sometimes creates tensors that are non-existent in
|
||||
// original graph. We could skip those and do a direct copy.
|
||||
std::cerr << "[WARNING] Found constant " << constant_name
|
||||
std::cerr << "[WARNING] Found constant or module state buffer "
|
||||
<< constant_name
|
||||
<< " in model, but not provided by user!\n";
|
||||
continue;
|
||||
}
|
||||
@ -272,13 +281,12 @@ class AOTInductorModelContainer {
|
||||
auto constant_name =
|
||||
std::string(models_[0]->constant_name(static_cast<int64_t>(idx)));
|
||||
auto it = constants_map.find(constant_name);
|
||||
if (it == constants_map.end() &&
|
||||
!(_should_skip_update(idx) && use_inactive)) {
|
||||
if (it == constants_map.end() && !use_inactive) {
|
||||
continue;
|
||||
}
|
||||
|
||||
AtenTensorHandle tensor;
|
||||
if (_should_skip_update(idx) && use_inactive) {
|
||||
if (it == constants_map.end() && use_inactive) {
|
||||
aoti_torch_clone(
|
||||
original_constants_map->find(constant_name)->second.get(), &tensor);
|
||||
} else {
|
||||
@ -313,13 +321,12 @@ class AOTInductorModelContainer {
|
||||
auto constant_name =
|
||||
std::string(models_[0]->constant_name(static_cast<int64_t>(idx)));
|
||||
auto it = constants_map.find(constant_name);
|
||||
if (it == constants_map.end() &&
|
||||
!(_should_skip_update(idx) && use_inactive)) {
|
||||
if (it == constants_map.end() && !use_inactive) {
|
||||
continue;
|
||||
}
|
||||
|
||||
AtenTensorHandle tensor;
|
||||
if (_should_skip_update(idx) && use_inactive) {
|
||||
if (it == constants_map.end() && use_inactive) {
|
||||
tensor = original_constants_map->find(constant_name)->second.get();
|
||||
} else {
|
||||
tensor = it->second;
|
||||
|
Reference in New Issue
Block a user