[aoti-fx] Initial AOTInductor FX (#160765)

Using the existing WrapperFxCodegen backend, this PR prototypes an AOT version of it that directly returns a graph module (`torch.fx.GraphModule`).

How to use:
```python
exported_gm = torch.export.export(model, inp, dynamic_shapes=dynamic_shapes).module()
compiled_gm = torch._inductor.aot_compile(
    exported_gm, inp, options={"fx_wrapper": True, "compile_threads": 1}
)
assert torch.allclose(model(*inp), compiled_gm(*inp))
```

The motivation is that backends like ExecuTorch and MTIA would like to use Inductor's optimizations but have their own graph-lowering pipelines, so they may not want to go through AOTI (which compiles to a .so). Returning a graph module lets such backends consume Inductor's output directly, as sketched below.
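As an illustration only (not part of this PR), here is a minimal sketch of how a downstream backend might consume the returned graph module. The node-walking loop stands in for a backend's own lowering pipeline, and the exact `call_function` targets that show up depend on what Inductor generated (this PR's tests exercise the flow on GPU):

```python
import torch
import torch._inductor


class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = M()
inp = (torch.ones(3), torch.ones(3))

exported_gm = torch.export.export(model, inp).module()
compiled_gm = torch._inductor.aot_compile(
    exported_gm, inp, options={"fx_wrapper": True}
)

# With fx_wrapper=True, aot_compile hands back a plain torch.fx.GraphModule
# rather than a path to a compiled .so, so a backend can inspect the nodes
# and feed them into its own lowering flow.
assert isinstance(compiled_gm, torch.fx.GraphModule)
for node in compiled_gm.graph.nodes:
    if node.op == "call_function":
        print(node.target, node.args)
```

The `print` loop is just a placeholder; a real backend would translate each node into its own IR.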
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160765
Approved by: https://github.com/jansel
Author: angelayi
Date: 2025-08-16 16:49:58 -07:00
Committed by: PyTorch MergeBot
Parent: 162bf78df6
Commit: bab79824cb
10 changed files with 139 additions and 35 deletions

@@ -541,8 +541,52 @@ class FxirTestCase(InductorTestCase):
            op="call_function", target=torch.empty_strided
        )
        (shape, stride) = empty_strided.args
        output_is_symbolic = any(isinstance(dim, torch.SymInt) for dim in shape)
        self.assertEqual(output_is_symbolic, use_dynamic_shapes)


class AOTFxirTestCase(InductorTestCase):
    device = GPU_TYPE

    def check(self, model, inp, dynamic_shapes=None):
        with torch.no_grad():
            ep = torch.export.export(model, inp, dynamic_shapes=dynamic_shapes)
            gm = torch._inductor.aot_compile(
                ep.module(), inp, options={"fx_wrapper": True}
            )
            self.assertTrue(torch.allclose(model(*inp), gm(*inp)))

    def test_aoti_fx_add(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
        self.check(M(), inp)

    def test_aoti_fx_const(self):
        class M(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.device = device
                self.a = torch.nn.Parameter(torch.ones(3, device=self.device))
                self.b = torch.ones(3, device=self.device)

            def forward(self, x, y):
                return x + y + self.a + self.b + torch.tensor(3, device=self.device)

        inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
        self.check(M(self.device), inp)

    def test_aoti_fx_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                return self.linear(x)

        inp = (torch.ones(3, 3, device=self.device),)
        self.check(M().to(self.device), inp)


if __name__ == "__main__":

@@ -148,6 +148,7 @@ def aot_compile(
    with torch.no_grad():
        so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options)  # type: ignore[arg-type]

    assert isinstance(so_path, (str, list))
    return so_path


def aot_load(so_path: str, device: str) -> Callable:

@@ -275,7 +275,7 @@ def aot_compile(
    kwargs: Optional[dict[str, Any]] = None,
    *,
    options: Optional[dict[str, Any]] = None,
) -> Union[str, list[Union[str, Weights]]]:
) -> Union[str, list[Union[str, Weights]], torch.fx.GraphModule]:
    """
    Ahead-of-time compile a given FX graph with TorchInductor into a shared library.

@@ -461,15 +461,18 @@ def get_scheduling_for_device(device: str) -> Optional[SchedulingConstructor]:
def get_wrapper_codegen_for_device(
    device: str, cpp_wrapper: bool = False
    device: str, cpp_wrapper: bool = False, fx_wrapper: bool = False
) -> Optional[WrapperConstructor]:
    if device in device_codegens:
        wrapper_codegen_obj: DeviceCodegen = device_codegens[device]
        return (
            wrapper_codegen_obj.cpp_wrapper_codegen
            if cpp_wrapper
            else wrapper_codegen_obj.wrapper_codegen
        )
        if fx_wrapper:
            from .wrapper_fxir import WrapperFxCodegen

            return WrapperFxCodegen
        elif cpp_wrapper:
            return wrapper_codegen_obj.cpp_wrapper_codegen
        else:
            return wrapper_codegen_obj.wrapper_codegen
    return None

@@ -14,7 +14,7 @@ from torch._higher_order_ops.triton_kernel_wrap import (
    tracing_triton_hopifier_singleton,
    triton_kernel_wrapper_mutation,
)
from torch._inductor.codecache import PyCodeCache
from torch._inductor.codecache import LambdaFuture, PyCodeCache
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._inductor.select_algorithm import extern_kernels  # noqa: F401
from torch._inductor.utils import sympy_product, sympy_subs

@@ -168,6 +168,9 @@ class FxConverter:
        mod = PyCodeCache.load(module_code)
        kernel = getattr(mod, kernel_name)
        if isinstance(kernel, LambdaFuture):
            kernel = kernel.result()
        if not isinstance(kernel, CachingAutotuner):
            raise NotImplementedError(
                textwrap.dedent(f"""
@@ -263,16 +266,32 @@ class FxConverter:
        """
        Converts graph inputs to FX placeholders.
        """
        for name, ir_node in V.graph.graph_inputs.items():
            # Introduce a new symbol for constant inputs.
            buffer = (
                SymbolBuffer(sympy.Symbol(name, is_integer=True))
                if isinstance(ir_node, (int, float, sympy.Integer, sympy.Float))
                else self._get_buffer(ir_node)
            )
            node = self.gm.graph.placeholder(buffer.get_name())
            self._create_meta_from_buffer(node, buffer)
            self._record_allocation(buffer, node)
        for node in V.graph.module.graph.find_nodes(op="placeholder"):  # type: ignore[operator, union-attr]
            name = node.name
            if name in V.graph.graph_inputs:
                ir_node = V.graph.graph_inputs[name]
                # Introduce a new symbol for constant inputs.
                buffer = (
                    SymbolBuffer(sympy.Symbol(name, is_integer=True))
                    if isinstance(ir_node, (int, float, sympy.Integer, sympy.Float))
                    else self._get_buffer(ir_node)
                )
                placeholder_node = self.gm.graph.placeholder(buffer.get_name())
                self._create_meta_from_buffer(placeholder_node, buffer)
                self._record_allocation(buffer, placeholder_node)
            elif V.aot_compilation:
                # Create dummy input nodes to match the input signature
                self.gm.graph.placeholder(name)

    def _generate_graph_constants(self) -> None:
        for name, value in V.graph.constants.items():
            node = self.gm.graph.get_attr(name)
            node.meta["val"] = value
            setattr(self.gm, name, value)
            self.buffer_to_node[name] = node

    def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]:
        """

@@ -334,6 +353,7 @@ class FxConverter:
        Main entrypoint for FX codegen.
        """
        self._generate_graph_inputs()
        self._generate_graph_constants()

        # Generate FX IR from Wrapper IR lines.
        for line in self.lines:

@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import copy
import enum
import functools
import io
@@ -724,6 +725,7 @@ class _CompileFxKwargs(TypedDict, total=False):
    layout_opt: Optional[bool]
    extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]]
    boxed_forward_device_index: Optional[BoxedDeviceIndex]
    fx_wrapper: bool


class _CompileFxCallable(Protocol):
@@ -745,6 +747,7 @@ def compile_fx_inner(
    kwargs.setdefault("is_backward", False)
    kwargs.setdefault("graph_id", None)
    kwargs.setdefault("cpp_wrapper", False)
    kwargs.setdefault("fx_wrapper", False)
    kwargs.setdefault("is_inference", False)
    kwargs.setdefault("boxed_forward_device_index", None)
    kwargs.setdefault("layout_opt", None)
@@ -840,7 +843,9 @@ def _compile_fx_inner(
    backends_support_caching = all(
        backend.supports_caching
        for backend in (
            get_wrapper_codegen_for_device(device.type, config.cpp_wrapper)
            get_wrapper_codegen_for_device(
                device.type, config.cpp_wrapper, config.fx_wrapper
            )
            for device in get_all_devices(gm)
        )
        if backend is not None
@@ -1187,6 +1192,7 @@ class _InProcessFxCompile(FxCompile):
        is_backward: bool = graph_kwargs.get("is_backward", False)
        graph_id: Optional[int] = graph_kwargs.get("graph_id", None)
        cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False)
        fx_wrapper: bool = graph_kwargs.get("fx_wrapper", False)
        aot_mode: bool = V.aot_compilation
        is_inference: bool = graph_kwargs.get("is_inference", False)
        extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] = (
@@ -1389,6 +1395,7 @@ class _InProcessFxCompile(FxCompile):
                is_inference=is_inference,
                is_backward=is_backward,
                is_const_graph=True,
                fx_wrapper=fx_wrapper,
            )
            with (
                V.set_graph_handler(const_graph),
@@ -1422,6 +1429,7 @@ class _InProcessFxCompile(FxCompile):
                ),
                const_module=const_graph,
                inputs_to_check=inputs_to_check,
                fx_wrapper=fx_wrapper,
            )

            metrics_helper = metrics.CachedMetricsHelper()
@@ -1459,7 +1467,15 @@ class _InProcessFxCompile(FxCompile):
                with dynamo_timed(
                    "GraphLowering.compile_to_fn", log_pt2_compile_event=True
                ):
                    if graph.aot_mode:
                    if graph.aot_mode and graph.fx_wrapper:
                        assert not graph.cpp_wrapper
                        compiled_fn = graph.codegen()[0].gm  # type: ignore[attr-defined]
                        output_code_log.debug(
                            "Output graph module: \n%s",
                            compiled_fn.print_readable(print_output=False),
                        )
                    elif graph.aot_mode:
                        from .codecache import AotCodeCompiler

                        assert graph.cpp_wrapper, (
@@ -1571,7 +1587,9 @@ class _InProcessFxCompile(FxCompile):
                        V.graph.disable_cudagraphs_reason = disable

                if V.aot_compilation:
                    assert isinstance(compiled_fn, (str, list))
                    assert isinstance(
                        compiled_fn, (str, list, torch.fx.GraphModule)
                    ), type(compiled_fn)
                    return CompiledAOTI(compiled_fn)

                # TODO: Hoist this above V.aot_compilation
@@ -1852,17 +1870,17 @@ def compile_fx_aot(
    example_inputs_: list[InputType],
    inner_compile: _CompileFxCallable = compile_fx_inner,
    config_patches: Optional[dict[str, Any]] = None,
) -> Union[list[Union[str, Weights]], str]:
) -> Union[list[Union[str, Weights]], str, GraphModule]:
    assert isinstance(model_, GraphModule), model_

    # [See NOTE] Unwrapping subclasses AOT
    unwrap_tensor_subclass_parameters(model_)

    config_patches: dict[str, Any] = (
        {"cpp_wrapper": True}
        if config_patches is None
        else {**config_patches, "cpp_wrapper": True}
    )
    config_patches: dict[str, Any] = copy.deepcopy(config_patches or {})
    if not (config_patches.get("fx_wrapper", False) or config.fx_wrapper):
        # If fx_wrapper is not set, then set cpp_wrapper
        config_patches["cpp_wrapper"] = True

    output_path = config_patches.get(
        "aot_inductor.output_path", config.aot_inductor.output_path
@@ -2124,11 +2142,15 @@ def compile_fx(
    )

    # TODO: This probably shouldn't be a recursive call
    if config.cpp_wrapper:
    if config.cpp_wrapper or config.fx_wrapper:
        cpp_wrapper_config = config.cpp_wrapper
        fx_wrapper_config = config.fx_wrapper
        with (
            config.patch(
                {
                    "cpp_wrapper": False,  # reset to break recursive call to compile_fx
                    "fx_wrapper": False,  # reset to break recursive call to compile_fx
                    **get_cpp_wrapper_config(),
                }
            ),
@@ -2174,7 +2196,11 @@ def compile_fx(
            return compile_fx(
                patched_mod,
                fake_args,
                inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
                inner_compile=functools.partial(
                    inner_compile,
                    cpp_wrapper=cpp_wrapper_config,
                    fx_wrapper=fx_wrapper_config,
                ),
                decompositions=decompositions,
                ignore_shape_env=ignore_shape_env,
            )

@@ -187,6 +187,8 @@ cpp_wrapper_build_separate: bool = (
    os.environ.get("TORCHINDUCTOR_CPP_WRAPPER_BUILD_SEPARATE", "0") == "1"
)

fx_wrapper: bool = os.environ.get("TORCHINDUCTOR_FX_WRAPPER", "0") == "1"

# Controls automatic precompiling of common include files for codecache.CppCodeCache
# (i.e. for cpp_wrapper mode and for cpp kernels on CPU). AOTI header precompiling is
# controlled by a separate flag.

@@ -312,6 +312,7 @@ class GraphLowering(torch.fx.Interpreter):
        const_module: Optional[GraphLowering] = None,
        name: Optional[str] = None,
        inputs_to_check: Optional[Sequence[int]] = None,
        fx_wrapper: bool = False,
    ) -> None:
        super().__init__(gm)
        self.example_inputs = example_inputs

@@ -411,6 +412,7 @@ class GraphLowering(torch.fx.Interpreter):
        self.creation_time = time.time()
        self.name = name  # type: ignore[assignment]
        self.cpp_wrapper = cpp_wrapper
        self.fx_wrapper = fx_wrapper

        # record multi_kernel choice for cpp_wrapper so the second pass knows
        # which sub-kernel is picked. Copy cpp_wrapper to another variable

@@ -2016,7 +2018,7 @@ class GraphLowering(torch.fx.Interpreter):
        self.device_ops = get_device_op_overrides(self.device_type)
        wrapper_code_gen_cls = get_wrapper_codegen_for_device(
            self.device_type, self.cpp_wrapper
            self.device_type, self.cpp_wrapper, self.fx_wrapper
        )
        assert wrapper_code_gen_cls is not None, (
            f"Device {self.device_type} not supported"

@@ -5614,7 +5614,10 @@ class ExternKernel(InputsKernel):
        from .codegen.cpp_wrapper_cpu import CppWrapperCpu

        device = d.type if (d := self.get_device()) else V.graph.device_type
        if V.graph.cpp_wrapper:
        if V.graph.fx_wrapper:
            assert self.python_kernel_name is not None
            return self.python_kernel_name
        elif V.graph.cpp_wrapper:
            assert isinstance(V.graph.wrapper_code, CppWrapperCpu), type(
                V.graph.wrapper_code
            )

@@ -7307,7 +7310,10 @@ class AssertScalar(ExternKernel):
        # simplify(u0 == 0), you will get True (because we've already runtime assert'ed
        # that it's true). But we're code generating the actual runtime assert here!!
        symbol = next(iter(self.get_free_symbol_uses(unbacked_only=False)))
        if V.graph.cpp_wrapper:
        if V.graph.fx_wrapper:
            # TODO fix
            pass
        elif V.graph.cpp_wrapper:
            symbol_str = f"std::to_string({symbol})"
            sizevar = V.graph.wrapper_code.codegen_cpp_sizevar(
                self.scalar, simplify=False

@@ -723,7 +723,7 @@ class CompiledAOTI(OutputCode):
    Class holding an AOTInductor compiled so.
    """

    filename: Union[str, list[Union[str, Weights]]]
    filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]

    def __call__(self, inputs: Sequence[Any]) -> Any:
        raise NotImplementedError("NYI")