mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144553 Approved by: https://github.com/ezyang ghstack dependencies: #144551
327 lines
12 KiB
Python
327 lines
12 KiB
Python
import copy
|
|
import dataclasses
|
|
import functools
|
|
import types
|
|
import typing
|
|
import typing_extensions
|
|
|
|
import torch
|
|
from torch.export.exported_program import _decompose_exported_program
|
|
|
|
|
|
def _copy_graph_module_and_signature(
|
|
ep: torch.fx.GraphModule,
|
|
) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]:
|
|
# copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(),
|
|
# and this can break placeholder names in some particular cases.
|
|
# For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'.
|
|
# So we manually overwrite placeholder names by reading the old graph.
|
|
gm = copy.deepcopy(ep.graph_module)
|
|
new_graph_signature = copy.deepcopy(ep.graph_signature)
|
|
|
|
# iterate over old/new graph modules
|
|
for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()): # type: ignore[union-attr]
|
|
old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"]
|
|
new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"]
|
|
# iterate over placeholders
|
|
assert len(old_phs) == len(new_phs)
|
|
for old_node, new_node in zip(old_phs, new_phs):
|
|
new_node.name = old_node.name
|
|
|
|
return gm, new_graph_signature # type: ignore[return-value]
|
|
|
|
|
|
def _remove_detach_pass(
|
|
gm: torch.fx.GraphModule, sig: torch.export.graph_signature.ExportGraphSignature
|
|
) -> None:
|
|
with gm._set_replace_hook(sig.get_replace_hook()):
|
|
for node in list(reversed(gm.graph.nodes)):
|
|
if node.op != "call_function":
|
|
continue
|
|
if (
|
|
node.target == torch.ops.aten.detach.default
|
|
and len(node.users) == 1
|
|
and next(iter(node.users)).target == torch.ops.aten.detach.default
|
|
):
|
|
next(iter(node.users)).replace_all_uses_with(node)
|
|
|
|
gm.graph.eliminate_dead_code()
|
|
gm.recompile()
|
|
|
|
|
|
def _export_forward_backward(
    ep: torch.export.ExportedProgram, joint_loss_index: int = 0
) -> torch.export.ExportedProgram:
    """
    WARNING: This API is highly unstable and will be subject to change in the future.

    Decomposes ``ep`` with the core ATen decomposition table, forwarding
    ``joint_loss_index`` to the decomposition step, then removes redundant
    chained detach calls from a copy of the resulting graph.

    Args:
        ep: The exported program to process.
        joint_loss_index: Passed through to ``_decompose_exported_program``
            (presumably the index of the loss output; confirm against that helper).

    Returns:
        The decomposed program updated with the cleaned-up graph module and its
        copied graph signature.
    """
    from torch._decomp import core_aten_decompositions

    decomposed = _decompose_exported_program(
        ep,
        cia_to_decomp={},
        python_decomp_table=core_aten_decompositions(),
        joint_loss_index=joint_loss_index,
        # For serialization purpose, we don't want to decompose custom triton ops.
        # If users would like to decompose custom triton ops, they could do it
        # with run_decompositions() API.
        decompose_custom_triton_ops=False,
    )

    # Operate on copies so the detach cleanup does not touch the graph owned by
    # the decomposed program.
    graph_module, signature = _copy_graph_module_and_signature(decomposed)
    _remove_detach_pass(graph_module, signature)

    return decomposed._update(graph_module, signature)
