mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Pull Request resolved: https://github.com/pytorch/pytorch/pull/156311 Approved by: https://github.com/albanD
		
			
				
	
	
		
			1553 lines
		
	
	
		
			53 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1553 lines
		
	
	
		
			53 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# mypy: allow-untyped-defs
 | 
						|
"""
 | 
						|
The weak_script annotation needs to be here instead of inside torch/jit/ so it
can be used in other places in torch/ (namely torch.nn) without running into
circular dependency problems.
 | 
						|
"""
 | 
						|
 | 
						|
import ast
 | 
						|
import builtins
 | 
						|
import collections
 | 
						|
import contextlib
 | 
						|
import enum
 | 
						|
import inspect
 | 
						|
import io
 | 
						|
import pickle
 | 
						|
import sys
 | 
						|
import textwrap
 | 
						|
import threading
 | 
						|
import types
 | 
						|
import typing
 | 
						|
import warnings
 | 
						|
import weakref
 | 
						|
from typing import (  # noqa: UP035, F401  # (Dict, List, Tuple) imported by torch.jit.annotations
 | 
						|
    Any,
 | 
						|
    Callable,
 | 
						|
    Dict,
 | 
						|
    Final,
 | 
						|
    ForwardRef,
 | 
						|
    get_args,
 | 
						|
    get_origin,
 | 
						|
    List,
 | 
						|
    Optional,
 | 
						|
    Tuple,
 | 
						|
    TypeVar,
 | 
						|
    Union,
 | 
						|
)
 | 
						|
from typing_extensions import ParamSpec
 | 
						|
 | 
						|
import torch
 | 
						|
 | 
						|
# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
 | 
						|
# Explicitly ask to import `torch.distributed.__init__` first.
 | 
						|
# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
 | 
						|
import torch.distributed.rpc
 | 
						|
import torch.package._mangling as package_mangling
 | 
						|
from torch._awaits import _Await
 | 
						|
from torch._C import _Await as CAwait, Future as CFuture
 | 
						|
from torch._sources import fake_range, get_source_lines_and_file, parse_def
 | 
						|
from torch.futures import Future
 | 
						|
 | 
						|
 | 
						|
_P = ParamSpec("_P")
 | 
						|
_R = TypeVar("_R")
 | 
						|
 | 
						|
IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)
 | 
						|
 | 
						|
BuiltinUnionType: Union[type, tuple[type, ...]]
 | 
						|
if sys.version_info >= (3, 10):
 | 
						|
    # NOTE: IS_PY310_PLUS doesn't work with mypy.
 | 
						|
    # cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks
 | 
						|
    BuiltinUnionType = types.UnionType
 | 
						|
else:
 | 
						|
    BuiltinUnionType = ()  # trick: this makes isinstance short circuit.
 | 
						|
 | 
						|
LockType: type
 | 
						|
try:
 | 
						|
    import _thread
 | 
						|
 | 
						|
    LockType = _thread.LockType
 | 
						|
except ImportError:
 | 
						|
    import _dummy_thread  # type: ignore[import-not-found]
 | 
						|
 | 
						|
    LockType = _dummy_thread.LockType
 | 
						|
 | 
						|
# Wrapper functions that can call either of 2 functions depending on a boolean
 | 
						|
# argument
 | 
						|
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, dict[str, Callable]]" = (
 | 
						|
    weakref.WeakKeyDictionary()
 | 
						|
)  # noqa: T484
 | 
						|
 | 
						|
 | 
						|
FAKE_FILENAME_PREFIX = "__torch_jit_dataclass"
 | 
						|
 | 
						|
 | 
						|
def is_final(ann) -> bool:
    """
    Return True if ``ann`` looks like a ``Final`` annotation from ``typing`` or
    ``typing_extensions``, either subscripted (``Final[int]``) or bare.

    Objects without a ``__module__`` attribute, or whose module is neither
    ``typing`` nor ``typing_extensions``, are never considered Final.
    """
    if not hasattr(ann, "__module__"):
        return False
    if ann.__module__ not in {"typing", "typing_extensions"}:
        return False
    # `Final[T]` reports `Final` as its origin. A bare `Final` is matched by
    # checking against the type of the `Final` special form.
    # NOTE(review): `type(Final)` is presumably the generic special-form type
    # shared with e.g. bare `ClassVar` -- confirm those should count as Final.
    return get_origin(ann) is Final or isinstance(ann, type(Final))
 | 
						|
 | 
						|
 | 
						|
# allows BroadcastingList instance to be subscriptable
 | 
						|
class BroadcastingListCls:
    """
    Placeholder object whose only job is to be subscriptable, so expressions
    such as ``BroadcastingList2[int]`` evaluate without error. Subscripting
    always returns None.
    """

    def __getitem__(self, item):
        # `item` (rather than `types`) avoids shadowing the imported `types` module.
        return None


# mypy doesn't support parameters on types, so we have to explicitly type each
# list size
BroadcastingList1 = BroadcastingListCls()
# BroadcastingList2 .. BroadcastingList6 are aliases of the same instance.
# Building them with a comprehension keeps the loop variable out of the module
# namespace; a plain `for i in ...` loop left `i` behind as a public module
# global that `from ... import *` would re-export.
globals().update({f"BroadcastingList{n}": BroadcastingList1 for n in range(2, 7)})
 | 
						|
 | 
						|
 | 
						|
def is_scripting() -> bool:
    r"""
    Function that returns True when in compilation and False otherwise. This
    is useful especially with the @unused decorator to leave code in your
    model that is not yet TorchScript compatible.

    When this Python function body itself runs, it always returns False.

    .. testcode::

        import torch

        @torch.jit.unused
        def unsupported_linear_op(x):
            return x

        def linear(x):
            if torch.jit.is_scripting():
                return torch.linear(x)
            else:
                return unsupported_linear_op(x)
    """
    # Plain Python execution is, by definition, not TorchScript compilation.
    return False
 | 
						|
 | 
						|
 | 
						|
# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
 | 
						|
def _qualified_name(obj, mangle_name=True) -> str:
 | 
						|
    # This special case allows us to override the qualified name on a type.
 | 
						|
    # It's currently used in conjunction with tracing, where we create a
 | 
						|
    # fake module to filter only supported attributes. However, since this
 | 
						|
    # new type is defined as a local class, we need a mechanism to override
 | 
						|
    # its qualname so it appears correctly in the TorchScript system. This,
 | 
						|
    # we set '_jit_override_qualname' with the original traced module's
 | 
						|
    # qualified name, which is picked up here
 | 
						|
    if hasattr(obj, "_jit_override_qualname"):
 | 
						|
        return obj._jit_override_qualname
 | 
						|
    # short-circuit in cases where the object already has a known qualified name
 | 
						|
    if isinstance(obj, torch._C.ScriptFunction):
 | 
						|
        return obj.qualified_name
 | 
						|
 | 
						|
    if getattr(obj, "__name__", None):
 | 
						|
        name = obj.__name__
 | 
						|
    # Enum classes do not have `__name__` attr, instead they have `name`.
 | 
						|
    elif isinstance(obj, enum.Enum):
 | 
						|
        name = obj.name
 | 
						|
    else:
 | 
						|
        raise RuntimeError("Could not get name of python class object")
 | 
						|
 | 
						|
    if name == "<lambda>":
 | 
						|
        name = "_lambda"  # make name a valid identifier
 | 
						|
 | 
						|
    module_name = obj.__module__
 | 
						|
 | 
						|
    # If the module is actually a torchbind module, then we should short circuit
 | 
						|
    if module_name == "torch._classes":
 | 
						|
        return obj.qualified_name
 | 
						|
 | 
						|
    # The Python docs are very clear that `__module__` can be None, but I can't
 | 
						|
    # figure out when it actually would be.
 | 
						|
    if module_name is None:
 | 
						|
        raise RuntimeError(
 | 
						|
            f"Could not get qualified name for class '{name}': "
 | 
						|
            "__module__ can't be None."
 | 
						|
        )
 | 
						|
 | 
						|
    # if getattr(sys.modules[module_name], name) is not obj:
 | 
						|
    #     raise RuntimeError(f"Could not get qualified name for class '{name}': "
 | 
						|
    #                        f"the attr {name} on module {module_name} is not the class")
 | 
						|
 | 
						|
    # torch.package and TorchScript have separate mangling schemes to avoid
 | 
						|
    # name collisions from multiple packages. To avoid them interfering with
 | 
						|
    # each other, normalize the package managing here.
 | 
						|
    if package_mangling.is_mangled(module_name):
 | 
						|
        module_name = module_name.replace("<", "_")
 | 
						|
        module_name = module_name.replace(">", "_")
 | 
						|
 | 
						|
    # The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h
 | 
						|
    # does not need mangle the python class name.
 | 
						|
    if mangle_name:
 | 
						|
        # __main__ is a builtin module, so rewrite it to "__torch__".
 | 
						|
        if module_name == "__main__":
 | 
						|
            module_name = "__torch__"
 | 
						|
        else:
 | 
						|
            # Everything else gets a "__torch__" prefix to avoid name collisions
 | 
						|
            # with the names of user values.
 | 
						|
            module_name = "__torch__." + module_name
 | 
						|
 | 
						|
    if "." in name:
 | 
						|
        raise RuntimeError(
 | 
						|
            f"Could not get qualified name for class '{name}': "
 | 
						|
            f"'{name}' is not a valid identifier"
 | 
						|
        )
 | 
						|
 | 
						|
    return module_name + "." + name
 | 
						|
 | 
						|
 | 
						|
