Compare commits

...

1 Commits

Author SHA1 Message Date
5ffea8b44e add ABI stable method for updating constant buffer 2025-10-01 15:00:36 -07:00
2 changed files with 35 additions and 0 deletions

View File

@ -249,6 +249,26 @@ AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer(
})
}
AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBufferPairs(
AOTInductorModelContainerHandle container_handle,
const AOTInductorConstantMapEntry* pairs,
size_t num_pairs,
bool use_inactive,
bool validate_full_update) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(container_handle);
// Build a local unordered_map inside
std::unordered_map<std::string, AtenTensorHandle> input_map;
input_map.reserve(num_pairs);
for (size_t i = 0; i < num_pairs; ++i) {
input_map.emplace(pairs[i].name, pairs[i].handle);
}
CONVERT_EXCEPTION_TO_ERROR_CODE({
container->update_constant_buffer(
input_map, use_inactive, validate_full_update, /*user_managed=*/true);
})
}
AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(
AOTInductorModelContainerHandle container_handle,
AOTInductorConstantMapHandle constant_map_handle,

View File

@ -30,6 +30,11 @@ using AOTInductorStreamHandle = AOTInductorStreamOpaque*;
struct AOTInductorConstantMap;
using AOTInductorConstantMapHandle = AOTInductorConstantMap*;
struct AOTInductorConstantMapEntry {
const char* name;
AtenTensorHandle handle;
};
// TODO: Deprecate this API. This was kept for BC compatibility.
// Please use AOTInductorModelContainerCreateWithDevice instead.
AOTI_API AOTIRuntimeError AOTInductorModelContainerCreate(
@ -151,6 +156,16 @@ AOTInductorModelContainerUpdateUserManagedConstantBuffer(
bool use_inactive,
bool validate_full_update);
// Same as AOTInductorModelContainerUpdateUserManagedConstantBuffer,
// but no std::unordered_map crosses DLL boundaries for cross-compilation.
AOTI_API AOTIRuntimeError
AOTInductorModelContainerUpdateUserManagedConstantBufferPairs(
AOTInductorModelContainerHandle container_handle,
const AOTInductorConstantMapEntry* pairs,
size_t num_pairs,
bool use_inactive,
bool validate_full_update);
// Setup the constant buffer in model container with provided ConstantMap
// use_inactive should be set as true if the inactive buffer is to be updated.
// validate_full_update checks if all constants are included in the ConstantMap