diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 2b3458df15a0..7fe9bf9370b0 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -127,71 +127,75 @@ def capture_pre_autograd_graph( """ from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG - - log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"}) + from torch._utils_internal import export_api_rollout_check assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance." if kwargs is None: kwargs = {} - # Do not decompose dropout for exported models, because in eval mode the dropout - # op disappears from the graph, which makes it difficult to switch to train mode. - # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832. - decomp_table = { - op: op.decompose - for op in FunctionalTensor.maybe_aliasing_or_mutating_ops - if op != torch.ops.aten.dropout.default - } - with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)): - m = torch._dynamo.export( - f, - dynamic_shapes=dynamic_shapes, - assume_static_by_default=True, - tracing_mode="symbolic", - decomposition_table=decomp_table, - pre_dispatch=True, - aten_graph=True, - _log_export_usage=False, - )( - *args, - **kwargs, - )[0] + if export_api_rollout_check(): + module = torch.export._trace._export(f, args, kwargs, dynamic_shapes=dynamic_shapes, pre_dispatch=True).module() + else: + log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"}) - _, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs) - - m.meta["inline_constraints"] = { - k: v - for k, v in fake_mode.shape_env.var_to_range.items() - if re.match(r"^[if]\d+$", str(k)) + # Do not decompose dropout for exported models, because in eval mode the dropout + # op disappears from the graph, which makes it difficult to switch to train mode. + # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832. + decomp_table = { + op: op.decompose + for op in FunctionalTensor.maybe_aliasing_or_mutating_ops + if op != torch.ops.aten.dropout.default } + with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)): + m = torch._dynamo.export( + f, + dynamic_shapes=dynamic_shapes, + assume_static_by_default=True, + tracing_mode="symbolic", + decomposition_table=decomp_table, + pre_dispatch=True, + aten_graph=True, + _log_export_usage=False, + )( + *args, + **kwargs, + )[0] - if isinstance(f, torch.nn.Module): - from torch.export._trace import _restore_state_dict - _restore_state_dict(f, m) + _, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs) - flat_args, _ = pytree.tree_flatten((args, kwargs or {})) - range_constraints = _process_constraints(fake_mode, m, 0, flat_args) + m.meta["inline_constraints"] = { + k: v + for k, v in fake_mode.shape_env.var_to_range.items() + if re.match(r"^[if]\d+$", str(k)) + } - module = _create_stateful_graph_module( - m, - range_constraints=range_constraints, - ) + if isinstance(f, torch.nn.Module): + from torch.export._trace import _restore_state_dict + _restore_state_dict(f, m) - error_message = \ - """ - Calling train() or eval() is not supported for exported models. - Alternatively, you may override these methods to do custom user behavior as follows: + flat_args, _ = pytree.tree_flatten((args, kwargs or {})) + range_constraints = _process_constraints(fake_mode, m, 0, flat_args) - def _my_train(self, mode: bool = True): - ... + module = _create_stateful_graph_module( + m, + range_constraints=range_constraints, + ) - def _my_eval(self): - ... + error_message = \ + """ + Calling train() or eval() is not supported for exported models. + Alternatively, you may override these methods to do custom user behavior as follows: - model.train = types.MethodType(_my_train, model) - model.eval = types.MethodType(_my_eval, model) - """ + def _my_train(self, mode: bool = True): + ... + + def _my_eval(self): + ... + + model.train = types.MethodType(_my_train, model) + model.eval = types.MethodType(_my_eval, model) + """ def _train(self, mode: bool = True): raise NotImplementedError(error_message) diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 43d4cfee2b6d..698492088147 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -100,6 +100,10 @@ def log_torchscript_usage(api: str): return +def export_api_rollout_check() -> bool: + return False + + def justknobs_check(name: str) -> bool: """ This function can be used to killswitch functionality in FB prod,