class SourceLoader:
    """
    In-memory map from a key ``fn`` to its source text.

    ``get_type_hint_captures`` looks up function objects here before falling
    back to ``inspect.getsource``.
    """

    def __init__(self):
        # key passed to `cache` -> source string
        self.content = {}

    def cache(self, fn, source):
        """Store ``source`` for ``fn``, replacing any earlier entry."""
        self.content[fn] = source

    def get_source(self, fn):
        """Return the stored source for ``fn``, or None when nothing was stored."""
        try:
            return self.content[fn]
        except KeyError:
            return None


# Module-level instance; get_type_hint_captures reads from it.
loader = SourceLoader()
 | 
						|
 | 
						|
 | 
						|
def createResolutionCallbackFromEnv(lookup_base):
    """
    Creates a resolution callback that will look up qualified names in an
    environment, starting with `lookup_base` for the base of any qualified
    names, then proceeding down the lookup chain with the resolved object.

    The returned callback takes an expression string such as
    ``"Dict[str, List[int]]"`` and returns the resolved Python object. It
    returns None when the expression cannot be fully resolved, so that the
    caller can fall back to the C++ resolver.

    You should not use this directly, it should only be used from the other
    createResolutionCallbackFrom* functions.
    """

    def lookupInModule(qualified_name, module):
        # Resolve "a.b.c" as getattr(getattr(getattr(module, "a"), "b"), "c").
        if "." in qualified_name:
            base, remaining_pieces = qualified_name.split(".", maxsplit=1)
            module_value = getattr(module, base)
            return lookupInModule(remaining_pieces, module_value)
        else:
            return getattr(module, qualified_name)

    def parseNestedExpr(expr, module) -> tuple[Any, int]:
        # Returns (resolved value, number of characters of `expr` consumed).
        i = 0
        while i < len(expr) and expr[i] not in (",", "[", "]"):
            i += 1

        # Special case logic for the empty Tuple as a subscript (used
        # in the type annotation `Tuple[()]`)
        if expr[:i] == "()":
            return (), i

        base = lookupInModule(expr[:i].strip(), module)
        if base is None:
            # An explicit raise (not `assert`) so this still fires under `python -O`.
            raise RuntimeError(f"Unresolvable type {expr[:i]}")
        if i == len(expr) or expr[i] != "[":
            return base, i

        # Here expr[i] == "[": parse comma-separated subscript arguments up to "]".
        parts = []
        while expr[i] != "]":
            i += 1  # step over the "[" or "," separator
            part, part_len = parseNestedExpr(expr[i:], module)
            parts.append(part)
            i += part_len
        if len(parts) > 1:
            return base[tuple(parts)], i + 1
        else:
            return base[parts[0]], i + 1

    def parseExpr(expr, module):
        try:
            value, len_parsed = parseNestedExpr(expr, module)
        except Exception:
            # The python resolver fails in several cases in known unit tests, and
            # is intended to fall back gracefully to the c++ resolver in general.
            # For example, python 2 style annotations which are frequent in our
            # unit tests often fail with types e.g. int not resolvable from the
            # calling frame.
            return None
        if len_parsed != len(expr):
            # Whole expression was not parsed, fall back to the c++ parser.
            # (Checked explicitly rather than with `assert`, which `python -O`
            # strips and would then return a partially parsed value.)
            return None
        return value

    return lambda expr: parseExpr(expr, lookup_base)
 | 
						|
 | 
						|
 | 
						|
def createResolutionCallbackFromFrame(frames_up: int = 0):
    """
    Creates a function which, given a string variable name,
    returns the value of the variable in the scope of the caller of
    the function which called createResolutionCallbackFromFrame (by default).

    This is used to enable access in-scope Python variables inside
    TorchScript fragments.

    frames_up is number of additional frames to go up on the stack.
    The default value is 0, which correspond to the frame of the caller
    of createResolutionCallbackFromFrame. Also for example, if frames_up is set
    to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame
    will be taken.

    Names are looked up in that frame's locals, then its globals, then
    ``builtins``; unknown names resolve to None.

    For example, the following program prints 2::

        def bar():
            cb = createResolutionCallbackFromFrame(1)
            print(cb("foo"))


        def baz():
            foo = 2
            bar()


        baz()
    """
    frame = inspect.currentframe()
    # Step out of this function's own frame plus `frames_up` further frames.
    for _ in range(frames_up + 1):
        assert frame is not None
        frame = frame.f_back

    assert frame is not None
    f_locals = frame.f_locals
    f_globals = frame.f_globals
    # Membership in the builtins namespace dict is equivalent to
    # `key in dir(builtins)` but avoids building and scanning a sorted list of
    # names on every single lookup.
    builtins_namespace = vars(builtins)

    class env:
        def __getattr__(self, key):
            if key in f_locals:
                return f_locals[key]
            elif key in f_globals:
                return f_globals[key]
            elif key in builtins_namespace:
                return getattr(builtins, key)
            return None

    return createResolutionCallbackFromEnv(env())
 | 
						|
 | 
						|
 | 
						|
def get_closure(fn):
    """
    Get a dictionary of closed over variables from a function.

    The result starts as a copy of ``fn.__globals__`` and is then updated with
    the current values of ``fn``'s free (closed-over) variables, so a closure
    variable shadows a global of the same name.

    A closure cell that has not been assigned yet (e.g. the enclosing function
    binds the name only after ``fn`` was created) has no value, and reading it
    raises ValueError. Such names are left out of the result. Any global of the
    same name is removed as well, because inside ``fn`` that name refers to the
    cell, not to the global.
    """
    captures = {}
    captures.update(fn.__globals__)

    # `__closure__` is None when `fn` has no free variables.
    cells = fn.__closure__ or ()
    for captured_name, cell in zip(fn.__code__.co_freevars, cells):
        try:
            captures[captured_name] = cell.cell_contents
        except ValueError:
            # Empty cell: the name is not bound yet; don't let a global leak in.
            captures.pop(captured_name, None)

    return captures
 | 
						|
 | 
						|
 | 
						|
# [local resolution in python]
 | 
						|
# Depending on where a variable is defined, and where it is used, we may
 | 
						|
# or may not be able to recover its value when recursively compiling a
 | 
						|
# script function. Remember in the general case, a module or function is
 | 
						|
# first defined and then later scripted. This means we do not have a
 | 
						|
# chance to capture the active frames when the function is defined. Hence any
 | 
						|
# name resolution has to happen later on the created closure. The way
 | 
						|
# python captures type annotations restricts what we can recover. The
# following example illustrates the different cases:
 | 
						|
#
 | 
						|
#         class MyGlobalClass:
 | 
						|
#         ...
 | 
						|
#         def my_local_scope():
 | 
						|
#             @torch.jit.script
 | 
						|
#             class MyClass:
 | 
						|
#                 ...
 | 
						|
#             @torch.jit.script
 | 
						|
#             class MyClassUsedAsVar:
 | 
						|
#                 ...
 | 
						|
#             def eg(x: MyClass, y: MyGlobalClass):
 | 
						|
#                 a_local_capture : Foo
 | 
						|
#                 return MyClassUsedAsVar(x)
 | 
						|
#
 | 
						|
# MyGlobalClass is defined in the __globals__ dictionary of function
 | 
						|
# 'eg', so it is always recoverable. my_local_scope introduces a new local
 | 
						|
# variable scope in the function. Classes defined here are only visible as
 | 
						|
# local variables. For the case of MyClassUsedAsVar, it is captured
 | 
						|
# because it is used as a variable inside the body of the function, and we
 | 
						|
# can resolve it using the captures returned from `get_closure`. However,
 | 
						|
# the type annotations are not captured by the closure. In Python
 | 
						|
