Revert "[Dynamo] Trace torch function modes entered outside of torch.compile (#133137)"

This reverts commit fafdd588f27e1d56090c6d260d0382c255eaf9eb.

Reverted https://github.com/pytorch/pytorch/pull/133137 on behalf of https://github.com/albanD due to Broke tests on main ([comment](https://github.com/pytorch/pytorch/pull/134732#issuecomment-2348886378))
PyTorch MergeBot
2024-09-13 12:52:58 +00:00
parent 3f30360d05
commit eb7dd91dd1
14 changed files with 202 additions and 455 deletions
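For context, the reverted change taught Dynamo to trace a TorchFunctionMode that is already active when a compiled function is called, instead of treating it as untraceable state. A minimal sketch of that usage pattern, modeled on the TestMode defined in the test diff below (illustrative only, not part of this commit):

import torch
from torch.overrides import BaseTorchFunctionMode

class TestMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func == torch.add:
            # intercept torch.add and return a constant instead
            return torch.zeros(2, 2)
        return super().__torch_function__(func, types, args, kwargs)

def fn(x):
    return torch.add(x, 3)

fn_opt = torch.compile(fn)

# The mode is entered *outside* the compiled region; the reverted PR made Dynamo
# trace through it rather than fall back.
with TestMode():
    out = fn_opt(torch.ones(2, 2))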

View File

@ -14,17 +14,6 @@ from torch.utils._device import DeviceContext
from torch.utils._python_dispatch import TorchDispatchMode
class TestMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.zeros(2, 2)
return super().__torch_function__(func, types, args, kwargs)
class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
@ -335,130 +324,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
fn(inp)
self.assertEqual(cnt.frame_count, 2)
def test_nested_torch_function_mode(self):
mode_1_called = False
mode_2_called = False
def reset_state():
nonlocal mode_1_called
nonlocal mode_2_called
mode_1_called = False
mode_2_called = False
ones = torch.ones(2, 2)
zeros = torch.zeros(2, 2)
class TestMode1(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
nonlocal mode_1_called
mode_1_called = True
if func == torch.add:
return zeros
return super().__torch_function__(func, types, args, kwargs)
class TestMode2(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
nonlocal mode_2_called
mode_2_called = True
if func == torch.mul:
return ones
return super().__torch_function__(func, types, args, kwargs)
def fn(x):
return torch.add(x, 3)
def fn_2(x):
return torch.mul(x, 3) + torch.add(x, 3)
inp = torch.ones(2, 2) + 1
for fn_i in [fn, fn_2]:
fn_opt = torch.compile(fn_i, fullgraph=True)
with TestMode1(), TestMode2():
expected = fn_i(inp), mode_1_called, mode_2_called
reset_state()
actual = fn_opt(inp), mode_1_called, mode_2_called
reset_state()
self.assertEqual(expected, actual)
def test_torch_function_mode_disable(self):
class TestSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.ones(2, 2)
return super().__torch_function__(func, types, args, kwargs)
class TestMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.zeros(2, 2)
return super().__torch_function__(func, types, args, kwargs)
def fn(x):
return torch.add(x, 3)
inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)
fn_opt = torch.compile(fn, fullgraph=True)
with TestMode(), torch._dynamo.config.patch(
"traceable_tensor_subclasses", {TestSubclass}
):
with torch._C.DisableTorchFunctionSubclass():
expected = fn(inp)
actual = fn_opt(inp)
self.assertEqual(expected, actual)
with torch._C.DisableTorchFunction():
expected = fn(inp)
actual = fn_opt(inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_highest_priority(self):
class TestSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.ones(2, 2)
return super().__torch_function__(func, types, args, kwargs)
def fn(x):
return torch.add(x, 3)
inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)
fn_opt = torch.compile(fn, fullgraph=True)
with TestMode(), torch._dynamo.config.patch(
"traceable_tensor_subclasses", {TestSubclass}
):
expected = fn(inp)
actual = fn_opt(inp)
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -26,7 +26,6 @@ from torch.testing._internal.common_utils import (
parametrize,
requires_cuda,
run_tests,
skipIfCrossRef,
skipIfRocm,
skipIfTorchDynamo,
TEST_WITH_TORCHDYNAMO,
@ -2883,7 +2882,6 @@ def forward(self, pred_1, x_1):
gm = make_fx(f, tracing_mode="symbolic")(add_wrong_dtype, init, x)
@skipIfNoDynamoSupport
@skipIfCrossRef # Arg order changes with crossref
def test_scan_simple_graph(self):
from torch._dynamo.testing import EagerAndRecordGraphs
@ -2990,7 +2988,6 @@ class TestControlFlowTraced(TestCase):
self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
@skipIfCrossRef # Arg order changes with crossref
def test_cond_simple_with_linear_compile_check_graph(self):
from torch._dynamo.testing import EagerAndRecordGraphs
@ -3253,7 +3250,6 @@ def forward(self, arg0_1):
self._check_compile(fn, inp, backend=backend)
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
@skipIfCrossRef # Arg order changes with cross ref
def test_while_loop_simple_with_linear_compile_check_graph(self):
fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
from torch._dynamo.testing import EagerAndRecordGraphs

View File

@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
from torch.fx import Node
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef
from torch.testing._internal.common_utils import IS_WINDOWS
class TestHelperModules:
@ -139,8 +139,6 @@ class TestMetaDataPorting(QuantizationTestCase):
self.assertEqual(v, node_tags[k])
return m
@skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack
# trace of the mode torch function impl doesn't match the traced graph stored lineno.
def test_simple_metadata_porting(self):
"""
Model under test

View File

@ -605,10 +605,6 @@ def _compile(
output: Optional[OutputGraph] = None
tracer: Optional[InstructionTranslator] = None
tf_mode_stack: List[
torch.overrides.TorchFunctionMode
] = torch.overrides._get_current_function_mode_stack()
@preserve_global_state
def transform(
instructions: List[Instruction], code_options: Dict[str, object]
@ -622,7 +618,6 @@ def _compile(
locals,
globals,
builtins,
tf_mode_stack,
code_options,
compiler_fn,
one_graph,

View File

@ -98,7 +98,6 @@ from .source import (
ScriptObjectQualifiedNameSource,
ShapeEnvSource,
SubclassAttrListSource,
TorchFunctionModeStackSource,
TupleIteratorGetItemSource,
TypeSource,
UnspecializedBuiltinNNModuleSource,
@ -112,7 +111,6 @@ from .utils import (
dict_keys_repr,
get_custom_getattr,
get_torch_function_mode_stack,
get_torch_function_mode_stack_at,
guard_failures,
istype,
key_is_id,
@ -316,7 +314,6 @@ CLOSURE_VARS = {
"___dict_contains": lambda a, b: a in b,
"___tuple_iterator_len": tuple_iterator_len,
"___tuple_iterator_getitem": tuple_iterator_getitem,
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
"__math_isnan": math.isnan,
"__numpy_isnan": None if np is None else np.isnan,
"inf": float("inf"),
@ -904,15 +901,6 @@ class GuardBuilder(GuardBuilderBase):
):
assert base_guard_manager # to make mypy happy
out = base_guard_manager
elif istype(source, TorchFunctionModeStackSource):
out = root_guard_manager.lambda_manager(
python_lambda=lambda _: get_torch_function_mode_stack_at(
source._get_index()
),
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, GradSource):
assert base_guard_manager # to make mypy happy
out = base_guard_manager.grad_manager(
@ -2226,8 +2214,6 @@ class CheckFunctionManager:
self.output_graph = output_graph
w_builder = None
# NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing
# in case a set default device call was made in the graph.
self.torch_function_mode_stack = (
output_graph.torch_function_mode_stack if output_graph else None
)
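The comment above about "a set default device call" refers to torch.set_default_device, which works by pushing a DeviceContext (itself a TorchFunctionMode) onto the torch function mode stack; that is why the guard machinery snapshots the stack recorded at the beginning of tracing. A small sketch of that interaction, using the private bindings that appear elsewhere in this diff (illustrative only):

import torch
from torch._C import _get_function_stack_at, _len_torch_function_stack
from torch.utils._device import DeviceContext

torch.set_default_device("cpu")   # enters a DeviceContext mode
stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())]
assert isinstance(stack[0], DeviceContext)
torch.set_default_device(None)    # exits it again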

View File

@ -619,7 +619,7 @@ class TorchFunctionModeStackSource(Source):
ind: int
def name(self):
return f"___get_torch_function_mode_stack_at({self._get_index()})"
return ""
def _get_index(self):
from .variables.torch_function import TorchFunctionModeStackVariable

View File

@ -19,7 +19,20 @@ import traceback
import types
import typing
import weakref
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
from typing import (
Any,
Callable,
cast,
Deque,
Dict,
List,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from unittest.mock import patch
import torch
@ -59,12 +72,14 @@ from .source import (
GlobalWeakRefSource,
LocalSource,
Source,
TorchFunctionModeStackSource,
)
from .trace_rules import is_builtin_constant, is_forbidden
from .utils import (
counters,
get_fake_value,
get_instruction_source_311,
get_torch_function_mode_stack,
graph_break_dup_warning_checker,
istype,
LazyString,
@ -105,10 +120,11 @@ from .variables.misc import (
)
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
from .variables.torch_function import (
SymbolicTorchFunctionState,
TorchFunctionModeVariable,
)
if TYPE_CHECKING:
from .variables.torch_function import TorchFunctionModeVariable
from .variables.user_defined import (
RemovableHandleVariable,
UserDefinedClassVariable,
@ -268,10 +284,6 @@ class BlockStackEntry:
return ReenterWith(self.stack_index)
def exit(self, tx):
if hasattr(self, "graph_break") and isinstance(
self.with_context, TorchFunctionModeVariable
):
return
assert self.with_context is not None
return self.with_context.exit(tx)
@ -640,9 +652,7 @@ def break_graph_if_unsupported(*, push):
# Reconstruct the context variable CLASS in the block stack
for b in self.block_stack:
assert b.with_context is not None
assert isinstance(
b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable)
)
assert isinstance(b.with_context, ContextWrappingVariable)
b.with_context.reconstruct_type(cg)
cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
self.output.add_output_instructions(cg.get_instructions())
@ -718,7 +728,7 @@ class InstructionTranslatorBase(
output: OutputGraph
symbolic_locals: Dict[str, VariableTracker]
symbolic_globals: Dict[str, VariableTracker]
symbolic_torch_function_state: SymbolicTorchFunctionState
symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"]
stack: List[VariableTracker]
instruction_pointer: Optional[int]
current_instruction: Instruction
@ -2538,7 +2548,7 @@ class InstructionTranslatorBase(
code_options: Dict[str, Any],
symbolic_locals: Dict[str, VariableTracker],
symbolic_globals: Dict[str, VariableTracker],
symbolic_torch_function_state: SymbolicTorchFunctionState,
symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
f_code: types.CodeType,
export: bool,
inline_depth: int,
@ -2553,7 +2563,7 @@ class InstructionTranslatorBase(
self.output = output
self.symbolic_locals = symbolic_locals
self.symbolic_globals = symbolic_globals
self.symbolic_torch_function_state = symbolic_torch_function_state
self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack
self.stack = []
# stack of variable names for tracking 3.13 closures
self.name_stack: list[Any] = []
@ -2642,7 +2652,6 @@ class InstructionTranslator(InstructionTranslatorBase):
f_locals,
f_globals,
f_builtins,
torch_function_mode_stack,
code_options,
compiler_fn,
one_graph,
@ -2677,7 +2686,7 @@ class InstructionTranslator(InstructionTranslatorBase):
symbolic_locals={}, # set below
# A global var is inserted only after a STORE_GLOBAL happens to it
symbolic_globals={},
symbolic_torch_function_state=None, # type: ignore[arg-type] # set below
symbolic_torch_function_mode_stack=collections.deque(),
f_code=f_code,
export=export,
inline_depth=0,
@ -2712,9 +2721,7 @@ class InstructionTranslator(InstructionTranslatorBase):
if k in f_locals
}
self.symbolic_torch_function_state = SymbolicTorchFunctionState(
torch_function_mode_stack
)
self._init_torch_function_mode_stack()
self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
if export:
@ -2755,6 +2762,29 @@ class InstructionTranslator(InstructionTranslatorBase):
)
unimplemented(msg)
def _init_torch_function_mode_stack(self):
from .variables.torch_function import TorchFunctionModeStackVariable
TorchFunctionModeStackVariable.reset()
self.symbolic_torch_function_mode_stack: Deque[
TorchFunctionModeVariable
] = collections.deque()
# We want to retrieve all modes to properly reconstruct the stack if needed
py_stack = get_torch_function_mode_stack(filter_ignored=False)
if py_stack:
has_device_context = isinstance(
py_stack[0], torch.utils._device.DeviceContext
)
for i, val in enumerate(py_stack):
self.symbolic_torch_function_mode_stack.append(
variables.LazyVariableTracker.create(
val, source=TorchFunctionModeStackSource(i)
)
)
def get_example_value(self, source: Source):
if isinstance(source, LocalSource):
return self.f_locals[source.local_name]
@ -3086,7 +3116,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
code,
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_state,
parent.symbolic_torch_function_mode_stack,
closure_cells,
func,
)
@ -3096,7 +3126,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
code,
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_state,
parent.symbolic_torch_function_mode_stack,
closure_cells,
func,
)
@ -3149,7 +3179,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
code: types.CodeType,
symbolic_locals: Dict[str, VariableTracker],
symbolic_globals: Dict[str, VariableTracker],
symbolic_torch_function_state: SymbolicTorchFunctionState,
symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
closure_cells: Dict[str, VariableTracker],
funcvar: BaseUserFunctionVariable,
) -> None:
@ -3166,7 +3196,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
f_builtins=f_builtins,
symbolic_locals=symbolic_locals,
symbolic_globals=symbolic_globals,
symbolic_torch_function_state=symbolic_torch_function_state,
symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack,
instructions=instructions,
code_options={k: getattr(code, k) for k in get_code_keys()},
f_code=code,

View File

@ -3258,7 +3258,6 @@ MOD_INLINELIST = [
"torch.testing",
"torch.utils._content_store",
"torch.utils._contextlib",
"torch.utils._device",
"torch.utils._foreach_utils",
"torch.utils._python_dispatch",
"torch.utils._pytree",
@ -3593,9 +3592,7 @@ def lookup_inner(
if reasons is not None:
reasons.add("func name is patched_init")
return SkipFunctionVariable
elif name == "__torch_function__" or (
obj and obj.__name__ == "__torch_function__"
):
elif name == "__torch_function__":
if reasons is not None:
reasons.add("func name is __torch_function__")
return UserFunctionVariable

View File

@ -63,6 +63,7 @@ import torch.fx.experimental.symbolic_shapes
import torch.utils._pytree as pytree
from torch import fx
from torch._C import (
_get_function_stack_at,
_instruction_counter,
_len_torch_function_stack,
_pop_torch_function_stack,
@ -3064,9 +3065,7 @@ def is_parameter_freezing():
def get_torch_function_mode_stack(filter_ignored=True):
from .variables.torch_function import IGNORED_MODES
stack = [
get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
]
stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())]
if filter_ignored:
stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]
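A standalone sketch of what get_torch_function_mode_stack shown above returns with and without filtering, assuming the module path torch._dynamo.utils and that IGNORED_MODES still contains only DeviceContext as defined later in this diff (illustrative only):

import torch
from torch._dynamo.utils import get_torch_function_mode_stack
from torch.overrides import BaseTorchFunctionMode

torch.set_default_device("cpu")            # DeviceContext lands on the stack
with BaseTorchFunctionMode():              # plus one ordinary mode
    full = get_torch_function_mode_stack(filter_ignored=False)
    filtered = get_torch_function_mode_stack()   # DeviceContext filtered out
    print(len(full), len(filtered))              # expected: 2 1
torch.set_default_device(None)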
@ -3086,11 +3085,6 @@ def set_torch_function_mode_stack(stack):
_push_on_torch_function_stack(mode)
def clear_torch_function_mode_stack():
for i in range(_len_torch_function_stack()):
_pop_torch_function_stack()
def verify_guard_fn_signature(value):
fn = value.__metadata_guard__
sig = inspect.signature(fn)

View File

@ -637,8 +637,6 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 1
tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0]
tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0]
tx.output.set_torch_function_state(values[0])

View File

@ -156,15 +156,6 @@ tracing_state_functions = {
bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"])
@functools.lru_cache(None)
def get_overridable_functions():
from itertools import chain
from torch.overrides import get_overridable_functions as get_overridable_functions_
return set(chain(*get_overridable_functions_().values()))
class BaseTorchVariable(VariableTracker):
"""common base for all torch.* functions, classes, modules and other things"""
@ -815,10 +806,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
self, tx: "InstructionTranslator", *args, **kwargs
):
assert not args and not kwargs
if not tx.symbolic_torch_function_state.mode_stack:
if not tx.symbolic_torch_function_mode_stack:
raise unimplemented("Popping from an empty torch function mode stack")
TorchFunctionModeStackVariable.register_mutation(tx)
return tx.symbolic_torch_function_state.pop_torch_function_mode()
return tx.symbolic_torch_function_mode_stack.pop()
@register(torch._C._push_on_torch_function_stack)
def handle_push_torch_function(
@ -826,7 +817,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
):
assert len(args) == 1 and not kwargs
TorchFunctionModeStackVariable.register_mutation(tx)
tx.symbolic_torch_function_state.push_torch_function_mode(args[0])
tx.symbolic_torch_function_mode_stack.append(args[0])
return ConstantVariable.create(None)
@register(torch._C._len_torch_function_stack)
@ -834,9 +825,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
self, tx: "InstructionTranslator", *args, **kwargs
):
assert not args and not kwargs
return ConstantVariable.create(
len(tx.symbolic_torch_function_state.mode_stack)
)
return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack))
@register(torch.set_default_device)
def handle_set_default_device(
@ -868,9 +857,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
from . import ConstantVariable, SymNodeVariable, TensorVariable
from .builder import wrap_fx_proxy
if self.torch_function_override_enabled(tx, args, kwargs):
return dispatch_torch_function(tx, self, args, kwargs)
if self.can_constant_fold_through() and check_unspec_or_constant_args(
args, kwargs
):
@ -892,144 +878,147 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
if result:
return result
any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
if can_dispatch_torch_function(tx, args, kwargs):
return dispatch_torch_function(tx, self, args, kwargs)
else:
any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
all_ints_or_floats = all(
isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
for x in args
)
if (
getattr(self.value, "__module__", "") == "torch"
and self.value.__name__ in bin_ops
and any_symints_or_symfloats
and all_ints_or_floats
):
msg = f"""\
all_ints_or_floats = all(
isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
for x in args
)
if (
getattr(self.value, "__module__", "") == "torch"
and self.value.__name__ in bin_ops
and any_symints_or_symfloats
and all_ints_or_floats
):
msg = f"""\
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
To support this behavior, we need to allow const-propping tensors that store symint data.
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
"""
log.warning(msg)
unimplemented(msg)
log.warning(msg)
unimplemented(msg)
# TODO(voz): Replace w/ dynamic shape rewrite table.
# Ideally, we would be able to do this at ctor time, but alas we need a combination
# of value + args to determine this.
fn_ = self.value
if any_symints_or_symfloats:
torch_sym_op = f"_sym_{self.value.__name__}"
if getattr(self.value, "__module__", None) == "math" and hasattr(
torch, torch_sym_op
):
fn_ = getattr(torch, torch_sym_op)
# TODO(voz): Replace w/ dynamic shape rewrite table.
# Ideally, we would be able to do this at ctor time, but alas we need a combination
# of value + args to determine this.
fn_ = self.value
if any_symints_or_symfloats:
torch_sym_op = f"_sym_{self.value.__name__}"
if getattr(self.value, "__module__", None) == "math" and hasattr(
torch, torch_sym_op
):
fn_ = getattr(torch, torch_sym_op)
fake_out_shape = None
if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
# Calling fake tensor propagation can mutate the out= tensor in
# tx.output.tracked_fakes. tracked_fakes are used to apply
# symbolic_shape guards. Mutating them destroys the information
# prior to tracing, which is essential for creating right
# guards. So save the shape now, and check later if it has
# changed. If it has, graph break.
fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
fake_out_shape = None
if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
# Calling fake tensor propagation can mutate the out= tensor in
# tx.output.tracked_fakes. tracked_fakes are used to apply
# symbolic_shape guards. Mutating them destroys the information
# prior to tracing, which is essential for creating right
# guards. So save the shape now, and check later if it has
# changed. If it has, graph break.
fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn_,
*proxy_args_kwargs(args, kwargs),
),
)
if (
isinstance(tensor_variable, TensorVariable)
and "requires_grad" in kwargs
and kwargs["requires_grad"].as_python_constant()
):
unimplemented(
"""factory functions that return tensors that require grad are not supported.
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn_,
*proxy_args_kwargs(args, kwargs),
),
)
if "out" in kwargs and not (
isinstance(kwargs["out"], variables.ConstantVariable)
and kwargs["out"].as_python_constant() is None
):
# out variants of torch operators like torch.sort and
# torch.sigmoid mutate the tensors in the out field. Track such
# tensors and rewrite the symbolic locals.
if isinstance(tensor_variable, TupleVariable):
assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
output_tensor_names = [
tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
]
for idx, name in enumerate(output_tensor_names):
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable.items[idx]
for out_tensor, result_tensor in zip(
kwargs["out"].items, tensor_variable.items
):
if (
isinstance(tensor_variable, TensorVariable)
and "requires_grad" in kwargs
and kwargs["requires_grad"].as_python_constant()
):
unimplemented(
"""factory functions that return tensors that require grad are not supported.
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
)
if "out" in kwargs and not (
isinstance(kwargs["out"], variables.ConstantVariable)
and kwargs["out"].as_python_constant() is None
):
# out variants of torch operators like torch.sort and
# torch.sigmoid mutate the tensors in the out field. Track such
# tensors and rewrite the symbolic locals.
if isinstance(tensor_variable, TupleVariable):
assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
output_tensor_names = [
tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
]
for idx, name in enumerate(output_tensor_names):
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable.items[idx]
for out_tensor, result_tensor in zip(
kwargs["out"].items, tensor_variable.items
):
if (
out_tensor.source
and out_tensor in tx.output.graphargs
and isinstance(out_tensor, variables.TensorVariable)
and isinstance(result_tensor, variables.TensorVariable)
and out_tensor.size != result_tensor.size
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
elif isinstance(tensor_variable, TensorVariable):
assert isinstance(kwargs["out"], TensorVariable)
assert "example_value" in kwargs["out"].proxy.node.meta
fake_tensor = tensor_variable.proxy.node.meta["example_value"]
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if (
out_tensor.source
and out_tensor in tx.output.graphargs
and isinstance(out_tensor, variables.TensorVariable)
and isinstance(result_tensor, variables.TensorVariable)
and out_tensor.size != result_tensor.size
kwargs["out"].source
and kwargs["out"] in tx.output.graphargs
and fake_out_shape != fake_tensor.shape
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
elif isinstance(tensor_variable, TensorVariable):
assert isinstance(kwargs["out"], TensorVariable)
assert "example_value" in kwargs["out"].proxy.node.meta
fake_tensor = tensor_variable.proxy.node.meta["example_value"]
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if (
kwargs["out"].source
and kwargs["out"] in tx.output.graphargs
and fake_out_shape != fake_tensor.shape
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where output tensor was non-contiguous"
)
name = tx.find_symbolic_locals_name(kwargs["out"])
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable
elif (
isinstance(tensor_variable, ConstantVariable)
and tensor_variable.value is None
):
# Handle out-variant custom ops that return None.
if isinstance(kwargs["out"], TensorVariable):
assert "example_value" in kwargs["out"].proxy.node.meta
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where output tensor was non-contiguous"
)
elif isinstance(kwargs["out"], ListVariable):
for idx, x in enumerate(kwargs["out"].items):
assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined]
fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined]
name = tx.find_symbolic_locals_name(kwargs["out"])
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable
elif (
isinstance(tensor_variable, ConstantVariable)
and tensor_variable.value is None
):
# Handle out-variant custom ops that return None.
if isinstance(kwargs["out"], TensorVariable):
assert "example_value" in kwargs["out"].proxy.node.meta
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where some of the output tensors were non-contiguous"
"out= op was called where output tensor was non-contiguous"
)
else:
unimplemented(f"out variant of {type(kwargs['out'])}")
elif isinstance(kwargs["out"], ListVariable):
for idx, x in enumerate(kwargs["out"].items):
assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined]
fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where some of the output tensors were non-contiguous"
)
else:
unimplemented(f"out variant of {type(kwargs['out'])}")
return tensor_variable
return tensor_variable
def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs):
"""inline behavior of torch.nn.modules.utils._ntuple"""
@ -1157,12 +1146,3 @@ Either create the tensor outside the compiled region, or do not set the tensor t
source
)
return result
def torch_function_override_enabled(self, tx, args, kwargs):
return (
self.get_function() in get_overridable_functions()
or isinstance(
self.get_function(),
(torch._ops.OpOverload, torch._ops.OpOverloadPacket),
)
) and can_dispatch_torch_function(tx, args, kwargs)
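The out=-handling code in this file's diff exists because out= variants of torch operators mutate, and may resize, the tensor passed as out, which Dynamo then has to reflect in its symbolic locals or fall back on. Eager behavior for reference (illustrative only, not part of the diff):

import torch

x = torch.randn(4)

out = torch.empty(4)
torch.sigmoid(x, out=out)        # result is written into `out` in place

small = torch.empty(0)
torch.sigmoid(x, out=small)      # `small` is resized to match; resizing a graph
                                 # input is the case the traced code falls back on
print(out.shape, small.shape)    # torch.Size([4]) torch.Size([4])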

View File

@ -1,11 +1,8 @@
# mypy: ignore-errors
import collections
import contextlib
import inspect
from typing import Deque, Dict, List, TYPE_CHECKING
from typing import Dict, List, TYPE_CHECKING
import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
@ -18,7 +15,6 @@ from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_att
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
from .user_defined import UserDefinedObjectVariable
@ -63,67 +59,6 @@ banned_attrs = [
IGNORED_MODES = {DeviceContext}
class SymbolicTorchFunctionState:
def __init__(self, py_stack):
# This is annoyingly complicated because of how the torch function subclass + mode C API was designed
# There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass
# These are their definitions:
# 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered
# (if either are entered, this will be False)
# 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR
# torch._C.DisableTorchFunction has been entered
# To disambiguate these and keep myself sane I added a C API to check whether all torch function
# concepts (modes and subclasses) are enabled.
# This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate
# the stack length from the enablement state of torch function modes.
# This is important because now if a mode is pushed while dynamo is tracing, we know whether
# or not torch function modes are enabled and whether we should trace it.
self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled()
# This differs from the C API of the same name
# this will only be false iff we have entered torch._C.DisableTorchFunction
# and does not take into account the mode stack length, while the C API bundles these
# two concepts
self.torch_function_mode_enabled = (
not torch._C._is_torch_function_all_disabled()
)
self.cur_mode = None
TorchFunctionModeStackVariable.reset()
self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque()
for i, val in enumerate(py_stack):
self.mode_stack.append(
LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i))
)
def in_torch_function_mode(self):
return len(self.mode_stack) > 0
def pop_torch_function_mode(self):
return self.mode_stack.pop()
def push_torch_function_mode(self, mode_var):
self.mode_stack.append(mode_var)
def call_torch_function_mode(self, tx, fn, types, args, kwargs):
with self._pop_mode_for_inlining() as cur_mode:
return cur_mode.call_torch_function(tx, fn, types, args, kwargs)
@contextlib.contextmanager
def _pop_mode_for_inlining(self):
old_mode = self.cur_mode
self.cur_mode = self.pop_torch_function_mode()
try:
yield self.cur_mode
finally:
mode = self.cur_mode
self.cur_mode = old_mode
self.push_torch_function_mode(mode)
class TorchFunctionModeStackVariable(VariableTracker):
"""Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation"""
@ -153,20 +88,19 @@ class TorchFunctionModeStackVariable(VariableTracker):
def register_mutation(cls, tx: "InstructionTranslator"):
if cls.stack_value_singleton not in tx.output.side_effects:
var = cls(
source=Source(),
symbolic_stack=tx.symbolic_torch_function_state.mode_stack,
source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack
)
tx.output.side_effects.track_mutable(cls.stack_value_singleton, var)
tx.output.side_effects.mutation(var)
@classmethod
def register_device_context_insertion(cls, tx: "InstructionTranslator"):
stack = tx.symbolic_torch_function_state.mode_stack
stack = tx.symbolic_torch_function_mode_stack
if stack and cls.is_device_context(stack[0]):
return
else:
cls.offset += 1
stack.insert(
tx.symbolic_torch_function_mode_stack.insert(
0,
TorchFunctionModeVariable(
None, source=TorchFunctionModeStackSource(-cls.offset)
@ -175,7 +109,7 @@ class TorchFunctionModeStackVariable(VariableTracker):
@classmethod
def clear_default_device(cls, tx: "InstructionTranslator"):
stack = tx.symbolic_torch_function_state.mode_stack
stack = tx.symbolic_torch_function_mode_stack
if stack and cls.is_device_context(stack[0]):
stack.popleft()
cls.offset -= 1
@ -190,39 +124,23 @@ class TorchFunctionModeStackVariable(VariableTracker):
class TorchFunctionModeVariable(ContextWrappingVariable):
def __init__(self, value, source=None, **kwargs):
def __init__(self, value, **kwargs):
super().__init__(value, **kwargs)
self.value = value
self.cm_obj = value # needed for BC with calling enter from CM code
self.source = source
@staticmethod
def get_global_mangled_name(tx, val):
return get_safe_global_name(
tx, f"__torch_function_mode_{val.__class__.__name__}", val
)
def reconstruct(self, codegen):
# This shouldn't be called unless we have a source
# We don't support locally created torch function modes yet
assert self.source
self.source.reconstruct(codegen)
def module_name(self):
return self.value.__module__
def fn_name(self):
return type(self.value).__name__
def python_type(self):
return type(self.value)
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
return call_torch_function(
tx,
self,
build_torch_function_fn(tx, self.value, self.source),
fn,
types,
args,
kwargs,
)
def _call_func(self, tx: "InstructionTranslator", values):
unimplemented("enter/exit for torch function mode NYI")
def _call_func(self, tx, values):
unimplemented("torch function mode context manager is not supported yet")
def _get_all_args(args, kwargs):
@ -313,13 +231,9 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source):
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
has_overridden_args = any(
return tx.output.torch_function_enabled and any(
has_torch_function(arg) for arg in _get_all_args(args, kwargs)
)
tf_state = tx.symbolic_torch_function_state
return (has_overridden_args and tf_state.torch_function_subclass_enabled) or (
tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode()
)
def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
@ -331,20 +245,11 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
_get_subclass_type,
)
types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args])
if tx.symbolic_torch_function_state.in_torch_function_mode():
res = tx.symbolic_torch_function_state.call_torch_function_mode(
tx, fn, types, args, kwargs
)
if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
return res
for arg in overloaded_args:
res = arg.call_torch_function(
tx,
fn,
types,
TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]),
args,
kwargs,
)
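The long comment removed from SymbolicTorchFunctionState above distinguishes two C-level disable knobs. A sketch of the behavior that comment describes, using only bindings referenced in this diff (assumed behavior, for illustration):

import torch

# Neither knob entered: subclass and mode dispatch are both active.
print(torch._C._is_torch_function_enabled())           # True

with torch._C.DisableTorchFunctionSubclass():
    # Subclass __torch_function__ is skipped, but modes on the stack still run.
    print(torch._C._is_torch_function_enabled())       # False

with torch._C.DisableTorchFunction():
    # Everything is disabled; per the removed code, only in this state does
    # _is_torch_function_all_disabled() report True.
    print(torch._C._is_torch_function_all_disabled())  # True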

View File

@ -82,6 +82,11 @@ def is_forbidden_context_manager(ctx):
from _pytest.python_api import RaisesContext
from _pytest.recwarn import WarningsChecker
# TODO mlazos: Temporary to get this stack to pass
# remove in subsequent PR
from torch.overrides import BaseTorchFunctionMode
f_ctxs.append(BaseTorchFunctionMode)
f_ctxs.append(RaisesContext)
f_ctxs.append(WarningsChecker)
except ImportError:
@ -408,6 +413,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
and self.source
and not is_forbidden_context_manager(self.value)
):
# import here to avoid an unfortunate circular dependency.
from .ctx_manager import GenericContextWrappingVariable
cm_obj = tx.output.side_effects.track_object_new(
@ -415,6 +421,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
)
cm_obj.call_method(tx, "__init__", args, kwargs)
return cm_obj
elif is_namedtuple_cls(self.value):
fields = namedtuple_fields(self.value)
# check if this a quasi-namedtuple or a real one

View File

@ -506,11 +506,7 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if (
not torch.compiler.is_dynamo_compiling()
and log.isEnabledFor(logging.DEBUG)
and config.extended_debug_current_loc
):
if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
frame = _find_user_code_frame()
if frame is not None:
log.debug(