Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Bugfix] Fix circular import between export and dynamo from tensor fn map (#158931)
Fixes #158120

The issue was caused by populating a builtin-to-tensor-fn map at import time: if `torch.export.export` was called with the `meta` device before any dynamo imports, the map would not yet be populated, so it would be populated at import time, which tries to call `torch.disable` before it has been initialized. The fix is to populate this map lazily.

```
python test/dynamo/imports_non_circular_repro.py TestImports.test_circular_import_with_export_meta
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158931
Approved by: https://github.com/StrongerXi, https://github.com/mlazos, https://github.com/anijain2305
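For context, the failure mode looks roughly like the sketch below. The actual regression test lives in `test/dynamo/imports_non_circular_repro.py` and is not reproduced here; the module, names, and shapes are illustrative assumptions, not the real test contents.

```python
# Illustrative repro sketch: run in a fresh interpreter so that no
# torch._dynamo module has been imported yet (hypothetical example,
# not the contents of the real test).
import torch


class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1


# Exporting with meta tensors before any dynamo import used to trigger
# the import-time population of the builtin tensor-fn map inside a
# partially initialized torch, failing on the `torch.disable` call.
ep = torch.export.export(AddOne(), (torch.ones(2, device="meta"),))
print(ep)
```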
committed by PyTorch MergeBot
parent abb0bf45df
commit efc810c7d0
```diff
@@ -33,12 +33,13 @@ import types
 import typing
 import unittest
 from collections import defaultdict, OrderedDict
-from collections.abc import KeysView, Sequence
-from typing import Callable, TYPE_CHECKING, Union
+from collections.abc import Iterable, KeysView, Sequence
+from typing import Any, Callable, TYPE_CHECKING, Union
 
 import torch
 from torch import sym_float, sym_int
 from torch._subclasses.meta_utils import is_sparse_any
+from torch.overrides import BaseTorchFunctionMode
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
 from .. import config, graph_break_hints, polyfills, variables
```
```diff
@@ -153,6 +154,126 @@ polyfill_fn_mapping = {
     operator.ge: polyfills.cmp_ge,
 }
 
+bin_ops = (
+    operator.pow,
+    operator.mul,
+    operator.matmul,
+    operator.floordiv,
+    operator.truediv,
+    operator.mod,
+    operator.add,
+    operator.lt,
+    operator.gt,
+    operator.ge,
+    operator.le,
+    operator.ne,
+    operator.eq,
+    operator.sub,
+    operator.ipow,
+    operator.imul,
+    operator.imatmul,
+    operator.ifloordiv,
+    operator.itruediv,
+    operator.imod,
+    operator.iadd,
+    operator.isub,
+)
+
+bin_int_ops = (
+    operator.and_,
+    operator.or_,
+    operator.xor,
+    operator.iand,
+    operator.ixor,
+    operator.ior,
+)
+
+un_int_ops = (operator.invert,)
+
+tensor_and_int_ops = (
+    operator.lshift,
+    operator.rshift,
+    operator.ilshift,
+    operator.irshift,
+    operator.getitem,
+)
+
+un_ops = (
+    operator.abs,
+    operator.pos,
+    operator.neg,
+    operator.not_,  # Note: this has a local scalar dense call
+    operator.length_hint,
+)
+
+BUILTIN_TO_TENSOR_FN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {}
+
+# These functions represent the r* versions of the above ops
+# Basically, if __add__(1, Tensor) is called, it is translated
+# to __radd__(Tensor, 1).
+# In the builtin var, we check if there is a tensor in the first args position,
+# if not, we swap the args and use the r* version of the op.
+BUILTIN_TO_TENSOR_RFN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {}
+
+
+def populate_builtin_to_tensor_fn_map():
+    global BUILTIN_TO_TENSOR_FN_MAP
+    if len(BUILTIN_TO_TENSOR_FN_MAP) > 0:
+        # Only populate once; after there are elements present no need to
+        # repopulate
+        return
+    most_recent_func = None
+
+    class GetMethodMode(BaseTorchFunctionMode):
+        """
+        Mode to extract the correct methods from torch function invocations
+        (Used to get the correct torch.Tensor methods from builtins)
+        """
+
+        def __torch_function__(self, func, types, args=(), kwargs=None):
+            kwargs = kwargs or {}
+            nonlocal most_recent_func
+            most_recent_func = func
+            return func(*args, **kwargs)
+
+    inp0 = torch.ones(1)
+    inp1 = torch.ones(1)
+    inp0_int = torch.ones(1, dtype=torch.int32)
+    inp1_int = torch.ones(1, dtype=torch.int32)
+    with GetMethodMode():
+        setups_and_oplists: list[tuple[Callable[..., Any], Iterable[Any]]] = [
+            (lambda o: o(inp0), un_ops),
+            (lambda o: o(inp0_int), un_int_ops),
+            (lambda o: o(inp0, inp1), bin_ops),
+            (lambda o: o(inp0_int, inp1_int), bin_int_ops),
+            (lambda o: o(inp0_int, 0), tensor_and_int_ops),
+        ]
+        for setup_fn, op_list in setups_and_oplists:
+            for op in op_list:
+                setup_fn(op)
+                assert most_recent_func is not None
+                BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func
+
+        # gather the reverse functions
+        rsetups_and_oplists: list[tuple[Callable[..., Any], Iterable[Any]]] = [
+            (
+                lambda o: o(1, inp1),
+                bin_ops,
+            ),  # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
+            (lambda o: o(1, inp1_int), bin_int_ops),
+            (lambda o: o(0, inp0_int), tensor_and_int_ops),
+        ]
+
+        rskips = {operator.matmul, operator.imatmul, operator.getitem}
+        for setup_fn, op_list in rsetups_and_oplists:
+            for op in op_list:
+                if op in rskips:
+                    continue
+                setup_fn(op)
+                assert most_recent_func is not None
+                if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]:
+                    BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func
+
 
 class BuiltinVariable(VariableTracker):
     """
```
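For readers unfamiliar with the trick `populate_builtin_to_tensor_fn_map` relies on: a `BaseTorchFunctionMode` subclass intercepts every torch-function call made while it is active, so invoking a builtin on example tensors under the mode reveals which `torch.Tensor` method that builtin ultimately dispatches to. A minimal standalone sketch of the same capture technique, independent of the dynamo code above:

```python
import operator

import torch
from torch.overrides import BaseTorchFunctionMode

captured = None


class RecordingMode(BaseTorchFunctionMode):
    """Record the function torch dispatches to, then run it normally."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        global captured
        captured = func
        return func(*args, **(kwargs or {}))


t = torch.ones(1)
with RecordingMode():
    operator.add(t, t)  # a plain builtin call on tensors hits the mode

# `captured` now holds the Tensor method behind operator.add
# (its exact repr depends on the torch version)
print(captured)
```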
```diff
@@ -1043,17 +1164,15 @@ class BuiltinVariable(VariableTracker):
 
         # insert handling for torch function here
         from .builder import SourcelessBuilder
-        from .torch_function import (
-            BUILTIN_TO_TENSOR_FN_MAP,
-            BUILTIN_TO_TENSOR_RFN_MAP,
-            can_dispatch_torch_function,
-            dispatch_torch_function,
-        )
+        from .torch_function import can_dispatch_torch_function, dispatch_torch_function
 
+        global BUILTIN_TO_TENSOR_RFN_MAP, BUILTIN_TO_TENSOR_FN_MAP
         if can_dispatch_torch_function(tx, args, kwargs):
             # Only remap the fn to tensor methods if we aren't exporting
             # export serde does not handle method descriptors today
             if not tx.export:
+                # Ensure the builtin maps are populated before accessing them
+                populate_builtin_to_tensor_fn_map()
                 # Use sourceless builder, we built the map ourselves
                 if not isinstance(args[0], TensorVariable):
                     if self.fn in BUILTIN_TO_TENSOR_RFN_MAP:
```
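The reverse map consulted here exists because Python resolves `1 - t` by trying `(1).__sub__(t)` first and, when that returns `NotImplemented`, falling back to `t.__rsub__(1)`; the builtin must therefore be re-dispatched with swapped arguments when the tensor is not in the first position. A small sketch of that check-and-swap logic, where `FN`/`RFN` and `dispatch` are illustrative stand-ins for the real module-level maps and `BuiltinVariable` logic:

```python
import operator

import torch

# Hand-written stand-ins for BUILTIN_TO_TENSOR_FN_MAP / _RFN_MAP entries.
FN = {operator.sub: torch.Tensor.sub}
RFN = {operator.sub: torch.Tensor.__rsub__}


def dispatch(op, a, b):
    # Mirror of the described logic: if the first arg is not a tensor,
    # swap the args and use the reflected (r*) tensor method.
    if not isinstance(a, torch.Tensor) and op in RFN:
        return RFN[op](b, a)
    return FN[op](a, b)


t = torch.full((1,), 3.0)
print(dispatch(operator.sub, t, 1))  # tensor([2.])  via Tensor.sub
print(dispatch(operator.sub, 1, t))  # tensor([-2.]) via Tensor.__rsub__
```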