mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[AOTI] Add more default options to compile_standalone (#158560)
Summary: When compiling for standalone, make embed_kernel_binary and emit_multi_arch_kernel default to True, and add a default name for model_name_for_generated_files to make the generated cpp project easier to understand. Also improved the weights object file naming to be more readable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158560 Approved by: https://github.com/yushangdi
This commit is contained in:
committed by
PyTorch MergeBot
parent
d87161c3c8
commit
a4b07fe8f6
@ -3427,20 +3427,36 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An
|
||||
Returns:
|
||||
dict[str, Any]: The possibly-updated `config_patches` dictionary.
|
||||
"""
|
||||
|
||||
def patch_config(
|
||||
config_patches: dict[str, Any], config_name: str, config_value: Any
|
||||
) -> None:
|
||||
value = config_patches.get(config_name, getattr(config, config_name))
|
||||
if value is None:
|
||||
config_patches[config_name] = config_value
|
||||
elif not value and value != config_value:
|
||||
raise RuntimeError(
|
||||
f"Invalid config: {config_name}={config_value} when aot_inductor.compile_standalone is True."
|
||||
)
|
||||
|
||||
compile_standalone = config_patches.get(
|
||||
"aot_inductor.compile_standalone", config.aot_inductor.compile_standalone
|
||||
)
|
||||
# Make a copy of the config_patches to avoid modifying the original dictionary, needed for testing
|
||||
config_patches = config_patches.copy()
|
||||
if compile_standalone:
|
||||
package_cpp_only = config_patches.get(
|
||||
"aot_inductor.package_cpp_only", config.aot_inductor.package_cpp_only
|
||||
# Standlaone AOTInductor means only generate cpp project for building a standalone binary
|
||||
patch_config(config_patches, "aot_inductor.package_cpp_only", True)
|
||||
# Standlaone AOTInductor needs to embed the kernel code in the binary
|
||||
patch_config(config_patches, "aot_inductor.embed_kernel_binary", True)
|
||||
# Default to use multi-arch kernel codegen for non-rocm GPU
|
||||
patch_config(
|
||||
config_patches, "aot_inductor.emit_multi_arch_kernel", not torch.version.hip
|
||||
)
|
||||
if package_cpp_only is None:
|
||||
config_patches = {**config_patches, "aot_inductor.package_cpp_only": True}
|
||||
elif not package_cpp_only:
|
||||
raise RuntimeError(
|
||||
"compile_standalone=True requires package_cpp_only=True. "
|
||||
"Please set aot_inductor.package_cpp_only=True in your inductor config."
|
||||
)
|
||||
patch_config(
|
||||
config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model"
|
||||
)
|
||||
|
||||
return config_patches
|
||||
|
||||
|
||||
@ -3471,14 +3487,6 @@ def is_valid_aoti_model_name() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def aoti_model_name_from_config() -> str:
|
||||
from torch._inductor import config
|
||||
|
||||
model_name = config.aot_inductor.model_name_for_generated_files
|
||||
model_name = "aoti_model" if model_name is None else model_name
|
||||
return model_name
|
||||
|
||||
|
||||
def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
|
||||
if unbacked_only:
|
||||
return free_unbacked_symbols(x)
|
||||
|
Reference in New Issue
Block a user