# 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be available as
# annotations on `eg`, but with postponed evaluation of annotations (PEP 563,
# `from __future__ import annotations`) they are represented as strings and
# the values are no longer present. Furthermore, since the body of `eg` does
 | 
						|
# not reference those names, they do not appear in the list of closed over
 | 
						|
# variables. In Python 2.x, type annotations are in comments, leading to a
 | 
						|
# similar situation where their definitions are not available. We anticipate
 | 
						|
# that most users will not run into this issue because their modules and
 | 
						|
# functions will be defined at a global scope like MyGlobalClass. In cases
 | 
						|
# where they are not, it is possible to work around issues by declaring the
 | 
						|
# values global in the function.
 | 
						|
# In Python 3.9, declaring a class as global makes it invisible to
# `inspect.getsource`; see https://bugs.python.org/issue42666 .
# This can be worked around by manually adding it to the `globals()` dictionary.
 | 
						|
 | 
						|
 | 
						|
def createResolutionCallbackFromClosure(fn):
    """
    Create a resolutionCallback by introspecting ``fn`` itself (its globals and
    closed-over variables) instead of looking up the stack for the enclosing
    scope.

    Names resolve against that closure first, then the ``typing`` module, then
    ``builtins``; anything else resolves to None.
    """
    captures = get_closure(fn)

    class closure_lookup:
        # A class rather than the raw dict, because
        # createResolutionCallbackFromEnv resolves every name with `getattr`.
        def __getattr__(self, key):
            if key in captures:
                return captures[key]
            for fallback_module in (typing, builtins):
                if hasattr(fallback_module, key):
                    return getattr(fallback_module, key)
            return None

    return createResolutionCallbackFromEnv(closure_lookup())
 | 
						|
 | 
						|
 | 
						|
def can_compile_class(cls) -> bool:
    """
    Return whether ``cls`` is a candidate for TorchScript class compilation.

    False is returned when ``cls`` is marked as ignored, when it subclasses
    ``torch.nn.Module``, tuple, list or Exception, or when any routine defined
    in its ``__dict__`` lacks a ``__code__`` object.
    """
    if is_ignored_fn(cls):
        return False

    # Ignore the following list of built-in classes.
    if issubclass(cls, (torch.nn.Module, tuple, list, Exception)):
        return False

    # If any of the functions on a type don't have a code object, this type
    # can't be compiled and is probably a builtin / bound from C.
    routines = [
        member
        for member in (getattr(cls, attr_name, None) for attr_name in cls.__dict__)
        if inspect.isroutine(member)
    ]
    return all(hasattr(routine, "__code__") for routine in routines)
 | 
						|
 | 
						|
 | 
						|
def get_callable_argument_names(fn) -> list[str]:
    """
    Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`.

    Parameters of every other kind (positional-only, ``*args``, keyword-only
    and ``**kwargs``) are skipped, not treated as an error. An empty list is
    returned only when ``inspect.signature`` cannot inspect ``fn`` or ``fn``
    has no POSITIONAL_OR_KEYWORD parameters.

    This is used by `torch.jit.trace` to assign meaningful argument names to
    traced functions and modules.

    Args:
        fn: A callable.
    Returns:
        Argument names: List[str]
    """
    # inspect.signature may fail, give up in that case.
    try:
        callable_signature = inspect.signature(fn)
    except Exception:
        return []

    # The four other parameter kinds do not map to individual values that can
    # be passed with a keyword as name.
    return [
        name
        for name, param in callable_signature.parameters.items()
        if param.kind == param.POSITIONAL_OR_KEYWORD
    ]
 | 
						|
 | 
						|
 | 
						|
def get_annotation_str(annotation):
    """
    Convert an AST node containing a type annotation to the string present in the source
    that represents the same annotation.

    Handles ``Name``, ``Attribute``, ``Subscript``, ``Tuple`` and ``Constant``
    nodes. Returns None when ``annotation`` (or any node nested inside it) is
    of another kind, e.g. ``int | None``. Returning a partial string such as
    ``"List[None]"`` would produce a bogus mapping, and a None inside a tuple
    would make ``",".join`` raise.
    """
    if isinstance(annotation, ast.Name):
        return annotation.id
    elif isinstance(annotation, ast.Attribute):
        value_str = get_annotation_str(annotation.value)
        if value_str is None:
            return None
        return f"{value_str}.{annotation.attr}"
    elif isinstance(annotation, ast.Subscript):
        # In Python3.9+ subscript indices are not wrapped in ast.Index
        value_str = get_annotation_str(annotation.value)
        slice_str = get_annotation_str(annotation.slice)
        if value_str is None or slice_str is None:
            return None
        return f"{value_str}[{slice_str}]"
    elif isinstance(annotation, ast.Tuple):
        elt_strs = [get_annotation_str(elt) for elt in annotation.elts]
        if any(elt_str is None for elt_str in elt_strs):
            return None
        return ",".join(elt_strs)
    elif isinstance(annotation, ast.Constant):
        return f"{annotation.value}"

    # If an AST node is not handled here, it's probably handled in ScriptTypeParser.
    return None
 | 
						|
 | 
						|
 | 
						|
def get_type_hint_captures(fn):
    """
    Get a dictionary containing type resolution mappings necessary to resolve types
    for the literal annotations on 'fn'. These are not considered to be closed-over by fn
    and must be obtained separately (e.g. using this function).

    Args:
        fn: A callable.
    Returns:
        A Dict[str, Any] containing a mapping from the literal annotations used on
        fn to the Python objects they refer to.
    Raises:
        OSError: if no source was cached for ``fn`` in ``loader`` and
            ``inspect.getsource`` cannot retrieve it either.
        RuntimeError: if the source does not parse to exactly one (non-async)
            function definition.
    """
    # First, get the source of the function. We need to parse it to find the actual
    # string names that were used to annotate the types, since inspect.signature()
    # only returns the class object that the annotation refers to, not the string
    # name. Source cached in `loader` takes precedence over inspect.getsource. If
    # neither is available (e.g. the function was synthesized dynamically at
    # runtime), an OSError is raised.
    src = loader.get_source(fn)
    if src is None:
        try:
            src = inspect.getsource(fn)
        except OSError as e:
            raise OSError(
                f"Failed to get source for {fn} using inspect.getsource"
            ) from e

    # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
    # types are strings. These are only understood by TorchScript in the context of a type annotation
    # that refers to a class in its own definition, but trying to include a mapping for this in the result
    # function would cause infinite recursion because the class is currently being compiled.
    # In addition, there is logic in ScriptTypeParser to handle this.
    signature = inspect.signature(fn)
    name_to_type = {
        name: parameter.annotation
        for name, parameter in signature.parameters.items()
        if parameter.annotation is not inspect.Parameter.empty
        and not isinstance(parameter.annotation, str)
    }

    # Then, get the literal type annotations from the function declaration
    # by source inspection. This accounts for the case in which aliases are used
    # to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
    # frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
    a = ast.parse(textwrap.dedent(src))
    if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
        raise RuntimeError(f"Expected {fn} to be a function")
    f = a.body[0]

    # Prepare a dictionary of source annotation -> type, which will be the final result of this function,
    # by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping
    # them to the type object corresponding to the annotation via name_to_type using the parameter name.
    annotation_to_type = {}

    # Every parameter kind can carry an annotation and is reported by
    # inspect.signature, so consider positional-only, regular, keyword-only and
    # variadic (*args / **kwargs) parameters alike.
    fn_args = f.args
    all_args = [*fn_args.posonlyargs, *fn_args.args, *fn_args.kwonlyargs]
    all_args.extend(arg for arg in (fn_args.vararg, fn_args.kwarg) if arg is not None)

    for arg in all_args:
        # Get the source type annotation string for this argument if possible.
        arg_annotation_str = (
            get_annotation_str(arg.annotation) if arg.annotation else None
        )

        # If the argument has no annotation or get_annotation_str cannot convert it to a string,
        # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle
        # this in the latter case.
        if arg_annotation_str is None:
            continue

        # Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not
        # be present in name_to_type is that the annotation itself is a string and not a type object
        # (common for self-referential annotations in classes). Once again, let ScriptTypeParser handle this.
        arg_name = arg.arg
        if arg_name in name_to_type:
            annotation_to_type[arg_annotation_str] = name_to_type[arg_name]

    # If there is a valid return annotation, include it in annotation_to_type. As with argument annotations,
    # the literal annotation has to be convertible to a string by get_annotation_str, and the actual type
    # of the annotation cannot be a string.
    literal_return_annotation = get_annotation_str(f.returns)
    valid_literal_annotation = literal_return_annotation is not None
    return_annotation = signature.return_annotation
    valid_return_annotation_type = (
        return_annotation is not inspect.Parameter.empty
        and not isinstance(return_annotation, str)
    )
    if valid_literal_annotation and valid_return_annotation_type:
        annotation_to_type[literal_return_annotation] = return_annotation

    return annotation_to_type
 | 
						|
 | 
						|
 | 
						|
