[Bugfix] Fix circular import between export and dynamo from tensor fn map (#158931)

Fixes #158120

The issue was caused by populating a builtin-to-tensor-fn map at import time. If `torch.export.export` was called with the `meta` device before any dynamo imports, this map would not yet be populated, so it would be populated as part of the dynamo import triggered by export, which in turn tried to call `torch.disable` before it had been initialized.
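For context, a minimal sketch of the failing pattern described above and in the linked issue; the module, its parameters, and the input shapes are illustrative assumptions, not the exact reproducer:

```python
# Illustrative repro sketch: export a meta-device module in a fresh process,
# before anything from torch._dynamo has been imported.
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)


with torch.device("meta"):
    m = M()

# Prior to this fix, the dynamo import triggered here could fail with a
# circular import; the map is now built lazily instead of at import time.
ep = torch.export.export(m, (torch.randn(4, 4, device="meta"),))
print(ep)
```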

The fix is to populate this map lazily.

Test plan:

```
python test/dynamo/imports_non_circular_repro.py TestImports.test_circular_import_with_export_meta
```
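Distilled from the diff below, the shape of the fix is a populate-on-first-use helper guarded by an "already populated" check, rather than building the map when the module is imported. This is a simplified sketch (only a handful of binary ops, no r*-map handling); the names mirror the diff but the surrounding module structure is assumed:

```python
import operator
from typing import Any, Callable

import torch
from torch.overrides import BaseTorchFunctionMode

# Built lazily so that merely importing this module never runs tensor ops.
BUILTIN_TO_TENSOR_FN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {}


def populate_builtin_to_tensor_fn_map() -> None:
    if BUILTIN_TO_TENSOR_FN_MAP:  # already populated on a previous call
        return

    captured: list[Callable[..., Any]] = []

    class GetMethodMode(BaseTorchFunctionMode):
        # Record the tensor function that each builtin dispatches to.
        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            captured.append(func)
            return func(*args, **kwargs)

    a, b = torch.ones(1), torch.ones(1)
    with GetMethodMode():
        for op in (operator.add, operator.mul, operator.sub):
            op(a, b)
            BUILTIN_TO_TENSOR_FN_MAP[op] = captured[-1]
```

Call sites then invoke the helper immediately before the first lookup, which is what the second hunk below does inside `BuiltinVariable`.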

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158931
Approved by: https://github.com/StrongerXi, https://github.com/mlazos, https://github.com/anijain2305
Author: Lucas Kabela
Date: 2025-07-24 22:24:52 +00:00
Committed by: PyTorch MergeBot
Parent: abb0bf45df
Commit: efc810c7d0
4 changed files with 160 additions and 78 deletions


@@ -33,12 +33,13 @@ import types
import typing
import unittest
from collections import defaultdict, OrderedDict
from collections.abc import KeysView, Sequence
from typing import Callable, TYPE_CHECKING, Union
from collections.abc import Iterable, KeysView, Sequence
from typing import Any, Callable, TYPE_CHECKING, Union
import torch
from torch import sym_float, sym_int
from torch._subclasses.meta_utils import is_sparse_any
from torch.overrides import BaseTorchFunctionMode
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config, graph_break_hints, polyfills, variables
@@ -153,6 +154,126 @@ polyfill_fn_mapping = {
    operator.ge: polyfills.cmp_ge,
}

bin_ops = (
    operator.pow,
    operator.mul,
    operator.matmul,
    operator.floordiv,
    operator.truediv,
    operator.mod,
    operator.add,
    operator.lt,
    operator.gt,
    operator.ge,
    operator.le,
    operator.ne,
    operator.eq,
    operator.sub,
    operator.ipow,
    operator.imul,
    operator.imatmul,
    operator.ifloordiv,
    operator.itruediv,
    operator.imod,
    operator.iadd,
    operator.isub,
)

bin_int_ops = (
    operator.and_,
    operator.or_,
    operator.xor,
    operator.iand,
    operator.ixor,
    operator.ior,
)

un_int_ops = (operator.invert,)

tensor_and_int_ops = (
    operator.lshift,
    operator.rshift,
    operator.ilshift,
    operator.irshift,
    operator.getitem,
)

un_ops = (
    operator.abs,
    operator.pos,
    operator.neg,
    operator.not_,  # Note: this has a local scalar dense call
    operator.length_hint,
)
BUILTIN_TO_TENSOR_FN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {}
# These functions represent the r* versions of the above ops
# Basically, if __add__(1, Tensor) is called, it is translated
# to __radd__(Tensor, 1).
# In the builtin var, we check if there is a tensor in the first args position,
# if not, we swap the args and use the r* version of the op.
BUILTIN_TO_TENSOR_RFN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {}
def populate_builtin_to_tensor_fn_map():
    global BUILTIN_TO_TENSOR_FN_MAP
    if len(BUILTIN_TO_TENSOR_FN_MAP) > 0:
        # Only populate once; after there are elements present no need to
        # repopulate
        return

    most_recent_func = None

    class GetMethodMode(BaseTorchFunctionMode):
        """
        Mode to extract the correct methods from torch function invocations
        (Used to get the correct torch.Tensor methods from builtins)
        """

        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            nonlocal most_recent_func
            most_recent_func = func
            return func(*args, **kwargs)

    inp0 = torch.ones(1)
    inp1 = torch.ones(1)
    inp0_int = torch.ones(1, dtype=torch.int32)
    inp1_int = torch.ones(1, dtype=torch.int32)

    with GetMethodMode():
        setups_and_oplists: list[tuple[Callable[..., Any], Iterable[Any]]] = [
            (lambda o: o(inp0), un_ops),
            (lambda o: o(inp0_int), un_int_ops),
            (lambda o: o(inp0, inp1), bin_ops),
            (lambda o: o(inp0_int, inp1_int), bin_int_ops),
            (lambda o: o(inp0_int, 0), tensor_and_int_ops),
        ]
        for setup_fn, op_list in setups_and_oplists:
            for op in op_list:
                setup_fn(op)
                assert most_recent_func is not None
                BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func

        # gather the reverse functions
        rsetups_and_oplists: list[tuple[Callable[..., Any], Iterable[Any]]] = [
            (
                lambda o: o(1, inp1),
                bin_ops,
            ),  # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
            (lambda o: o(1, inp1_int), bin_int_ops),
            (lambda o: o(0, inp0_int), tensor_and_int_ops),
        ]
        rskips = {operator.matmul, operator.imatmul, operator.getitem}
        for setup_fn, op_list in rsetups_and_oplists:
            for op in op_list:
                if op in rskips:
                    continue
                setup_fn(op)
                assert most_recent_func is not None
                if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]:
                    BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func
class BuiltinVariable(VariableTracker):
    """
@@ -1043,17 +1164,15 @@ class BuiltinVariable(VariableTracker):
        # insert handling for torch function here
        from .builder import SourcelessBuilder
        from .torch_function import (
            BUILTIN_TO_TENSOR_FN_MAP,
            BUILTIN_TO_TENSOR_RFN_MAP,
            can_dispatch_torch_function,
            dispatch_torch_function,
        )
        from .torch_function import can_dispatch_torch_function, dispatch_torch_function

        global BUILTIN_TO_TENSOR_RFN_MAP, BUILTIN_TO_TENSOR_FN_MAP

        if can_dispatch_torch_function(tx, args, kwargs):
            # Only remap the fn to tensor methods if we aren't exporting
            # export serde does not handle method descriptors today
            if not tx.export:
                # Ensure the builtin maps are populated before accessing them
                populate_builtin_to_tensor_fn_map()
                # Use sourceless builder, we built the map ourselves
                if not isinstance(args[0], TensorVariable):
                    if self.fn in BUILTIN_TO_TENSOR_RFN_MAP: