Compare commits

...

2 Commits

4 changed files with 75 additions and 10 deletions

View File

@ -269,6 +269,42 @@ AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBufferPairs(
})
}
// Updates user-managed constant buffers from an array of (name, handle)
// pairs whose names are the constants' original FQNs rather than the
// constant names used inside the AOTI model.
//
// Each FQN is translated to its constant name through the container's
// constant_original_fqn()/constant_name() tables, and the translated map is
// forwarded to update_constant_buffer() with user_managed=true.
// use_inactive and validate_full_update are passed through unchanged.
//
// All fallible work (map construction, FQN lookup, the update itself) runs
// inside CONVERT_EXCEPTION_TO_ERROR_CODE so that a failure -- including an
// FQN that matches no constant -- is reported through the returned
// AOTIRuntimeError instead of an exception escaping this entry point.
AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFQNNames(
    AOTInductorModelContainerHandle container_handle,
    const AOTInductorConstantMapEntry* pairs,
    size_t num_pairs,
    bool use_inactive,
    bool validate_full_update) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  // Aliases keep template-argument commas out of the macro invocation below.
  using FqnToNameMap = std::unordered_map<std::string, std::string>;
  using ConstantHandleMap = std::unordered_map<std::string, AtenTensorHandle>;
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    // Index the container's constants by their original FQN.
    const size_t num_constants = container->num_constants();
    FqnToNameMap fqn_to_constant_name;
    fqn_to_constant_name.reserve(num_constants);
    for (size_t i = 0; i < num_constants; ++i) {
      fqn_to_constant_name.emplace(
          container->constant_original_fqn(i), container->constant_name(i));
    }

    // Re-key the caller's tensors by constant name. If the same FQN appears
    // more than once in pairs, emplace keeps the first occurrence.
    ConstantHandleMap input_map;
    input_map.reserve(num_pairs);
    for (size_t i = 0; i < num_pairs; ++i) {
      const auto it = fqn_to_constant_name.find(pairs[i].name);
      if (it == fqn_to_constant_name.end()) {
        throw std::runtime_error(
            std::string("Constant not found for FQN: ") + pairs[i].name);
      }
      input_map.emplace(it->second, pairs[i].handle);
    }

    container->update_constant_buffer(
        input_map, use_inactive, validate_full_update, /*user_managed=*/true);
  })
}
AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(
AOTInductorModelContainerHandle container_handle,
AOTInductorConstantMapHandle constant_map_handle,

View File

@ -1585,6 +1585,8 @@ class aot_inductor:
package_constants_in_so: bool = True
# Experimental. Flag to control whether to package weight separately on disk
# The keys for the weights are FQNs, not the constant names used in the AOTI model.
# The mapping between constant names and FQNs can be obtained using getConstantNamesToOriginalFQNs.
package_constants_on_disk: bool = False
# Experimental. Controls automatic precompiling of common AOTI include files.

View File

@ -158,6 +158,7 @@ AOTInductorModelContainerUpdateUserManagedConstantBuffer(
// Same as AOTInductorModelContainerUpdateUserManagedConstantBuffer,
// but no std::unordered_map crosses DLL boundaries for cross-compilation.
// The keys in pairs are constant names in AOTI model.
AOTI_API AOTIRuntimeError
AOTInductorModelContainerUpdateUserManagedConstantBufferPairs(
AOTInductorModelContainerHandle container_handle,
@ -166,6 +167,15 @@ AOTInductorModelContainerUpdateUserManagedConstantBufferPairs(
bool use_inactive,
bool validate_full_update);
// Variant of AOTInductorModelContainerUpdateUserManagedConstantBufferPairs
// that takes the same arguments, except that the keys in pairs are the
// original_fqn names of the constants rather than the constant names in
// AOTI model. Each key is looked up against the container's original FQNs to
// find the corresponding constant name, so every key is expected to match
// the original FQN of a constant held by the container.
AOTI_API AOTIRuntimeError
AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFQNNames(
AOTInductorModelContainerHandle container_handle,
const AOTInductorConstantMapEntry* pairs,
size_t num_pairs,
bool use_inactive,
bool validate_full_update);
// Setup the constant buffer in model container with provided ConstantMap
// use_inactive should be set as true if the inactive buffer is to be updated.
// validate_full_update checks if all constants are included in the ConstantMap

View File

@ -305,24 +305,41 @@ def _package_aoti_files(
filename = f"{WEIGHT_FILENAME_PREFIX}{idx}"
model_name, weight_name = get_complete(group, all_weights)
complete_tensor, _ = all_weights[model_name].get_weight(weight_name)
buffer = io.BytesIO()
torch.save(complete_tensor, buffer, pickle_protocol=pickle_protocol)
tensor_bytes = _get_raw_tensor_bytes(complete_tensor)
archive_writer.write_bytes(
os.path.join(WEIGHTS_DIR, filename), buffer.getvalue()
os.path.join(WEIGHTS_DIR, filename), tensor_bytes
)
for model_name, weight_name in group:
_, w_property = all_weights[model_name].get_weight(weight_name)
weights_configs[model_name][weight_name] = (
filename,
w_property.shape,
w_property.stride,
w_property.offset,
tensor, _ = all_weights[model_name].get_weight(weight_name)
weights_configs[model_name][weight_name] = schema.PayloadMeta(
path_name=filename,
is_param=isinstance(tensor, torch.nn.Parameter),
use_pickle=False,
tensor_meta=serialize_tensor_meta(tensor),
)
# buffer = io.BytesIO()
# torch.save(complete_tensor, buffer, pickle_protocol=pickle_protocol)
# archive_writer.write_bytes(
# os.path.join(WEIGHTS_DIR, filename), buffer.getvalue()
# )
# for model_name, weight_name in group:
# tensor, _ = all_weights[model_name].get_weight(weight_name)
# weights_configs[model_name][weight_name] = schema.PayloadMeta(
# path_name=filename,
# is_param=isinstance(tensor, torch.nn.Parameter),
# use_pickle=True,
# tensor_meta=serialize_tensor_meta(tensor),
# )
for model_name, weights_config in weights_configs.items():
payload_config = schema.PayloadConfig(config=weights_config)
archive_writer.write_string(
os.path.join(AOTINDUCTOR_DIR, model_name, "weights_config.json"),
json.dumps(weights_config),
json.dumps(_dataclass_to_dict(payload_config)),
)
logger.debug("packaging weights_config for model %s", model_name)
logger.debug(weights_config)