def createResolutionCallbackForClassMethods(cls):
    """
    Build a resolution callback for the methods defined directly on ``cls``.

    The closed-over variables of each Python-level method (via get_closure)
    and its literal type-annotation mappings (via get_type_hint_captures) are
    merged into one dictionary. The returned callback looks names up there
    first, then in ``builtins``, and returns None otherwise.
    """
    # cls is a type here, so `ismethod` is false since the methods on the type
    # aren't bound to anything, so Python treats them as regular functions.
    candidate_fns = []
    for attr_name in cls.__dict__:
        member = getattr(cls, attr_name)
        if not inspect.isroutine(member):
            continue
        # Skip built-ins, as they do not have global scope nor type hints.
        # Needed to support `enum.Enum` derived classes in Python-3.11,
        # which adds a `_new_member_` attribute that is an alias to `__new__`.
        if inspect.isbuiltin(member) or not hasattr(member, "__globals__"):
            continue
        candidate_fns.append(member)

    captures = {}
    for method in candidate_fns:
        captures.update(get_closure(method))
        captures.update(get_type_hint_captures(method))

    def lookup_in_class(key):
        try:
            return captures[key]
        except KeyError:
            return getattr(builtins, key, None)

    return lookup_in_class
 | 
						|
 | 
						|
 | 
						|
def boolean_dispatch(
    arg_name,
    arg_index,
    default,
    if_true,
    if_false,
    module_name,
    func_name,
):
    """
    Dispatches to either of 2 script functions based on a boolean argument.
    In TorchScript, the boolean argument must be constant so that the correct
    function to use can be determined at compile time.

    Args:
        arg_name: keyword name of the boolean argument.
        arg_index: positional index of the boolean argument.
        default: flag value used when the argument is passed neither by
            keyword nor positionally.
        if_true: implementation called when the flag is truthy.
        if_false: implementation called when the flag is falsy.
        module_name: if not None, assigned to the dispatcher's ``__module__``.
        func_name: if not None, assigned to the dispatcher's ``__name__``.

    Returns:
        The dispatcher function, also registered in ``boolean_dispatched``.

    Raises:
        RuntimeError: if both ``if_true`` and ``if_false`` have docstrings.
    """

    def fn(*args, **kwargs):
        # A keyword argument wins over a positional one; otherwise use `default`.
        if arg_name in kwargs:
            flag = kwargs[arg_name]
        elif arg_index < len(args):
            flag = args[arg_index]
        else:
            flag = default
        chosen = if_true if flag else if_false
        return chosen(*args, **kwargs)

    # At most one implementation may carry a docstring; it is copied onto the
    # other implementation and onto the dispatcher.
    true_doc = if_true.__doc__
    false_doc = if_false.__doc__
    if true_doc is not None and false_doc is not None:
        raise RuntimeError("only one function can have a docstring")
    if true_doc is None and false_doc is not None:
        doc = false_doc
        if_true.__doc__ = doc
    elif false_doc is None and true_doc is not None:
        doc = true_doc
        if_false.__doc__ = doc
    else:
        # neither function has a docstring
        doc = None
    fn.__doc__ = doc

    if module_name is not None:
        fn.__module__ = module_name
    if func_name is not None:
        fn.__name__ = func_name

    boolean_dispatched[fn] = dict(
        if_true=if_true,
        if_false=if_false,
        index=arg_index,
        default=default,
        arg_name=arg_name,
    )
    return fn
 | 
						|
 | 
						|
 | 
						|
class FunctionModifiers:
    """
    String markers recording how TorchScript should treat a function.

    The decorators in this module store one of these values on a function as
    its ``_torchscript_modifier`` attribute; callables without that attribute
    are reported as ``DEFAULT``. See export(), unused() and ignore() for details.
    """

    # Fallback reported for callables that carry no explicit marker.
    DEFAULT = "default (compile if called from a exported function / forward)"
    # Set by export().
    EXPORT = "export (compile this function even if nothing calls it)"
    # Set by ignore() when used without drop=True.
    IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
    # Set by unused() and by ignore(drop=True).
    UNUSED = "unused (ignored and replaced with raising of an exception)"
    # Set by _drop().
    _DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
    # Set by _copy_to_script_wrapper().
    COPY_TO_SCRIPT_WRAPPER = (
        "if this method is not scripted, copy the python method onto the scripted model"
    )
 | 
						|
 | 
						|
 | 
						|
