pytorch/torch/_inductor/__init__.py
angelayi cd9ee49a69 [aoti] Add cpp loader (#135374)
* Added a C++ loader, `AOTIModelPackageLoader`, which can load the .pt2, build the .so, and create a runner. From Python, users call the `run` function directly, whereas in C++ users can access the `runner_` member directly if they are more familiar with that. I couldn't figure out how to bind the `get_runner()` function to Python.
* Added a new config, `aot_inductor.package_cpp_only`, which will **not** package the .so. This means that whenever the package is loaded, we will need to build the .so. It is turned off by default so that new environments do not need to rebuild the .so. `package_cpp_only` is a feature that torchchat intends to use to provide flexibility to users.
* Added a new config, `aot_inductor.metadata`, which stores user-provided metadata, serialized into the .pt2 as a JSON file. It also stores the device used when exporting ("cuda" or "cpu") so that at load time we can use it to determine which AOTIModelContainerRunner to use. The metadata can be accessed through `loader.get_metadata()`. TODO: move this metadata to the top-level `package_aoti` function so that it no longer needs to be a config.
* Separated out `package_aoti` as a standalone function instead of it being called automatically in Inductor. This prepares for the case where users compile multiple models and want to bundle them into one package. The specific use case is in torchchat, where we want to package the separately-exported encoder and decoder layers. An example of how to use this is in `test_multiple_methods`.
* `load_package` will load a single model from the package, given the model name; a rough end-to-end sketch follows this list.
* The loader doesn't support Windows for now; I think I need to add some more special-casing to make the build commands work on Windows.
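
A rough end-to-end sketch of the compile-then-load flow described above; the model, the package file name, and the exact load-side signatures of `load_package`, `get_metadata`, and `run` are illustrative assumptions rather than verbatim API from this PR:

```python
import torch
from torch._inductor import aoti_compile_and_package
from torch._inductor.package import load_package  # assumed location of the loader entry point


class MyModel(torch.nn.Module):  # placeholder model for illustration
    def forward(self, x):
        return torch.relu(x)


example_args = (torch.randn(8, 10),)
ep = torch.export.export(MyModel(), example_args)

# Compile with AOTInductor and bundle the generated files into a .pt2 package.
pt2_path = aoti_compile_and_package(ep, example_args, package_path="my_model.pt2")

# Later, possibly in a fresh process: load a single model out of the package by
# name, inspect the metadata recorded at export time, and run it.
loader = load_package(pt2_path, "model")
print(loader.get_metadata())
out = loader.run(example_args)
```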

Differential Revision: [D62329906](https://our.internmc.facebook.com/intern/diff/D62329906)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135374
Approved by: https://github.com/desertfire, https://github.com/malfet
2024-09-11 03:00:01 +00:00


# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, Tuple

import torch.fx
import torch.utils._pytree as pytree


__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"]


def compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    options: Optional[Dict[str, Any]] = None,
):
    """
    Compile a given FX graph with TorchInductor. This allows compiling
    FX graphs captured without using TorchDynamo.

    Args:
        gm: The FX graph to compile.
        example_inputs: List of tensor inputs.
        options: Optional dict of config options. See `torch._inductor.config`.

    Returns:
        Callable with same behavior as gm but faster.
    """
    from .compile_fx import compile_fx

    return compile_fx(gm, example_inputs, config_patches=options)
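
# A minimal usage sketch for `compile` (the traced function below is a stand-in,
# not part of this module): capture an FX graph without TorchDynamo, compile it,
# then call the result exactly like the original callable.
#
#     import torch
#     import torch.fx
#
#     def f(x):
#         return torch.relu(x) + 1
#
#     gm = torch.fx.symbolic_trace(f)
#     compiled_f = torch._inductor.compile(gm, [torch.randn(32, 64)])
#     out = compiled_f(torch.randn(32, 64))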


def aoti_compile_and_package(
    exported_program,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    package_path: Optional[str] = None,
    inductor_configs: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Compiles the exported program with AOTInductor, and packages it into a .pt2
    file specified by the input package_path.
    """
    from torch._inductor.package import package_aoti
    from torch.export import ExportedProgram

    if not isinstance(exported_program, ExportedProgram):
        raise ValueError("Only ExportedProgram is supported")

    assert package_path is None or package_path.endswith(".pt2")

    inductor_configs = inductor_configs or {}

    if inductor_configs.get("aot_inductor.output_path"):
        raise RuntimeError(
            "Please pass in a package path to aoti_compile_and_package() instead "
            "of setting the aot_inductor.output_path config."
        )
    inductor_configs["aot_inductor.package"] = True

    m = exported_program.module()
    assert isinstance(m, torch.fx.GraphModule)

    aoti_files = aot_compile(m, args, kwargs, options=inductor_configs)  # type: ignore[arg-type]

    if package_path is None:
        package_path = aoti_files + ".pt2"

    res = package_aoti(package_path, aoti_files)
    assert res == package_path
    return package_path
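
# A usage sketch for `aoti_compile_and_package` (the exported model is a
# placeholder; the signature and the ".pt2" requirement come from the function above):
#
#     example_args = (torch.randn(8, 10),)
#     ep = torch.export.export(MyModel(), example_args)
#     pt2_path = aoti_compile_and_package(
#         ep,
#         example_args,
#         package_path="my_model.pt2",  # must end with ".pt2"
#     )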


def aot_compile(
    gm: torch.fx.GraphModule,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    options: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Ahead-of-time compile a given FX graph with TorchInductor into a shared library.

    Args:
        gm: The FX graph to compile.
        args: Example arguments
        kwargs: Example keyword arguments
        options: Optional dict of config options. See `torch._inductor.config`.

    Returns:
        Path to the generated shared library
    """
    from .compile_fx import compile_fx_aot, graph_returns_tuple

    assert graph_returns_tuple(gm), (
        "Graph output must be a tuple(). This is so that we can avoid "
        "pytree processing of the outputs. Please change the module to "
        "have tuple outputs."
    )

    # We will serialize the pytree info into the .so as constant strings
    in_spec = None
    out_spec = None
    if isinstance(gm.graph._codegen, torch.fx.graph._PyTreeCodeGen):
        codegen = gm.graph._codegen
        gm.graph._codegen = torch.fx.graph.CodeGen()
        gm.recompile()

        if codegen.pytree_info.in_spec is not None:
            in_spec = codegen.pytree_info.in_spec

        if codegen.pytree_info.out_spec is not None:
            out_spec = codegen.pytree_info.out_spec
    else:
        if hasattr(gm, "_in_spec"):
            in_spec = gm._in_spec

        if hasattr(gm, "_out_spec"):
            out_spec = gm._out_spec

    serialized_in_spec = pytree.treespec_dumps(in_spec) if in_spec is not None else ""
    serialized_out_spec = (
        pytree.treespec_dumps(out_spec) if out_spec is not None else ""
    )

    flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
        (args, kwargs or {})
    )

    # Replace non-tensor (constant) inputs with Nones, since these are not being
    # used anyways by the graph
    flat_example_inputs = [
        x[1] if isinstance(x[1], torch.Tensor) else None for x in flat_args_with_path
    ]

    if in_spec is not None and received_spec != in_spec:
        raise ValueError(  # noqa: B904
            "Trying to flatten user inputs with exported input tree spec: \n"
            f"{in_spec}\n"
            "but actually got inputs with tree spec of: \n"
            f"{received_spec}"
        )

    options = (
        {
            "aot_inductor.serialized_in_spec": serialized_in_spec,
            "aot_inductor.serialized_out_spec": serialized_out_spec,
        }
        if options is None
        else {
            **options,
            "aot_inductor.serialized_in_spec": serialized_in_spec,
            "aot_inductor.serialized_out_spec": serialized_out_spec,
        }
    )

    return compile_fx_aot(
        gm,
        flat_example_inputs,  # type: ignore[arg-type]
        config_patches=options,
    )
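
# A usage sketch for `aot_compile` (the traced module is a placeholder): the graph
# must return a tuple, and the result is a path to the compiled shared library.
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return (torch.relu(x),)  # tuple output, as required above
#
#     gm = torch.fx.symbolic_trace(M())
#     so_path = aot_compile(gm, (torch.randn(8, 10),))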


def list_mode_options(
    mode: Optional[str] = None, dynamic: Optional[bool] = None
) -> Dict[str, Any]:
    r"""Returns a dictionary describing the optimizations that each of the available
    modes passed to `torch.compile()` performs.

    Args:
        mode (str, optional): The mode to return the optimizations for.
            If None, returns optimizations for all modes
        dynamic (bool, optional): Whether dynamic shape is enabled.

    Example::
        >>> torch._inductor.list_mode_options()
    """

    mode_options: Dict[str, Dict[str, bool]] = {
        "default": {},
        # enable cudagraphs
        "reduce-overhead": {
            "triton.cudagraphs": True,
        },
        # enable max-autotune
        "max-autotune-no-cudagraphs": {
            "max_autotune": True,
        },
        # enable max-autotune
        # enable cudagraphs
        "max-autotune": {
            "max_autotune": True,
            "triton.cudagraphs": True,
        },
    }
    return mode_options[mode] if mode else mode_options  # type: ignore[return-value]
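
# For example, based on the table above:
#
#     >>> torch._inductor.list_mode_options("max-autotune")
#     {'max_autotune': True, 'triton.cudagraphs': True}
#
# while calling it with no argument returns the full mode -> options mapping.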


def list_options() -> List[str]:
    r"""Returns a list of the optimization and debug configuration options
    available to `torch.compile()`.

    The options are documented in `torch._inductor.config`.

    Example::
        >>> torch._inductor.list_options()
    """

    from torch._inductor import config

    current_config: Dict[str, Any] = config.shallow_copy_dict()

    return list(current_config.keys())


def cudagraph_mark_step_begin():
    "Indicates that a new iteration of inference or training is about to begin."
    from .cudagraph_trees import mark_step_begin

    mark_step_begin()
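
# A rough usage sketch for `cudagraph_mark_step_begin` (the model and loop are
# placeholders): when running a model compiled with CUDA graphs (e.g. the
# "reduce-overhead" mode), marking the start of each iteration lets Inductor's
# CUDA graph trees treat it as a fresh step.
#
#     compiled_model = torch.compile(model, mode="reduce-overhead")
#     for batch in loader:
#         torch._inductor.cudagraph_mark_step_begin()
#         out = compiled_model(batch)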