[AOTInductor] Fix state of ConstantFolding (#153152)

Summary:
Bug fix for constant folding states. We were not setting the correct state on each update.
One race condition would be:
(1) All threads obtain the model_exec_lock from main run.
(2) In the second round of updating the constant buffer, we should have set the secondary state to INITIALIZED, but the primary state is mistakenly set instead.
(3) run_const_fold gets called and a model_exec_lock is obtained; at this point it is waiting for an available model.
(4) The main run enters the INITIALIZED branch and waits for the unique_lock (while a shared_lock is still being held by (3) at this moment).

Test Plan:
TBD

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153152
Approved by: https://github.com/jingsh, https://github.com/chenyang78
This commit is contained in:
Mu-Chu Lee
2025-05-08 07:09:19 -07:00
committed by PyTorch MergeBot
parent f2ea63658f
commit c227865720

View File

@ -107,25 +107,27 @@ class AOTInductorModelContainer {
std::shared_lock model_lk(model_exec_mutex_);
auto* model = get_available_model();
if (constant_folded_ == ConstantState::INITIALIZED) {
ConstantState& const_folded =
use_secondary_ ? constant_folded_secondary_ : constant_folded_;
if (const_folded == ConstantState::INITIALIZED) {
// At this point, constant is not ready yet. We need to call constant
// folding before we execute the model. We obtain a unique lock at this
// point to make sure constant is ready for all.
model_lk.unlock();
std::unique_lock constants_folding_lk(model_exec_mutex_);
// Double locking to make sure constant folding is only ran once.
if (constant_folded_ == ConstantState::INITIALIZED) {
if (const_folded == ConstantState::INITIALIZED) {
auto folded_const_map = model->run_const_fold(
stream, proxy_executor, /* initialization = */ true);
update_constant_buffer(
std::move(folded_const_map),
/* use_inactive = */ false,
/* validate_full_update = */ false);
constant_folded_ = ConstantState::FOLDED;
const_folded = ConstantState::FOLDED;
}
constants_folding_lk.unlock();
model_lk.lock();
} else if (constant_folded_ != ConstantState::FOLDED) {
} else if (const_folded != ConstantState::FOLDED) {
throw std::runtime_error(
"Unknown constant state: " + toStringConstantState(constant_folded_));
}
@ -159,14 +161,16 @@ class AOTInductorModelContainer {
AOTIProxyExecutorHandle proxy_executor) {
auto* model = available_models_[0];
if (constant_folded_ == ConstantState::INITIALIZED) {
ConstantState& const_folded =
use_secondary_ ? constant_folded_secondary_ : constant_folded_;
if (const_folded == ConstantState::INITIALIZED) {
auto folded_const_map = model->run_const_fold(
stream, proxy_executor, /* initialization = */ true);
update_constant_buffer(
std::move(folded_const_map),
/* use_inactive = */ false,
/* validate_full_update = */ false);
constant_folded_ = ConstantState::FOLDED;
const_folded = ConstantState::FOLDED;
} else if (constant_folded_ != ConstantState::FOLDED) {
throw std::runtime_error(
"Unknown constant state: " + toStringConstantState(constant_folded_));
@ -253,29 +257,31 @@ class AOTInductorModelContainer {
bool inactive_buffer,
DeviceStreamType stream,
AOTIProxyExecutorHandle proxy_executor) {
std::shared_lock model_lk(model_exec_mutex_);
auto* model = get_available_model();
AOTInductorModel* model;
ConstantState& const_folded = inactive_buffer == use_secondary_
? constant_folded_
: constant_folded_secondary_;
if (!inactive_buffer) {
// We would need to acquire a unique lock if we want to run constant
// folding on the active buffer.
model_lk.unlock();
std::unique_lock constants_folding_lk(model_exec_mutex_);
model = get_available_model();
try {
auto folded_const_map = model->run_const_fold(stream, proxy_executor);
update_constant_buffer(
std::move(folded_const_map),
/* use_inactive = */ false,
/* validate_full_update = */ false);
constant_folded_ = ConstantState::FOLDED;
const_folded = ConstantState::FOLDED;
} catch (...) {
std::lock_guard lk(models_mutex_);
available_models_.push_back(model);
throw;
}
constants_folding_lk.unlock();
model_lk.lock();
} else {
std::shared_lock model_lk(model_exec_mutex_);
model = get_available_model();
// We swap the constant mapping to the inactive buffer in the model to run
// const run.
auto constants_map = get_constants_map(/* get_inactive= */ true);
@ -298,7 +304,7 @@ class AOTInductorModelContainer {
model->update_constants_map(
constants_map, /* remap_constants_array= */ false);
model->update_constants_array(constants_array);
constant_folded_secondary_ = ConstantState::FOLDED;
const_folded = ConstantState::FOLDED;
} catch (...) {
std::lock_guard lk(models_mutex_);
available_models_.push_back(model);
@ -535,7 +541,6 @@ class AOTInductorModelContainer {
model->update_constants_array(constants_array);
}
std::swap(constant_folded_, constant_folded_secondary_);
use_secondary_ = !use_secondary_;
}