mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add export docs, improve asserts (#94961)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94961 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
027ebca4d7
commit
ac07de4a61
@ -50,6 +50,7 @@ from torch.fx.experimental import proxy_tensor
|
||||
always_optimize_code_objects = utils.ExactWeakKeyDictionary()
|
||||
null_context = contextlib.nullcontext
|
||||
|
||||
|
||||
# See https://github.com/python/typing/pull/240
|
||||
class Unset(Enum):
|
||||
token = 0
|
||||
@ -565,6 +566,38 @@ def explain(f, *args, **kwargs):
|
||||
def export(
|
||||
f, *args, aten_graph=False, decomposition_table=None, tracing_mode="real", **kwargs
|
||||
):
|
||||
"""
|
||||
Export an input function f to a format that can be executed outside of PyTorch using the FX graph.
|
||||
|
||||
Args:
|
||||
f (callable): A PyTorch function to be exported.
|
||||
|
||||
*args: Variable length argument list to be passed to the function f.
|
||||
|
||||
aten_graph (bool): If True, exports a graph with ATen operators.
|
||||
If False, exports a graph with Python operators. Default is False.
|
||||
|
||||
decomposition_table (dict): A dictionary that maps operators to their decomposition functions.
|
||||
Required if aten_graph or tracing_mode is specified. Default is None.
|
||||
|
||||
tracing_mode (str): Specifies the tracing mode. Must be set to "real" if decomposition_table is not specified.
|
||||
If decomposition_table is specified, the options are "symbolic" or "fake". Default is "real".
|
||||
|
||||
**kwargs: Arbitrary keyword arguments to be passed to the function f.
|
||||
|
||||
Returns:
|
||||
A tuple of (graph, guards)
|
||||
Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options.
|
||||
Guards: The guards we accumulated during tracing f above
|
||||
|
||||
Raises:
|
||||
AssertionError: If decomposition_table or tracing_mode is specified without setting aten_graph=True,
|
||||
or if graph breaks during tracing in export.
|
||||
|
||||
AssertionError: If Dynamo input and output is not consistent with traced input/output.
|
||||
|
||||
Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
|
||||
"""
|
||||
check_if_dynamo_supported()
|
||||
torch._C._log_api_usage_once("torch._dynamo.export")
|
||||
if decomposition_table is not None or tracing_mode != "real":
|
||||
@ -617,7 +650,9 @@ def export(
|
||||
):
|
||||
nonlocal graph
|
||||
|
||||
assert graph is None, "whole graph export entails exactly one graph"
|
||||
assert (
|
||||
graph is None
|
||||
), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
|
||||
graph = gm
|
||||
|
||||
def result_capturing_wrapper(*graph_inputs):
|
||||
@ -645,8 +680,10 @@ def export(
|
||||
result_traced = opt_f(*args, **kwargs)
|
||||
remove_from_cache(f)
|
||||
|
||||
assert graph is not None, "whole graph export entails exactly one call"
|
||||
assert out_guards is not None, "whole graph export entails exactly one guard export"
|
||||
assert (
|
||||
graph is not None
|
||||
), "Failed to produce a graph during tracing. Tracing through 'f' must produce a single graph."
|
||||
assert out_guards is not None, "Failed to produce guards during tracing"
|
||||
|
||||
matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)
|
||||
|
||||
|
Reference in New Issue
Block a user