[export] Make aoti_call_delegate hop traceable (#148804)

Summary: The `aoti_call_delegate` hop now uses a stateless `original_gm` for tracing with fake tensors, and the OSS AOTI Runner for running with real tensors. (A condensed sketch of this dispatch rule follows the commit metadata below.)

Differential Revision: D70738393

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148804
Approved by: https://github.com/SherlockNoMad
Author: Yiming Zhou
Date: 2025-04-03 20:44:31 +00:00
Committed by: PyTorch MergeBot
Parent: 51da241c0a
Commit: a3f9e04656

3 changed files with 83 additions and 28 deletions
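
Before the per-file diffs, a minimal sketch of the dispatch rule this change introduces, for orientation. It is illustrative only: the `dispatch` helper name is hypothetical, and the real logic is `call_delegate_cpu` in the third file below. Fake tensor arguments route through the stateless `original_gm`; real tensors route to the AOTI runner.

import torch
from torch._subclasses.fake_tensor import FakeTensor

def dispatch(lowered_module, original_gm, weight_args, input_args):
    # Hypothetical condensation of call_delegate_cpu's fake/real split.
    args = [*weight_args, *input_args]
    if any(isinstance(a, FakeTensor) for a in args):
        # Tracing path: propagate metadata through the stateless GraphModule.
        return original_gm(*args)
    # Runtime path: the compiled .so consumes only the real inputs
    # (AOTInductorRunnerWrapper takes them unpacked; AOTInductorEPModule takes a list).
    return lowered_module(*input_args)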

torch/_export/passes/lift_constants_pass.py

@@ -178,6 +178,8 @@ def lift_constants_pass(
             continue
         if "LoweredBackendModule" in type(constant_val).__name__:
             continue
+        if "AOTInductorRunnerWrapper" in type(constant_val).__name__:
+            continue
         if isinstance(constant_val, torch.utils._pytree.TreeSpec):
             continue
@@ -237,7 +239,6 @@ lift_constants_pass(
             constant_name = f"lifted_tensor_{num_tensor_constants}"
             constant_fqn = get_constant_fqn(node, constant_name)
             num_tensor_constants += 1
-
         else:
             raise SpecViolationError(
                 f"getattr node {node} referencing unsupported type {type(constant_val)}"

torch/_export/verifier.py

@@ -271,6 +271,8 @@ class Verifier(metaclass=_VerifierMeta):
             elif type(attr).__name__ == "AOTInductorEPModule":
                 continue
+            elif type(attr).__name__ == "AOTInductorRunnerWrapper":
+                continue
             if not isinstance(attr, _allowed_getattr_types(is_toplevel_gm)):
                 raise SpecViolationError(
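
Both hunks above recognize AOTI artifacts by class name rather than with an isinstance check. A small sketch of the pattern (the helper name `is_aoti_artifact` is hypothetical; matching on `type(obj).__name__` presumably avoids a hard import of the AOTInductor runtime types):

def is_aoti_artifact(obj) -> bool:
    # Mirrors the checks above: these objects are opaque to export passes
    # and should be skipped rather than lifted or verified as constants.
    return type(obj).__name__ in ("AOTInductorEPModule", "AOTInductorRunnerWrapper")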

torch/_higher_order_ops/aoti_call_delegate.py

@@ -1,20 +1,25 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

 # pyre-strict

 from __future__ import annotations

 import torch
 import torch.utils._pytree as pytree
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch.fx.experimental.proxy_tensor import (
+    disable_proxy_modes_tracing,
+    ProxyTorchDispatchMode,
+    track_tensor_tree,
+)

-AOTI_LOWERED_MODULE = "AOTInductorEPModule"
+AOTI_LOWERED_MODULE = "AOTInductorEPModule/AOTInductorRunnerWrapper"


 class AOTICallDelegate(HigherOrderOperator):
@@ -22,7 +27,7 @@ class AOTICallDelegate(HigherOrderOperator):
     It has the following signature:

     aoti_call_delegate(
-        lowered_module: AOTInductorEPModule,
+        lowered_module: Union[AOTInductorEPModule, AOTInductorRunnerWrapper],
         original_gm: fx.GraphModule,
         weight_args: List[Tensor],
         input_args: List[Tensor],
@@ -30,15 +35,9 @@ class AOTICallDelegate(HigherOrderOperator):
     where,
     - lowered_module is the AOTInductor lowered submodule, backed by a compiled .so file, supporting real tensor inputs
-    - original_gm is the original GraphModule before lowering, allowing FakeTensor propagation
+    - original_gm is the stateless version of the original GraphModule before lowering, allowing FakeTensor propagation
     - weight_args is the list of weights in the original GraphModule, including parameters and buffers
     - input_args is the list of flattened inputs
-
-    NOTE: aoti_call_delegate doesn't support retracing yet, as original_gm is currently stateful with weights as get_attr nodes.
-    This will fail functionalization during retrace. When we move AOTI to accept stateless GraphModule, we can enable retracing.
-    During serialization, we have special handling for aoti_call_delegate, as AOTInductorEPModule is not serializable
-    and the stateful original_gm fails the verifier.
     """

     def __init__(self) -> None:
@@ -62,7 +61,6 @@ aoti_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)
-

 @aoti_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
 # pyre-ignore
 def call_delegate_cpu(
     lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
     original_gm: torch.fx.GraphModule,
@@ -77,27 +75,60 @@ def call_delegate_cpu(
     new_args = pytree.tree_map_only(
         tuple(map_types.keys()),
         lambda a: map_types[type(a)](a),
-        input_args,
+        weight_args + input_args,
         lambda a: isinstance(a, tuple(map_types.keys())),
     )

-    has_fake_input_args = any(isinstance(arg, FakeTensor) for arg in new_args)
-    has_fake_params = any(
-        isinstance(param, FakeTensor) for param in original_gm.parameters()
-    )
-    has_fake_buffers = any(
-        isinstance(buffer, FakeTensor) for buffer in original_gm.buffers()
-    )
-
-    if has_fake_input_args or has_fake_params or has_fake_buffers:
-        # aoti lowered module doesn't support fake tensor
-        return original_gm(*new_args)
+    has_fake_args = any(isinstance(arg, FakeTensor) for arg in new_args)
+    if has_fake_args:
+        # use stateless original_gm for tracing with fake tensors
+        fake_out = original_gm(*new_args)
+        return fake_out
     else:
-        return lowered_module(new_args)  # type: ignore[misc]
+        # use AOTI Runner for real tensors
+        new_input_args = new_args[len(weight_args) :]
+        if type(lowered_module).__name__ == "AOTInductorRunnerWrapper":
+            return lowered_module(*new_input_args)  # type: ignore[misc]
+        elif type(lowered_module).__name__ == "AOTInductorEPModule":
+            return lowered_module(new_input_args)  # type: ignore[misc]
+        else:
+            raise RuntimeError(
+                f"Unexpected lowered_module type: {type(lowered_module)}."
+            )
+
+
+def trace_aoti_call_delegate(
+    proxy_mode, func_overload, lowered_module, original_gm, weight_args, input_args
+):
+    proxy_mode.tracer.root.register_module("lowered_module", lowered_module)
+    proxy_mode.tracer.root.register_module("original_gm", original_gm)
+
+    node_args = (lowered_module, original_gm, weight_args, input_args)
+    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
+    out_proxy = proxy_mode.tracer.create_proxy(
+        "call_function", func_overload, proxy_args, {}, name="aoti_call_delegate"
+    )
+
+    with disable_proxy_modes_tracing():
+        out = call_delegate_cpu(lowered_module, original_gm, weight_args, input_args)
+
+    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
+
+
+@aoti_call_delegate.py_impl(ProxyTorchDispatchMode)
+def call_delegate_proxy_torch_dispatch_mode(
+    mode: ProxyTorchDispatchMode,
+    lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
+    original_gm: torch.fx.GraphModule,
+    weight_args: list[torch.Tensor],
+    input_args: list[torch.Tensor],
+):
+    res = trace_aoti_call_delegate(
+        mode, aoti_call_delegate, lowered_module, original_gm, weight_args, input_args
+    )
+    return res


 @aoti_call_delegate.py_impl(FakeTensorMode)
 # pyre-ignore
 def call_delegate_fake_tensor_mode(
     mode: FakeTensorMode,
     lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
@@ -107,3 +138,24 @@ def call_delegate_fake_tensor_mode(
 ) -> list[torch.Tensor]:
     with mode:
         return call_delegate_cpu(lowered_module, original_gm, weight_args, input_args)
+
+
+@aoti_call_delegate.py_functionalize_impl
+def call_delegate_functionalize(
+    ctx,
+    lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
+    original_gm: torch.fx.GraphModule,
+    weight_args: list[torch.Tensor],
+    input_args: list[torch.Tensor],
+):
+    unwrapped_weight_args = tuple(
+        ctx.unwrap_tensors(weight_arg) for weight_arg in weight_args
+    )
+    unwrapped_input_args = tuple(
+        ctx.unwrap_tensors(input_arg) for input_arg in input_args
+    )
+    with ctx.redispatch_to_next():
+        res = aoti_call_delegate(
+            lowered_module, original_gm, unwrapped_weight_args, unwrapped_input_args  # type: ignore[arg-type]
+        )
+        return ctx.wrap_tensors(res)
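
To close, a hedged sketch of what the new ProxyTorchDispatchMode implementation enables, assuming `gm` is a GraphModule whose graph already contains an `aoti_call_delegate` node (the setup and names are illustrative, not taken from this PR):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Retracing under fake tensors now dispatches to
# call_delegate_proxy_torch_dispatch_mode, which re-emits the hop as a single
# call_function node instead of tracing into (or failing on) the lowered module.
retraced = make_fx(gm, tracing_mode="fake")(torch.randn(4))
assert any(
    "aoti_call_delegate" in str(node.target)
    for node in retraced.graph.nodes
    if node.op == "call_function"
)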