Skip buffer in dense update (#148533)

Summary:
as title.

PyTorch Module buffers will not be published in delta publishing. In Quinn's previous diff, constant type annotations were introduced.

In addition to skipping constants, we also need to skip a buffer if it is not found in the user-provided delta weights list.

Test Plan: https://docs.google.com/document/d/1wiqUo0PyZ4g6YJIJlL_LE084ZEuE74iu74gZjqGGjWY/edit?tab=t.0#heading=h.dby6cwiw1xrn

Differential Revision: D69553929

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148533
Approved by: https://github.com/22quinn, https://github.com/jingsh
This commit is contained in:
Zhuoran Zhao
2025-03-07 01:59:55 +00:00
committed by PyTorch MergeBot
parent 00cd6c07b9
commit dfb4094b9c

View File

@ -223,9 +223,17 @@ class AOTInductorModelContainer {
bool _should_skip_update(const size_t idx) const {
  // Folded constants (TensorConstant) are never delta-published, so any
  // update for them must be skipped unconditionally.
  const auto type = models_[0]->constant_type(static_cast<int64_t>(idx));
  return type == ConstantType::TensorConstant;
}
bool _could_skip_update(const size_t idx) const {
  // Module state buffers may be omitted from the user-provided delta
  // weights list; if one is not provided by upstream services, it is OK
  // to relax the check and skip it.
  const auto type = models_[0]->constant_type(static_cast<int64_t>(idx));
  return type == ConstantType::Buffer;
}
void assert_all_constants(
const std::unordered_map<std::string, AtenTensorHandle>& constants_map) {
auto num_constants = models_[0]->num_constants();
@ -238,10 +246,11 @@ class AOTInductorModelContainer {
std::string(models_[0]->constant_name(static_cast<int64_t>(idx)));
auto it = constants_map.find(constant_name);
if (it == constants_map.end()) {
if (_should_skip_update(idx)) {
if (_should_skip_update(idx) || _could_skip_update(idx)) {
// tracing sometimes creates tensors that are non-existent in
// original graph. We could skip those and do a direct copy.
std::cerr << "[WARNING] Found constant " << constant_name
std::cerr << "[WARNING] Found constant or module state buffer "
<< constant_name
<< " in model, but not provided by user!\n";
continue;
}
@ -272,13 +281,12 @@ class AOTInductorModelContainer {
auto constant_name =
std::string(models_[0]->constant_name(static_cast<int64_t>(idx)));
auto it = constants_map.find(constant_name);
if (it == constants_map.end() &&
!(_should_skip_update(idx) && use_inactive)) {
if (it == constants_map.end() && !use_inactive) {
continue;
}
AtenTensorHandle tensor;
if (_should_skip_update(idx) && use_inactive) {
if (it == constants_map.end() && use_inactive) {
aoti_torch_clone(
original_constants_map->find(constant_name)->second.get(), &tensor);
} else {
@ -313,13 +321,12 @@ class AOTInductorModelContainer {
auto constant_name =
std::string(models_[0]->constant_name(static_cast<int64_t>(idx)));
auto it = constants_map.find(constant_name);
if (it == constants_map.end() &&
!(_should_skip_update(idx) && use_inactive)) {
if (it == constants_map.end() && !use_inactive) {
continue;
}
AtenTensorHandle tensor;
if (_should_skip_update(idx) && use_inactive) {
if (it == constants_map.end() && use_inactive) {
tensor = original_constants_map->find(constant_name)->second.get();
} else {
tensor = it->second;