mirror of https://github.com/pytorch/pytorch.git
Revert "[Dynamo] Trace torch function modes entered outside of torch.compile (#133137)"
This reverts commit fafdd588f27e1d56090c6d260d0382c255eaf9eb. Reverted https://github.com/pytorch/pytorch/pull/133137 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378))
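For context, the change being reverted taught Dynamo to trace `TorchFunctionMode` instances that were entered outside of the compiled region, which is exactly what the removed `TorchFunctionModeTests` below exercise. The following is a minimal sketch of that scenario, adapted from the removed tests; the mode class name is illustrative, and the eager/compiled results only match on builds where this tracing support is present:

```python
import torch
from torch.overrides import BaseTorchFunctionMode


class AddReturnsZeros(BaseTorchFunctionMode):
    # Illustrative mode, mirroring the TestMode classes in the removed tests:
    # intercept torch.add and return a constant tensor instead.
    def __torch_function__(self, func, types, args, kwargs=None):
        if not kwargs:
            kwargs = {}
        if func == torch.add:
            return torch.zeros(2, 2)
        return super().__torch_function__(func, types, args, kwargs)


def fn(x):
    return torch.add(x, 3)


compiled = torch.compile(fn, fullgraph=True)
inp = torch.ones(2, 2)

# The mode is entered *outside* torch.compile; with the reverted PR applied,
# the compiled call dispatches through the mode just like the eager call does.
with AddReturnsZeros():
    expected = fn(inp)      # eager: intercepted, returns zeros
    actual = compiled(inp)  # compiled: should match eager when modes are traced
```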
@@ -14,17 +14,6 @@ from torch.utils._device import DeviceContext
from torch.utils._python_dispatch import TorchDispatchMode


class TestMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        if not kwargs:
            kwargs = {}

        if func == torch.add:
            return torch.zeros(2, 2)

        return super().__torch_function__(func, types, args, kwargs)


class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
@@ -335,130 +324,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
        fn(inp)
        self.assertEqual(cnt.frame_count, 2)

    def test_nested_torch_function_mode(self):
        mode_1_called = False
        mode_2_called = False

        def reset_state():
            nonlocal mode_1_called
            nonlocal mode_2_called
            mode_1_called = False
            mode_2_called = False

        ones = torch.ones(2, 2)
        zeros = torch.zeros(2, 2)

        class TestMode1(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                nonlocal mode_1_called

                mode_1_called = True

                if func == torch.add:
                    return zeros

                return super().__torch_function__(func, types, args, kwargs)

        class TestMode2(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                nonlocal mode_2_called

                mode_2_called = True

                if func == torch.mul:
                    return ones

                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        def fn_2(x):
            return torch.mul(x, 3) + torch.add(x, 3)

        inp = torch.ones(2, 2) + 1

        for fn_i in [fn, fn_2]:
            fn_opt = torch.compile(fn_i, fullgraph=True)
            with TestMode1(), TestMode2():
                expected = fn_i(inp), mode_1_called, mode_2_called
                reset_state()
                actual = fn_opt(inp), mode_1_called, mode_2_called
                reset_state()

            self.assertEqual(expected, actual)

    def test_torch_function_mode_disable(self):
        class TestSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}
                if func == torch.add:
                    return torch.ones(2, 2)
                return super().__torch_function__(func, types, args, kwargs)

        class TestMode(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                if func == torch.add:
                    return torch.zeros(2, 2)

                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)

        fn_opt = torch.compile(fn, fullgraph=True)
        with TestMode(), torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {TestSubclass}
        ):
            with torch._C.DisableTorchFunctionSubclass():
                expected = fn(inp)
                actual = fn_opt(inp)

                self.assertEqual(expected, actual)

            with torch._C.DisableTorchFunction():
                expected = fn(inp)
                actual = fn_opt(inp)

                self.assertEqual(expected, actual)

    def test_torch_function_mode_highest_priority(self):
        class TestSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}
                if func == torch.add:
                    return torch.ones(2, 2)
                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)

        fn_opt = torch.compile(fn, fullgraph=True)
        with TestMode(), torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {TestSubclass}
        ):
            expected = fn(inp)
            actual = fn_opt(inp)

            self.assertEqual(expected, actual)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
@@ -26,7 +26,6 @@ from torch.testing._internal.common_utils import (
    parametrize,
    requires_cuda,
    run_tests,
    skipIfCrossRef,
    skipIfRocm,
    skipIfTorchDynamo,
    TEST_WITH_TORCHDYNAMO,
@@ -2883,7 +2882,6 @@ def forward(self, pred_1, x_1):
            gm = make_fx(f, tracing_mode="symbolic")(add_wrong_dtype, init, x)

    @skipIfNoDynamoSupport
    @skipIfCrossRef  # Arg order changes with crossref
    def test_scan_simple_graph(self):
        from torch._dynamo.testing import EagerAndRecordGraphs

@@ -2990,7 +2988,6 @@ class TestControlFlowTraced(TestCase):
        self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))

    @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
    @skipIfCrossRef  # Arg order changes with crossref
    def test_cond_simple_with_linear_compile_check_graph(self):
        from torch._dynamo.testing import EagerAndRecordGraphs

@@ -3253,7 +3250,6 @@ def forward(self, arg0_1):
        self._check_compile(fn, inp, backend=backend)

    @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
    @skipIfCrossRef  # Arg order changes with cross ref
    def test_while_loop_simple_with_linear_compile_check_graph(self):
        fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
        from torch._dynamo.testing import EagerAndRecordGraphs
@@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
from torch.fx import Node
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef
from torch.testing._internal.common_utils import IS_WINDOWS


class TestHelperModules:
@@ -139,8 +139,6 @@ class TestMetaDataPorting(QuantizationTestCase):
            self.assertEqual(v, node_tags[k])
        return m

    @skipIfCrossRef  # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack
    # trace of the mode torch function impl doesn't match the traced graph stored lineno.
    def test_simple_metadata_porting(self):
        """
        Model under test
@@ -605,10 +605,6 @@ def _compile(
    output: Optional[OutputGraph] = None
    tracer: Optional[InstructionTranslator] = None

    tf_mode_stack: List[
        torch.overrides.TorchFunctionMode
    ] = torch.overrides._get_current_function_mode_stack()

    @preserve_global_state
    def transform(
        instructions: List[Instruction], code_options: Dict[str, object]
@@ -622,7 +618,6 @@ def _compile(
            locals,
            globals,
            builtins,
            tf_mode_stack,
            code_options,
            compiler_fn,
            one_graph,
@@ -98,7 +98,6 @@ from .source import (
    ScriptObjectQualifiedNameSource,
    ShapeEnvSource,
    SubclassAttrListSource,
    TorchFunctionModeStackSource,
    TupleIteratorGetItemSource,
    TypeSource,
    UnspecializedBuiltinNNModuleSource,
@@ -112,7 +111,6 @@ from .utils import (
    dict_keys_repr,
    get_custom_getattr,
    get_torch_function_mode_stack,
    get_torch_function_mode_stack_at,
    guard_failures,
    istype,
    key_is_id,
@@ -316,7 +314,6 @@ CLOSURE_VARS = {
    "___dict_contains": lambda a, b: a in b,
    "___tuple_iterator_len": tuple_iterator_len,
    "___tuple_iterator_getitem": tuple_iterator_getitem,
    "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
    "__math_isnan": math.isnan,
    "__numpy_isnan": None if np is None else np.isnan,
    "inf": float("inf"),
@@ -904,15 +901,6 @@ class GuardBuilder(GuardBuilderBase):
        ):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager
        elif istype(source, TorchFunctionModeStackSource):
            out = root_guard_manager.lambda_manager(
                python_lambda=lambda _: get_torch_function_mode_stack_at(
                    source._get_index()
                ),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GradSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.grad_manager(
@@ -2226,8 +2214,6 @@ class CheckFunctionManager:
        self.output_graph = output_graph
        w_builder = None

        # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing
        # in case a set default device call was made in the graph.
        self.torch_function_mode_stack = (
            output_graph.torch_function_mode_stack if output_graph else None
        )
@@ -619,7 +619,7 @@ class TorchFunctionModeStackSource(Source):
    ind: int

    def name(self):
        return f"___get_torch_function_mode_stack_at({self._get_index()})"
        return ""

    def _get_index(self):
        from .variables.torch_function import TorchFunctionModeStackVariable
@@ -19,7 +19,20 @@ import traceback
import types
import typing
import weakref
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
from typing import (
    Any,
    Callable,
    cast,
    Deque,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)
from unittest.mock import patch

import torch
@@ -59,12 +72,14 @@ from .source import (
    GlobalWeakRefSource,
    LocalSource,
    Source,
    TorchFunctionModeStackSource,
)
from .trace_rules import is_builtin_constant, is_forbidden
from .utils import (
    counters,
    get_fake_value,
    get_instruction_source_311,
    get_torch_function_mode_stack,
    graph_break_dup_warning_checker,
    istype,
    LazyString,
@@ -105,10 +120,11 @@ from .variables.misc import (
)
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
from .variables.torch_function import (
    SymbolicTorchFunctionState,
    TorchFunctionModeVariable,
)


if TYPE_CHECKING:
    from .variables.torch_function import TorchFunctionModeVariable

from .variables.user_defined import (
    RemovableHandleVariable,
    UserDefinedClassVariable,
@@ -268,10 +284,6 @@ class BlockStackEntry:
            return ReenterWith(self.stack_index)

    def exit(self, tx):
        if hasattr(self, "graph_break") and isinstance(
            self.with_context, TorchFunctionModeVariable
        ):
            return
        assert self.with_context is not None
        return self.with_context.exit(tx)

@@ -640,9 +652,7 @@ def break_graph_if_unsupported(*, push):
            # Reconstruct the context variable CLASS in the block stack
            for b in self.block_stack:
                assert b.with_context is not None
                assert isinstance(
                    b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable)
                )
                assert isinstance(b.with_context, ContextWrappingVariable)
                b.with_context.reconstruct_type(cg)
                cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
            self.output.add_output_instructions(cg.get_instructions())
@@ -718,7 +728,7 @@ class InstructionTranslatorBase(
    output: OutputGraph
    symbolic_locals: Dict[str, VariableTracker]
    symbolic_globals: Dict[str, VariableTracker]
    symbolic_torch_function_state: SymbolicTorchFunctionState
    symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"]
    stack: List[VariableTracker]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
@@ -2538,7 +2548,7 @@ class InstructionTranslatorBase(
        code_options: Dict[str, Any],
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        symbolic_torch_function_state: SymbolicTorchFunctionState,
        symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
        f_code: types.CodeType,
        export: bool,
        inline_depth: int,
@@ -2553,7 +2563,7 @@ class InstructionTranslatorBase(
        self.output = output
        self.symbolic_locals = symbolic_locals
        self.symbolic_globals = symbolic_globals
        self.symbolic_torch_function_state = symbolic_torch_function_state
        self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack
        self.stack = []
        # stack of variable names for tracking 3.13 closures
        self.name_stack: list[Any] = []
@@ -2642,7 +2652,6 @@ class InstructionTranslator(InstructionTranslatorBase):
        f_locals,
        f_globals,
        f_builtins,
        torch_function_mode_stack,
        code_options,
        compiler_fn,
        one_graph,
@@ -2677,7 +2686,7 @@ class InstructionTranslator(InstructionTranslatorBase):
            symbolic_locals={},  # set below
            # A global var is inserted only after a STORE_GLOBAL happens to it
            symbolic_globals={},
            symbolic_torch_function_state=None,  # type: ignore[arg-type] # set below
            symbolic_torch_function_mode_stack=collections.deque(),
            f_code=f_code,
            export=export,
            inline_depth=0,
@@ -2712,9 +2721,7 @@ class InstructionTranslator(InstructionTranslatorBase):
                if k in f_locals
            }

            self.symbolic_torch_function_state = SymbolicTorchFunctionState(
                torch_function_mode_stack
            )
            self._init_torch_function_mode_stack()

            self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
            if export:
@@ -2755,6 +2762,29 @@ class InstructionTranslator(InstructionTranslatorBase):
            )
            unimplemented(msg)

    def _init_torch_function_mode_stack(self):
        from .variables.torch_function import TorchFunctionModeStackVariable

        TorchFunctionModeStackVariable.reset()

        self.symbolic_torch_function_mode_stack: Deque[
            TorchFunctionModeVariable
        ] = collections.deque()
        # We want to retrieve all modes to properly reconstruct the stack if needed
        py_stack = get_torch_function_mode_stack(filter_ignored=False)

        if py_stack:
            has_device_context = isinstance(
                py_stack[0], torch.utils._device.DeviceContext
            )

        for i, val in enumerate(py_stack):
            self.symbolic_torch_function_mode_stack.append(
                variables.LazyVariableTracker.create(
                    val, source=TorchFunctionModeStackSource(i)
                )
            )

    def get_example_value(self, source: Source):
        if isinstance(source, LocalSource):
            return self.f_locals[source.local_name]
@@ -3086,7 +3116,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                code,
                sub_locals,
                parent.symbolic_globals,
                parent.symbolic_torch_function_state,
                parent.symbolic_torch_function_mode_stack,
                closure_cells,
                func,
            )
@@ -3096,7 +3126,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                code,
                sub_locals,
                parent.symbolic_globals,
                parent.symbolic_torch_function_state,
                parent.symbolic_torch_function_mode_stack,
                closure_cells,
                func,
            )
@@ -3149,7 +3179,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
        code: types.CodeType,
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        symbolic_torch_function_state: SymbolicTorchFunctionState,
        symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
        closure_cells: Dict[str, VariableTracker],
        funcvar: BaseUserFunctionVariable,
    ) -> None:
@@ -3166,7 +3196,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
            f_builtins=f_builtins,
            symbolic_locals=symbolic_locals,
            symbolic_globals=symbolic_globals,
            symbolic_torch_function_state=symbolic_torch_function_state,
            symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack,
            instructions=instructions,
            code_options={k: getattr(code, k) for k in get_code_keys()},
            f_code=code,
@@ -3258,7 +3258,6 @@ MOD_INLINELIST = [
    "torch.testing",
    "torch.utils._content_store",
    "torch.utils._contextlib",
    "torch.utils._device",
    "torch.utils._foreach_utils",
    "torch.utils._python_dispatch",
    "torch.utils._pytree",
@@ -3593,9 +3592,7 @@ def lookup_inner(
            if reasons is not None:
                reasons.add("func name is patched_init")
            return SkipFunctionVariable
        elif name == "__torch_function__" or (
            obj and obj.__name__ == "__torch_function__"
        ):
        elif name == "__torch_function__":
            if reasons is not None:
                reasons.add("func name is __torch_function__")
            return UserFunctionVariable
@@ -63,6 +63,7 @@ import torch.fx.experimental.symbolic_shapes
import torch.utils._pytree as pytree
from torch import fx
from torch._C import (
    _get_function_stack_at,
    _instruction_counter,
    _len_torch_function_stack,
    _pop_torch_function_stack,
@@ -3064,9 +3065,7 @@ def is_parameter_freezing():
def get_torch_function_mode_stack(filter_ignored=True):
    from .variables.torch_function import IGNORED_MODES

    stack = [
        get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
    ]
    stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())]
    if filter_ignored:
        stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]

@@ -3086,11 +3085,6 @@ def set_torch_function_mode_stack(stack):
        _push_on_torch_function_stack(mode)


def clear_torch_function_mode_stack():
    for i in range(_len_torch_function_stack()):
        _pop_torch_function_stack()


def verify_guard_fn_signature(value):
    fn = value.__metadata_guard__
    sig = inspect.signature(fn)
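The `get_torch_function_mode_stack` / `set_torch_function_mode_stack` / `clear_torch_function_mode_stack` helpers above are thin wrappers over private bindings that also appear in this patch (`torch._C._len_torch_function_stack`, `_pop_torch_function_stack`, `_push_on_torch_function_stack`, and `torch.overrides._get_current_function_mode_stack`). A rough sketch of how those pieces relate, using only calls referenced in the diff (private APIs, so subject to change between releases):

```python
import torch
from torch.overrides import BaseTorchFunctionMode, _get_current_function_mode_stack

with BaseTorchFunctionMode():
    # The Python-level view of the mode stack agrees with the C-level length.
    stack = _get_current_function_mode_stack()
    assert len(stack) == torch._C._len_torch_function_stack()

    # Pop the innermost mode and push it back so the stack stays balanced
    # for the context manager's __exit__.
    top = torch._C._pop_torch_function_stack()
    torch._C._push_on_torch_function_stack(top)
```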
@@ -637,8 +637,6 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0]
        tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0]
        tx.output.set_torch_function_state(values[0])


@@ -156,15 +156,6 @@ tracing_state_functions = {
bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"])


@functools.lru_cache(None)
def get_overridable_functions():
    from itertools import chain

    from torch.overrides import get_overridable_functions as get_overridable_functions_

    return set(chain(*get_overridable_functions_().values()))


class BaseTorchVariable(VariableTracker):
    """common base for all torch.* functions, classes, modules and other things"""

@@ -815,10 +806,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            assert not args and not kwargs
            if not tx.symbolic_torch_function_state.mode_stack:
            if not tx.symbolic_torch_function_mode_stack:
                raise unimplemented("Popping from an empty torch function mode stack")
            TorchFunctionModeStackVariable.register_mutation(tx)
            return tx.symbolic_torch_function_state.pop_torch_function_mode()
            return tx.symbolic_torch_function_mode_stack.pop()

        @register(torch._C._push_on_torch_function_stack)
        def handle_push_torch_function(
@@ -826,7 +817,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
        ):
            assert len(args) == 1 and not kwargs
            TorchFunctionModeStackVariable.register_mutation(tx)
            tx.symbolic_torch_function_state.push_torch_function_mode(args[0])
            tx.symbolic_torch_function_mode_stack.append(args[0])
            return ConstantVariable.create(None)

        @register(torch._C._len_torch_function_stack)
@@ -834,9 +825,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
            self, tx: "InstructionTranslator", *args, **kwargs
        ):
            assert not args and not kwargs
            return ConstantVariable.create(
                len(tx.symbolic_torch_function_state.mode_stack)
            )
            return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack))

        @register(torch.set_default_device)
        def handle_set_default_device(
@@ -868,9 +857,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
        from . import ConstantVariable, SymNodeVariable, TensorVariable
        from .builder import wrap_fx_proxy

        if self.torch_function_override_enabled(tx, args, kwargs):
            return dispatch_torch_function(tx, self, args, kwargs)

        if self.can_constant_fold_through() and check_unspec_or_constant_args(
            args, kwargs
        ):
@@ -892,144 +878,147 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
if result:
return result

any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
if can_dispatch_torch_function(tx, args, kwargs):
return dispatch_torch_function(tx, self, args, kwargs)
else:
any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)

all_ints_or_floats = all(
isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
for x in args
)
if (
getattr(self.value, "__module__", "") == "torch"
and self.value.__name__ in bin_ops
and any_symints_or_symfloats
and all_ints_or_floats
):
msg = f"""\
all_ints_or_floats = all(
isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
for x in args
)
if (
getattr(self.value, "__module__", "") == "torch"
and self.value.__name__ in bin_ops
and any_symints_or_symfloats
and all_ints_or_floats
):
msg = f"""\
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
To support this behavior, we need to allow const-propping tensors that store symint data.
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
"""
log.warning(msg)
unimplemented(msg)
log.warning(msg)
unimplemented(msg)

# TODO(voz): Replace w/ dynamic shape rewrite table.
# Ideally, we would be able to do this at ctor time, but alas we need a combination
# of value + args to determine this.
fn_ = self.value
if any_symints_or_symfloats:
torch_sym_op = f"_sym_{self.value.__name__}"
if getattr(self.value, "__module__", None) == "math" and hasattr(
torch, torch_sym_op
):
fn_ = getattr(torch, torch_sym_op)
# TODO(voz): Replace w/ dynamic shape rewrite table.
# Ideally, we would be able to do this at ctor time, but alas we need a combination
# of value + args to determine this.
fn_ = self.value
if any_symints_or_symfloats:
torch_sym_op = f"_sym_{self.value.__name__}"
if getattr(self.value, "__module__", None) == "math" and hasattr(
torch, torch_sym_op
):
fn_ = getattr(torch, torch_sym_op)

fake_out_shape = None
if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
# Calling fake tensor propagation can mutate the out= tensor in
# tx.output.tracked_fakes. tracked_fakes are used to apply
# symbolic_shape guards. Mutating them destroys the information
# prior to tracing, which is essential for creating right
# guards. So save the shape now, and check later if it has
# changed. If it has, graph break.
fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
fake_out_shape = None
if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
# Calling fake tensor propagation can mutate the out= tensor in
# tx.output.tracked_fakes. tracked_fakes are used to apply
# symbolic_shape guards. Mutating them destroys the information
# prior to tracing, which is essential for creating right
# guards. So save the shape now, and check later if it has
# changed. If it has, graph break.
fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape

tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn_,
*proxy_args_kwargs(args, kwargs),
),
)

if (
isinstance(tensor_variable, TensorVariable)
and "requires_grad" in kwargs
and kwargs["requires_grad"].as_python_constant()
):
unimplemented(
"""factory functions that return tensors that require grad are not supported.
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn_,
*proxy_args_kwargs(args, kwargs),
),
)

if "out" in kwargs and not (
isinstance(kwargs["out"], variables.ConstantVariable)
and kwargs["out"].as_python_constant() is None
):
# out variants of torch operators like torch.sort and
# torch.sigmoid mutate the tensors in the out field. Track such
# tensors and rewrite the symbolic locals.
if isinstance(tensor_variable, TupleVariable):
assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
output_tensor_names = [
tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
]
for idx, name in enumerate(output_tensor_names):
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable.items[idx]
for out_tensor, result_tensor in zip(
kwargs["out"].items, tensor_variable.items
):
if (
isinstance(tensor_variable, TensorVariable)
and "requires_grad" in kwargs
and kwargs["requires_grad"].as_python_constant()
):
unimplemented(
"""factory functions that return tensors that require grad are not supported.
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
)

if "out" in kwargs and not (
isinstance(kwargs["out"], variables.ConstantVariable)
and kwargs["out"].as_python_constant() is None
):
# out variants of torch operators like torch.sort and
# torch.sigmoid mutate the tensors in the out field. Track such
# tensors and rewrite the symbolic locals.
if isinstance(tensor_variable, TupleVariable):
assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
output_tensor_names = [
tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
]
for idx, name in enumerate(output_tensor_names):
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable.items[idx]
for out_tensor, result_tensor in zip(
kwargs["out"].items, tensor_variable.items
):
if (
out_tensor.source
and out_tensor in tx.output.graphargs
and isinstance(out_tensor, variables.TensorVariable)
and isinstance(result_tensor, variables.TensorVariable)
and out_tensor.size != result_tensor.size
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
elif isinstance(tensor_variable, TensorVariable):
assert isinstance(kwargs["out"], TensorVariable)
assert "example_value" in kwargs["out"].proxy.node.meta
fake_tensor = tensor_variable.proxy.node.meta["example_value"]
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if (
out_tensor.source
and out_tensor in tx.output.graphargs
and isinstance(out_tensor, variables.TensorVariable)
and isinstance(result_tensor, variables.TensorVariable)
and out_tensor.size != result_tensor.size
kwargs["out"].source
and kwargs["out"] in tx.output.graphargs
and fake_out_shape != fake_tensor.shape
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
elif isinstance(tensor_variable, TensorVariable):
assert isinstance(kwargs["out"], TensorVariable)
assert "example_value" in kwargs["out"].proxy.node.meta
fake_tensor = tensor_variable.proxy.node.meta["example_value"]
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if (
kwargs["out"].source
and kwargs["out"] in tx.output.graphargs
and fake_out_shape != fake_tensor.shape
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where output tensor was non-contiguous"
)
name = tx.find_symbolic_locals_name(kwargs["out"])
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable
elif (
isinstance(tensor_variable, ConstantVariable)
and tensor_variable.value is None
):
# Handle out-variant custom ops that return None.
if isinstance(kwargs["out"], TensorVariable):
assert "example_value" in kwargs["out"].proxy.node.meta
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where output tensor was non-contiguous"
)
elif isinstance(kwargs["out"], ListVariable):
for idx, x in enumerate(kwargs["out"].items):
assert "example_value" in x.proxy.node.meta  # type: ignore[attr-defined]
fake_out = x.proxy.node.meta["example_value"]  # type: ignore[attr-defined]
name = tx.find_symbolic_locals_name(kwargs["out"])
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable
elif (
isinstance(tensor_variable, ConstantVariable)
and tensor_variable.value is None
):
# Handle out-variant custom ops that return None.
if isinstance(kwargs["out"], TensorVariable):
assert "example_value" in kwargs["out"].proxy.node.meta
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where some of the output tensors were non-contiguous"
"out= op was called where output tensor was non-contiguous"
)
else:
unimplemented(f"out variant of {type(kwargs['out'])}")
elif isinstance(kwargs["out"], ListVariable):
for idx, x in enumerate(kwargs["out"].items):
assert "example_value" in x.proxy.node.meta  # type: ignore[attr-defined]
fake_out = x.proxy.node.meta["example_value"]  # type: ignore[attr-defined]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where some of the output tensors were non-contiguous"
)
else:
unimplemented(f"out variant of {type(kwargs['out'])}")

return tensor_variable
return tensor_variable

def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs):
"""inline behavior of torch.nn.modules.utils._ntuple"""
@@ -1157,12 +1146,3 @@ Either create the tensor outside the compiled region, or do not set the tensor t
                source
            )
        return result

    def torch_function_override_enabled(self, tx, args, kwargs):
        return (
            self.get_function() in get_overridable_functions()
            or isinstance(
                self.get_function(),
                (torch._ops.OpOverload, torch._ops.OpOverloadPacket),
            )
        ) and can_dispatch_torch_function(tx, args, kwargs)
@@ -1,11 +1,8 @@
# mypy: ignore-errors

import collections
import contextlib
import inspect
from typing import Deque, Dict, List, TYPE_CHECKING
from typing import Dict, List, TYPE_CHECKING

import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
@@ -18,7 +15,6 @@ from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_att
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
from .user_defined import UserDefinedObjectVariable
@@ -63,67 +59,6 @@ banned_attrs = [
IGNORED_MODES = {DeviceContext}


class SymbolicTorchFunctionState:
    def __init__(self, py_stack):
        # This is annoyingly complicated because of how the torch function subclass + mode C API was designed
        # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass
        # These are their definitions:
        # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered
        # (if either are entered, this will be False)
        # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR
        # torch._C.DisableTorchFunction has been entered
        # To disambiguate these and keep myself sane I added a C API to check whether all torch function
        # concepts (modes and subclasses) are enabled.
        # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate
        # the stack length from the enablement state of torch function modes.
        # This is important because now if a mode is pushed while dynamo is tracing, we know whether
        # or not torch function modes are enabled and whether we should trace it.
        self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled()

        # This differs from the C API of the same name
        # this will only be false iff we have entered torch._C.DisableTorchFunction
        # and does not take into account the mode stack length, while the C API bundles these
        # two concepts
        self.torch_function_mode_enabled = (
            not torch._C._is_torch_function_all_disabled()
        )

        self.cur_mode = None

        TorchFunctionModeStackVariable.reset()

        self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque()

        for i, val in enumerate(py_stack):
            self.mode_stack.append(
                LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i))
            )

    def in_torch_function_mode(self):
        return len(self.mode_stack) > 0

    def pop_torch_function_mode(self):
        return self.mode_stack.pop()

    def push_torch_function_mode(self, mode_var):
        self.mode_stack.append(mode_var)

    def call_torch_function_mode(self, tx, fn, types, args, kwargs):
        with self._pop_mode_for_inlining() as cur_mode:
            return cur_mode.call_torch_function(tx, fn, types, args, kwargs)

    @contextlib.contextmanager
    def _pop_mode_for_inlining(self):
        old_mode = self.cur_mode
        self.cur_mode = self.pop_torch_function_mode()
        try:
            yield self.cur_mode
        finally:
            mode = self.cur_mode
            self.cur_mode = old_mode
            self.push_torch_function_mode(mode)


class TorchFunctionModeStackVariable(VariableTracker):
    """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation"""

@@ -153,20 +88,19 @@ class TorchFunctionModeStackVariable(VariableTracker):
    def register_mutation(cls, tx: "InstructionTranslator"):
        if cls.stack_value_singleton not in tx.output.side_effects:
            var = cls(
                source=Source(),
                symbolic_stack=tx.symbolic_torch_function_state.mode_stack,
                source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack
            )
            tx.output.side_effects.track_mutable(cls.stack_value_singleton, var)
            tx.output.side_effects.mutation(var)

    @classmethod
    def register_device_context_insertion(cls, tx: "InstructionTranslator"):
        stack = tx.symbolic_torch_function_state.mode_stack
        stack = tx.symbolic_torch_function_mode_stack
        if stack and cls.is_device_context(stack[0]):
            return
        else:
            cls.offset += 1
            stack.insert(
            tx.symbolic_torch_function_mode_stack.insert(
                0,
                TorchFunctionModeVariable(
                    None, source=TorchFunctionModeStackSource(-cls.offset)
@@ -175,7 +109,7 @@ class TorchFunctionModeStackVariable(VariableTracker):

    @classmethod
    def clear_default_device(cls, tx: "InstructionTranslator"):
        stack = tx.symbolic_torch_function_state.mode_stack
        stack = tx.symbolic_torch_function_mode_stack
        if stack and cls.is_device_context(stack[0]):
            stack.popleft()
            cls.offset -= 1
@@ -190,39 +124,23 @@ class TorchFunctionModeStackVariable(VariableTracker):


class TorchFunctionModeVariable(ContextWrappingVariable):
    def __init__(self, value, source=None, **kwargs):
    def __init__(self, value, **kwargs):
        super().__init__(value, **kwargs)
        self.value = value
        self.cm_obj = value  # needed for BC with calling enter from CM code
        self.source = source

    @staticmethod
    def get_global_mangled_name(tx, val):
        return get_safe_global_name(
            tx, f"__torch_function_mode_{val.__class__.__name__}", val
        )

    def reconstruct(self, codegen):
        # This shouldn't be called unless we have a source
        # We don't support locally created torch function modes yet
        assert self.source
        self.source.reconstruct(codegen)

    def module_name(self):
        return self.value.__module__

    def fn_name(self):
        return type(self.value).__name__

    def python_type(self):
        return type(self.value)

    def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
        return call_torch_function(
            tx,
            self,
            build_torch_function_fn(tx, self.value, self.source),
            fn,
            types,
            args,
            kwargs,
        )

    def _call_func(self, tx: "InstructionTranslator", values):
        unimplemented("enter/exit for torch function mode NYI")
    def _call_func(self, tx, values):
        unimplemented("torch function mode context manager is not supported yet")


def _get_all_args(args, kwargs):
@@ -313,13 +231,9 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source):


def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
    has_overridden_args = any(
    return tx.output.torch_function_enabled and any(
        has_torch_function(arg) for arg in _get_all_args(args, kwargs)
    )
    tf_state = tx.symbolic_torch_function_state
    return (has_overridden_args and tf_state.torch_function_subclass_enabled) or (
        tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode()
    )


def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
@@ -331,20 +245,11 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
        _get_subclass_type,
    )

    types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args])

    if tx.symbolic_torch_function_state.in_torch_function_mode():
        res = tx.symbolic_torch_function_state.call_torch_function_mode(
            tx, fn, types, args, kwargs
        )
        if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
            return res

    for arg in overloaded_args:
        res = arg.call_torch_function(
            tx,
            fn,
            types,
            TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]),
            args,
            kwargs,
        )
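The long comment in the removed `SymbolicTorchFunctionState.__init__` above distinguishes the two C-level disable knobs, `torch._C.DisableTorchFunction` and `torch._C.DisableTorchFunctionSubclass`. A small illustration of that distinction, asserting only what the comment itself states (these are private bindings, so availability and exact semantics may differ across releases):

```python
import torch

# With neither knob entered, torch function handling is fully enabled.
assert torch._C._is_torch_function_enabled()

with torch._C.DisableTorchFunctionSubclass():
    # Entering either knob makes _is_torch_function_enabled() report False ...
    assert not torch._C._is_torch_function_enabled()
    # ... but per the comment, only DisableTorchFunction disables *all*
    # torch function concepts (modes included).
    assert not torch._C._is_torch_function_all_disabled()

with torch._C.DisableTorchFunction():
    assert not torch._C._is_torch_function_enabled()
    assert torch._C._is_torch_function_all_disabled()
```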
@@ -82,6 +82,11 @@ def is_forbidden_context_manager(ctx):
        from _pytest.python_api import RaisesContext
        from _pytest.recwarn import WarningsChecker

        # TODO mlazos: Temporary to get this stack to pass
        # remove in subsequent PR
        from torch.overrides import BaseTorchFunctionMode

        f_ctxs.append(BaseTorchFunctionMode)
        f_ctxs.append(RaisesContext)
        f_ctxs.append(WarningsChecker)
    except ImportError:
@@ -408,6 +413,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
            and self.source
            and not is_forbidden_context_manager(self.value)
        ):
            # import here to avoid an unfortunate circular dependency.
            from .ctx_manager import GenericContextWrappingVariable

            cm_obj = tx.output.side_effects.track_object_new(
@@ -415,6 +421,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
            )
            cm_obj.call_method(tx, "__init__", args, kwargs)
            return cm_obj

        elif is_namedtuple_cls(self.value):
            fields = namedtuple_fields(self.value)
            # check if this a quasi-namedtuple or a real one
@@ -506,11 +506,7 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if (
            not torch.compiler.is_dynamo_compiling()
            and log.isEnabledFor(logging.DEBUG)
            and config.extended_debug_current_loc
        ):
        if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
            frame = _find_user_code_frame()
            if frame is not None:
                log.debug(