mirror of https://github.com/pytorch/pytorch.git
[export] Make aoti_call_delegate hop traceable (#148804)
Summary: The `aoti_call_delegate` hop now uses a stateless `original_gm` for tracing with fake tensors and the OSS AOTI Runner for running with real tensors.

Differential Revision: D70738393

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148804
Approved by: https://github.com/SherlockNoMad
Committed by: PyTorch MergeBot
Parent: 51da241c0a
Commit: a3f9e04656
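For orientation before the diff, here is a minimal, self-contained sketch of the dispatch behavior the summary describes: fake tensors are routed through the stateless graph module, while real tensors go to the compiled runner. The names below (StubRunner, stub_gm, call_delegate_sketch) are illustrative stand-ins, not the real AOTI runner or the hop's actual implementation.

# Illustrative sketch only; StubRunner and stub_gm stand in for the compiled
# AOTI runner and the stateless original_gm used by the hop.
import torch
from torch._subclasses.fake_tensor import FakeTensor


class StubRunner:
    """Stand-in for a compiled AOTI runner that consumes real tensors."""

    def __call__(self, inputs):
        return [x * 2 for x in inputs]


def stub_gm(*args):
    """Stand-in for the stateless GraphModule used for fake-tensor tracing."""
    return [a * 2 for a in args]


def call_delegate_sketch(runner, gm, weight_args, input_args):
    args = list(weight_args) + list(input_args)
    if any(isinstance(a, FakeTensor) for a in args):
        # Tracing path: propagate fake tensors through the stateless graph.
        return gm(*args)
    # Execution path: hand real tensors to the compiled runner.
    return runner(input_args)


print(call_delegate_sketch(StubRunner(), stub_gm, [], [torch.ones(2)]))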
@@ -178,6 +178,8 @@ def lift_constants_pass(
                continue
            if "LoweredBackendModule" in type(constant_val).__name__:
                continue
            if "AOTInductorRunnerWrapper" in type(constant_val).__name__:
                continue
            if isinstance(constant_val, torch.utils._pytree.TreeSpec):
                continue

@@ -237,7 +239,6 @@ def lift_constants_pass(
                constant_name = f"lifted_tensor_{num_tensor_constants}"
                constant_fqn = get_constant_fqn(node, constant_name)
                num_tensor_constants += 1

            else:
                raise SpecViolationError(
                    f"getattr node {node} referencing unsupported type {type(constant_val)}"

@@ -271,6 +271,8 @@ class Verifier(metaclass=_VerifierMeta):
            elif type(attr).__name__ == "AOTInductorEPModule":
                continue

            elif type(attr).__name__ == "AOTInductorRunnerWrapper":
                continue

            if not isinstance(attr, _allowed_getattr_types(is_toplevel_gm)):
                raise SpecViolationError(

@@ -1,20 +1,25 @@
# mypy: allow-untyped-defs

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import torch
import torch.utils._pytree as pytree
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


AOTI_LOWERED_MODULE = "AOTInductorEPModule"
AOTI_LOWERED_MODULE = "AOTInductorEPModule/AOTInductorRunnerWrapper"


class AOTICallDelegate(HigherOrderOperator):
@@ -22,7 +27,7 @@ class AOTICallDelegate(HigherOrderOperator):

    It has the following signature:
    aoti_call_delegate(
        lowered_module: AOTInductorEPModule,
        lowered_module: Union[AOTInductorEPModule, AOTInductorRunnerWrapper],
        original_gm: fx.GraphModule,
        weight_args: List[Tensor],
        input_args: List[Tensor],
@@ -30,15 +35,9 @@ class AOTICallDelegate(HigherOrderOperator):

    where,
    - lowered_module is the AOTInductor lowered submodule, backed by a compiled .so file, supporting real tensor inputs
    - original_gm is the original GraphModule before lowering, allowing FakeTensor propagation
    - original_gm is the stateless version of the original GraphModule before lowering, allowing FakeTensor propagation
    - weight_args is the list of weights in the original GraphModule, including parameters and buffers
    - input_args is the list of flattened inputs

    NOTE: aoti_call_delegate doesn't support retracing yet, as original_gm is currently stateful with weights as get_attr nodes.
    This will fail functionalization during retrace. When we move AOTI to accept a stateless GraphModule, we can enable retracing.

    During serialization, we have special handling for aoti_call_delegate, as AOTInductorEPModule is not serializable
    and the stateful original_gm fails the verifier.
    """

    def __init__(self) -> None:
@@ -62,7 +61,6 @@ aoti_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)


@aoti_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
# pyre-ignore
def call_delegate_cpu(
    lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
    original_gm: torch.fx.GraphModule,
@@ -77,27 +75,60 @@ def call_delegate_cpu(
    new_args = pytree.tree_map_only(
        tuple(map_types.keys()),
        lambda a: map_types[type(a)](a),
        input_args,
        weight_args + input_args,
        lambda a: isinstance(a, tuple(map_types.keys())),
    )

    has_fake_input_args = any(isinstance(arg, FakeTensor) for arg in new_args)
    has_fake_params = any(
        isinstance(param, FakeTensor) for param in original_gm.parameters()
    )
    has_fake_buffers = any(
        isinstance(buffer, FakeTensor) for buffer in original_gm.buffers()
    )

    if has_fake_input_args or has_fake_params or has_fake_buffers:
        # aoti lowered module doesn't support fake tensor
        return original_gm(*new_args)
    has_fake_args = any(isinstance(arg, FakeTensor) for arg in new_args)
    if has_fake_args:
        # use stateless original_gm for tracing with fake tensors
        fake_out = original_gm(*new_args)
        return fake_out
    else:
        return lowered_module(new_args)  # type: ignore[misc]
    # use AOTI Runner for real tensors
    new_input_args = new_args[len(weight_args) :]
    if type(lowered_module).__name__ == "AOTInductorRunnerWrapper":
        return lowered_module(*new_input_args)  # type: ignore[misc]
    elif type(lowered_module).__name__ == "AOTInductorEPModule":
        return lowered_module(new_input_args)  # type: ignore[misc]
    else:
        raise RuntimeError(
            f"Unexpected lowered_module type: {type(lowered_module)}."
        )


def trace_aoti_call_delegate(
    proxy_mode, func_overload, lowered_module, original_gm, weight_args, input_args
):
    proxy_mode.tracer.root.register_module("lowered_module", lowered_module)
    proxy_mode.tracer.root.register_module("original_gm", original_gm)

    node_args = (lowered_module, original_gm, weight_args, input_args)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)

    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="aoti_call_delegate"
    )
    with disable_proxy_modes_tracing():
        out = call_delegate_cpu(lowered_module, original_gm, weight_args, input_args)

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@aoti_call_delegate.py_impl(ProxyTorchDispatchMode)
def call_delegate_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
    original_gm: torch.fx.GraphModule,
    weight_args: list[torch.Tensor],
    input_args: list[torch.Tensor],
):
    res = trace_aoti_call_delegate(
        mode, aoti_call_delegate, lowered_module, original_gm, weight_args, input_args
    )
    return res


@aoti_call_delegate.py_impl(FakeTensorMode)
# pyre-ignore
def call_delegate_fake_tensor_mode(
    mode: FakeTensorMode,
    lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
@@ -107,3 +138,24 @@ def call_delegate_fake_tensor_mode(
) -> list[torch.Tensor]:
    with mode:
        return call_delegate_cpu(lowered_module, original_gm, weight_args, input_args)


@aoti_call_delegate.py_functionalize_impl
def call_delegate_functionalize(
    ctx,
    lowered_module: AOTI_LOWERED_MODULE,  # type: ignore[valid-type]
    original_gm: torch.fx.GraphModule,
    weight_args: list[torch.Tensor],
    input_args: list[torch.Tensor],
):
    unwrapped_weight_args = tuple(
        ctx.unwrap_tensors(weight_arg) for weight_arg in weight_args
    )
    unwrapped_input_args = tuple(
        ctx.unwrap_tensors(input_arg) for input_arg in input_args
    )
    with ctx.redispatch_to_next():
        res = aoti_call_delegate(
            lowered_module, original_gm, unwrapped_weight_args, unwrapped_input_args  # type: ignore[arg-type]
        )
        return ctx.wrap_tensors(res)
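A side note on the pytree.tree_map_only call at the top of call_delegate_cpu: it promotes plain Python scalars in the combined weight/input argument list to tensors before dispatch. The standalone snippet below illustrates that pattern; the exact contents of map_types are not shown in the hunk above, so the int/float mapping used here is an assumption.

# Standalone illustration of the tree_map_only pattern; the int/float ->
# torch.scalar_tensor mapping is an assumed example, not necessarily the real map_types.
import torch
import torch.utils._pytree as pytree

map_types = {int: torch.scalar_tensor, float: torch.scalar_tensor}
args = [3, 4.5, torch.ones(2)]

new_args = pytree.tree_map_only(
    tuple(map_types.keys()),          # only touch leaves of these types
    lambda a: map_types[type(a)](a),  # promote scalars to 0-d tensors
    args,
)
print(new_args)  # ints/floats become tensors; existing tensors pass through unchanged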