Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[aoti-fx] Initial AOTInductor FX (#160765)
Using the existing WrapperFxCodegen backend, this PR prototypes an AOT version of it which directly returns a graph module. How to use:

```python
exported_gm = torch.export.export(model, inp, dynamic_shapes=dynamic_shapes).module()
compiled_gm = torch._inductor.aot_compile(
    exported_gm, inp, options={"fx_wrapper": True, "compile_threads": 1}
)
assert torch.allclose(model(*inp), compiled_gm(*inp))
```

The motivation is that backends like ExecuTorch/MTIA would like to use Inductor's optimization technologies but may have their own graph-lowering pipelines, so they might not want to use AOTI (which generates a .so).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160765
Approved by: https://github.com/jansel
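For illustration, a minimal end-to-end sketch of how a downstream backend might consume the returned graph module. The toy `Add` model is hypothetical, and the device choice mirrors the new tests, which run on `GPU_TYPE`:

```python
import torch


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


device = "cuda"  # the new tests run on GPU_TYPE; other devices are not shown here
model = Add().to(device)
inp = (torch.ones(3, device=device), torch.ones(3, device=device))

exported_gm = torch.export.export(model, inp).module()
compiled_gm = torch._inductor.aot_compile(
    exported_gm, inp, options={"fx_wrapper": True}
)

# The result is an ordinary torch.fx.GraphModule, so a custom lowering
# pipeline can walk or rewrite it instead of loading a compiled .so.
for node in compiled_gm.graph.nodes:
    print(node.op, node.target)

assert torch.allclose(model(*inp), compiled_gm(*inp))
```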
Committed by PyTorch MergeBot
Parent: 162bf78df6
Commit: bab79824cb
@@ -541,8 +541,52 @@ class FxirTestCase(InductorTestCase):
             op="call_function", target=torch.empty_strided
         )
         (shape, stride) = empty_strided.args
         output_is_symbolic = any(isinstance(dim, torch.SymInt) for dim in shape)
         self.assertEqual(output_is_symbolic, use_dynamic_shapes)
 
 
+class AOTFxirTestCase(InductorTestCase):
+    device = GPU_TYPE
+
+    def check(self, model, inp, dynamic_shapes=None):
+        with torch.no_grad():
+            ep = torch.export.export(model, inp, dynamic_shapes=dynamic_shapes)
+            gm = torch._inductor.aot_compile(
+                ep.module(), inp, options={"fx_wrapper": True}
+            )
+            self.assertTrue(torch.allclose(model(*inp), gm(*inp)))
+
+    def test_aoti_fx_add(self):
+        class M(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y
+
+        inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
+        self.check(M(), inp)
+
+    def test_aoti_fx_const(self):
+        class M(torch.nn.Module):
+            def __init__(self, device):
+                super().__init__()
+                self.device = device
+                self.a = torch.nn.Parameter(torch.ones(3, device=self.device))
+                self.b = torch.ones(3, device=self.device)
+
+            def forward(self, x, y):
+                return x + y + self.a + self.b + torch.tensor(3, device=self.device)
+
+        inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device))
+        self.check(M(self.device), inp)
+
+    def test_aoti_fx_linear(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(3, 3)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        inp = (torch.ones(3, 3, device=self.device),)
+        self.check(M().to(self.device), inp)
+
+
 if __name__ == "__main__":
@@ -148,6 +148,7 @@ def aot_compile(
     with torch.no_grad():
         so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options)  # type: ignore[arg-type]
 
+    assert isinstance(so_path, (str, list))
     return so_path
 
 
 def aot_load(so_path: str, device: str) -> Callable:
@@ -275,7 +275,7 @@ def aot_compile(
     kwargs: Optional[dict[str, Any]] = None,
     *,
     options: Optional[dict[str, Any]] = None,
-) -> Union[str, list[Union[str, Weights]]]:
+) -> Union[str, list[Union[str, Weights]], torch.fx.GraphModule]:
     """
     Ahead-of-time compile a given FX graph with TorchInductor into a shared library.
 
@@ -461,15 +461,18 @@ def get_scheduling_for_device(device: str) -> Optional[SchedulingConstructor]:
 
 
 def get_wrapper_codegen_for_device(
-    device: str, cpp_wrapper: bool = False
+    device: str, cpp_wrapper: bool = False, fx_wrapper: bool = False
 ) -> Optional[WrapperConstructor]:
     if device in device_codegens:
         wrapper_codegen_obj: DeviceCodegen = device_codegens[device]
-        return (
-            wrapper_codegen_obj.cpp_wrapper_codegen
-            if cpp_wrapper
-            else wrapper_codegen_obj.wrapper_codegen
-        )
+        if fx_wrapper:
+            from .wrapper_fxir import WrapperFxCodegen
+
+            return WrapperFxCodegen
+        elif cpp_wrapper:
+            return wrapper_codegen_obj.cpp_wrapper_codegen
+        else:
+            return wrapper_codegen_obj.wrapper_codegen
     return None
 
 
@@ -14,7 +14,7 @@ from torch._higher_order_ops.triton_kernel_wrap import (
     tracing_triton_hopifier_singleton,
     triton_kernel_wrapper_mutation,
 )
-from torch._inductor.codecache import PyCodeCache
+from torch._inductor.codecache import LambdaFuture, PyCodeCache
 from torch._inductor.runtime.triton_heuristics import CachingAutotuner
 from torch._inductor.select_algorithm import extern_kernels  # noqa: F401
 from torch._inductor.utils import sympy_product, sympy_subs
@@ -168,6 +168,9 @@ class FxConverter:
         mod = PyCodeCache.load(module_code)
         kernel = getattr(mod, kernel_name)
 
+        if isinstance(kernel, LambdaFuture):
+            kernel = kernel.result()
+
         if not isinstance(kernel, CachingAutotuner):
             raise NotImplementedError(
                 textwrap.dedent(f"""
@@ -263,16 +266,32 @@ class FxConverter:
         """
         Converts graph inputs to FX placeholders.
         """
-        for name, ir_node in V.graph.graph_inputs.items():
-            # Introduce a new symbol for constant inputs.
-            buffer = (
-                SymbolBuffer(sympy.Symbol(name, is_integer=True))
-                if isinstance(ir_node, (int, float, sympy.Integer, sympy.Float))
-                else self._get_buffer(ir_node)
-            )
-            node = self.gm.graph.placeholder(buffer.get_name())
-            self._create_meta_from_buffer(node, buffer)
-            self._record_allocation(buffer, node)
+        for node in V.graph.module.graph.find_nodes(op="placeholder"):  # type: ignore[operator, union-attr]
+            name = node.name
+            if name in V.graph.graph_inputs:
+                ir_node = V.graph.graph_inputs[name]
+
+                # Introduce a new symbol for constant inputs.
+                buffer = (
+                    SymbolBuffer(sympy.Symbol(name, is_integer=True))
+                    if isinstance(ir_node, (int, float, sympy.Integer, sympy.Float))
+                    else self._get_buffer(ir_node)
+                )
+                placeholder_node = self.gm.graph.placeholder(buffer.get_name())
+                self._create_meta_from_buffer(placeholder_node, buffer)
+                self._record_allocation(buffer, placeholder_node)
+
+            elif V.aot_compilation:
+                # Create dummy input nodes to match the input signature
+                self.gm.graph.placeholder(name)
+
+    def _generate_graph_constants(self) -> None:
+        for name, value in V.graph.constants.items():
+            node = self.gm.graph.get_attr(name)
+            node.meta["val"] = value
+            setattr(self.gm, name, value)
+            self.buffer_to_node[name] = node
 
     def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]:
         """
@@ -334,6 +353,7 @@ class FxConverter:
         Main entrypoint for FX codegen.
         """
         self._generate_graph_inputs()
+        self._generate_graph_constants()
 
         # Generate FX IR from Wrapper IR lines.
         for line in self.lines:
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import contextlib
+import copy
 import enum
 import functools
 import io
@@ -724,6 +725,7 @@ class _CompileFxKwargs(TypedDict, total=False):
     layout_opt: Optional[bool]
     extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]]
     boxed_forward_device_index: Optional[BoxedDeviceIndex]
+    fx_wrapper: bool
 
 
 class _CompileFxCallable(Protocol):
@@ -745,6 +747,7 @@ def compile_fx_inner(
     kwargs.setdefault("is_backward", False)
     kwargs.setdefault("graph_id", None)
     kwargs.setdefault("cpp_wrapper", False)
+    kwargs.setdefault("fx_wrapper", False)
     kwargs.setdefault("is_inference", False)
     kwargs.setdefault("boxed_forward_device_index", None)
     kwargs.setdefault("layout_opt", None)
@@ -840,7 +843,9 @@ def _compile_fx_inner(
     backends_support_caching = all(
         backend.supports_caching
         for backend in (
-            get_wrapper_codegen_for_device(device.type, config.cpp_wrapper)
+            get_wrapper_codegen_for_device(
+                device.type, config.cpp_wrapper, config.fx_wrapper
+            )
             for device in get_all_devices(gm)
         )
         if backend is not None
@@ -1187,6 +1192,7 @@ class _InProcessFxCompile(FxCompile):
         is_backward: bool = graph_kwargs.get("is_backward", False)
         graph_id: Optional[int] = graph_kwargs.get("graph_id", None)
         cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False)
+        fx_wrapper: bool = graph_kwargs.get("fx_wrapper", False)
         aot_mode: bool = V.aot_compilation
         is_inference: bool = graph_kwargs.get("is_inference", False)
         extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] = (
@@ -1389,6 +1395,7 @@ class _InProcessFxCompile(FxCompile):
                     is_inference=is_inference,
                     is_backward=is_backward,
                     is_const_graph=True,
+                    fx_wrapper=fx_wrapper,
                 )
                 with (
                     V.set_graph_handler(const_graph),
@@ -1422,6 +1429,7 @@ class _InProcessFxCompile(FxCompile):
                     ),
                     const_module=const_graph,
                     inputs_to_check=inputs_to_check,
+                    fx_wrapper=fx_wrapper,
                 )
                 metrics_helper = metrics.CachedMetricsHelper()
 
@@ -1459,7 +1467,15 @@ class _InProcessFxCompile(FxCompile):
             with dynamo_timed(
                 "GraphLowering.compile_to_fn", log_pt2_compile_event=True
             ):
-                if graph.aot_mode:
+                if graph.aot_mode and graph.fx_wrapper:
+                    assert not graph.cpp_wrapper
+                    compiled_fn = graph.codegen()[0].gm  # type: ignore[attr-defined]
+                    output_code_log.debug(
+                        "Output graph module: \n%s",
+                        compiled_fn.print_readable(print_output=False),
+                    )
+
+                elif graph.aot_mode:
                     from .codecache import AotCodeCompiler
 
                     assert graph.cpp_wrapper, (
@@ -1571,7 +1587,9 @@ class _InProcessFxCompile(FxCompile):
                 V.graph.disable_cudagraphs_reason = disable
 
             if V.aot_compilation:
-                assert isinstance(compiled_fn, (str, list))
+                assert isinstance(
+                    compiled_fn, (str, list, torch.fx.GraphModule)
+                ), type(compiled_fn)
                 return CompiledAOTI(compiled_fn)
 
             # TODO: Hoist this above V.aot_compilation
@@ -1852,17 +1870,17 @@ def compile_fx_aot(
     example_inputs_: list[InputType],
     inner_compile: _CompileFxCallable = compile_fx_inner,
     config_patches: Optional[dict[str, Any]] = None,
-) -> Union[list[Union[str, Weights]], str]:
+) -> Union[list[Union[str, Weights]], str, GraphModule]:
     assert isinstance(model_, GraphModule), model_
 
     # [See NOTE] Unwrapping subclasses AOT
     unwrap_tensor_subclass_parameters(model_)
 
-    config_patches: dict[str, Any] = (
-        {"cpp_wrapper": True}
-        if config_patches is None
-        else {**config_patches, "cpp_wrapper": True}
-    )
+    config_patches: dict[str, Any] = copy.deepcopy(config_patches or {})
+
+    if not (config_patches.get("fx_wrapper", False) or config.fx_wrapper):
+        # If fx_wrapper is not set, then set cpp_wrapper
+        config_patches["cpp_wrapper"] = True
 
     output_path = config_patches.get(
         "aot_inductor.output_path", config.aot_inductor.output_path
@@ -2124,11 +2142,15 @@ def compile_fx(
         )
 
         # TODO: This probably shouldn't be a recursive call
-        if config.cpp_wrapper:
+        if config.cpp_wrapper or config.fx_wrapper:
+            cpp_wrapper_config = config.cpp_wrapper
+            fx_wrapper_config = config.fx_wrapper
+
             with (
                 config.patch(
                     {
                         "cpp_wrapper": False,  # reset to break recursive call to compile_fx
+                        "fx_wrapper": False,  # reset to break recursive call to compile_fx
                         **get_cpp_wrapper_config(),
                     }
                 ),
@@ -2174,7 +2196,11 @@ def compile_fx(
             return compile_fx(
                 patched_mod,
                 fake_args,
-                inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
+                inner_compile=functools.partial(
+                    inner_compile,
+                    cpp_wrapper=cpp_wrapper_config,
+                    fx_wrapper=fx_wrapper_config,
+                ),
                 decompositions=decompositions,
                 ignore_shape_env=ignore_shape_env,
             )
@@ -187,6 +187,8 @@ cpp_wrapper_build_separate: bool = (
     os.environ.get("TORCHINDUCTOR_CPP_WRAPPER_BUILD_SEPARATE", "0") == "1"
 )
 
+fx_wrapper: bool = os.environ.get("TORCHINDUCTOR_FX_WRAPPER", "0") == "1"
+
 # Controls automatic precompiling of common include files for codecache.CppCodeCache
 # (i.e. for cpp_wrapper mode and for cpp kernels on CPU). AOTI header precompiling is
 # controlled by a separate flag.
@@ -312,6 +312,7 @@ class GraphLowering(torch.fx.Interpreter):
         const_module: Optional[GraphLowering] = None,
         name: Optional[str] = None,
         inputs_to_check: Optional[Sequence[int]] = None,
+        fx_wrapper: bool = False,
     ) -> None:
         super().__init__(gm)
         self.example_inputs = example_inputs
@@ -411,6 +412,7 @@ class GraphLowering(torch.fx.Interpreter):
         self.creation_time = time.time()
         self.name = name  # type: ignore[assignment]
         self.cpp_wrapper = cpp_wrapper
+        self.fx_wrapper = fx_wrapper
 
         # record multi_kernel choice for cpp_wrapper so the second pass knows
         # which sub-kernel is picked. Copy cpp_wrapper to another variable
@@ -2016,7 +2018,7 @@ class GraphLowering(torch.fx.Interpreter):
 
         self.device_ops = get_device_op_overrides(self.device_type)
         wrapper_code_gen_cls = get_wrapper_codegen_for_device(
-            self.device_type, self.cpp_wrapper
+            self.device_type, self.cpp_wrapper, self.fx_wrapper
         )
         assert wrapper_code_gen_cls is not None, (
            f"Device {self.device_type} not supported"
@@ -5614,7 +5614,10 @@ class ExternKernel(InputsKernel):
         from .codegen.cpp_wrapper_cpu import CppWrapperCpu
 
         device = d.type if (d := self.get_device()) else V.graph.device_type
-        if V.graph.cpp_wrapper:
+        if V.graph.fx_wrapper:
+            assert self.python_kernel_name is not None
+            return self.python_kernel_name
+        elif V.graph.cpp_wrapper:
             assert isinstance(V.graph.wrapper_code, CppWrapperCpu), type(
                 V.graph.wrapper_code
             )
@@ -7307,7 +7310,10 @@ class AssertScalar(ExternKernel):
        # simplify(u0 == 0), you will get True (because we've already runtime assert'ed
        # that it's true). But we're code generating the actual runtime assert here!!
        symbol = next(iter(self.get_free_symbol_uses(unbacked_only=False)))
-        if V.graph.cpp_wrapper:
+        if V.graph.fx_wrapper:
+            # TODO fix
+            pass
+        elif V.graph.cpp_wrapper:
             symbol_str = f"std::to_string({symbol})"
             sizevar = V.graph.wrapper_code.codegen_cpp_sizevar(
                 self.scalar, simplify=False
@@ -723,7 +723,7 @@ class CompiledAOTI(OutputCode):
     Class holding an AOTInductor compiled so.
     """
 
-    filename: Union[str, list[Union[str, Weights]]]
+    filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
 
     def __call__(self, inputs: Sequence[Any]) -> Any:
         raise NotImplementedError("NYI")