From efc810c7d02f492eed5f1393ffd8c0b35b52e46f Mon Sep 17 00:00:00 2001
From: Lucas Kabela
Date: Thu, 24 Jul 2025 22:24:52 +0000
Subject: [PATCH] [Bugfix] Fix circular import between export and dynamo from tensor fn map (#158931)

Fixes #158120

The issue was caused by populating a builtin-to-tensor fn map at import
time: if torch.export.export was called under the `meta` device before
any dynamo import, the map was still unpopulated, so it got populated
during that import, which tried to call `torch.disable` before it was
initialized.

The fix is to populate this map lazily.

```
python test/dynamo/imports_non_circular_repro.py TestImports.test_circular_import_with_export_meta
```
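
As an aside, the lazy-population pattern the fix relies on can be sketched as
follows. This is a minimal standalone sketch; the names `_BUILTIN_TO_METHOD_MAP`,
`_populate_builtin_to_method_map`, and `lookup` are illustrative and are not part
of this patch.

```
# Minimal sketch of lazy population (illustrative names, not the patched code).
import operator

# Left empty at import time, so importing this module has no side effects
# and cannot participate in a circular import.
_BUILTIN_TO_METHOD_MAP: dict = {}


def _populate_builtin_to_method_map() -> None:
    # Populate only once, on first use rather than at import time.
    if _BUILTIN_TO_METHOD_MAP:
        return
    _BUILTIN_TO_METHOD_MAP.update(
        {
            operator.add: "add",
            operator.sub: "sub",
        }
    )


def lookup(op):
    # Every reader ensures the map is populated before consulting it.
    _populate_builtin_to_method_map()
    return _BUILTIN_TO_METHOD_MAP.get(op)


if __name__ == "__main__":
    print(lookup(operator.add))  # prints "add"
```

The actual fix applies the same idea: `populate_builtin_to_tensor_fn_map()`
returns immediately once `BUILTIN_TO_TENSOR_FN_MAP` is non-empty, and
`BuiltinVariable` calls it before reading the maps.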

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158931
Approved by: https://github.com/StrongerXi, https://github.com/mlazos, https://github.com/anijain2305
---
 test/dynamo/imports_non_circular_repro.py |  29 +++++
 test/dynamo/test_modes.py                 |   6 +-
 torch/_dynamo/variables/builtin.py        | 135 ++++++++++++++++++++--
 torch/_dynamo/variables/torch_function.py |  68 -----------
 4 files changed, 160 insertions(+), 78 deletions(-)
 create mode 100644 test/dynamo/imports_non_circular_repro.py

diff --git a/test/dynamo/imports_non_circular_repro.py b/test/dynamo/imports_non_circular_repro.py
new file mode 100644
index 000000000000..8a8d058b810a
--- /dev/null
+++ b/test/dynamo/imports_non_circular_repro.py
@@ -0,0 +1,29 @@
+# Owner(s): ["module: dynamo"]
+"""
+This file is aimed at providing a simple testcase to reproduce
+https://github.com/pytorch/pytorch/issues/158120
+
+This means that we cannot rely on torch.dynamo before importing
+torch.export, so we can't add this to a file that is a dynamo testcase
+"""
+
+import unittest
+
+import torch
+
+
+class TestImports(unittest.TestCase):
+    def test_circular_import_with_export_meta(self):
+        from torch.export import export
+
+        conv = torch.nn.Conv2d(3, 64, 3, padding=1)
+        # Note: we want to validate that export within
+        # torch.device("meta") does not fail due to circular
+        # import
+        with torch.device("meta"):
+            ep = export(conv, (torch.zeros(64, 3, 1, 1),))
+        self.assertIsNotNone(ep)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py
index 868627d02026..f8869fd804ef 100644
--- a/test/dynamo/test_modes.py
+++ b/test/dynamo/test_modes.py
@@ -504,11 +504,13 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
     # Needs larger cache size since we recompile for each op
     @patch.object(torch._dynamo.config, "recompile_limit", 48)
     def test_builtin_equivalent_funcs(self):
+        from torch._dynamo.variables.builtin import (
+            BUILTIN_TO_TENSOR_FN_MAP,
+            BUILTIN_TO_TENSOR_RFN_MAP,
+        )
         from torch._dynamo.variables.torch_function import (
             bin_int_ops,
             bin_ops,
-            BUILTIN_TO_TENSOR_FN_MAP,
-            BUILTIN_TO_TENSOR_RFN_MAP,
             tensor_and_int_ops,
             un_int_ops,
             un_ops,
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 9114cfbe3be5..34012263c73f 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -33,12 +33,13 @@ import types
 import typing
 import unittest
 from collections import defaultdict, OrderedDict
-from collections.abc import KeysView, Sequence
-from typing import Callable, TYPE_CHECKING, Union
+from collections.abc import Iterable, KeysView, Sequence
+from typing import Any, Callable, TYPE_CHECKING, Union
 
 import torch
 from torch import sym_float, sym_int
 from torch._subclasses.meta_utils import is_sparse_any
+from torch.overrides import BaseTorchFunctionMode
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
 from .. import config, graph_break_hints, polyfills, variables
@@ -153,6 +154,126 @@ polyfill_fn_mapping = {
     operator.ge: polyfills.cmp_ge,
 }
 
+bin_ops = (
+    operator.pow,
+    operator.mul,
+    operator.matmul,
+    operator.floordiv,
+    operator.truediv,
+    operator.mod,
+    operator.add,
+    operator.lt,
+    operator.gt,
+    operator.ge,
+    operator.le,
+    operator.ne,
+    operator.eq,
+    operator.sub,
+    operator.ipow,
+    operator.imul,
+    operator.imatmul,
+    operator.ifloordiv,
+    operator.itruediv,
+    operator.imod,
+    operator.iadd,
+    operator.isub,
+)
+
+bin_int_ops = (
+    operator.and_,
+    operator.or_,
+    operator.xor,
+    operator.iand,
+    operator.ixor,
+    operator.ior,
+)
+
+un_int_ops = (operator.invert,)
+
+tensor_and_int_ops = (
+    operator.lshift,
+    operator.rshift,
+    operator.ilshift,
+    operator.irshift,
+    operator.getitem,
+)
+
+un_ops = (
+    operator.abs,
+    operator.pos,
+    operator.neg,
+    operator.not_,  # Note: this has a local scalar dense call
+    operator.length_hint,
+)
+
+BUILTIN_TO_TENSOR_FN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {}
+
+# These functions represent the r* versions of the above ops
+# Basically, if __add__(1, Tensor) is called, it is translated
+# to __radd__(Tensor, 1).
+# In the builtin var, we check if there is a tensor in the first args position,
+# if not, we swap the args and use the r* version of the op.
+BUILTIN_TO_TENSOR_RFN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {}
+
+
+def populate_builtin_to_tensor_fn_map():
+    global BUILTIN_TO_TENSOR_FN_MAP
+    if len(BUILTIN_TO_TENSOR_FN_MAP) > 0:
+        # Only populate once; after there are elements present no need to
+        # repopulate
+        return
+    most_recent_func = None
+
+    class GetMethodMode(BaseTorchFunctionMode):
+        """
+        Mode to extract the correct methods from torch function invocations
+        (Used to get the correct torch.Tensor methods from builtins)
+        """
+
+        def __torch_function__(self, func, types, args=(), kwargs=None):
+            kwargs = kwargs or {}
+            nonlocal most_recent_func
+            most_recent_func = func
+            return func(*args, **kwargs)
+
+    inp0 = torch.ones(1)
+    inp1 = torch.ones(1)
+    inp0_int = torch.ones(1, dtype=torch.int32)
+    inp1_int = torch.ones(1, dtype=torch.int32)
+    with GetMethodMode():
+        setups_and_oplists: list[tuple[Callable[..., Any], Iterable[Any]]] = [
+            (lambda o: o(inp0), un_ops),
+            (lambda o: o(inp0_int), un_int_ops),
+            (lambda o: o(inp0, inp1), bin_ops),
+            (lambda o: o(inp0_int, inp1_int), bin_int_ops),
+            (lambda o: o(inp0_int, 0), tensor_and_int_ops),
+        ]
+        for setup_fn, op_list in setups_and_oplists:
+            for op in op_list:
+                setup_fn(op)
+                assert most_recent_func is not None
+                BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func
+
+        # gather the reverse functions
+        rsetups_and_oplists: list[tuple[Callable[..., Any], Iterable[Any]]] = [
+            (
+                lambda o: o(1, inp1),
+                bin_ops,
+            ),  # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
+            (lambda o: o(1, inp1_int), bin_int_ops),
+            (lambda o: o(0, inp0_int), tensor_and_int_ops),
+        ]
+
+        rskips = {operator.matmul, operator.imatmul, operator.getitem}
+        for setup_fn, op_list in rsetups_and_oplists:
+            for op in op_list:
+                if op in rskips:
+                    continue
+                setup_fn(op)
+                assert most_recent_func is not None
+                if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]:
+                    BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func
+
 
 class BuiltinVariable(VariableTracker):
     """
@@ -1043,17 +1164,15 @@ class BuiltinVariable(VariableTracker):
 
         # insert handling for torch function here
         from .builder import SourcelessBuilder
-        from .torch_function import (
-            BUILTIN_TO_TENSOR_FN_MAP,
-            BUILTIN_TO_TENSOR_RFN_MAP,
-            can_dispatch_torch_function,
-            dispatch_torch_function,
-        )
+        from .torch_function import can_dispatch_torch_function, dispatch_torch_function
 
+        global BUILTIN_TO_TENSOR_RFN_MAP, BUILTIN_TO_TENSOR_FN_MAP
         if can_dispatch_torch_function(tx, args, kwargs):
             # Only remap the fn to tensor methods if we aren't exporting
             # export serde does not handle method descriptors today
             if not tx.export:
+                # Ensure the builtin maps are populated before accessing them
+                populate_builtin_to_tensor_fn_map()
                 # Use sourceless builder, we built the map ourselves
                 if not isinstance(args[0], TensorVariable):
                     if self.fn in BUILTIN_TO_TENSOR_RFN_MAP:
diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py
index bb3ac286773b..7ee8c48b0ffb 100644
--- a/torch/_dynamo/variables/torch_function.py
+++ b/torch/_dynamo/variables/torch_function.py
@@ -38,7 +38,6 @@ import torch.utils._pytree as pytree
 from torch._guards import Source
 from torch.overrides import (
     _get_overloaded_args,
-    BaseTorchFunctionMode,
     get_default_nowrap_functions,
     TorchFunctionMode,
 )
@@ -124,73 +123,6 @@ un_ops = [
     operator.length_hint,
 ]
 
-BUILTIN_TO_TENSOR_FN_MAP = {}
-
-# These functions represent the r* versions of the above ops
-# Basically, if __add__(1, Tensor) is called, it is translated
-# to __radd__(Tensor, 1).
-# In the builtin var, we check if there is a tensor in the first args position,
-# if not, we swap the args and use the r* version of the op.
-BUILTIN_TO_TENSOR_RFN_MAP = {}
-
-
-def populate_builtin_to_tensor_fn_map():
-    global BUILTIN_TO_TENSOR_FN_MAP
-
-    most_recent_func = None
-
-    class GetMethodMode(BaseTorchFunctionMode):
-        """
-        Mode to extract the correct methods from torch function invocations
-        (Used to get the correct torch.Tensor methods from builtins)
-        """
-
-        def __torch_function__(self, func, types, args=(), kwargs=None):
-            kwargs = kwargs or {}
-            nonlocal most_recent_func
-            most_recent_func = func
-            return func(*args, **kwargs)
-
-    inp0 = torch.ones(1)
-    inp1 = torch.ones(1)
-    inp0_int = torch.ones(1, dtype=torch.int32)
-    inp1_int = torch.ones(1, dtype=torch.int32)
-    with GetMethodMode():
-        setups_and_oplists = [
-            (lambda o: o(inp0), un_ops),
-            (lambda o: o(inp0_int), un_int_ops),
-            (lambda o: o(inp0, inp1), bin_ops),
-            (lambda o: o(inp0_int, inp1_int), bin_int_ops),
-            (lambda o: o(inp0_int, 0), tensor_and_int_ops),
-        ]
-        for setup_fn, op_list in setups_and_oplists:
-            for op in op_list:
-                setup_fn(op)
-                assert most_recent_func is not None
-                BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func
-
-        # gather the reverse functions
-        rsetups_and_oplists = [
-            (
-                lambda o: o(1, inp1),
-                bin_ops,
-            ),  # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
-            (lambda o: o(1, inp1_int), bin_int_ops),
-            (lambda o: o(0, inp0_int), tensor_and_int_ops),
-        ]
-
-        rskips = {operator.matmul, operator.imatmul, operator.getitem}
-        for setup_fn, op_list in rsetups_and_oplists:
-            for op in op_list:
-                if op in rskips:
-                    continue
-                setup_fn(op)
-                assert most_recent_func is not None
-                if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]:
-                    BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func
-
-
-populate_builtin_to_tensor_fn_map()
 
 banned_attrs = [
     fn.__self__.__name__