|
|
|
|
|
|
@typing.no_type_check
def _sticky_export(forward_func, dynamic_shapes_callback=None):
    """
    Lazily export the model when its forward is called.

    The returned wrapper exports the model with the arguments of each call it
    receives, stores the exported module on ``wrapper._exported_artifact`` and
    returns the result of running that exported module on the same arguments.
    The wrapper re-installs itself as ``model.forward`` after every call, so each
    call made through it triggers a new export.

    Usage:
        model.forward = _sticky_export(model.forward, dynamic_shapes_callback=callback)

    Args:
        forward_func: A bound ``forward`` method; the owning model is taken from
            ``forward_func.__self__``.
        dynamic_shapes_callback: Optional callable invoked with the call's
            ``*args, **kwargs``; its return value is passed to
            ``torch.export.export`` as ``dynamic_shapes``.

    Returns:
        The wrapper function to be assigned to ``model.forward``.
    """
    model = forward_func.__self__
    original_forward = forward_func.__func__

    @functools.wraps(forward_func)
    def wrapper(*args, **kwargs):
        # Unpatch forward to avoid recursion during export
        model.forward = types.MethodType(original_forward, model)

        try:
            # The callback runs inside the try block so that the wrapper is
            # restored even when the callback itself raises.
            dynamic_shapes_spec = None
            if dynamic_shapes_callback:
                dynamic_shapes_spec = dynamic_shapes_callback(*args, **kwargs)

            exported = torch.export.export(
                model,
                args,
                kwargs,
                dynamic_shapes=dynamic_shapes_spec,
            ).module()
            wrapper._exported_artifact = exported
        finally:
            # Restore the wrapper after export
            model.forward = wrapper

        return exported(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class _ExportMethod:
|
|
overloads: dict[str, torch.export.ExportedProgram]
|
|
fallbacks: list[torch.export.ExportedProgram]
|
|
|
|
|
|
# Generic placeholders for annotating wrappers that keep the signature of the
# callable they wrap: `_InputT` stands for its parameters, `_RetT` for its
# return type.
_InputT = typing_extensions.ParamSpec("_InputT")
_RetT = typing.TypeVar("_RetT")
|
|
|
|
|
|
class _ExportPackage:
    """
    An export package is a collection of torch.export()-ed PyTorch models consisting of
    a list of exported methods and their corresponding overloads. ExportPackage is introduced
    on top of torch.export() to support the following use cases:
      - Exporting a model with multiple methods if a model has multiple independent parts.
      - Exporting a function with multiple overloads based on tensor shapes or other metadata.

    ExportPackage is designed to contain multiple methods (associated with method names) and for
    each method, it can have multiple overloads (associated with overload names).

    Here is an example of the data structure for an ExportPackage:
    ```
    ExportPackage(
        methods={
            "decoder": ExportMethod(
                overloads={
                    "prefill": ExportedProgram(...),
                    "decode": ExportedProgram(...),
                },
                fallbacks=[],
            ),
            "encoder": ExportMethod(overloads={}, fallbacks=[ExportedProgram(...)]),
        },
    )
    ```

    To export a model into an ExportPackage, users can use the exporter API provided by ExportPackage.
    Exporter is a decorator that takes a method name and a callable and returns a wrapper. The
    wrapper will export the function into an ExportPackage, when it's invoked with some sample
    inputs (similar to how torch.compile() works). For more details, please refer to the document
    on the ._exporter() method.

    This design allows users to decouple the exported callables from the actual sample inputs which can
    be helpful for use cases where the exported callable is hidden behind helper functions or when sample
    inputs are hard to get.

    NOTE: This is an experimental API and anything can be changed in the future.

    Example usage:
    ```
        def fn(x):
            return x + 1

        def main(f, x):
            x += 1
            ret = f(x)
            return ret + 1

        package = _ExportPackage()
        main(package._exporter("fn", fn), torch.randn(3, 2))
    ```

    """

    def __init__(self) -> None:
        # Maps a method name to the overloads/fallbacks exported for it.
        self.methods: dict[str, _ExportMethod] = {}

    def _exporter(
        self,
        method: str,
        fn: typing.Callable[_InputT, _RetT],
        *,
        fallback: str = "once",
    ) -> typing.Callable[_InputT, _RetT]:
        """
        A function/module decorator that sets up a callable to be exported later, when it is invoked.
        By default the exporter will only trigger torch.export for once and error on
        later invocations. To customize this behavior, users have the following two options:
          1. Call ._define_overload() method on the returned wrapper to define an overload.
          2. Adjust the fallback policy using `fallback` argument.

        An "overload" is a named branch for an ExportMethod with a user defined precondition,
        typically based on input tensor shapes. It's up to a downstream backend implementation
        of ExportMethod to respect the precondition later in inference.

        _define_overload() takes arguments like the following:
          - A name, for indexing purposes in a backend. It must be a valid Python identifier
            and must not have been defined already.
          - A callable (spec) that:
            - Has the same model input signature as the original model code. When the exported
              callable is a torch.nn.Module, the module itself is passed to the spec as an extra
              first argument.
            - Returns an optional dynamic shape spec.

        Exporter will only export an overload when the spec callable successfully returns
        a result without raising AssertionError.

        For example:
        ```
        package = _ExportPackage()


        def prefill(x, xa, kv_cache):
            assert x.shape[1] == 3
            assert kv_cache == {}


        def decode(x, xa, kv_cache):
            assert x.shape[1] > 1
            assert len(kv_cache) > 0
            return {...}  # dynamic shape specs here


        exporter = (
            package._exporter("decoder", decoder)
            ._define_overload("prefill", prefill)
            ._define_overload("decode", decode)
        )
        ```

        A "fallback" is exported when no overload precondition matches a given set of sample
        inputs. Overloads are tried first, in the order they were defined, and the first one
        whose spec does not raise AssertionError is used.
        Fallbacks don't have names and are ordered in a list. It's up to a backend to decide
        which fallback is used among multiple ones.

        A reference backend implementation of ExportMethod may look like the following:
        ```
        def execute(method: ExportMethod, *args, **kwargs):
            for overload in method.overloads:
                if match_precondition(overload, *args, **kwargs):
                    return execute_overload(overload, *args, **kwargs)
            for fallback in method.fallbacks:
                if match_precondition(fallback, *args, **kwargs):
                    return execute_fallback(fallback, *args, **kwargs)
        ```

        Args:
            method(str): The method name for an exported part of PyTorch model. This
                         will be saved together with the exported/compiled artifacts
                         in any serialization format and can be used as the key to
                         index ExportPackage methods later.
            fn(callable): A PyTorch function/module to be exported.
            fallback(str): The fallback policy to decide when to call torch.export
              - "once" is the default policy. Under this policy a PyTorch program is assumed
                to be only called once later and an error will be raised for subsequent
                runs.
              - "error" means the ExportMethod will never have any fallbacks, meaning
                users should define all the possible overloads ahead of time.

        Returns:
            A wrapper around ``fn`` carrying a ``_define_overload`` attribute. Calling the
            wrapper exports ``fn`` with the call's arguments (or reuses an already exported
            overload) and returns the result of running the exported module.

        Raises:
            RuntimeError: Raised by the wrapper, at call time, when no overload matches and
                the fallback policy forbids exporting (policy "error", a second fallback
                under "once", or an unknown policy name).
        """

        # These containers are shared between this package's method entry and the
        # closures below, so exports performed later show up in `self.methods`.
        fallbacks: list[torch.export.ExportedProgram] = []
        specs: dict[str, typing.Callable[_InputT, typing.Any]] = {}
        overloads: dict[str, torch.export.ExportedProgram] = {}
        self.methods[method] = _ExportMethod(fallbacks=fallbacks, overloads=overloads)

        @functools.wraps(fn)
        def _exporter_context(*args, **kwargs):  # type: ignore[no-untyped-def]
            import torch.export._wrapper_utils

            model: torch.nn.Module
            if not isinstance(fn, torch.nn.Module):
                model = torch.export._wrapper_utils._WrapperModule(fn)
            else:
                model = fn

            # Try the overloads in definition order; a spec signals "precondition
            # not met" by raising AssertionError.
            for k, v in specs.items():
                try:
                    if isinstance(fn, torch.nn.Module):
                        dynamic_shapes = v(fn, *args, **kwargs)  # type: ignore[arg-type]
                    else:
                        dynamic_shapes = v(*args, **kwargs)
                except AssertionError:
                    continue
                # Each overload is exported at most once and reused afterwards.
                ep = overloads.get(k)
                if ep is None:
                    ep = torch.export.export(
                        model, args, kwargs, dynamic_shapes=dynamic_shapes
                    )
                    overloads[k] = ep
                return ep.module()(*args, **kwargs)

            # No overload matched: apply the fallback policy.
            if fallback == "error":
                raise RuntimeError(
                    f"Exporter: Cannot export fallback {fn} when fallback policy is set to 'error', "
                    + "please specify an overload or adjust the fallback policy."
                )
            elif fallback == "once":
                if fallbacks:
                    raise RuntimeError(
                        f"Exporter: Cannot export {fn} more than once, "
                        + "please specify an overload or adjust the fallback policy."
                    )
            else:
                raise RuntimeError(f"Unknown fallback policy: {fallback}")
            ep = torch.export.export(model, args, kwargs)

            fallbacks.append(ep)
            return ep.module()(*args, **kwargs)

        if isinstance(fn, torch.nn.Module):
            # Keep the plain function under its own name: `_exporter_context` is
            # rebound to the OptimizedModule below, so a lambda closing over that
            # name would return the module itself if it were ever called after
            # construction.
            export_fn = _exporter_context
            _exporter_context = torch._dynamo.eval_frame.OptimizedModule(  # type: ignore[assignment] # noqa: F811
                fn, lambda _: export_fn
            )

        def _define_overload(
            overload: str, spec: typing.Callable[_InputT, typing.Any]
        ) -> typing.Any:
            assert overload not in specs
            assert callable(spec)
            assert overload.isidentifier()
            specs[overload] = spec
            # Return the wrapper so that overload definitions can be chained.
            return _exporter_context

        assert not hasattr(fn, "_define_overload")
        _exporter_context._define_overload = _define_overload  # type: ignore[attr-defined]

        return _exporter_context
|