def export(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    Mark a method of an ``nn.Module`` as an additional entry point of the
    resulting :class:`ScriptModule`, so that it is compiled even if nothing
    calls it.

    ``forward`` is already treated as an entry point and does not need this
    decorator. Functions and methods that a compiled method calls are compiled
    when the compiler reaches them, so they do not need it either.

    Example (using ``@torch.jit.export`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            def implicitly_compiled_method(self, x):
                return x + 99

            # `forward` is implicitly decorated with `@torch.jit.export`,
            # so adding it here would have no effect
            def forward(self, x):
                return x + 10

            @torch.jit.export
            def another_forward(self, x):
                # When the compiler sees this call, it will compile
                # `implicitly_compiled_method`
                return self.implicitly_compiled_method(x)

            def unused_method(self, x):
                return x - 20

        # `m` will contain compiled methods:
        #     `forward`
        #     `another_forward`
        #     `implicitly_compiled_method`
        # `unused_method` will not be compiled since it was not called from
        # any compiled methods and wasn't decorated with `@torch.jit.export`
        m = torch.jit.script(MyModule())
    """
    # Record the marker on the function object itself and hand it back unchanged.
    setattr(fn, "_torchscript_modifier", FunctionModifiers.EXPORT)  # noqa: B010
    return fn
 | 
						|
 | 
						|
 | 
						|
def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and replaced with the raising of an exception. This allows you
    to leave code in your model that is not yet TorchScript compatible and still
    export your model.

    When applied to a ``property``, its getter and setter (whichever exist) are
    marked and the property object itself is returned.

    Example (using ``@torch.jit.unused`` on a method)::

        import torch
        import torch.nn as nn


        class MyModule(nn.Module):
            def __init__(self, use_memory_efficient):
                super().__init__()
                self.use_memory_efficient = use_memory_efficient

            @torch.jit.unused
            def memory_efficient(self, x):
                import pdb

                pdb.set_trace()
                return x + 10

            def forward(self, x):
                # Use not-yet-scriptable memory efficient mode
                if self.use_memory_efficient:
                    return self.memory_efficient(x)
                else:
                    return x + 10


        m = torch.jit.script(MyModule(use_memory_efficient=False))
        m.save("m.pt")

        m = torch.jit.script(MyModule(use_memory_efficient=True))
        # exception raised
        m(torch.rand(100))
    """
    if isinstance(fn, property):
        prop = fn
        # A property is not itself a function, so mark its accessor functions.
        # ``property(fset=...)`` has no getter; skip missing accessors instead of
        # failing with an AttributeError on ``None``.
        for accessor in (prop.fget, prop.fset):
            if accessor is not None:
                setattr(  # noqa: B010
                    accessor, "_torchscript_modifier", FunctionModifiers.UNUSED
                )
        return prop

    fn._torchscript_modifier = FunctionModifiers.UNUSED  # type: ignore[attr-defined]
    return fn
 | 
						|
 | 
						|
 | 
						|
# No op context manager from python side
 | 
						|
class _IgnoreContextManager(contextlib.AbstractContextManager):
 | 
						|
    def __init__(self, **kwargs):
 | 
						|
        pass
 | 
						|
 | 
						|
    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
 | 
						|
        pass
 | 
						|
 | 
						|
 | 
						|
def ignore(drop=False, **kwargs):
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and left as a Python function. This allows you to leave code in
    your model that is not yet TorchScript compatible. If called from TorchScript,
    ignored functions will dispatch the call to the Python interpreter. Models with ignored
    functions cannot be exported; use :func:`@torch.jit.unused <torch.jit.unused>` instead.

    Example (using ``@torch.jit.ignore`` on a method)::

        import torch
        import torch.nn as nn


        class MyModule(nn.Module):
            @torch.jit.ignore
            def debugger(self, x):
                import pdb

                pdb.set_trace()

            def forward(self, x):
                x += 10
                # The compiler would normally try to compile `debugger`,
                # but since it is `@ignore`d, it will be left as a call
                # to Python
                self.debugger(x)
                return x


        m = torch.jit.script(MyModule())

        # Error! The call `debugger` cannot be saved since it calls into Python
        m.save("m.pt")

    Example (using ``@torch.jit.ignore(drop=True)`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            @torch.jit.ignore(drop=True)
            def training_method(self, x):
                import pdb
                pdb.set_trace()

            def forward(self, x):
                if self.training:
                    self.training_method(x)
                return x

        m = torch.jit.script(MyModule())

        # This is OK since `training_method` is not saved, the call is replaced
        # with a `raise`.
        m.save("m.pt")

    .. testcleanup::

        import os
        os.remove('m.pt')

    Args:
        drop: either the decorated function itself (bare ``@torch.jit.ignore``)
            or a bool. ``True`` is deprecated and marks the function as
            ``UNUSED`` instead of ``IGNORE``.
        **kwargs: only ``drop_on_export`` (deprecated alias for ``drop``) is
            consulted; any other keyword is ignored.

    Raises:
        RuntimeError: if ``drop`` is neither callable nor a bool.
    """

    if callable(drop):
        # used without any args, so drop is actually a function
        #   @torch.jit.ignore
        #   def fn(...):
        fn = drop
        fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    if not isinstance(drop, bool):
        raise RuntimeError(
            f"Argument to @torch.jit.ignore must be a bool or a function but got {drop}"
        )

    # for backwards compat
    drop_on_export = kwargs.pop("drop_on_export", None)
    # stacklevel=2 attributes the deprecation to the decorator's call site.
    if drop_on_export:
        warnings.warn(
            "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now.",
            category=FutureWarning,
            stacklevel=2,
        )

        drop = drop_on_export
    elif drop:
        warnings.warn(
            "ignore(True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now.",
            category=FutureWarning,
            stacklevel=2,
        )

    def decorator(fn):
        if drop:
            fn._torchscript_modifier = FunctionModifiers.UNUSED
        else:
            fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    return decorator
 | 
						|
 | 
						|
 | 
						|
def _drop(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Tag ``fn`` with ``FunctionModifiers._DROP`` and return the same object."""
    setattr(fn, "_torchscript_modifier", FunctionModifiers._DROP)  # noqa: B010
    return fn
 | 
						|
 | 
						|
 | 
						|
def _copy_to_script_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Tag ``fn`` with ``FunctionModifiers.COPY_TO_SCRIPT_WRAPPER`` and return it."""
    setattr(  # noqa: B010
        fn, "_torchscript_modifier", FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
    )
    return fn
 | 
						|
 | 
						|
 | 
						|
def module_has_exports(mod):
    """Return True if any callable attribute of ``mod`` is marked ``EXPORT``.

    Candidate names come from ``dir(mod)``; names for which ``hasattr`` is
    False are skipped.
    """
    for attr_name in dir(mod):
        if not hasattr(mod, attr_name):
            continue
        candidate = getattr(mod, attr_name)
        if not callable(candidate):
            continue
        if get_torchscript_modifier(candidate) is FunctionModifiers.EXPORT:
            return True
    return False
 | 
						|
 | 
						|
 | 
						|
# WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you
 | 
						|
# rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to
 | 
						|
# allow JIT'd code to still be covered.
 | 
						|
def should_drop(fn) -> bool:
    """Return True if ``fn``'s TorchScript modifier is ``UNUSED`` or ``_DROP``."""
    modifier = get_torchscript_modifier(fn)
    if modifier is None:
        # Non-callables carry no modifier at all.
        return False
    dropping_modifiers = (FunctionModifiers.UNUSED, FunctionModifiers._DROP)
    return any(modifier is candidate for candidate in dropping_modifiers)
 | 
						|
 | 
						|
 | 
						|
def is_ignored_fn(fn) -> bool:
    """Return True if ``fn`` is marked ``UNUSED``, ``IGNORE`` or ``_DROP``."""
    ignoring_modifiers = (
        FunctionModifiers.UNUSED,
        FunctionModifiers.IGNORE,
        FunctionModifiers._DROP,
    )
    modifier = get_torchscript_modifier(fn)
    return any(modifier is candidate for candidate in ignoring_modifiers)
 | 
						|
 | 
						|
 | 
						|
def _is_drop_fn(fn) -> bool:
    """Return True if ``fn`` was tagged by ``_drop``."""
    return get_torchscript_modifier(fn) is FunctionModifiers._DROP
 | 
						|
 | 
						|
 | 
						|
def is_static_fn(cls, fn) -> bool:
    """Return True if attribute ``fn`` of ``cls`` is a ``staticmethod``.

    ``inspect.getattr_static`` is used so the raw descriptor is inspected
    rather than the function it produces; a missing attribute yields False.
    """
    raw_attr = inspect.getattr_static(cls, fn, None)
    return isinstance(raw_attr, staticmethod)
 | 
						|
 | 
						|
 | 
						|
def get_static_fn(cls, fn):
    """Return the plain function wrapped by the ``staticmethod`` ``fn`` on ``cls``.

    The lookup goes through ``inspect.getattr_static`` without a default, so a
    missing attribute raises ``AttributeError``.
    """
    static_wrapper = inspect.getattr_static(cls, fn)
    return static_wrapper.__func__
 | 
						|
 | 
						|
 | 
						|
def get_torchscript_modifier(fn):
    """Return the TorchScript modifier recorded on ``fn``.

    Non-callables yield ``None``. Objects exposing ``__func__`` (e.g. bound
    methods) are unwrapped first. Callables with no ``_torchscript_modifier``
    attribute report ``FunctionModifiers.DEFAULT``.
    """
    if callable(fn):
        target = fn.__func__ if hasattr(fn, "__func__") else fn
        return getattr(target, "_torchscript_modifier", FunctionModifiers.DEFAULT)
    return None
 | 
						|
 | 
						|
 | 
						|
def copy_torchscript_modifier(orig, new) -> None:
    """Copy ``orig``'s TorchScript modifier onto ``new``.

    Nothing is copied when the lookup yields ``None`` (e.g. ``orig`` is not
    callable).
    """
    modifier = get_torchscript_modifier(orig)
    if modifier is not None:
        new._torchscript_modifier = modifier
 | 
						|
 | 
						|
 | 
						|
# overloading registration
 | 
						|
# overloads get registered in this file, and compiled in torch/jit/__init__.py
 | 
						|
# so that they can be imported in nn/functional.py without an import cycle
 | 
						|
 | 
						|
# qualified_name => list[overload_functions]
 | 
						|
_overloaded_fns: dict[str, list[Callable]] = {}  # noqa: T484
 | 
						|
 | 
						|
 | 
						|
_OVERLOAD_EXAMPLE = """
 | 
						|
Example usage of overload function:
 | 
						|
@torch.jit._overload
 | 
						|
def my_function(x: type0) -> type0: # decl 1
 | 
						|
    pass
 | 
						|
 | 
						|
@torch.jit._overload
 | 
						|
def my_function(x: type1) -> type1: # decl 2
 | 
						|
    pass
 | 
						|
 | 
						|
def my_function(x):                 # implementation
 | 
						|
    if isinstance(x, type0):
 | 
						|
        return x
 | 
						|
    elif isinstance(x, type1):
 | 
						|
        return x
 | 
						|
"""
 | 
						|
 | 
						|
 | 
						|
def get_overload_no_implementation_error_message(kind, obj):
    """Build the error text for an overloaded ``obj`` that lacks an implementation.

    ``kind`` is a caller-supplied label interpolated into the message. The text
    names the missing implementation, then shows the file, line and source of
    ``obj``, followed by ``_OVERLOAD_EXAMPLE``.
    """
    src_lines, lineno, path = get_source_lines_and_file(obj)
    header = (
        f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
        "sure a definition is provided and defined after all overload declarations.\n"
        f'File "{path}", line {lineno}:\n'
    )
    return header + "".join(src_lines) + "\n" + _OVERLOAD_EXAMPLE
 | 
						|
 | 
						|
 | 
						|
def _check_overload_body(func):
    """Require an overload declaration's body to be exactly ``pass`` or ``...``.

    If ``parse_def`` raises ``OSError`` (source unavailable), a warning is
    emitted and the check is skipped. Otherwise a ``RuntimeError`` is raised
    when the body is anything other than a single ``pass`` / ``...`` statement.
    """
    try:
        parsed_def = parse_def(func)
    except OSError:
        # This is only an early sanity check, so missing source is not fatal.
        warnings.warn(
            f"Unable to retrieve source for @torch.jit._overload function: {func}."
        )
        return

    statements = parsed_def.ast.body[0].body

    def _is_placeholder(stmt):
        # `pass`, or a bare `...` expression statement.
        if isinstance(stmt, ast.Pass):
            return True
        return (
            isinstance(stmt, ast.Expr)
            and isinstance(stmt.value, ast.Constant)
            and stmt.value.value is Ellipsis
        )

    if len(statements) == 1 and _is_placeholder(statements[0]):
        return

    leading_source = "\n".join(parsed_def.source.split("\n")[:3])
    raise RuntimeError(
        "Only `pass` statement or `...` can be the body of overload declaration:\n"
        + leading_source
        + " <- Expecting `pass` or `...` here!\n"
        + _OVERLOAD_EXAMPLE
    )
 | 
						|
 | 
						|
 | 
						|
def _overload(func):
    """Validate ``func`` as an overload declaration, register it, and return it.

    Declarations are appended to ``_overloaded_fns`` under the function's
    qualified name.
    """
    _check_overload_body(func)
    qual_name = _qualified_name(func)

    registered = _overloaded_fns.get(qual_name)
    if registered is None:
        registered = _overloaded_fns[qual_name] = []
    registered.append(func)
    return func
 | 
						|
 | 
						|
 | 
						|
def _get_fn_overloads(qual_name):
    """Return the overloads registered under ``qual_name``, or None if there are none."""
    return _overloaded_fns.get(qual_name, None)
 | 
						|
 | 
						|
 | 
						|
def _clear_fn_overloads(qual_name) -> None:
    """Forget the overloads registered under ``qual_name`` (KeyError if absent)."""
    _overloaded_fns.pop(qual_name)
 | 
						|
 | 
						|
 | 
						|
def get_class_name_lineno(method) -> tuple[str, int]:
    """Return ``(co_name, co_firstlineno)`` of the frame two levels above this call.

    It is called from ``_overload_method`` while that runs as a decorator inside
    a class body, so two frames up is the class body: the result is the class
    name and the first line of its definition. ``method`` is not used.
    """
    frame = inspect.currentframe()
    # Hop out of this function's frame, then out of _overload_method's frame.
    hops_left = 2
    while hops_left > 0:
        assert frame is not None  # currentframe()/f_back are Optional
        frame = frame.f_back
        hops_left -= 1

    assert frame is not None
    code = frame.f_code
    return code.co_name, code.co_firstlineno
 | 
						|
 | 
						|
 | 
						|
# At the point the decorator is applied to class methods the method
 | 
						|
# has no reference to its owning class. _qualified_name would not include
 | 
						|
# the class it is defined in, so any methods with the same name in the same file
 | 
						|
# would have the same _qualified_name, even if they were defined in different
 | 
						|
# classes. This problem only exists in python 2.
 | 
						|
# We get around this problem by looking at the stack frame and identifying
 | 
						|
# the class name, and throwing an error whenever overloads are used
 | 
						|
# when modules of the same name are in the same file
 | 
						|
 | 
						|
# qualified_name => class name => list[overload_functions]
 | 
						|
_overloaded_methods: dict[str, dict[str, list[Callable]]] = {}  # noqa: T484
 | 
						|
 | 
						|
 | 
						|
# (qualified_name, class name) => class_fileno
 | 
						|
_overloaded_method_class_fileno: dict[tuple[str, str], int] = {}
 | 
						|
 | 
						|
 | 
						|
def _overload_method(func):
    """Validate ``func`` as a method overload declaration, register it, and return it.

    The owning class is identified by walking the call stack (see
    get_class_name_lineno), so this is expected to run as a decorator directly
    inside the class body.

    Raises:
        RuntimeError: if the same method name was already overloaded in a
            different class of the same name in the same module.
    """
    _check_overload_body(func)
    qual_name = _qualified_name(func)

    per_class = _overloaded_methods.get(qual_name, None)
    if per_class is None:
        per_class = _overloaded_methods[qual_name] = {}

    # Must be called directly from here: it counts frames up to the class body.
    class_name, line_no = get_class_name_lineno(func)
    registry_key = (qual_name, class_name)

    overloads = per_class.get(class_name, None)
    if overloads is None:
        overloads = per_class[class_name] = []
        _overloaded_method_class_fileno[registry_key] = line_no
    elif _overloaded_method_class_fileno[registry_key] != line_no:
        raise RuntimeError(
            "Cannot currently overload the same method name in two different"
            " classes with the same name in the same module"
        )

    overloads.append(func)
    return func
 | 
						|
 | 
						|
 | 
						|
def _get_overloaded_methods(method, mod_class):
    """Return the overload declarations registered for ``method`` on ``mod_class``.

    Returns None when ``method`` has no ``__name__`` or when no overloads were
    registered for it under ``mod_class.__name__``.

    Raises:
        AssertionError: if ``method``'s source line lies outside the source span
            of ``mod_class``, e.g. because a class of the same name was
            redeclared within the same file.
    """
    # TODO: __name__ not set for submodules in recursive script
    if not hasattr(method, "__name__"):
        return None
    qual_name = _qualified_name(method)
    class_name_map = _overloaded_methods.get(qual_name, None)
    if class_name_map is None:
        return None
    overloads = class_name_map.get(mod_class.__name__, None)
    if overloads is None:
        return None

    method_line_no = get_source_lines_and_file(method)[1]
    # Fetch the class source once and derive both its start line and extent.
    class_source_lines, mod_class_fileno, _ = get_source_lines_and_file(mod_class)
    mod_end_fileno = mod_class_fileno + len(class_source_lines)
    if not (mod_class_fileno <= method_line_no <= mod_end_fileno):
        raise AssertionError(
            "Overloads are not usable when a module is redeclared within the same file: "
            + str(method)
        )
    return overloads
 | 
						|
 | 
						|
 | 
						|
def is_tuple(ann) -> bool:
    """Return True if ``ann`` is a parameterized ``tuple[...]`` / ``Tuple[...]``.

    A bare ``typing.Tuple`` is rejected via raise_error_container_parameter_missing;
    a bare builtin ``tuple`` simply yields False.
    """
    if ann is typing.Tuple:  # noqa: UP006
        raise_error_container_parameter_missing("Tuple")

    # Objects without __module__ (e.g. plain ints) cannot be generic aliases.
    if not hasattr(ann, "__module__"):
        return False
    from_builtin_or_typing = ann.__module__ in ("builtins", "typing")
    return from_builtin_or_typing and get_origin(ann) is tuple
 | 
						|
 | 
						|
 | 
						|
def is_list(ann) -> bool:
    """Return True if ``ann`` is a parameterized ``list[...]`` / ``List[...]``.

    A bare ``typing.List`` is rejected via raise_error_container_parameter_missing;
    a bare builtin ``list`` simply yields False.
    """
    if ann is typing.List:  # noqa: UP006
        raise_error_container_parameter_missing("List")

    if not hasattr(ann, "__module__"):
        return False
    from_builtin_or_typing = ann.__module__ in ("builtins", "typing")
    return from_builtin_or_typing and get_origin(ann) is list
 | 
						|
 | 
						|
 | 
						|
def is_dict(ann) -> bool:
    """Return True if ``ann`` is a parameterized ``dict[...]`` / ``Dict[...]``.

    A bare ``typing.Dict`` is rejected via raise_error_container_parameter_missing;
    a bare builtin ``dict`` simply yields False.
    """
    if ann is typing.Dict:  # noqa: UP006
        raise_error_container_parameter_missing("Dict")

    if not hasattr(ann, "__module__"):
        return False
    from_builtin_or_typing = ann.__module__ in ("builtins", "typing")
    return from_builtin_or_typing and get_origin(ann) is dict
 | 
						|
 | 
						|
 | 
						|
def is_union(ann):
    """Return True if ``ann`` is a union annotation.

    Accepts instances of ``BuiltinUnionType`` (presumably ``types.UnionType``,
    i.e. ``X | Y`` — confirm at its definition) and ``typing.Union[...]``.
    A bare ``typing.Union`` is rejected via raise_error_container_parameter_missing.
    """
    if ann is Union:
        raise_error_container_parameter_missing("Union")

    if isinstance(ann, BuiltinUnionType):
        return True
    return (
        hasattr(ann, "__module__")
        and ann.__module__ == "typing"
        and get_origin(ann) is Union
    )
 | 
						|
 | 
						|
 | 
						|
def is_optional(ann):
    """Return True if ``ann`` is an optional annotation: a two-member union with None.

    Any spelling recognized by is_union qualifies (``Optional[T]``,
    ``Union[T, None]``, ``T | None``). Unions with more than two members are
    not reported as optional even if one of them is None. A bare
    ``typing.Optional`` is rejected via raise_error_container_parameter_missing.
    """
    if ann is Optional:
        raise_error_container_parameter_missing("Optional")

    # ``Optional[T]`` is just ``Union[T, None]``: ``get_origin`` reports
    # ``Union`` for it, never ``Optional``, so the union check covers every form.
    if not is_union(ann):
        return False
    ann_args = get_args(ann)
    return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args)
 | 
						|
 | 
						|
 | 
						|
def is_future(ann) -> bool:
    """Return True if ``ann`` is a parameterized ``Future[T]`` annotation.

    Raises RuntimeError for a bare ``Future`` without a contained type.
    """
    if ann is not Future:
        return get_origin(ann) is Future
    raise RuntimeError(
        "Attempted to use Future without a "
        "contained type. Please add a contained type, e.g. "
        "Future[int]"
    )
 | 
						|
 | 
						|
 | 
						|
def is_await(ann) -> bool:
 | 
						|
    if ann is _Await:
 | 
						|
        return True
 | 
						|
    return get_origin(ann) is _Await
 | 
						|
 | 
						|
 | 
						|
if torch.distributed.rpc.is_available():
    from torch._C._distributed_rpc import PyRRef
    from torch.distributed.rpc import RRef

    def is_rref(ann) -> bool:
        """Return True if ``ann`` is a parameterized ``RRef[T]`` annotation.

        Raises RuntimeError for a bare ``RRef`` without a contained type.
        """
        if ann is RRef:
            raise RuntimeError(
                "Attempted to use RRef without a "
                "contained type. Please add a contained type, e.g. "
                "RRef[int]"
            )
        return get_origin(ann) is RRef

    def is_rref_instance(obj) -> bool:
        """Return True if ``obj`` is a ``PyRRef`` instance."""
        return isinstance(obj, PyRRef)

else:

    def is_rref(ann) -> bool:
        """RPC is unavailable, so no annotation can be an ``RRef``.

        Defined so both branches expose the same helpers.
        """
        return False

    def is_rref_instance(obj) -> bool:
        # If the RPC module doesn't exist then RRefs don't exist either.
        return False
 | 
						|
 | 
						|
 | 
						|
def _try_get_dispatched_fn(fn):
 | 
						|
    if not callable(fn):
 | 
						|
        return None
 | 
						|
    return boolean_dispatched.get(fn)
 | 
						|
 | 
						|
 | 
						|
def _get_named_tuple_properties(
    obj,
    loc: Optional[torch._C._jit_tree_views.SourceRange] = None,
    rcb=None,
):
    """Collect what TorchScript needs to register the NamedTuple class ``obj``.

    Args:
        obj: a NamedTuple class (asserted to be a ``tuple`` subclass with ``_fields``).
        loc: source range used in error messages; defaults to ``fake_range()``.
        rcb: optional resolution callback used to resolve ForwardRef field
            annotations (see the note in the body).

    Returns:
        ``(name, field_names, field_types, defaults)`` where ``field_types`` has
        one TorchScript type per field (``TensorType.getInferred()`` for
        unannotated fields) and ``defaults`` lists, in field order, the default
        values of the fields that have one. ``name`` is
        ``type(obj).__name__``; NOTE(review): since ``obj`` is a class this is
        its metaclass's name rather than the NamedTuple's own name — confirm
        whether callers rely on this before changing it.

    Raises:
        ValueError: if a ForwardRef annotation cannot be resolved through ``rcb``.
    """
    if loc is None:
        loc = fake_range()

    assert issubclass(obj, tuple) and hasattr(obj, "_fields")
    if hasattr(obj, "_field_defaults"):
        defaults = [
            obj._field_defaults[field]
            for field in obj._fields
            if field in obj._field_defaults
        ]
    else:
        defaults = []
    # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function
    # Also, annotations from base class are not inherited so they need to be queried explicitly
    if sys.version_info[:2] < (3, 10):
        obj_annotations = getattr(obj, "__annotations__", {})
    else:
        obj_annotations = inspect.get_annotations(obj)
        if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
            obj_annotations = inspect.get_annotations(obj.__base__)

    annotations = []
    for field in obj._fields:
        if field in obj_annotations:
            field_type = obj_annotations[field]
            # [Note: ForwardRef annotations in NamedTuple attributes]
            # NamedTuple types are slightly different from normal types.
            #
            # Normally, annotations are evaluated like this (during jit.script):
            # 1. Load strings of python code into c++ and parse.
            # 2. Get annotations as strings
            # 3. Use the PythonResolver's resolution callback (rcb) to convert
            #    the string into a python object
            # 4. We call into annotations.py:ann_to_type to convert python obj
            #    from step 3 into a type that torchscript understands.
            #
            # NamedTuples are more complicated, because they have sub-types.
            # Normally, once we have the NamedTuple type object from #3,
            # we can just look at the annotation literal values and use
            # ann_to_type directly on them.
            #
            # But sometimes, users will annotate with string literals, e.g.
            #    x: 'int'
            # This also happens with PEP 563 (from __future__ import annotations)
            #
            # These annotations appear in the annotation dict as ForwardRef('int').
            #
            # Then, we need to convert the string into a python object. This
            # requires having local context for custom objects or imported types.
            # rcb() is what gives us this. So, we plumb rcb through the stack so
            # it can be used in this context for the if block below.
            #
            # FAQ:
            # - Why do we need this special handling for NamedTuple but string
            #   annotations work fine for normal types? Normally, we parse the
            #   string directly and then call rcb() directly from C++.
            # - Why not use ForwardRef._evaluate? For that, we need globals()
            #   and locals() for the local context where the NamedTuple was defined.
            #   rcb is what lets us look up into these. So, basically rcb does the
            #   hard work for us.
            if isinstance(field_type, ForwardRef) and rcb is not None:
                rcb_type = rcb(field_type.__forward_arg__)
                # rcb returns None if it can't find anything.
                if rcb_type is None:
                    raise ValueError(
                        f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
                        " Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
                        f" Issue occurred at {loc.highlight()}"
                    )
                field_type = rcb_type
            the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
            annotations.append(the_type)
        else:
            annotations.append(torch._C.TensorType.getInferred())
    return type(obj).__name__, obj._fields, annotations, defaults
 | 
						|
 | 
						|
 | 
						|
def _create_named_tuple(
 | 
						|
    t,
 | 
						|
    unqual_name: str,
 | 
						|
    field_names: list[str],
 | 
						|
    defaults: tuple[Any, ...],
 | 
						|
):
 | 
						|
    TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults)  # type: ignore[call-arg, no-redef, misc]
 | 
						|
    return TupleType(*t)
 | 
						|
 | 
						|
 | 
						|
