Add export docs, improve asserts (#94961)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94961
Approved by: https://github.com/tugsbayasgalan
Michael Voznesensky
2023-03-02 18:15:11 +00:00
committed by PyTorch MergeBot
parent 027ebca4d7
commit ac07de4a61


@@ -50,6 +50,7 @@ from torch.fx.experimental import proxy_tensor
always_optimize_code_objects = utils.ExactWeakKeyDictionary()
null_context = contextlib.nullcontext
# See https://github.com/python/typing/pull/240
class Unset(Enum):
token = 0
@@ -565,6 +566,38 @@ def explain(f, *args, **kwargs):
def export(
f, *args, aten_graph=False, decomposition_table=None, tracing_mode="real", **kwargs
):
"""
Export an input function f to an FX graph that can be executed outside of PyTorch.
Args:
f (callable): A PyTorch function to be exported.
*args: Variable length argument list to be passed to the function f.
aten_graph (bool): If True, exports a graph with ATen operators.
If False, exports a graph with Python operators. Default is False.
decomposition_table (dict): A dictionary that maps operators to their decomposition functions.
Passing it requires aten_graph=True. Default is None.
tracing_mode (str): Specifies the tracing mode: "real", "symbolic", or "fake".
Modes other than "real" require aten_graph=True. Default is "real".
**kwargs: Arbitrary keyword arguments to be passed to the function f.
Returns:
A tuple of (graph, guards)
Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options.
Guards: The guards accumulated while tracing f.
Raises:
AssertionError: If decomposition_table or tracing_mode is specified without setting aten_graph=True,
or if a graph break occurs while tracing in export.
AssertionError: If the Dynamo-captured inputs and outputs are not consistent with the traced inputs and outputs.
Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
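Example (a minimal, illustrative sketch; the function fn, the input shapes, and the aten_graph/tracing_mode pairing below are arbitrary choices, not requirements of the API)::

    import torch

    def fn(x, y):
        return x + y

    # Default export: a graph of torch-level operators plus the guards collected while tracing.
    graph, guards = torch._dynamo.export(fn, torch.randn(3), torch.randn(3))
    out = graph(torch.randn(3), torch.randn(3))  # the returned GraphModule replays the traced computation

    # ATen-level export; tracing modes other than "real" require aten_graph=True.
    aten_graph_module, _ = torch._dynamo.export(
        fn, torch.randn(3), torch.randn(3), aten_graph=True, tracing_mode="symbolic"
    )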
"""
check_if_dynamo_supported()
torch._C._log_api_usage_once("torch._dynamo.export")
if decomposition_table is not None or tracing_mode != "real":
@@ -617,7 +650,9 @@ def export(
):
nonlocal graph
assert graph is None, "whole graph export entails exactly one graph"
assert (
graph is None
), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
graph = gm
def result_capturing_wrapper(*graph_inputs):
@@ -645,8 +680,10 @@ def export(
result_traced = opt_f(*args, **kwargs)
remove_from_cache(f)
assert graph is not None, "whole graph export entails exactly one call"
assert out_guards is not None, "whole graph export entails exactly one guard export"
assert (
graph is not None
), "Failed to produce a graph during tracing. Tracing through 'f' must produce a single graph."
assert out_guards is not None, "Failed to produce guards during tracing"
matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)
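
The reworded messages above spell out the whole-graph requirement: export must trace f into exactly one graph, so any graph break aborts the export. A small sketch of that failure mode (the function below and its print-induced break are illustrative assumptions, not code from this commit):

import torch

def breaks(x):
    # print() is executed in Python rather than captured, forcing a graph break,
    # so tracing cannot produce the single graph that export requires.
    print("side effect")
    return x * 2

try:
    torch._dynamo.export(breaks, torch.randn(4))
except Exception as exc:
    # Expected to fail: export refuses functions that do not trace to a single graph.
    print(type(exc).__name__, exc)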