mirror of https://github.com/pytorch/pytorch.git · synced 2025-11-04 16:04:58 +08:00
			
		
		
		
Summary: tlparse shows "unknown" for certain items when _export.aot_compile() passes the graph obtained from dynamo.export() to inductor.aot_compile(); we also do not have access to the Dynamo trace in the GraphModule exported by Dynamo. This change plumbs the compile_context through to aot_compile as part of GraphModule.meta, without a major change to APIs within Dynamo.

Addresses issue: https://github.com/pytorch/pytorch/issues/123759?fbclid=IwY2xjawGE0LBleHRuA2FlbQIxMQABHS-PRpxvsrsHCDPdStHpqr1jQvx1YOnrPsRAfYAb-oXkU8MxidkIUENY-Q_aem_MAT2oaOgD03C8ggBNm575Q#issuecomment-2430722505

Test Plan:
```
buck2 test mode/opt //caffe2/test/dynamo:test_dynamo
Buck UI: https://www.internalfb.com/buck2/ad64c267-65be-47cf-a94f-e4b26e6e030b
Test UI: https://www.internalfb.com/intern/testinfra/testrun/9288674286334710
Network: Up: 83KiB  Down: 314KiB  (reSessionID-1dad223b-c91d-4718-97a4-bb2c81e480f0)
Jobs completed: 10750. Time elapsed: 19:18.5s. Cache hits: 0%. Commands: 3 (cached: 0, remote: 0, local: 3)
Tests finished: Pass 5365. Fail 2. Fatal 0. Skip 4. Build failure 0

buck2 test mode/opt //caffe2/test/dynamo:test_dynamo_fb
Buck UI: https://www.internalfb.com/buck2/179a60bb-34e1-43b3-97ad-91af8a93ab01
Test UI: https://www.internalfb.com/intern/testinfra/testrun/2533275046340687
Network: Up: 201KiB  Down: 1.8GiB  (reSessionID-36f33983-6d78-4ec9-aa1b-34cee80dcb4f)
Jobs completed: 17. Time elapsed: 42.9s. Cache hits: 0%. Commands: 1 (cached: 0, remote: 0, local: 1)
Tests finished: Pass 6. Fail 0. Fatal 0. Skip 0. Build failure 0
```
tlparse report: https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpxZGXf6/index.html

Report fixed: https://github.com/pytorch/pytorch/issues/123759?fbclid=IwY2xjawGE0LBleHRuA2FlbQIxMQABHS-PRpxvsrsHCDPdStHpqr1jQvx1YOnrPsRAfYAb-oXkU8MxidkIUENY-Q_aem_MAT2oaOgD03C8ggBNm575Q#issuecomment-2430722505

Differential Revision: D64863946

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138793
Approved by: https://github.com/ezyang
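A minimal sketch of the approach described in the Summary; the meta key and call sites below are illustrative assumptions, not the literal diff:

```python
from torch._guards import compile_context, CompileContext

# In dynamo.export(): record the active compile context on the exported
# GraphModule (the meta key name here is hypothetical).
gm.meta["dynamo_compile_context"] = CompileContext.try_get()

# In _export.aot_compile(): re-enter the recorded context around the Inductor
# call so tlparse can attribute the AOTI artifacts to the original Dynamo trace.
with compile_context(gm.meta.get("dynamo_compile_context")):
    so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options)
```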
		
			
				
	
	
		
369 lines · 14 KiB · Python
# mypy: allow-untyped-defs
import copy
import dataclasses
import functools
import io
import json
import logging
import os
import re
import sys
import types
import warnings
import weakref
import zipfile
from collections import OrderedDict
from contextlib import contextmanager
from functools import lru_cache

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
import torch.fx
import torch.utils._pytree as pytree

from torch._dispatch.python import enable_python_dispatcher
from torch._guards import compile_context
from torch._utils_internal import log_export_usage
from torch.export._tree_utils import reorder_kwargs
from torch.export.graph_signature import (
    ArgumentSpec,
    ConstantArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymIntArgument,
    TensorArgument,
)
from torch.fx import traceback as fx_traceback
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo

from .wrappers import _wrap_submodules

log = logging.getLogger(__name__)


@dataclasses.dataclass
class ExportDynamoConfig:
    """
    Manage Export-specific configurations of Dynamo.
    """
    allow_rnn: bool = True


# We only want to print this once to avoid flooding logs in workflows where capture_pre_autograd_graph
# is called multiple times.
@lru_cache
def capture_pre_autograd_graph_warning():
    from torch._inductor import config

    log.warning("+============================+")
    log.warning("|     !!!   WARNING   !!!    |")
    log.warning("+============================+")
    log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any functional guarantees moving forward.")
    log.warning("Please switch to torch.export.export_for_training instead.")
    if config.is_fbcode():
        log.warning("For unittest, capture_pre_autograd_graph() will fall back to torch.export.export_for_training.")  # noqa: B950


@lru_cache
def print_export_warning():
    log.warning("Using torch.export.export_for_training(..., strict=True)")


def gm_using_training_ir(graph_module):
    """
    Returns True if the graph module is detected to use training IR.

    This function checks for two specific conditions within the nodes of the graph module:
    1. The presence of the `torch.ops.aten.batch_norm.default` operation, which indicates the use of training IR.
    2. The presence of deprecated IR tags on node meta, or of batch norm ops produced by the deprecated IR.

    The function raises a RuntimeError if both conditions are met, indicating a conflict in the IR.
    Otherwise it returns True when training-IR batch norm is present, or when no deprecated-IR markers are found.
    """
    # TODO: clean up this code after training IR migration.
    # T199018392
    has_training_ir_batch_norm = False
    has_deprecated_ir_tag = getattr(graph_module, "capture_pre_autograd_graph_tag", False)
    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            if node.target == torch.ops.aten.batch_norm.default:
                has_training_ir_batch_norm = True
            if node.meta.get("capture_pre_autograd_graph_tag", False):
                has_deprecated_ir_tag = True
            if node.target in [
                torch.ops.aten._native_batch_norm_legit.default,
                torch.ops.aten.cudnn_batch_norm.default,
                torch.ops.aten.miopen_batch_norm.default,
            ]:
                has_deprecated_ir_tag = True

    if has_deprecated_ir_tag and has_training_ir_batch_norm:
        raise RuntimeError("Conflicting IR detected.")
    return has_training_ir_batch_norm or not has_deprecated_ir_tag


@compatibility(is_backward_compatible=False)
def capture_pre_autograd_graph(
    f: torch.nn.Module,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> torch.nn.Module:
    """
    A helper function that is intended to trace a module before any pre-autograd
    decomposition is run. The produced module will be "non-functional" and
    composed of aten operators. Later, this API will be deleted in favor of the
    more general torch.export API.

    Args:
      f: nn.Module to be traced

      args: example positional inputs.

      kwargs: optional example keyword inputs.

      dynamic_shapes: Should either be:
         1) a dict from argument names of ``f`` to their dynamic shape specifications,
         2) a tuple that specifies dynamic shape specifications for each input in original order.
         If you are specifying dynamism on keyword args, you will need to pass them in the order that
         is defined in the original function signature.

         The dynamic shape of a tensor argument can be specified as either
         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
         not required to include static dimension indices in this dict, but when they are,
         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
         recursively specified by using mappings or sequences of contained specifications.
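
         Example (illustrative): assuming a hypothetical ``MyModule`` whose
         ``forward`` takes a single tensor ``x``::

             from torch.export import Dim

             m = MyModule()
             x = torch.randn(4, 8)
             # mark dimension 0 of ``x`` as dynamic; unspecified dims stay static
             gm = capture_pre_autograd_graph(m, (x,), dynamic_shapes={"x": {0: Dim("batch")}})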

    Returns:
        An nn.Module containing the traced method.

    """
    from torch.export._trace import _extract_fake_inputs, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
    from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
    from torch._export.non_strict_utils import make_constraints
    from torch._subclasses.functional_tensor import FunctionalTensor
    from torch.export._unlift import _create_stateful_graph_module
    from torch.export.dynamic_shapes import _combine_args

    capture_pre_autograd_graph_warning()

    if sys.platform == "win32":
        raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")

    assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."

    if kwargs is None:
        kwargs = {}

    if capture_pre_autograd_graph_using_training_ir():
        print_export_warning()
        module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
    else:
        log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})

        # Do not decompose dropout for exported models, because in eval mode the dropout
        # op disappears from the graph, which makes it difficult to switch to train mode.
        # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
        decomp_table = {
            op: op.decompose
            for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
            if op != torch.ops.aten.dropout.default
        }
        with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
            m = torch._dynamo.export(
                f,
                dynamic_shapes=dynamic_shapes,
                assume_static_by_default=True,
                tracing_mode="symbolic",
                decomposition_table=decomp_table,
                pre_dispatch=True,
                aten_graph=True,
                _log_export_usage=False,
            )(
                *args,
                **kwargs,
            )[0]

            _, _, fake_mode = _extract_fake_inputs(m, args, kwargs)

            m.meta["inline_constraints"] = {
                k: v
                for k, v in fake_mode.shape_env.var_to_range.items()
                if re.match(r"^[if]\d+$", str(k))
            }

            if isinstance(f, torch.nn.Module):
                from torch.export._trace import _restore_state_dict
                _restore_state_dict(f, m)

            flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
            combined_args = _combine_args(f, args, kwargs)
            range_constraints = make_constraints(
                fake_mode,
                m,
                combined_args,
                dynamic_shapes,
                0,
            )

            module = _create_stateful_graph_module(
                m,
                range_constraints=range_constraints,
            )

            setattr(module, "capture_pre_autograd_graph_tag", True)  # noqa: B010
            for node in module.graph.nodes:
                node.meta["capture_pre_autograd_graph_tag"] = True

    error_message = \
        """
        Calling train() or eval() is not supported for exported models.
        Alternatively, you may override these methods to do custom user behavior as follows:

            def _my_train(self, mode: bool = True):
                ...

            def _my_eval(self):
                ...

            model.train = types.MethodType(_my_train, model)
            model.eval = types.MethodType(_my_eval, model)
        """

    def _train(self, mode: bool = True):
        raise NotImplementedError(error_message)

    def _eval(self, mode: bool = True):
        raise NotImplementedError(error_message)

    module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
    module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]

    # Remove Proxy objects because they cannot be deepcopied or pickled.
    if hasattr(module, "_buffers"):
        torch._export.utils.remove_proxy_from_state_dict(
            module._buffers, in_place=True
        )
    return module


# We only want to print this once to avoid flooding logs in workflows where aot_compile
# is called multiple times.
@lru_cache
def aot_compile_warning():
    from torch._inductor import config

    log.warning("+============================+")
    log.warning("|     !!!   WARNING   !!!    |")
    log.warning("+============================+")
    log.warning(
        "torch._export.aot_compile() is being deprecated, please switch to "
        "directly calling torch._inductor.aoti_compile_and_package(torch.export.export()) instead.")
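

# A hedged migration sketch for the deprecation warning above, assuming the
# standard torch.export / AOTInductor flow; ``model`` and ``example_inputs``
# are hypothetical placeholders:
#
#     ep = torch.export.export(model, example_inputs)
#     package_path = torch._inductor.aoti_compile_and_package(ep)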


def aot_compile(
    f: Callable,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    remove_runtime_assertions: bool = False,
    disable_constraint_solver: bool = False,
    same_signature: bool = True,
) -> str:
    """
    Note: this function is not stable yet.

    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside, generates executable C++ code from the program, and returns
    the path to the generated shared library.

    Args:
        f: the `nn.Module` or callable to trace.

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes: Should either be:
            1) a dict from argument names of ``f`` to their dynamic shape specifications,
            2) a tuple that specifies dynamic shape specifications for each input in original order.
            If you are specifying dynamism on keyword args, you will need to pass them in the order that
            is defined in the original function signature.

            The dynamic shape of a tensor argument can be specified as either
            (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
            not required to include static dimension indices in this dict, but when they are,
            they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
            where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
            are denoted by None. Arguments that are dicts or tuples / lists of tensors are
            recursively specified by using mappings or sequences of contained specifications.

        options: A dictionary of options to control Inductor.

        remove_runtime_assertions: Whether runtime assertions should be removed
            from the compiled program.

        disable_constraint_solver: Whether the dim constraint solver must be disabled.

        same_signature: Whether the traced graph must preserve the calling
            signature of ``f``.
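
        Example (illustrative): assuming a hypothetical ``MyModule`` whose
        ``forward`` takes a single tensor ``x``::

            from torch.export import Dim

            m = MyModule()
            x = torch.randn(4, 8)
            # compile with a dynamic batch dimension on ``x``
            so_path = torch._export.aot_compile(m, (x,), dynamic_shapes={"x": {0: Dim("batch")}})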

    Returns:
        Path to the generated shared library.
    """
    from torch.export._trace import _export_to_torch_ir
    from torch._inductor.decomposition import select_decomp_table
    from torch._inductor import config

    aot_compile_warning()

    if config.is_predispatch:
        gm = torch.export._trace._export(f, args, kwargs, dynamic_shapes, pre_dispatch=True).module()
    else:
        # We want to export to Torch IR here to utilize the pre_grad passes in
        # inductor, which run on Torch IR.
        gm = _export_to_torch_ir(
            f,
            args,
            kwargs,
            dynamic_shapes,
            disable_constraint_solver=disable_constraint_solver,
            same_signature=same_signature,
            # Disabling this flag, because instead we can rely on the mapping
            # dynamo_flat_name_to_original_fqn which is coming from Dynamo.
            restore_fqn=False,
        )

    with torch.no_grad():
        so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options)  # type: ignore[arg-type]

    return so_path


def aot_load(so_path: str, device: str) -> Callable:
    """
    Loads a shared library generated by aot_compile and returns a callable.

    Args:
        so_path: Path to the shared library.

        device: Device on which to run the model, e.g. "cpu" or "cuda".

    Returns:
        A callable that runs the compiled model.
    """
    if device == "cpu":
        runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)  # type: ignore[call-arg]
    elif device == "cuda" or device.startswith("cuda:"):
        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
    else:
        raise RuntimeError("Unsupported device " + device)

    def optimized(*args, **kwargs):
        call_spec = runner.get_call_spec()  # type: ignore[attr-defined]
        in_spec = pytree.treespec_loads(call_spec[0])
        out_spec = pytree.treespec_loads(call_spec[1])
        flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
        # The AOTI runner consumes only tensor inputs; non-tensor leaves are dropped.
        flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
        flat_outputs = runner.run(flat_inputs)  # type: ignore[attr-defined]
        return pytree.tree_unflatten(flat_outputs, out_spec)

    return optimized
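

# A minimal end-to-end usage sketch of the two APIs above, kept as a comment so
# it never runs on import. ``MyModule`` and the example shapes are hypothetical:
#
#     from torch.export import Dim
#
#     model = MyModule().eval()
#     example_inputs = (torch.randn(4, 8),)
#     so_path = torch._export.aot_compile(
#         model, example_inputs, dynamic_shapes={"x": {0: Dim("batch")}}
#     )
#     compiled = torch._export.aot_load(so_path, device="cpu")
#     out = compiled(torch.randn(16, 8))  # a different batch size works at runtime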