@contextlib.contextmanager
 | 
						|
def _disable_emit_hooks():
 | 
						|
    hooks = torch._C._jit_get_emit_hooks()
 | 
						|
    torch._C._jit_set_emit_hooks(None, None)
 | 
						|
    try:
 | 
						|
        yield
 | 
						|
    finally:
 | 
						|
        torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
 | 
						|
 | 
						|
 | 
						|
def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None:  # noqa: F811
 | 
						|
    # noqa: F841
 | 
						|
    def __enter__(self) -> None:
 | 
						|
        self.hooks = torch._C._jit_get_emit_hooks()
 | 
						|
        torch._C._jit_set_emit_hooks(None, None)
 | 
						|
 | 
						|
    def __exit__(self, *args) -> None:
 | 
						|
        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
 | 
						|
 | 
						|
 | 
						|
def _is_exception(obj) -> bool:
 | 
						|
    if not inspect.isclass(obj):
 | 
						|
        return False
 | 
						|
    return issubclass(obj, Exception)
 | 
						|
 | 
						|
 | 
						|
def raise_error_container_parameter_missing(target_type) -> None:
    """Raise ``RuntimeError`` for a container annotation used without type args.

    ``target_type`` is the display name of the bare container (e.g. ``"List"``
    or ``"dict"``). Names ending in ``"ict"`` get a two-argument example in
    the message; every other name gets a one-argument example. This function
    always raises.
    """
    if not target_type.endswith("ict"):
        raise RuntimeError(
            f"Attempted to use {target_type} without a "
            "contained type. Please add a contained type, e.g. "
            f"{target_type}[int]"
        )
    raise RuntimeError(
        f"Attempted to use {target_type} without "
        "contained types. Please add contained type, e.g. "
        f"{target_type}[int, int]"
    )
 | 
						|
 | 
						|
 | 
						|
