# mypy: allow-untyped-defs
import collections
import functools
import inspect
import textwrap
import types
import warnings

import torch
import torch._jit_internal as _jit_internal
from torch._sources import fake_range
from torch.jit._builtins import _find_builtin
from torch.jit._check import AttributeTypeIsSupportedChecker
from torch.jit._state import _add_script_class, _get_script_class, _python_cu
from torch.jit.frontend import (
    get_class_properties,
    get_default_args,
    get_jit_class_def,
    get_jit_def,
)
from torch.nn import Module


ScriptMethodStub = collections.namedtuple(
    "ScriptMethodStub", ("resolution_callback", "def_", "original_method")
)
PropertyStub = collections.namedtuple("PropertyStub", ("resolution_callback", "def_"))


# TODO: there should be a more principled way of doing this.
ignored_attributes = [
    "_version",
    "_parameters",
    "_buffers",
    "_non_persistent_buffers_set",
    "_backward_hooks",
    "_backward_pre_hooks",
    "_forward_hooks",
    "_forward_hooks_with_kwargs",
    "_forward_pre_hooks",
    "_forward_pre_hooks_with_kwargs",
    "_forward_hooks_always_called",
    "_state_dict_hooks",
    "_state_dict_pre_hooks",
    "_load_state_dict_pre_hooks",
    "_load_state_dict_post_hooks",
    "_modules",
    "_initializing",
    "dump_patches",
]


def _compile_and_register_class(obj, rcb, qualified_name):
    script_class = _get_script_class(obj)

    if not script_class:
        ast = get_jit_class_def(obj, obj.__name__)
        defaults = torch.jit.frontend.get_default_args_for_class(obj)
        script_class = torch._C._jit_script_class_compile(
            qualified_name, ast, defaults, rcb
        )
        _add_script_class(obj, script_class)

    return script_class


def make_stub(func, name):
    rcb = _jit_internal.createResolutionCallbackFromClosure(func)
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
    return ScriptMethodStub(rcb, ast, func)


def make_stub_from_method(nn_module, method_name):
    func = getattr(nn_module, method_name)
    if isinstance(func, ScriptMethodStub):
        return func
    # Make sure the name present in the resulting AST will match the name
    # requested here. The only time they don't match is if you do something
    # like:
    #   def _forward(self):
    #       pass
    #   forward = _forward
    # In this case, the actual function object will have the name `_forward`,
    # even though we requested a stub for `forward`.
    return make_stub(func, method_name)


def make_stubs_from_exported_methods(mod):
    stubs = []
    for name in dir(mod):
        item = getattr(mod, name, None)
        if (
            _jit_internal.get_torchscript_modifier(item)
            is _jit_internal.FunctionModifiers.EXPORT
        ):
            stubs.append(make_stub_from_method(mod, name))

    return stubs


def jit_ignored_properties(module):
    user_annotated_ignored_attributes = getattr(
        module, "__jit_ignored_attributes__", []
    )

    def get_properties_names(module):
        return {k for k, v in vars(module).items() if isinstance(v, property)}

    properties = get_properties_names(type(module))
    user_annotated_ignored_properties = set()

    for ignored_attr in user_annotated_ignored_attributes:
        if ignored_attr in properties:
            user_annotated_ignored_properties.add(ignored_attr)
    return user_annotated_ignored_properties


# base types that can be constants
# in addition, tuples and lists of these base types are also considered constants
# If you edit this list, then you also need to edit the handlers in
# ConstantValue in jit/script/init.cpp
_constant_types = (
    bool,
    float,
    int,
    str,
    type(None),
    torch.device,
    torch.layout,
    torch.dtype,
    torch.qscheme,
)


def _get_valid_constant(attr, v, owner_type):
    if isinstance(v, _constant_types):
        return v
    elif isinstance(v, (tuple, list)):
        return tuple(_get_valid_constant(attr, x, owner_type) for x in v)
    constants = ", ".join(torch.typename(typ) for typ in _constant_types)
    raise TypeError(
        textwrap.dedent(
            f"""
        '{torch.typename(type(v))}' object in attribute '{owner_type}.{attr}' is not a valid constant.
        Valid constants are:
        1. a nn.ModuleList
        2. a value of type {{{constants}}}
        3. a list or tuple of (2)
        """
        )
    )


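# Illustrative sketch (hypothetical module, not executed here) of how
# `__constants__` interacts with _get_valid_constant above:
#
#     class M(torch.nn.Module):
#         __constants__ = ["scale"]   # int is in _constant_types -> accepted
#
#         def __init__(self):
#             super().__init__()
#             self.scale = 2
#             self.bad = object()     # listing "bad" would raise the TypeError above
#
#         def forward(self, x):
#             return x * self.scale

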
class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
    def __init__(self, source, filename, file_lineno, leading_whitespace_len):
        super().__init__(source, filename, file_lineno, leading_whitespace_len)


def get_annotations(obj):
    # In Python 3.10+ it is recommended to use inspect.get_annotations.
    # See https://docs.python.org/3.10/howto/annotations.html
    # But also, in 3.10 annotations from a base class are not inherited
    # by an unannotated derived one, so they must be manually extracted.
    annotations = inspect.get_annotations(obj)
    if annotations:
        return annotations

    def get_cls_annotations(cls):
        cls_annotations = inspect.get_annotations(cls)
        if cls_annotations:
            return cls_annotations
        for base in cls.__bases__:
            cls_annotations = get_cls_annotations(base)
            if cls_annotations:
                return cls_annotations
        return {}

    cls = obj if isinstance(obj, type) else type(obj)
    return get_cls_annotations(cls)


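# Illustrative sketch (hypothetical classes) of the 3.10+ inheritance gap the
# helper above works around: inspect.get_annotations does not walk base
# classes, so an unannotated subclass reports no annotations of its own:
#
#     class Base(torch.nn.Module):
#         x: int
#
#     class Derived(Base):
#         pass
#
#     inspect.get_annotations(Derived)  # {}
#     get_annotations(Derived)          # {'x': int}, found via the base-class walk

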
def infer_concrete_type_builder(nn_module, share_types=True):
    """
    Build a ConcreteModuleTypeBuilder from an nn.Module.

    This ConcreteModuleType doesn't have a JIT type associated with it yet; it
    must be filled in by the caller.
    """
    concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
    if isinstance(nn_module, torch.nn.ModuleDict):
        concrete_type_builder.set_module_dict()
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
        concrete_type_builder.set_module_list()
    if isinstance(nn_module, torch.nn.ParameterList):
        concrete_type_builder.set_parameter_list()
    if isinstance(nn_module, torch.nn.ParameterDict):
        concrete_type_builder.set_parameter_dict()

    class_annotations = get_annotations(nn_module)
    if isinstance(nn_module, torch.ao.quantization.QuantWrapper):
        class_annotations = {}

    # Get user-annotated ignored attributes.
    user_annotated_ignored_attributes = getattr(
        nn_module, "__jit_ignored_attributes__", []
    )
    concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes)
    ignored_properties = jit_ignored_properties(nn_module)

    # try to infer the type from type annotation or from the object itself
    def infer_type(name, item):
        # The forward function from Module is special; never use its
        # annotation; we need to infer the type directly using JIT. I
        # originally wanted to write this test as
        # isinstance(class_annotations[name], Callable), but isinstance on
        # typing things doesn't seem to work: isinstance(list, Callable) is
        # also true!
        inferred = False
        try:
            if (
                name in class_annotations
                and class_annotations[name]
                != torch.nn.Module.__annotations__["forward"]
            ):
                ann_to_type = torch.jit.annotations.ann_to_type(
                    class_annotations[name], fake_range()
                )
                attr_type = torch._C.InferredType(ann_to_type)
            elif isinstance(item, torch.jit.Attribute):
                ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            else:
                attr_type = torch._C._jit_try_infer_type(item)
                inferred = True
        except RuntimeError as re:
            raise RuntimeError(f"Error inferring type for {name}: {item}: {re}") from re

        return attr_type, inferred

    added_names = set()

    for name, item in nn_module._parameters.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        # We currently have the invariant in various places in our code
        # that parameters must be Tensors. However, the nn.Module API also
        # allows NoneType parameters. These parameters are not returned as
        # part of `parameters()` and its variants, but are available
        # through direct attribute access.
        concrete_type_builder.add_attribute(name, attr_type.type(), True, False)
        added_names.add(name)

    for name, item in nn_module._buffers.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        concrete_type_builder.add_attribute(name, attr_type.type(), False, True)
        added_names.add(name)

    for name, item in nn_module._modules.items():
        if name in user_annotated_ignored_attributes:
            continue

        attr_type, _ = infer_type(name, item)
        if item is None:
            # Modules can be None. We don't have direct support for optional
            # Modules, so we register it as a NoneType attribute instead.
            concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
            continue
        if attr_type.success():
            assert attr_type.type().is_interface_type()
            # if the type can be inferred, it should be a module interface type
            sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(
                attr_type.type()
            )
        else:
            # otherwise we get the concrete module type for item and add it to concrete_type
            sub_concrete_type = get_module_concrete_type(item, share_types)
        concrete_type_builder.add_module(name, sub_concrete_type)

        added_names.add(name)

    # populate constants_set
    constants_set = set(getattr(nn_module, "__constants__", ()))

    # Constants annotated via `Final[T]` rather than being added to `__constants__`
    for name, ann in class_annotations.items():
        if torch._jit_internal.is_final(ann):
            constants_set.add(name)

    for name in constants_set:
        if name in added_names:
            # TODO: We should really error in this case, but it's bc-breaking so
            # we need to warn for at least one release
            if name in nn_module._modules:
                hint = "submodule"
            elif name in nn_module._buffers:
                hint = "buffer"
            elif name in nn_module._parameters:
                hint = "parameter"
            else:
                raise AssertionError(
                    "added_names must be submodule, parameter, or buffer"
                )

            warnings.warn(
                f"'{name}' was found in ScriptModule constants, "
                f"but it is a non-constant {hint}. Consider removing it."
            )
            continue
        if not hasattr(nn_module, name):
            # TODO: We should really error in this case, but it's bc-breaking so
            # we need to warn for at least one release
            warnings.warn(
                f"'{name}' was found in ScriptModule constants, "
                "but was not actually set in __init__. "
                "Consider removing it."
            )
            continue
        value = getattr(nn_module, name)
        concrete_type_builder.add_constant(
            name, _get_valid_constant(name, value, type(nn_module).__name__)
        )
        added_names.add(name)

    # populate overloads
    overloads = getattr(nn_module, "__overloads__", {})
    # update with any annotated overloads
    overloads.update(
        get_overload_name_mapping(
            get_overload_annotations(nn_module, ignored_properties)
        )
    )
    for name, overloaded_names in overloads.items():
        concrete_type_builder.add_overload(name, overloaded_names)

    for name, value in nn_module.__dict__.items():
        if name in ignored_attributes or name.startswith("__"):
            # Python objects have lots of random attributes attached to them;
            # PyTorch adds a few more. Prevent these from getting compiled.
            continue

        if name in user_annotated_ignored_attributes:
            continue

        if name in added_names:
            # Don't re-add anything we already added
            continue

        isoverloadpacket = isinstance(value, torch._ops.OpOverloadPacket)
        if isoverloadpacket:
            value = value.op
        # Handle Python function attributes
        if inspect.isfunction(value):
            try:
                scripted_fn = torch.jit.script(value)
                concrete_type_builder.add_function_attribute(
                    name, torch._C._jit_try_infer_type(scripted_fn).type(), value
                )
            except Exception as e:
                # If we fail to script the function, it isn't a hard error.
                # Instead, we will add it to the list of attributes we failed
                # to convert, with the compilation error.
                hint = (
                    "(This function exists as an attribute on the Python module, "
                    "but we failed to compile it to a TorchScript function. "
                    f"\nThe error stack is reproduced here:\n{e}"
                )
                concrete_type_builder.add_failed_attribute(name, hint)

            continue

        # Handle calls to builtin functions (either bespoke builtins from torch.jit._builtins or
        # a call to an aten function like torch.add)
        builtin_symbol_name = _find_builtin(value)
        if builtin_symbol_name:
            concrete_type_builder.add_builtin_function(name, builtin_symbol_name)
            continue

        # Handle Script function attributes
        if isinstance(value, torch.jit.ScriptFunction):
            concrete_type_builder.add_function_attribute(
                name, torch._C._jit_try_infer_type(value).type(), value
            )
            continue

        # If we got here, this is a regular "data" attribute, add it to the concrete type
        attr_type, inferred = infer_type(name, value)
        if attr_type.success():
            concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
        else:
            # TODO: could add more detail here. For example, what the user should do
            # when the pytype is `list` or `NoneType`
            inferred_msg = (
                "Its type was inferred; try adding a type annotation for the attribute."
                if inferred
                else ""
            )
            additional_info = f"{attr_type.reason()}. {inferred_msg}"
            hint = (
                "(This attribute exists on the Python module, "
                f"but we failed to convert Python type: '{torch.typename(type(value))}' "
                f"to a TorchScript type. {additional_info})"
            )
            concrete_type_builder.add_failed_attribute(name, hint)

    # add hooks to concrete type
    for hook in nn_module._forward_hooks.values():
        concrete_type_builder.add_forward_hook(hook)
    for pre_hook in nn_module._forward_pre_hooks.values():
        concrete_type_builder.add_forward_pre_hook(pre_hook)

    return concrete_type_builder


class ConcreteTypeStore:
    type_store: dict[type[Module], list[torch._C.ConcreteModuleType]]
    methods_compiled: set[torch._C.ConcreteModuleType]

    def __init__(self) -> None:
        # Python module type => List[ConcreteModuleType]
        self.type_store = {}
        # ConcreteTypes that have had their methods already compiled
        self.methods_compiled = set()

    def get_or_create_concrete_type(self, nn_module):
        """Infer a ConcreteType from this `nn.Module` instance. Underlying JIT types are reused if possible."""
        concrete_type_builder = infer_concrete_type_builder(nn_module)

        nn_module_type = type(nn_module)
        if nn_module_type not in self.type_store:
            self.type_store[nn_module_type] = []

        # Search the type store for an already-available JIT type
        known_types = self.type_store[nn_module_type]
        for known_type in known_types:
            if known_type.equals(concrete_type_builder):
                return known_type

        # We didn't find anything; generate a new JIT type from this concrete type
        concrete_type = concrete_type_builder.build()
        self.type_store[nn_module_type].append(concrete_type)
        return concrete_type


concrete_type_store = ConcreteTypeStore()


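# Illustrative sketch (hypothetical MyModule) of the type sharing implemented
# above: scripting two structurally identical instances should hit the same
# cached entry in `concrete_type_store`, so their underlying JIT types are
# shared, e.g.:
#
#     m1 = torch.jit.script(MyModule())
#     m2 = torch.jit.script(MyModule())
#     # m1 and m2 reuse one ConcreteModuleType from the store
#
# Tracing instead passes share_types=False and builds a fresh "poisoned" type
# per module (see get_module_concrete_type below).

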
def create_methods_and_properties_from_stubs(
    concrete_type, method_stubs, property_stubs
):
    method_defs = [m.def_ for m in method_stubs]
    method_rcbs = [m.resolution_callback for m in method_stubs]
    method_defaults = [get_default_args(m.original_method) for m in method_stubs]

    property_defs = [p.def_ for p in property_stubs]
    property_rcbs = [p.resolution_callback for p in property_stubs]

    concrete_type._create_methods_and_properties(
        property_defs, property_rcbs, method_defs, method_rcbs, method_defaults
    )


def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs):
    hook_defs = [h.def_ for h in hook_stubs]
    hook_rcbs = [h.resolution_callback for h in hook_stubs]

    pre_hook_defs = [h.def_ for h in pre_hook_stubs]
    pre_hook_rcbs = [h.resolution_callback for h in pre_hook_stubs]

    concrete_type._create_hooks(hook_defs, hook_rcbs, pre_hook_defs, pre_hook_rcbs)


def get_module_concrete_type(nn_module, share_types=True):
    """
    Get a concrete type for nn_modules.

    If share_types is True, the concrete type is fetched from concrete_type_store.
    If it is False, a new concrete type is created without first searching concrete_type_store.

    Args:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        share_types:  Whether to share underlying JIT types between modules (if possible).

    Returns:
        A concrete type for nn_module.
    """
    assert isinstance(nn_module, Module)
    if isinstance(nn_module, torch.jit.ScriptModule) and hasattr(
        nn_module, "_concrete_type"
    ):
        return nn_module._concrete_type

    if share_types:
        # Look into the store of cached JIT types
        concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
    else:
        # Get a concrete type directly, without trying to reuse an existing JIT
        # type from the type store.
        concrete_type_builder = infer_concrete_type_builder(nn_module, share_types)
        concrete_type_builder.set_poisoned()
        concrete_type = concrete_type_builder.build()

    return concrete_type


def create_script_class(obj):
    """
    Create and return a RecursiveScriptClass instance from a Python object.

    Arguments:
        obj: A Python object.
    """
    qualified_class_name = _jit_internal._qualified_name(type(obj))
    rcb = _jit_internal.createResolutionCallbackForClassMethods(type(obj))
    # Script the type of obj if it hasn't already been scripted.
    _compile_and_register_class(type(obj), rcb, qualified_class_name)
    class_ty = _python_cu.get_class(qualified_class_name)
    # Create an empty torch._C.ScriptObject with the scripted type.
    cpp_object = torch._C._create_object_with_type(class_ty)
    # Copy all of the attributes over to the torch._C.ScriptObject.
    for name, value in obj.__dict__.items():
        cpp_object.setattr(name, value)

    # Wrap the torch._C.ScriptObject in a RecursiveScriptClass instance.
    return wrap_cpp_class(cpp_object)


def create_script_module(nn_module, stubs_fn, share_types=True, is_tracing=False):
    """
    Create a new ScriptModule from an nn.Module.

    Args:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
        share_types:  Whether to share underlying JIT types between modules (if possible).
            NOTE: Only set this to False when we cannot guarantee type sharing will work
                correctly. This only happens today for traced modules, where the same
                module can produce different traced methods depending on the inputs.
        is_tracing: Whether this function is called during tracing or scripting. If tracing,
                we don't need to do AttributeTypeIsSupportedChecker because all the unsupported
                attributes will be baked as constants in the traced graph. In addition,
                this check significantly slows down traced modules when the module size is big.
    """
    assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
    check_module_initialized(nn_module)
    concrete_type = get_module_concrete_type(nn_module, share_types)
    if not is_tracing:
        AttributeTypeIsSupportedChecker().check(nn_module)
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)


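# Illustrative sketch: for scripting, this function is typically reached
# through the public API, roughly (hypothetical MyModule):
#
#     scripted = torch.jit.script(MyModule())
#
# which ends up calling something like:
#
#     create_script_module(my_module, infer_methods_to_compile)
#
# where infer_methods_to_compile (defined below) selects `forward` plus any
# @torch.jit.export-ed methods as compilation roots.

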
def create_script_module_impl(nn_module, concrete_type, stubs_fn):
    """
    Convert an nn.Module to a RecursiveScriptModule.

    Args:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        concrete_type:  The fully initialized ConcreteType of the module.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
    """
    cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
    method_stubs = stubs_fn(nn_module)
    property_stubs = get_property_stubs(nn_module)
    hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)
    ignored_properties = jit_ignored_properties(nn_module)

    def init_fn(script_module):
        # Initialize the ScriptModule:
        # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule.
        for name in concrete_type.get_attributes().keys():
            orig_value = getattr(nn_module, name)
            orig_value = (
                orig_value.value
                if isinstance(orig_value, torch.jit.Attribute)
                else orig_value
            )
            cpp_module.setattr(name, orig_value)

        # 2. Copy the submodules from the original `nn_module` to the new ScriptModule,
        #    recursively scripting them.
        for name, sub_concrete_type in concrete_type.get_modules():
            orig_value = getattr(nn_module, name)
            assert isinstance(orig_value, Module), (
                f"Expected Module but got {type(orig_value)}"
            )
            module_type = sub_concrete_type.jit_type
            if isinstance(module_type, torch._C.InterfaceType):
                # use the interface inference rule to compile the module
                scripted = interface_script(module_type, orig_value)
            elif isinstance(orig_value, torch.jit.ScriptModule):
                scripted = orig_value
            else:
                # always reuse the provided stubs_fn to infer the methods to compile
                scripted = create_script_module_impl(
                    orig_value, sub_concrete_type, stubs_fn
                )

            cpp_module.setattr(name, scripted)
            script_module._modules[name] = scripted

        # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule.
        #    This ensures we can access these Python methods on the ScriptModule.
        for name in dir(nn_module):
            if name in ignored_properties:
                continue
            item = getattr(nn_module, name, None)
            if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
                unbound_function = getattr(nn_module, name).__func__
                bound_method = unbound_function.__get__(script_module)
                setattr(script_module, name, bound_method)
            elif concrete_type.is_ignored_attribute(name):
                setattr(script_module, name, item)

        # For convenience, attach the concrete type to the new ScriptModule
        script_module._concrete_type = concrete_type

    # Actually create the ScriptModule, initializing it with the function we just defined
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)

    # Compile methods if necessary
    if concrete_type not in concrete_type_store.methods_compiled:
        create_methods_and_properties_from_stubs(
            concrete_type, method_stubs, property_stubs
        )
        # Create hooks after methods to ensure no name collisions between hooks and methods.
        # If done before, hooks can overshadow methods that aren't exported.
        create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
        torch._C._run_emit_module_hook(cpp_module)
        concrete_type_store.methods_compiled.add(concrete_type)

    # Copy the forward hooks and pre-hooks to the new ScriptModule
    # to allow the hooks to be run from eager as ScriptFunctions
    for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
        script_module._forward_pre_hooks[idx] = fn
    for idx, fn in enumerate(script_module._c._get_forward_hooks()):
        script_module._forward_hooks[idx] = fn

    # Special handling so methods like __len__ work in script methods on classes derived from containers
    if (
        isinstance(
            nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
        )
        and "__len__" not in cpp_module._method_names()
    ):
        script_module.define(f"def __len__(self):\n   return {len(nn_module)}\n")
    if (
        isinstance(nn_module, torch.nn.ModuleDict)
        and "__contains__" not in cpp_module._method_names()
    ):
        if len(nn_module.keys()):
            keys = repr(list(nn_module.keys()))
            script_module.define(
                f"def __contains__(self, key: str):\n   return key in {keys}\n"
            )
        else:
            script_module.define("def __contains__(self, key: str):\n   return False\n")

    # Make the compiled methods available to the Python ScriptModule class.
    for method_stub in method_stubs:
        if method_stub.original_method is None:
            # define()'d methods don't have a Python original_method, so we
            # don't need to do any Python re-wrapping stuff
            continue

        name = method_stub.original_method.__name__
        if name != method_stub.def_.name().name:
            # TODO: Why skip this? Because @torch.jit._overload_method will
            # mangle the name of the function.
            continue
        script_method = cpp_module._get_method(name)

        # Wrap the original to propagate docstrings and such.
        # TODO: we don't currently do this for functions that are recursively
        # compiled; we should.
        wrapped_script_method = functools.wraps(method_stub.original_method)(
            script_method
        )

        # Add the methods to the script_module directly. This ensures they will
        # be found first when `name` is looked up (as opposed to the stubs or
        # nn.Module.forward)
        script_module.__dict__[name] = wrapped_script_method

    # Make module properties available on the Python ScriptModule class.
    for property_stub in property_stubs:
        property_name = property_stub.def_.name().name
        fget = cpp_module._get_method(property_stub.def_.getter_name().name)
        # Setter is optional, so it may not exist.
        setter_name = property_stub.def_.setter_name()
        fset = cpp_module._get_method(setter_name.name) if setter_name else None
        script_module.__dict__[property_name] = property(property_name, fget, fset)  # type: ignore[arg-type]

    # copy over python methods to script module if they aren't defined on the script module
    # this is currently an internal api used only on module containers
    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if (
            _jit_internal.get_torchscript_modifier(item)
            is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
        ):
            add_python_attr_to_scripted_model(script_module, nn_module, name)

    return script_module


# We define shims of certain attributes on the RecursiveScriptModule to support
# magic methods. To check if a script model defines an attribute we need
# to also check that the attribute is not the shim.
def script_model_defines_attr(script_model, attr):
    script_attr = getattr(script_model, attr, None)
    if script_attr is None:
        return False
    default_attr = getattr(torch.jit.RecursiveScriptModule, attr, None)
    if default_attr is None:
        return False
    return script_attr != default_attr


def add_python_attr_to_scripted_model(script_model, orig, attr):
    if hasattr(orig, attr) and script_model_defines_attr(script_model, attr):
        setattr(script_model, attr, getattr(orig, attr))


def get_overload_annotations(mod, jit_ignored_properties):
    # original function => [(mangled overload name, overload function)]
    overloads = {}

    for name in dir(type(mod)):
        if name in jit_ignored_properties:
            continue
        item = getattr(mod, name, None)
        if not callable(item):
            continue

        # builtin functions like repr() in python 2 do not have __module__ defined
        if hasattr(item, "__module__") and item.__module__ is not None:
            method_overloads = _jit_internal._get_overloaded_methods(
                item, mod.__class__
            )
            if method_overloads is None:
                continue

            # pyrefly: ignore  # missing-attribute
            if item.__func__ in method_overloads:
                raise RuntimeError(
                    _jit_internal.get_overload_no_implementation_error_message(
                        "method", item.__func__
                    )
                )

            names = [name + "__" + str(i) for i in range(len(method_overloads))]
            overloads[item] = list(zip(names, method_overloads))

    return overloads


def get_overload_name_mapping(overload_info):
    # Same format as __overloads__
    # original function => [overload names]
    overload_name_mappings: dict[str, list[str]] = {}
    for orig_fn, overloads in overload_info.items():
        original_name = orig_fn.__name__
        if original_name not in overload_name_mappings:
            overload_name_mappings[original_name] = []

        for overload_name, _ in overloads:
            overload_name_mappings[original_name].append(overload_name)
    return overload_name_mappings


def _check_no_signature(func):
    signature = torch.jit.annotations.get_signature(
        func, None, fake_range(), inspect.ismethod(func)
    )
    if signature is None:
        qual_name = _jit_internal._qualified_name(func)
        raise RuntimeError(
            f"Must explicitly add type annotations to overloaded functions: {qual_name}"
        )


def make_stubs_for_overloads(overload_info):
    overload_stubs = []
    for orig_fn, overloads in overload_info.items():
        orig_ast = get_jit_def(
            orig_fn, orig_fn.__name__, self_name="RecursiveScriptModule"
        )
        for overload_name, overload_fn in overloads:
            _check_no_signature(overload_fn)
            over_ast = get_jit_def(
                overload_fn, overload_fn.__name__, self_name="RecursiveScriptModule"
            )
            new_ast = torch._C._replace_overloaded_method_decl(
                over_ast.decl(), orig_ast, overload_name
            )
            _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn)
            overload_stubs.append(ScriptMethodStub(_rcb, new_ast, overload_fn))
    return overload_stubs


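# Illustrative sketch (hypothetical module; _overload_method is an internal
# decorator, subject to change) of the overload machinery above:
#
#     class M(torch.nn.Module):
#         @torch.jit._overload_method
#         def forward(self, x: int) -> int: ...
#
#         @torch.jit._overload_method
#         def forward(self, x: str) -> str: ...
#
#         def forward(self, x):
#             return x
#
# get_overload_annotations pairs the implementation with mangled names such
# as "forward__0" and "forward__1", and make_stubs_for_overloads grafts each
# overload's declaration onto the implementation's body.

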
def check_module_initialized(mod):
    assert isinstance(mod, torch.nn.Module)
    if not hasattr(mod, "_parameters"):
        raise RuntimeError(
            f"'{torch.typename(type(mod))}' has not been initialized, did you forget to call 'super()'?"
        )

    # This is to avoid importing torch.distributed.nn
    if not hasattr(mod, "remote_parameters"):
        for name, param in mod._parameters.items():
            if param is not None and torch.nn.parameter.is_lazy(param):
                raise RuntimeError(
                    f"'{torch.typename(type(mod))}' has uninitialized parameters {name}. Did you forget to run a forward pass?"
                )
        for name, buf in mod._buffers.items():
            if buf is not None and torch.nn.parameter.is_lazy(buf):
                raise RuntimeError(
                    f"'{torch.typename(type(mod))}' has uninitialized buffers {name}. Did you forget to run a forward pass?"
                )


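# Illustrative sketch (hypothetical module) of the most common failure caught
# above, an __init__ that never calls super().__init__():
#
#     class M(torch.nn.Module):
#         def __init__(self):
#             self.x = 1  # no super().__init__(), so _parameters is missing
#
#     torch.jit.script(M())  # RuntimeError: ... did you forget to call 'super()'?

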
def infer_methods_to_compile(nn_module):
    """Implement the default rules for which methods should act as starting points for compilation.

    (TODO add a link when the rules are published).
    """
    check_module_initialized(nn_module)
    ignored_properties = jit_ignored_properties(nn_module)

    methods: list[str] = []
    if hasattr(nn_module, "forward") and not _jit_internal.is_ignored_fn(
        nn_module.forward
    ):
        forward_func = getattr(nn_module.forward, "__func__", None)
        module_forward = getattr(torch.nn.Module, "forward", None)
        if forward_func != module_forward:
            methods = ["forward"]

    exported = []
    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if (
            _jit_internal.get_torchscript_modifier(item)
            is _jit_internal.FunctionModifiers.EXPORT
        ):
            exported.append(name)

    methods = methods + exported

    overload_name_mappings = dict(getattr(nn_module, "__overloads__", {}))
    overload_info = get_overload_annotations(nn_module, ignored_properties)
    overload_name_mappings.update(get_overload_name_mapping(overload_info))
    overload_stubs = make_stubs_for_overloads(overload_info)

    nn_module.__overloads__ = overload_name_mappings

    # we shouldn't directly compile overloaded methods, just their overloads
    def ignore_overloaded(method_name):
        return method_name not in overload_name_mappings

    filtered_methods = filter(ignore_overloaded, methods)

    # Unique the methods. We don't want to use a set to store the methods because it
    # introduces non-determinism to compile order.
    uniquer: set[str] = set()
    uniqued_methods = []
    for name in filtered_methods:
        if name in uniquer:
            continue
        uniqued_methods.append(name)
        uniquer.add(name)

    stubs = [make_stub_from_method(nn_module, method) for method in uniqued_methods]
    return overload_stubs + stubs


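# Illustrative sketch (hypothetical module) of the default rule above: an
# overridden, non-ignored `forward` plus @torch.jit.export-ed methods become
# compilation roots; everything else is compiled only if reached from a root:
#
#     class M(torch.nn.Module):
#         def forward(self, x):    # root: overrides nn.Module.forward
#             return self.helper(x)
#
#         @torch.jit.export
#         def inference(self, x):  # root: explicitly exported
#             return self.forward(x) * 2
#
#         def helper(self, x):     # compiled because forward calls it
#             return x + 1

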
def get_hook_stubs(nn_module):
    """Return forward hook and pre_hook ScriptModuleStubs."""
    check_module_initialized(nn_module)
    hook_map: dict = {}

    hook_stubs = []
    for hook in nn_module._forward_hooks.values():
        if hook.__name__ in hook_map:
            if id(hook) != id(hook_map[hook.__name__]):
                raise RuntimeError(
                    f"Hook '{hook.__name__}' on {type(nn_module).__name__} "
                    "has at least two different python definitions."
                    " Please use unique names for all hooks."
                )
        else:
            hook_map[hook.__name__] = hook
        hook_stubs.append(make_stub(hook, hook.__name__))

    pre_hook_stubs = []
    for pre_hook in nn_module._forward_pre_hooks.values():
        if pre_hook.__name__ in hook_map:
            if id(pre_hook) != id(hook_map[pre_hook.__name__]):
                raise RuntimeError(
                    f"Pre-hook '{pre_hook.__name__}' on {type(nn_module).__name__} "
                    "has at least two different python definitions."
                    " Please use unique names for all hooks."
                )
        else:
            hook_map[pre_hook.__name__] = pre_hook
        pre_hook_stubs.append(make_stub(pre_hook, pre_hook.__name__))

    return hook_stubs, pre_hook_stubs


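# Illustrative sketch (hypothetical) of the uniqueness rule enforced above:
# two *different* function objects sharing one __name__ raise the
# RuntimeError, while registering the same function twice is fine:
#
#     def hook(module, inputs, output):
#         return output
#
#     m.register_forward_hook(hook)
#     m.register_forward_hook(hook)  # OK: same id(), so no name conflict

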
def get_property_stubs(nn_module):
    """Create property stubs for the properties of the module by creating method stubs for the getter and setter."""
    module_ty = type(nn_module)
    properties_asts = get_class_properties(module_ty, self_name="RecursiveScriptModule")
    rcbs = {}

    for name in dir(module_ty):
        item = getattr(module_ty, name, None)
        if isinstance(item, property):
            if not item.fget:
                raise RuntimeError(
                    f"Property {name} of {nn_module.__name__} must have a getter"
                )

            rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget)

    stubs = [PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts]
    return stubs


def interface_script(mod_interface, nn_module):
    """
    Make a ScriptModule from an nn.Module, using the interface methods rule for determining which methods to compile.

    Args:
        mod_interface: the interface type that the module has
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
    """
    if isinstance(nn_module, torch.jit.ScriptModule):
        return nn_module

    check_module_initialized(nn_module)

    def infer_interface_methods_to_compile(nn_module):
        """Rule to infer the methods from the interface type.

        It is used to know which methods need to act as starting points for compilation.
        """
        stubs = [
            make_stub_from_method(nn_module, method)
            for method in mod_interface.getMethodNames()
        ]
        return stubs

    return create_script_module(nn_module, infer_interface_methods_to_compile)


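# Illustrative sketch (hypothetical interface) of the rule above: a submodule
# typed as a module interface is compiled via interface_script, so only the
# interface's methods become compilation roots:
#
#     @torch.jit.interface
#     class OpInterface(torch.nn.Module):
#         def run(self, x: torch.Tensor) -> torch.Tensor: ...
#
# Annotating a submodule attribute as `OpInterface` means only `run` is
# compiled, regardless of what else the concrete module defines.

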
def try_compile_fn(fn, loc):
    if _jit_internal.is_ignored_fn(fn):
        # Don't do anything for @ignore'd functions
        return None

    if isinstance(fn, torch.nn.Module):
        # Since modules are callable, pybind recognizes them as functions;
        # don't do anything for them either
        return None

    if not inspect.isfunction(fn) and not inspect.ismethod(fn):
        raise RuntimeError(
            f"`{fn}` is not a function. Recursive scripting only supports "
            "Python functions or methods currently.\n"
            f"Consider manually annotating `{fn}` with @torch.jit.script."
        )

    # The object returned by __prepare_scriptable__ might have a different closure.
    # Resolve it here to get the right resolution callback.
    fn = fn.__prepare_scriptable__() if hasattr(fn, "__prepare_scriptable__") else fn  # type: ignore[operator]

    # We don't have the actual scope where the function was defined, but we can
    # extract the necessary info from the closed over variables on the function
    # object
    rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
    return torch.jit.script(fn, _rcb=rcb)


def wrap_cpp_class(cpp_class):
    """Wrap this torch._C.Object in a Python RecursiveScriptClass."""
    return torch.jit.RecursiveScriptClass(cpp_class)


def wrap_cpp_module(cpp_module):
    """Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules."""

    def init_fn(script_module):
        for name, cpp_module in torch._C.ModuleDict(script_module._c).items():
            setattr(script_module, name, wrap_cpp_module(cpp_module))
        script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(
            script_module._c._type()
        )

        for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
            script_module._forward_pre_hooks[idx] = fn
        for idx, fn in enumerate(script_module._c._get_forward_hooks()):
            script_module._forward_hooks[idx] = fn

    return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)


def compile_unbound_method(concrete_type, fn):
    if _jit_internal.is_ignored_fn(fn):
        return None
    stub = make_stub(fn, fn.__name__)
    with torch._jit_internal._disable_emit_hooks():
        # We don't want to call the hooks here since the graph that is calling
        # this function is not yet complete
        create_methods_and_properties_from_stubs(concrete_type, (stub,), ())
    return stub


def lazy_bind(concrete_type, unbound_method):
    """
    Return a function that lazily binds `unbound_method` to a provided Module IValue, then invokes the method.

    We do this so that any Python shenanigans that will poison type sharing
    are impossible at compile time.
    """

    def lazy_binding_method(cpp_module, *args):
        def init_fn(script_module):
            orig_class = concrete_type.py_class

            # Copy @ignored/@unused methods from the original module to the new one.
            # This ensures they are available during execution.
            for name in dir(orig_class):
                item = getattr(orig_class, name, None)
                if _jit_internal.is_ignored_fn(item):
                    setattr(script_module, name, item)

            # Copy constants over so they are available during execution.
            for name, value in concrete_type.get_constants().items():
                setattr(script_module, name, value)

        script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
        method = types.MethodType(unbound_method, script_module)
        return method(*args)

    # make the lazy binding method "look like" the original method
    lazy_binding_method.original_fn = unbound_method  # type: ignore[attr-defined]
    lazy_binding_method.__name__ = unbound_method.__name__
    torch._jit_internal.copy_torchscript_modifier(unbound_method, lazy_binding_method)

    return lazy_binding_method