From ac07de4a61831bd23e2d3d270890cde3983d9257 Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Thu, 2 Mar 2023 18:15:11 +0000 Subject: [PATCH] Add export docs, improve asserts (#94961) Pull Request resolved: https://github.com/pytorch/pytorch/pull/94961 Approved by: https://github.com/tugsbayasgalan --- torch/_dynamo/eval_frame.py | 43 ++++++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 9383c9fab34e..249da7d9ec3f 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -50,6 +50,7 @@ from torch.fx.experimental import proxy_tensor always_optimize_code_objects = utils.ExactWeakKeyDictionary() null_context = contextlib.nullcontext + # See https://github.com/python/typing/pull/240 class Unset(Enum): token = 0 @@ -565,6 +566,38 @@ def explain(f, *args, **kwargs): def export( f, *args, aten_graph=False, decomposition_table=None, tracing_mode="real", **kwargs ): + """ + Export an input function f to a format that can be executed outside of PyTorch using the FX graph. + + Args: + f (callable): A PyTorch function to be exported. + + *args: Variable length argument list to be passed to the function f. + + aten_graph (bool): If True, exports a graph with ATen operators. + If False, exports a graph with Python operators. Default is False. + + decomposition_table (dict): A dictionary that maps operators to their decomposition functions. + Required if aten_graph or tracing_mode is specified. Default is None. + + tracing_mode (str): Specifies the tracing mode. Must be set to "real" if decomposition_table is not specified. + If decomposition_table is specified, the options are "symbolic" or "fake". Default is "real". + + **kwargs: Arbitrary keyword arguments to be passed to the function f. + + Returns: + A tuple of (graph, guards) + Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options. + Guards: The guards we accumulated during tracing f above + + Raises: + AssertionError: If decomposition_table or tracing_mode is specified without setting aten_graph=True, + or if graph breaks during tracing in export. + + AssertionError: If Dynamo input and output is not consistent with traced input/output. + + Note - this headerdoc was authored by ChatGPT, with slight modifications by the author. + """ check_if_dynamo_supported() torch._C._log_api_usage_once("torch._dynamo.export") if decomposition_table is not None or tracing_mode != "real": @@ -617,7 +650,9 @@ def export( ): nonlocal graph - assert graph is None, "whole graph export entails exactly one graph" + assert ( + graph is None + ), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph." graph = gm def result_capturing_wrapper(*graph_inputs): @@ -645,8 +680,10 @@ def export( result_traced = opt_f(*args, **kwargs) remove_from_cache(f) - assert graph is not None, "whole graph export entails exactly one call" - assert out_guards is not None, "whole graph export entails exactly one guard export" + assert ( + graph is not None + ), "Failed to produce a graph during tracing. Tracing through 'f' must produce a single graph." + assert out_guards is not None, "Failed to produce guards during tracing" matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)