_RAW_TYPE_NAME_MAPPING = {
 | 
						|
    dict: "dict",
 | 
						|
    list: "list",
 | 
						|
    tuple: "tuple",
 | 
						|
    typing.Dict: "Dict",  # noqa: UP006
 | 
						|
    typing.List: "List",  # noqa: UP006
 | 
						|
    typing.Optional: "Optional",
 | 
						|
    typing.Tuple: "Tuple",  # noqa: UP006
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
def check_args_exist(target_type) -> None:
    """Reject bare (unsubscripted) container annotations.

    When ``target_type`` is one of the raw types listed in
    ``_RAW_TYPE_NAME_MAPPING`` (e.g. ``list`` or ``typing.List`` with no type
    arguments), ``raise_error_container_parameter_missing`` raises a
    ``RuntimeError``. Any other input returns ``None``.
    """
    raw_name = _RAW_TYPE_NAME_MAPPING.get(target_type)
    if raw_name:
        raise_error_container_parameter_missing(raw_name)
 | 
						|
 | 
						|
 | 
						|
def check_empty_containers(obj) -> None:
    """Warn when ``obj`` compares equal to an empty list, dict or tuple.

    An empty container carries no runtime evidence of its element type, so
    an eager-mode ``torch.jit.isinstance`` check cannot distinguish e.g.
    ``List[int]`` from ``List[float]``.
    """
    if any(obj == empty for empty in ([], {}, ())):
        warnings.warn(
            "The inner type of a container is lost when "
            "calling torch.jit.isinstance in eager mode. For "
            "example, List[int] would become list and "
            "therefore falsely return True for List[float] or"
            " List[str]."
        )
 | 
						|
 | 
						|
 | 
						|
def _container_element_matches(el, el_type) -> bool:
    """Return whether ``el`` matches ``el_type``, recursing into generics.

    Subscripted annotations (anything ``get_origin`` reports an origin for,
    e.g. ``List[int]`` or ``Optional[str]``) are delegated to
    ``container_checker``. Plain types are checked with ``isinstance``.
    """
    if get_origin(el_type):
        return container_checker(el, el_type)
    return isinstance(el, el_type)


# supports List/Dict/Tuple and Optional/Union types
# TODO support future
def container_checker(obj, target_type) -> bool:
    """Eager-mode structural check of ``obj`` against a subscripted annotation.

    Args:
        obj: the value to inspect.
        target_type: a subscripted ``List``/``Dict``/``Tuple``/``Union``/
            ``Optional`` annotation, or the builtin-generic (``list[int]``) /
            ``X | Y`` equivalents.

    Returns:
        True if ``obj``, and recursively its elements, match ``target_type``.
        False otherwise, including for origins this function does not support.

    Raises:
        RuntimeError: via ``check_args_exist``, when ``target_type`` is a bare
            container listed in ``_RAW_TYPE_NAME_MAPPING``.
    """
    origin_type = get_origin(target_type)
    check_args_exist(target_type)
    if origin_type is None:
        return False
    elif origin_type is list or origin_type is typing.List:  # noqa: UP006
        check_empty_containers(obj)
        if not isinstance(obj, list):
            return False
        arg_type = get_args(target_type)[0]
        # Nested containers (ex: List[List[str]]) recurse via the helper.
        return all(_container_element_matches(el, arg_type) for el in obj)
    elif origin_type is typing.Dict or origin_type is dict:  # noqa: UP006
        check_empty_containers(obj)
        if not isinstance(obj, dict):
            return False
        dict_args = get_args(target_type)
        key_type, val_type = dict_args[0], dict_args[1]
        # Keys are checked before values, pair by pair.
        return all(
            _container_element_matches(key, key_type)
            and _container_element_matches(val, val_type)
            for key, val in obj.items()
        )
    elif origin_type is typing.Tuple or origin_type is tuple:  # noqa: UP006
        check_empty_containers(obj)
        if not isinstance(obj, tuple):
            return False
        arg_types = get_args(target_type)
        if len(arg_types) == 2 and arg_types[1] is Ellipsis:
            # Tuple[T, ...] denotes a homogeneous tuple of any length.
            return all(_container_element_matches(el, arg_types[0]) for el in obj)
        if len(obj) != len(arg_types):
            return False
        return all(
            _container_element_matches(el, el_type)
            for el, el_type in zip(obj, arg_types)
        )
    elif origin_type is Union or (
        # get_origin may return a non-class special form (e.g. typing.Literal);
        # issubclass raises TypeError for those, so check it is a class first.
        isinstance(origin_type, type)
        and issubclass(origin_type, BuiltinUnionType)
    ):  # also handles Optional
        inner_types = get_args(target_type)
        if obj is None and type(None) in inner_types:
            # Fast path for Optional[...]: None is a direct member.
            return True
        # Try every member. A failed container match must not end the search
        # early (e.g. Union[List[int], str] must still accept "abc").
        return any(_container_element_matches(obj, t) for t in inner_types)
    return False
 | 
						|
 | 
						|
 | 
						|
def _isinstance(obj, target_type) -> bool:
 | 
						|
    if isinstance(target_type, collections.abc.Container):
 | 
						|
        if not isinstance(target_type, tuple):
 | 
						|
            raise RuntimeError(
 | 
						|
                "The second argument to "
 | 
						|
                "`torch.jit.isinstance` must be a type "
 | 
						|
                "or a tuple of types"
 | 
						|
            )
 | 
						|
        for t_type in target_type:
 | 
						|
            if _isinstance(obj, t_type):
 | 
						|
                return True
 | 
						|
        return False
 | 
						|
 | 
						|
    origin_type = get_origin(target_type)
 | 
						|
    if origin_type:
 | 
						|
        return container_checker(obj, target_type)
 | 
						|
 | 
						|
    # Check to handle non-typed optional origin returns as none instead
 | 
						|
    #    of as optional in 3.7-3.8
 | 
						|
    check_args_exist(target_type)
 | 
						|
 | 
						|
    # handle non-containers
 | 
						|
    return isinstance(obj, target_type)
 | 
						|
 | 
						|
 | 
						|
class _TensorExtractor(pickle.Pickler):
 | 
						|
    def __init__(self, *args, tensors: list[torch.Tensor], **kwargs):
 | 
						|
        super().__init__(*args, **kwargs)
 | 
						|
        self.tensors = tensors
 | 
						|
 | 
						|
    def persistent_id(self, obj):
 | 
						|
        if isinstance(obj, torch.Tensor):
 | 
						|
            self.tensors.append(obj)
 | 
						|
            return ""
 | 
						|
        # Since we just want to extract tensors, we don't mind if an object is
 | 
						|
        # unpicklable if it doesn't contain tensors, as we can just ignore/skip
 | 
						|
        # it. To play it safe, we only do so for common objects that we're sure
 | 
						|
        # don't contain tensors. Feel free to add new types here. Note also that
 | 
						|
        # even if a type isn't listed here this won't block users, since they
 | 
						|
        # can just add a __getstate__ or __reduce__ method to their class.
 | 
						|
        if isinstance(obj, LockType):
 | 
						|
            return ""
 | 
						|
        # Futures and RRefs don't technically contain a value, they just offer
 | 
						|
        # the means to access a value.
 | 
						|
        if isinstance(obj, CFuture) or is_rref_instance(obj):
 | 
						|
            return ""
 | 
						|
        if isinstance(obj, CAwait):
 | 
						|
            return ""
 | 
						|
        if isinstance(obj, torch.cuda.Event):
 | 
						|
            return ""
 | 
						|
        if isinstance(obj, threading.Thread):
 | 
						|
            return ""
 | 
						|
        return None
 | 
						|
 | 
						|
 | 
						|
def _extract_tensors(obj):
    r"""
    Collect the tensors reachable from ``obj`` by pickling it.

    Intended to be called only from C++; see
    ``torch/csrc/jit/python/python_ivalue.h``. The pickled bytes go to a
    throwaway in-memory buffer. Only the list of tensors found during
    pickling is returned.
    """
    collected: list[torch.Tensor] = []
    scratch = io.BytesIO()
    _TensorExtractor(scratch, protocol=-1, tensors=collected).dump(obj)
    return collected
 | 
						|
 | 
						|
 | 
						|
def _get_model_id(obj) -> Optional[str]:
 | 
						|
    if isinstance(obj, torch.jit.ScriptModule):
 | 
						|
        return str(obj._c._type())
 | 
						|
    elif isinstance(obj, torch.jit.ScriptFunction):
 | 
						|
        return obj.qualified_name
 | 
						|
    else:
 | 
						|
        return None
 | 
						|
 | 
						|
 | 
						|
# In Python 3.11+, typed enums (IntEnum, for example) retain a number of base
# class methods in subclasses that were previously dropped. To preserve the
# old behavior, explicitly drop them here.
if sys.version_info >= (3, 11):
    for _enum_method in (
        enum.Enum.__new__,
        enum.Enum.__format__,
        enum.Enum.__repr__,
        enum.Enum.__str__,
    ):
        _drop(_enum_method)
    # Do not leave the loop variable behind as a module attribute.
    del _enum_method
 |