Compare commits

...

3 Commits

Author SHA1 Message Date
7342afdad4 [Dynamo] Clear/restore torch function mode stack to prevent overriding torch.compile infrastructure
ghstack-source-id: 4abb0201b60fe47b2a19629ad9383b333218beb7
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134733

fix
2024-09-07 14:21:49 -07:00
7f237dd271 [Dynamo] Trace torch function modes
ghstack-source-id: 13e72b90a2dc3918bf367c6304db25084db532e8
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133137

Fix

Fix test2

Fix device inlining

fix2

fix2
2024-09-07 14:21:47 -07:00
7e1c068aa2 [Dynamo] Disable metadata tf mode when tracing cond
ghstack-source-id: 177ce9f6c7118a9b5671815025abc86131abe2f3
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134732
2024-09-07 13:23:01 -07:00
21 changed files with 779 additions and 440 deletions

View File

@ -14,6 +14,17 @@ from torch.utils._device import DeviceContext
from torch.utils._python_dispatch import TorchDispatchMode
class TestMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.zeros(2, 2)
return super().__torch_function__(func, types, args, kwargs)
class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
@ -324,6 +335,199 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
fn(inp)
self.assertEqual(cnt.frame_count, 2)
def test_nested_torch_function_mode(self):
mode_1_called = False
mode_2_called = False
def reset_state():
nonlocal mode_1_called
nonlocal mode_2_called
mode_1_called = False
mode_2_called = False
ones = torch.ones(2, 2)
zeros = torch.zeros(2, 2)
class TestMode1(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
nonlocal mode_1_called
mode_1_called = True
if func == torch.add:
return zeros
return super().__torch_function__(func, types, args, kwargs)
class TestMode2(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
nonlocal mode_2_called
mode_2_called = True
if func == torch.mul:
return ones
return super().__torch_function__(func, types, args, kwargs)
def fn(x):
return torch.add(x, 3)
def fn_2(x):
return torch.mul(x, 3) + torch.add(x, 3)
inp = torch.ones(2, 2) + 1
for fn_i in [fn, fn_2]:
fn_opt = torch.compile(fn_i, fullgraph=True)
with TestMode1(), TestMode2():
expected = fn_i(inp), mode_1_called, mode_2_called
reset_state()
actual = fn_opt(inp), mode_1_called, mode_2_called
reset_state()
self.assertEqual(expected, actual)
def test_torch_function_mode_disable(self):
class TestSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.ones(2, 2)
return super().__torch_function__(func, types, args, kwargs)
class TestMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.zeros(2, 2)
return super().__torch_function__(func, types, args, kwargs)
def fn(x):
return torch.add(x, 3)
inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)
fn_opt = torch.compile(fn, fullgraph=True)
with TestMode(), torch._dynamo.config.patch(
"traceable_tensor_subclasses", {TestSubclass}
):
with torch._C.DisableTorchFunctionSubclass():
expected = fn(inp)
actual = fn_opt(inp)
self.assertEqual(expected, actual)
with torch._C.DisableTorchFunction():
expected = fn(inp)
actual = fn_opt(inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_highest_priority(self):
class TestSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.ones(2, 2)
return super().__torch_function__(func, types, args, kwargs)
def fn(x):
return torch.add(x, 3)
inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)
fn_opt = torch.compile(fn, fullgraph=True)
with TestMode(), torch._dynamo.config.patch(
"traceable_tensor_subclasses", {TestSubclass}
):
expected = fn(inp)
actual = fn_opt(inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_enter_exit(self):
def fn(x, y):
with TestMode():
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn, fullgraph=True)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_graph_break(self):
def fn(x, y):
with TestMode():
torch._dynamo.graph_break()
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_and_pop_graph_break(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_and_pop_graph_break_mutation(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
z.y = 5
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
o = torch.mul(o, z.y)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -9,12 +9,7 @@ from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.while_loop import while_loop
from torch._subclasses.functional_tensor import (
CppFunctionalizeAPI,
FunctionalTensor,
FunctionalTensorMode,
PythonFunctionalizeAPI,
)
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import SM70OrLater
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
@ -24,6 +19,7 @@ from torch.testing._internal.common_utils import (
IS_WINDOWS,
parametrize,
run_tests,
skipIfCrossRef,
skipIfTorchDynamo,
TEST_WITH_TORCHDYNAMO,
TestCase,
@ -1557,6 +1553,7 @@ class TestControlFlowTraced(TestCase):
self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
@skipIfCrossRef # Arg order changes with crossref
def test_cond_simple_with_linear_compile_check_graph(self):
from torch._dynamo.testing import EagerAndRecordGraphs
@ -1664,19 +1661,6 @@ def forward(self, arg0_1, arg1_1, arg2_1):
""", # noqa: B950
)
def _wrap_with_functionalize(self, fn, func_type):
mode = None
if func_type == "cpp":
fn = CppFunctionalizeAPI().functionalize(fn)
elif func_type == "python":
fn = PythonFunctionalizeAPI().functionalize(fn)
mode = FunctionalTensorMode()
elif func_type == "functorch":
fn = torch.func.functionalize(fn)
else:
assert func_type == "no"
return fn, mode
@parametrize("func_type", ["no", "cpp", "python", "functorch"])
def test_while_loop_simple_functionalize_check_graph(self, func_type):
fn, inp = WHILE_LOOP_TESTS["simple_with_mutation"]
@ -1819,6 +1803,7 @@ def forward(self, arg0_1):
self._check_compile(fn, inp, backend=backend)
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
@skipIfCrossRef # Arg order changes with cross ref
def test_while_loop_simple_with_linear_compile_check_graph(self):
fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
from torch._dynamo.testing import EagerAndRecordGraphs
@ -1894,135 +1879,6 @@ def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_bo
return (child, child_1)""", # noqa: B950
)
def test_while_loop_nested2_traced(self):
fn, inp = WHILE_LOOP_TESTS["nested2"]
graphs = self._check_tracing(fn, inp)
gm = graphs["symbolic"]
outer_body = gm.while_loop_body_graph_0
outer_cond = gm.while_loop_cond_graph_0
inner_body = outer_body.while_loop_body_graph_0
inner_cond = outer_body.while_loop_cond_graph_0
self.assertExpectedInline(
gm.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
getitem_2 = while_loop[2]
getitem_3 = while_loop[3]; while_loop = None
return (getitem, getitem_1, getitem_2, getitem_3)
""", # noqa: B950
)
self.assertExpectedInline(
outer_body.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
getitem_2 = while_loop[2]
getitem_3 = while_loop[3]; while_loop = None
sub = torch.ops.aten.sub.Tensor(getitem, 1); getitem = None
clone = torch.ops.aten.clone.default(getitem_1); getitem_1 = None
mul = torch.ops.aten.mul.Tensor(getitem_2, 2); getitem_2 = None
div = torch.ops.aten.div.Tensor(getitem_3, 2); getitem_3 = None
return (sub, clone, mul, div)
""", # noqa: B950
)
self.assertExpectedInline(
outer_body.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
getitem_2 = while_loop[2]
getitem_3 = while_loop[3]; while_loop = None
sub = torch.ops.aten.sub.Tensor(getitem, 1); getitem = None
clone = torch.ops.aten.clone.default(getitem_1); getitem_1 = None
mul = torch.ops.aten.mul.Tensor(getitem_2, 2); getitem_2 = None
div = torch.ops.aten.div.Tensor(getitem_3, 2); getitem_3 = None
return (sub, clone, mul, div)
""", # noqa: B950
)
self.assertExpectedInline(
inner_body.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
sub = torch.ops.aten.sub.Tensor(arg1_1, 1); arg1_1 = None
add = torch.ops.aten.add.Tensor(arg2_1, 3.14); arg2_1 = None
sub_1 = torch.ops.aten.sub.Tensor(arg3_1, 2.71); arg3_1 = None
return (clone, sub, add, sub_1)
""",
)
self.assertExpectedInline(
inner_cond.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
gt = torch.ops.aten.gt.Scalar(arg1_1, 0); arg1_1 = None
return gt
""",
)
def test_cond_nested_traced(self):
def true_nested(y):
return y * y
def false_nested(y):
return y + y
def true_fn(x, pred2):
z = cond(pred2, true_nested, false_nested, [x])
return x + z
def false_fn(x, _):
return x.cos()
def f(x, pred, pred2):
return cond(pred, true_fn, false_fn, [x, pred2])
x = torch.randn(4)
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
result_true_true = graph.forward(
x, torch.tensor(True), torch.tensor(True)
) # True + True -> x * x
result_true_false = graph.forward(
x, torch.tensor(True), torch.tensor(False)
) # True + True -> x + x
result_false_true = graph.forward(
x, torch.tensor(False), torch.tensor(True)
) # False + either -> cos
result_false_false = graph.forward(
x, torch.tensor(False), torch.tensor(False)
) # False + either -> cos
self.assertNotEqual(result_true_true, result_true_false)
self.assertFalse(torch.allclose(result_false_true, result_true_true))
self.assertEqual(result_false_true, result_false_false)
self.assertEqual(result_true_true, (x * x) + x)
self.assertEqual(result_true_false, x + x + x)
self.assertEqual(result_false_true, torch.cos(x))
graph = make_fx(f, tracing_mode="symbolic")(
x, torch.tensor(False), torch.tensor(False)
)
self.assertEqual(
graph(x, torch.tensor(True), torch.tensor(True)),
f(x, torch.tensor(True), torch.tensor(True)),
)
def test_cond_functionalized(self):
def true_fn(x):
y = x.sin()

View File

@ -181,12 +181,10 @@ class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
self.assertExpectedInline(
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \
"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
# No stacktrace found for following nodes
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = \
arg3_1 = arg1_1 = arg0_1 = foo_default = None
return ()""",
foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = foo_default = None
return ()""", # noqa: B950
)
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
@ -240,7 +238,7 @@ arg3_1 = arg1_1 = arg0_1 = foo_default = None
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = None
getitem_4: "f32[3][1]cpu" = foo_default[0]
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
return (getitem_4, getitem_5)""", # noqa: B950
@ -327,9 +325,8 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
# No stacktrace found for following nodes
foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \
arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
return ()""",
foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1); arg2_1 = arg3_1 = arg0_1 = arg1_1 = foo_default = None
return ()""", # noqa: B950
)
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
@ -403,9 +400,9 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
post_grad_graphs,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1); arg3_1 = arg4_1 = arg1_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1); arg5_1 = copy__1 = None
foo_default = torch.ops.mylib.foo.default(arg1_1, [arg4_1, arg5_1], arg2_1, 2, arg3_1); arg4_1 = arg5_1 = arg3_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -415,9 +412,9 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg3_1 = arg4_1 = arg2_1 = foo_default = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -504,12 +501,11 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = None
foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg3_1 = arg4_1 = arg2_1 = None
getitem_4: "f32[3][1]cpu" = foo_default[0]
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
return (getitem_4, getitem_5)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -564,12 +560,12 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
graph_aot,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1])
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg2_1])
getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_2); arg2_1 = getitem_2 = copy__1 = None
return (add,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -579,12 +575,12 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1
graph_aot,
"""\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1])
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg0_1, arg1_1])
getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None
return (add,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -596,8 +592,8 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
graph_inductor,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
foo_default = torch.ops.mylib.foo.default(arg1_1, arg2_1); foo_default = None
add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg2_1)
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return (add,)""",
@ -609,8 +605,8 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1
graph_inductor,
"""\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None
add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg0_1)
foo_default = torch.ops.mylib.foo.default(arg0_1, arg1_1); foo_default = None
add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None
return (add,)""",
@ -895,8 +891,8 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1); arg2_1 = arg3_1 = arg1_1 = foo_default = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,

View File

@ -31,6 +31,18 @@ def eager(gm, fake_tensor_inputs, **kwargs):
return gm.forward
def make_eager_backend_with_torch_function_mode(mode):
"""Used to trace HOPs (cond and while) for eager exectution, the metadata
TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
in the HOP, so we need to externally run this mode and not trace it."""
def fn(gm, fake_tensor_inputs, **kwargs):
with mode:
return gm.forward
return fn
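A hedged usage sketch of the factory above: only make_eager_backend_with_torch_function_mode and its signature come from this diff; the import path, MetadataMode, and fn are illustrative assumptions.

import torch
from torch.overrides import BaseTorchFunctionMode

# Assumed import path for the factory defined above.
from torch._dynamo.backends.debugging import (
    make_eager_backend_with_torch_function_mode,
)


class MetadataMode(BaseTorchFunctionMode):
    # Illustrative stand-in for the metadata TF mode mentioned in the docstring.
    pass


def fn(x):
    return torch.add(x, 1)


backend = make_eager_backend_with_torch_function_mode(MetadataMode())
compiled = torch.compile(fn, backend=backend)
print(compiled(torch.ones(2)))  # compiles with the custom eager backend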
@register_backend
def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
if kwargs:

View File

@ -94,6 +94,7 @@ from .symbolic_convert import (
from .trace_rules import is_numpy
from .utils import (
CleanupManager,
clear_torch_function_mode_stack,
CompilationMetrics,
counters,
dynamo_timed,
@ -108,6 +109,7 @@ from .utils import (
orig_code_map,
record_compilation_metrics,
reset_graph_break_dup_checker,
set_torch_function_mode_stack,
setup_compile_debug,
troubleshooting_url,
write_record_to_file,
@ -204,6 +206,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
py_rng_state = random.getstate()
torch_rng_state = torch.random.get_rng_state()
cuda_rng_state = None
prior_tf_mode_stack = torch.overrides._get_current_function_mode_stack()
clear_torch_function_mode_stack()
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state()
allow_tf32 = torch._C._get_cublas_allow_tf32()
@ -220,6 +224,10 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
finally:
cleanup.close()
exit_stack.close()
assert (
torch._C._len_torch_function_stack() == 0
), "Torch function mode stack state changed while dynamo tracing, please report a bug"
set_torch_function_mode_stack(prior_tf_mode_stack)
torch._C._set_grad_enabled(prior_grad_mode)
torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
torch.use_deterministic_algorithms(
@ -605,6 +613,10 @@ def _compile(
output: Optional[OutputGraph] = None
tracer: Optional[InstructionTranslator] = None
tf_mode_stack: List[
torch.overrides.TorchFunctionMode
] = torch.overrides._get_current_function_mode_stack()
@preserve_global_state
def transform(
instructions: List[Instruction], code_options: Dict[str, object]
@ -618,6 +630,7 @@ def _compile(
locals,
globals,
builtins,
tf_mode_stack,
code_options,
compiler_fn,
one_graph,

View File

@ -97,6 +97,7 @@ from .source import (
ScriptObjectQualifiedNameSource,
ShapeEnvSource,
SubclassAttrListSource,
TorchFunctionModeStackSource,
TupleIteratorGetItemSource,
TypeSource,
UnspecializedBuiltinNNModuleSource,
@ -110,6 +111,7 @@ from .utils import (
dict_keys_repr,
get_custom_getattr,
get_torch_function_mode_stack,
get_torch_function_mode_stack_at,
guard_failures,
istype,
key_is_id,
@ -313,6 +315,7 @@ CLOSURE_VARS = {
"___dict_contains": lambda a, b: a in b,
"___tuple_iterator_len": tuple_iterator_len,
"___tuple_iterator_getitem": tuple_iterator_getitem,
"___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
"__math_isnan": math.isnan,
"__numpy_isnan": None if np is None else np.isnan,
"inf": float("inf"),
@ -900,6 +903,15 @@ class GuardBuilder(GuardBuilderBase):
):
assert base_guard_manager # to make mypy happy
out = base_guard_manager
elif istype(source, TorchFunctionModeStackSource):
out = root_guard_manager.lambda_manager(
python_lambda=lambda _: get_torch_function_mode_stack_at(
source._get_index()
),
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, GradSource):
assert base_guard_manager # to make mypy happy
out = base_guard_manager.grad_manager(
@ -2206,6 +2218,8 @@ class CheckFunctionManager:
self.output_graph = output_graph
w_builder = None
# NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing
# in case a set default device call was made in the graph.
self.torch_function_mode_stack = (
output_graph.torch_function_mode_stack if output_graph else None
)

View File

@ -1020,7 +1020,7 @@ class OutputGraph:
prefix_insts.clear()
for block in reversed(tx.block_stack):
block.exit(tx)
block.exit(tx, is_graph_break=reason.graph_break)
self.cleanup_graph()
tx.prune_dead_locals()

View File

@ -25,6 +25,26 @@ if TYPE_CHECKING:
sys as sys,
)
from torch.overrides import BaseTorchFunctionMode
# These classes handle support for TorchFunctionModes across graph breaks.
# Today, __enter__ for the modes we support simply pushes the mode onto the stack.
# Since the stack mutation is replayed, no cleanup logic needs to run when a graph
# break occurs: replaying the mutations keeps the torch function mode stack correct
# at the break, and the stack is reconstructed normally when we compile the resume
# function on the other side of the break.
# However, to ensure we exit properly in the resume function, we need to re-enter
# the contexts as we do other contexts. These contexts do nothing on enter, but
# provide the correct exit logic to keep the stack state correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
def __enter__(self):
pass
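A minimal sketch of the exit-only behavior described above; the explicit push stands in for the replayed stack mutation, and the class is restated so the snippet is self-contained.

import torch
from torch.overrides import BaseTorchFunctionMode


class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    # Same shape as the polyfill above: entering is a no-op because the mode was
    # already pushed by the replayed bytecode before the resume function runs.
    def __enter__(self):
        pass


mode = BaseTorchFunctionMode()
torch._C._push_on_torch_function_stack(mode)  # stands in for the replayed mutation
with NoEnterTorchFunctionMode():
    assert torch._C._len_torch_function_stack() == 1  # enter pushed nothing extra
# __exit__ (inherited from TorchFunctionMode) pops the top of the stack,
# restoring the pre-break state.
assert torch._C._len_torch_function_stack() == 0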
def index(iterator, item, start=0, end=None):
from itertools import islice

View File

@ -608,7 +608,7 @@ class TorchFunctionModeStackSource(Source):
ind: int
def name(self):
return ""
return f"___get_torch_function_mode_stack_at({self._get_index()})"
def _get_index(self):
from .variables.torch_function import TorchFunctionModeStackVariable

View File

@ -19,20 +19,7 @@ import traceback
import types
import typing
import weakref
from typing import (
Any,
Callable,
cast,
Deque,
Dict,
List,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
from unittest.mock import patch
import torch
@ -72,14 +59,12 @@ from .source import (
GlobalWeakRefSource,
LocalSource,
Source,
TorchFunctionModeStackSource,
)
from .trace_rules import is_builtin_constant, is_forbidden
from .utils import (
counters,
get_fake_value,
get_instruction_source_311,
get_torch_function_mode_stack,
graph_break_dup_warning_checker,
istype,
LazyString,
@ -120,11 +105,10 @@ from .variables.misc import (
)
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
if TYPE_CHECKING:
from .variables.torch_function import TorchFunctionModeVariable
from .variables.torch_function import (
SymbolicTorchFunctionState,
TorchFunctionModeVariable,
)
from .variables.user_defined import (
RemovableHandleVariable,
UserDefinedClassVariable,
@ -283,9 +267,12 @@ class BlockStackEntry:
else:
return ReenterWith(self.stack_index)
def exit(self, tx):
def exit(self, tx, is_graph_break):
assert self.with_context is not None
return self.with_context.exit(tx)
if (
is_graph_break and self.with_context.exit_on_graph_break()
) or not is_graph_break:
return self.with_context.exit(tx)
class ReturnValueOp(Exception):
@ -651,8 +638,12 @@ def break_graph_if_unsupported(*, push):
cleanup: List[Instruction] = []
# Reconstruct the context variable CLASS in the block stack
for b in self.block_stack:
# Don't exit any modes we have entered,
# output bytecode will mutate the tf mode stack accordingly
if isinstance(b.with_context, TorchFunctionModeVariable):
continue
assert b.with_context is not None
assert isinstance(b.with_context, ContextWrappingVariable)
assert isinstance(b.with_context, (ContextWrappingVariable))
b.with_context.reconstruct_type(cg)
cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
self.output.add_output_instructions(cg.get_instructions())
@ -728,7 +719,7 @@ class InstructionTranslatorBase(
output: OutputGraph
symbolic_locals: Dict[str, VariableTracker]
symbolic_globals: Dict[str, VariableTracker]
symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"]
symbolic_torch_function_state: SymbolicTorchFunctionState
stack: List[VariableTracker]
instruction_pointer: Optional[int]
current_instruction: Instruction
@ -2305,7 +2296,11 @@ class InstructionTranslatorBase(
):
unimplemented(f"{inst.opname} {ctx}")
if isinstance(ctx, GenericContextWrappingVariable):
if (
isinstance(ctx, GenericContextWrappingVariable)
and not ctx.supports_graph_breaks()
):
self.generic_context_manager_depth += 1
# Need this redundant check for mypy
@ -2548,7 +2543,7 @@ class InstructionTranslatorBase(
code_options: Dict[str, Any],
symbolic_locals: Dict[str, VariableTracker],
symbolic_globals: Dict[str, VariableTracker],
symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
symbolic_torch_function_state: SymbolicTorchFunctionState,
f_code: types.CodeType,
export: bool,
inline_depth: int,
@ -2563,7 +2558,7 @@ class InstructionTranslatorBase(
self.output = output
self.symbolic_locals = symbolic_locals
self.symbolic_globals = symbolic_globals
self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack
self.symbolic_torch_function_state = symbolic_torch_function_state
self.stack = []
# stack of variable names for tracking 3.13 closures
self.name_stack: list[Any] = []
@ -2652,6 +2647,7 @@ class InstructionTranslator(InstructionTranslatorBase):
f_locals,
f_globals,
f_builtins,
torch_function_mode_stack,
code_options,
compiler_fn,
one_graph,
@ -2686,7 +2682,7 @@ class InstructionTranslator(InstructionTranslatorBase):
symbolic_locals={}, # set below
# A global var is inserted only after a STORE_GLOBAL happens to it
symbolic_globals={},
symbolic_torch_function_mode_stack=collections.deque(),
symbolic_torch_function_state=None, # type: ignore[arg-type] # set below
f_code=f_code,
export=export,
inline_depth=0,
@ -2721,7 +2717,9 @@ class InstructionTranslator(InstructionTranslatorBase):
if k in f_locals
}
self._init_torch_function_mode_stack()
self.symbolic_torch_function_state = SymbolicTorchFunctionState(
torch_function_mode_stack
)
self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
if export:
@ -2762,29 +2760,6 @@ class InstructionTranslator(InstructionTranslatorBase):
)
unimplemented(msg)
def _init_torch_function_mode_stack(self):
from .variables.torch_function import TorchFunctionModeStackVariable
TorchFunctionModeStackVariable.reset()
self.symbolic_torch_function_mode_stack: Deque[
TorchFunctionModeVariable
] = collections.deque()
# We want to retrieve all modes to properly reconstruct the stack if needed
py_stack = get_torch_function_mode_stack(filter_ignored=False)
if py_stack:
has_device_context = isinstance(
py_stack[0], torch.utils._device.DeviceContext
)
for i, val in enumerate(py_stack):
self.symbolic_torch_function_mode_stack.append(
variables.LazyVariableTracker.create(
val, source=TorchFunctionModeStackSource(i)
)
)
def get_example_value(self, source: Source):
if isinstance(source, LocalSource):
return self.f_locals[source.local_name]
@ -3116,7 +3091,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
code,
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_mode_stack,
parent.symbolic_torch_function_state,
closure_cells,
func,
)
@ -3126,7 +3101,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
code,
sub_locals,
parent.symbolic_globals,
parent.symbolic_torch_function_mode_stack,
parent.symbolic_torch_function_state,
closure_cells,
func,
)
@ -3179,7 +3154,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
code: types.CodeType,
symbolic_locals: Dict[str, VariableTracker],
symbolic_globals: Dict[str, VariableTracker],
symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"],
symbolic_torch_function_state: SymbolicTorchFunctionState,
closure_cells: Dict[str, VariableTracker],
funcvar: BaseUserFunctionVariable,
) -> None:
@ -3196,7 +3171,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
f_builtins=f_builtins,
symbolic_locals=symbolic_locals,
symbolic_globals=symbolic_globals,
symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack,
symbolic_torch_function_state=symbolic_torch_function_state,
instructions=instructions,
code_options={k: getattr(code, k) for k in get_code_keys()},
f_code=code,

View File

@ -3254,6 +3254,7 @@ MOD_INLINELIST = [
"torch.testing",
"torch.utils._content_store",
"torch.utils._contextlib",
"torch.utils._device",
"torch.utils._foreach_utils",
"torch.utils._python_dispatch",
"torch.utils._pytree",
@ -3588,7 +3589,9 @@ def lookup_inner(
if reasons is not None:
reasons.add("func name is patched_init")
return SkipFunctionVariable
elif name == "__torch_function__":
elif name == "__torch_function__" or (
obj and obj.__name__ == "__torch_function__"
):
if reasons is not None:
reasons.add("func name is __torch_function__")
return UserFunctionVariable

View File

@ -63,7 +63,6 @@ import torch.fx.experimental.symbolic_shapes
import torch.utils._pytree as pytree
from torch import fx
from torch._C import (
_get_function_stack_at,
_instruction_counter,
_len_torch_function_stack,
_pop_torch_function_stack,
@ -3065,7 +3064,9 @@ def is_parameter_freezing():
def get_torch_function_mode_stack(filter_ignored=True):
from .variables.torch_function import IGNORED_MODES
stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())]
stack = [
get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
]
if filter_ignored:
stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]
@ -3085,6 +3086,11 @@ def set_torch_function_mode_stack(stack):
_push_on_torch_function_stack(mode)
def clear_torch_function_mode_stack():
for i in range(_len_torch_function_stack()):
_pop_torch_function_stack()
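A minimal sketch of how these helpers compose in preserve_global_state above; the import path is an assumption and the try body is illustrative.

import torch
from torch._dynamo.utils import (  # assumed module path for the helpers in this diff
    clear_torch_function_mode_stack,
    set_torch_function_mode_stack,
)

prior_stack = torch.overrides._get_current_function_mode_stack()
clear_torch_function_mode_stack()  # dynamo traces with an empty mode stack
try:
    pass  # tracing would happen here
finally:
    set_torch_function_mode_stack(prior_stack)  # restore the user's modes afterwards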
def verify_guard_fn_signature(value):
fn = value.__metadata_guard__
sig = inspect.signature(fn)

View File

@ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker):
if isinstance(args[0], UserFunctionVariable):
return WrappedUserFunctionVariable(args[0], self)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return True
class GenericContextWrappingVariable(UserDefinedObjectVariable):
# Some methods in ContextWrappingVariable assumes the arguments are
@ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
tx.generic_context_manager_depth -= 1
return x
def supports_graph_breaks(self):
return False
def exit_on_graph_break(self):
return True
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
"""represents torch grad requries grad"""
@ -637,6 +649,8 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 1
tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0]
tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0]
tx.output.set_torch_function_state(values[0])

View File

@ -149,6 +149,15 @@ tracing_state_functions = {
bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"])
@functools.lru_cache(None)
def get_overridable_functions():
from itertools import chain
from torch.overrides import get_overridable_functions as get_overridable_functions_
return set(chain(*get_overridable_functions_().values()))
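A small illustration of what the cached helper above computes: it flattens torch.overrides' namespace-to-functions mapping into one set, so the membership check in torch_function_override_enabled (later in this file) is a simple set lookup.

from itertools import chain

import torch
from torch.overrides import get_overridable_functions

# Same flattening as the helper: {namespace: [fns]} -> one set of functions.
overridable = set(chain(*get_overridable_functions().values()))
assert torch.add in overridable  # torch.add participates in __torch_function__ dispatch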
class BaseTorchVariable(VariableTracker):
"""common base for all torch.* functions, classes, modules and other things"""
@ -782,10 +791,10 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
self, tx: "InstructionTranslator", *args, **kwargs
):
assert not args and not kwargs
if not tx.symbolic_torch_function_mode_stack:
if not tx.symbolic_torch_function_state.mode_stack:
raise unimplemented("Popping from an empty torch function mode stack")
TorchFunctionModeStackVariable.register_mutation(tx)
return tx.symbolic_torch_function_mode_stack.pop()
return tx.symbolic_torch_function_state.pop_torch_function_mode()
@register(torch._C._push_on_torch_function_stack)
def handle_push_torch_function(
@ -793,7 +802,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
):
assert len(args) == 1 and not kwargs
TorchFunctionModeStackVariable.register_mutation(tx)
tx.symbolic_torch_function_mode_stack.append(args[0])
tx.symbolic_torch_function_state.push_torch_function_mode(args[0])
return ConstantVariable.create(None)
@register(torch._C._len_torch_function_stack)
@ -801,7 +810,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
self, tx: "InstructionTranslator", *args, **kwargs
):
assert not args and not kwargs
return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack))
return ConstantVariable.create(
len(tx.symbolic_torch_function_state.mode_stack)
)
@register(torch.set_default_device)
def handle_set_default_device(
@ -833,6 +844,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
from . import ConstantVariable, SymNodeVariable, TensorVariable
from .builder import wrap_fx_proxy
if self.torch_function_override_enabled(tx, args, kwargs):
return dispatch_torch_function(tx, self, args, kwargs)
if self.can_constant_fold_through() and check_unspec_or_constant_args(
args, kwargs
):
@ -850,147 +864,144 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
if result:
return result
if can_dispatch_torch_function(tx, args, kwargs):
return dispatch_torch_function(tx, self, args, kwargs)
else:
any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
all_ints_or_floats = all(
isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
for x in args
)
if (
getattr(self.value, "__module__", "") == "torch"
and self.value.__name__ in bin_ops
and any_symints_or_symfloats
and all_ints_or_floats
):
msg = f"""\
all_ints_or_floats = all(
isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
for x in args
)
if (
getattr(self.value, "__module__", "") == "torch"
and self.value.__name__ in bin_ops
and any_symints_or_symfloats
and all_ints_or_floats
):
msg = f"""\
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
To support this behavior, we need to allow const-propping tensors that store symint data.
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
"""
log.warning(msg)
unimplemented(msg)
log.warning(msg)
unimplemented(msg)
# TODO(voz): Replace w/ dynamic shape rewrite table.
# Ideally, we would be able to do this at ctor time, but alas we need a combination
# of value + args to determine this.
fn_ = self.value
if any_symints_or_symfloats:
torch_sym_op = f"_sym_{self.value.__name__}"
if getattr(self.value, "__module__", None) == "math" and hasattr(
torch, torch_sym_op
):
fn_ = getattr(torch, torch_sym_op)
# TODO(voz): Replace w/ dynamic shape rewrite table.
# Ideally, we would be able to do this at ctor time, but alas we need a combination
# of value + args to determine this.
fn_ = self.value
if any_symints_or_symfloats:
torch_sym_op = f"_sym_{self.value.__name__}"
if getattr(self.value, "__module__", None) == "math" and hasattr(
torch, torch_sym_op
):
fn_ = getattr(torch, torch_sym_op)
fake_out_shape = None
if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
# Calling fake tensor propagation can mutate the out= tensor in
# tx.output.tracked_fakes. tracked_fakes are used to apply
# symbolic_shape guards. Mutating them destroys the information
# prior to tracing, which is essential for creating right
# guards. So save the shape now, and check later if it has
# changed. If it has, graph break.
fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
fake_out_shape = None
if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
# Calling fake tensor propagation can mutate the out= tensor in
# tx.output.tracked_fakes. tracked_fakes are used to apply
# symbolic_shape guards. Mutating them destroys the information
# prior to tracing, which is essential for creating right
# guards. So save the shape now, and check later if it has
# changed. If it has, graph break.
fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn_,
*proxy_args_kwargs(args, kwargs),
),
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn_,
*proxy_args_kwargs(args, kwargs),
),
)
if (
isinstance(tensor_variable, TensorVariable)
and "requires_grad" in kwargs
and kwargs["requires_grad"].as_python_constant()
):
unimplemented(
"""factory functions that return tensors that require grad are not supported.
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
)
if (
isinstance(tensor_variable, TensorVariable)
and "requires_grad" in kwargs
and kwargs["requires_grad"].as_python_constant()
):
unimplemented(
"""factory functions that return tensors that require grad are not supported.
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
)
if "out" in kwargs and not (
isinstance(kwargs["out"], variables.ConstantVariable)
and kwargs["out"].as_python_constant() is None
):
# out variants of torch operators like torch.sort and
# torch.sigmoid mutate the tensors in the out field. Track such
# tensors and rewrite the symbolic locals.
if isinstance(tensor_variable, TupleVariable):
assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
output_tensor_names = [
tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
]
for idx, name in enumerate(output_tensor_names):
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable.items[idx]
for out_tensor, result_tensor in zip(
kwargs["out"].items, tensor_variable.items
):
if (
out_tensor.source
and out_tensor in tx.output.graphargs
and isinstance(out_tensor, variables.TensorVariable)
and isinstance(result_tensor, variables.TensorVariable)
and out_tensor.size != result_tensor.size
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
elif isinstance(tensor_variable, TensorVariable):
assert isinstance(kwargs["out"], TensorVariable)
assert "example_value" in kwargs["out"].proxy.node.meta
fake_tensor = tensor_variable.proxy.node.meta["example_value"]
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if "out" in kwargs and not (
isinstance(kwargs["out"], variables.ConstantVariable)
and kwargs["out"].as_python_constant() is None
):
# out variants of torch operators like torch.sort and
# torch.sigmoid mutate the tensors in the out field. Track such
# tensors and rewrite the symbolic locals.
if isinstance(tensor_variable, TupleVariable):
assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
output_tensor_names = [
tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
]
for idx, name in enumerate(output_tensor_names):
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable.items[idx]
for out_tensor, result_tensor in zip(
kwargs["out"].items, tensor_variable.items
):
if (
kwargs["out"].source
and kwargs["out"] in tx.output.graphargs
and fake_out_shape != fake_tensor.shape
out_tensor.source
and out_tensor in tx.output.graphargs
and isinstance(out_tensor, variables.TensorVariable)
and isinstance(result_tensor, variables.TensorVariable)
and out_tensor.size != result_tensor.size
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
elif isinstance(tensor_variable, TensorVariable):
assert isinstance(kwargs["out"], TensorVariable)
assert "example_value" in kwargs["out"].proxy.node.meta
fake_tensor = tensor_variable.proxy.node.meta["example_value"]
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if (
kwargs["out"].source
and kwargs["out"] in tx.output.graphargs
and fake_out_shape != fake_tensor.shape
):
# It's hard to get out variants with resizing on graph inputs work
# properly across dynamo/aot/inductor, just fall back.
unimplemented("out variants with resizing on graph inputs")
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where output tensor was non-contiguous"
)
name = tx.find_symbolic_locals_name(kwargs["out"])
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable
elif (
isinstance(tensor_variable, ConstantVariable)
and tensor_variable.value is None
):
# Handle out-variant custom ops that return None.
if isinstance(kwargs["out"], TensorVariable):
assert "example_value" in kwargs["out"].proxy.node.meta
fake_out = kwargs["out"].proxy.node.meta["example_value"]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where output tensor was non-contiguous"
)
name = tx.find_symbolic_locals_name(kwargs["out"])
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable
elif (
isinstance(tensor_variable, ConstantVariable)
and tensor_variable.value is None
):
# Handle out-variant custom ops that return None.
if isinstance(kwargs["out"], TensorVariable):
assert "example_value" in kwargs["out"].proxy.node.meta
fake_out = kwargs["out"].proxy.node.meta["example_value"]
elif isinstance(kwargs["out"], ListVariable):
for idx, x in enumerate(kwargs["out"].items):
assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined]
fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where output tensor was non-contiguous"
"out= op was called where some of the output tensors were non-contiguous"
)
elif isinstance(kwargs["out"], ListVariable):
for idx, x in enumerate(kwargs["out"].items):
assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined]
fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined]
if not torch._prims_common.is_contiguous(fake_out):
# It's difficult to handle strides correctly in functionalization
# when calling an out= op with a non-contiguous out argument
unimplemented(
"out= op was called where some of the output tensors were non-contiguous"
)
else:
unimplemented(f"out variant of {type(kwargs['out'])}")
else:
unimplemented(f"out variant of {type(kwargs['out'])}")
return tensor_variable
return tensor_variable
def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs):
"""inline behavior of torch.nn.modules.utils._ntuple"""
@ -1118,3 +1129,9 @@ Either create the tensor outside the compiled region, or do not set the tensor t
source
)
return result
def torch_function_override_enabled(self, tx, args, kwargs):
return (
self.get_function() in get_overridable_functions()
and can_dispatch_torch_function(tx, args, kwargs)
)

View File

@ -1,20 +1,34 @@
# mypy: ignore-errors
import collections
import contextlib
import inspect
from typing import Dict, List, TYPE_CHECKING
from typing import Deque, Dict, List, TYPE_CHECKING
import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from torch.overrides import (
_get_overloaded_args,
get_default_nowrap_functions,
TorchFunctionMode,
)
from torch.utils._device import DeviceContext
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
from ..utils import (
class_has_getattribute,
get_safe_global_name,
has_torch_function,
is_tensor_base_attr_getter,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .ctx_manager import GenericContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
from .user_defined import UserDefinedObjectVariable
@ -59,6 +73,67 @@ banned_attrs = [
IGNORED_MODES = {DeviceContext}
class SymbolicTorchFunctionState:
def __init__(self, py_stack):
# This is annoyingly complicated because of how the torch function subclass + mode C API was designed
# There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass
# These are their definitions:
# 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered
# (if either are entered, this will be False)
# 2) torch._C._is_torch_function_mode_enabled returns False if either the torch mode stack is empty OR
# torch._C.DisableTorchFunction has been entered
# To disambiguate these and keep myself sane I added a C API to check whether all torch function
# concepts (modes and subclasses) are enabled.
# This returns True iff we have not entered torch._C.DisableTorchFunction, and allows us to separate
# the stack length from the enablement state of torch function modes.
# This is important because now if a mode is pushed while dynamo is tracing, we know whether
# or not torch function modes are enabled and whether we should trace it.
self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled()
# This differs from the C API of the same name
# this will only be false iff we have entered torch._C.DisableTorchFunction
# and does not take into account the mode stack length, while the C API bundles these
# two concepts
self.torch_function_mode_enabled = (
not torch._C._is_torch_function_all_disabled()
)
self.cur_mode = None
TorchFunctionModeStackVariable.reset()
self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque()
for i, val in enumerate(py_stack):
self.mode_stack.append(
LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i))
)
def in_torch_function_mode(self):
return len(self.mode_stack) > 0
def pop_torch_function_mode(self):
return self.mode_stack.pop()
def push_torch_function_mode(self, mode_var):
self.mode_stack.append(mode_var)
def call_torch_function_mode(self, tx, fn, types, args, kwargs):
with self._pop_mode_for_inlining() as cur_mode:
return cur_mode.call_torch_function(tx, fn, types, args, kwargs)
@contextlib.contextmanager
def _pop_mode_for_inlining(self):
old_mode = self.cur_mode
self.cur_mode = self.pop_torch_function_mode()
try:
yield self.cur_mode
finally:
mode = self.cur_mode
self.cur_mode = old_mode
self.push_torch_function_mode(mode)
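A hedged demonstration of the two C knobs distinguished in the comment above; the assertions encode a reading of that comment rather than output copied from a run.

import torch
from torch.overrides import BaseTorchFunctionMode

with BaseTorchFunctionMode():  # make the torch function mode stack non-empty
    with torch._C.DisableTorchFunctionSubclass():
        # Subclass dispatch is off, but modes still run.
        assert not torch._C._is_torch_function_enabled()
        assert torch._C._is_torch_function_mode_enabled()
    with torch._C.DisableTorchFunction():
        # Both subclass and mode dispatch are off.
        assert not torch._C._is_torch_function_enabled()
        assert not torch._C._is_torch_function_mode_enabled()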
class TorchFunctionModeStackVariable(VariableTracker):
"""Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation"""
@ -88,19 +163,20 @@ class TorchFunctionModeStackVariable(VariableTracker):
def register_mutation(cls, tx: "InstructionTranslator"):
if cls.stack_value_singleton not in tx.output.side_effects:
var = cls(
source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack
source=Source(),
symbolic_stack=tx.symbolic_torch_function_state.mode_stack,
)
tx.output.side_effects.track_mutable(cls.stack_value_singleton, var)
tx.output.side_effects.mutation(var)
@classmethod
def register_device_context_insertion(cls, tx: "InstructionTranslator"):
stack = tx.symbolic_torch_function_mode_stack
stack = tx.symbolic_torch_function_state.mode_stack
if stack and cls.is_device_context(stack[0]):
return
else:
cls.offset += 1
tx.symbolic_torch_function_mode_stack.insert(
stack.insert(
0,
TorchFunctionModeVariable(
None, source=TorchFunctionModeStackSource(-cls.offset)
@ -109,7 +185,7 @@ class TorchFunctionModeStackVariable(VariableTracker):
@classmethod
def clear_default_device(cls, tx: "InstructionTranslator"):
stack = tx.symbolic_torch_function_mode_stack
stack = tx.symbolic_torch_function_state.mode_stack
if stack and cls.is_device_context(stack[0]):
stack.popleft()
cls.offset -= 1
@ -123,24 +199,91 @@ class TorchFunctionModeStackVariable(VariableTracker):
return ind + cls.offset
class TorchFunctionModeVariable(ContextWrappingVariable):
def __init__(self, value, **kwargs):
super().__init__(value, **kwargs)
self.value = value
class TorchFunctionModeVariable(GenericContextWrappingVariable):
@staticmethod
def get_global_mangled_name(tx, val):
return get_safe_global_name(
tx, f"__torch_function_mode_{val.__class__.__name__}", val
def is_supported_torch_function_mode(ty):
# Supported in this sense means we can support graph breaks under the
# context.
# We are able to trace custom modes, but if there are graph breaks under them
# and they have a custom __enter__/__exit__, we don't handle this for the
# same reason we don't handle generic context managers: there may be side effects
# that are affected by executing the function across two frames instead of one.
# Today we support the enter/exit of the default TorchFunctionMode as well as
# DeviceContext (which is used for set_default_device).
return issubclass(ty, (DeviceContext, NoEnterTorchFunctionMode)) or (
not class_has_getattribute(ty)
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
)
def __init__(self, value, source=None, **kwargs):
super().__init__(value, **kwargs)
self.value = value
self.cm_obj = value # needed for BC with calling enter from CM code
self.source = source
def reconstruct(self, codegen):
# We don't support locally created torch function modes yet
# This shouldn't be called unless we have a source
assert self.source
self.source.reconstruct(codegen)
def _call_func(self, tx, values):
unimplemented("torch function mode context manager is not supported yet")
def module_name(self):
return self.value.__module__
def fn_name(self):
return type(self.value).__name__
def python_type(self):
return type(self.value)
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
return call_torch_function(
tx,
self,
build_torch_function_fn(tx, self.value, self.source),
fn,
types,
args,
kwargs,
)
def enter(self, tx):
from .torch import TorchInGraphFunctionVariable
TorchInGraphFunctionVariable(
torch._C._push_on_torch_function_stack
).call_function(tx, [self], {})
return ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
from .torch import TorchInGraphFunctionVariable
TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
tx, [], {}
)
return ConstantVariable.create(None)
def reconstruct_type(self, codegen):
ty = NoEnterTorchFunctionMode
# NoEnterDeviceTorchFunctionMode
# if isinstance(self.value, DeviceContext)
# else NoEnterTorchFunctionMode
# codegen(
# AttrSource(
# codegen.tx.import_source(torch._dynamo.polyfills.__name__), ty.__name__),
# )
codegen(
AttrSource(
codegen.tx.import_source(ty.__module__),
ty.__name__,
)
)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return False
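A hedged end-to-end sketch of what this variable is meant to enable, assuming the rest of this PR stack lands: a default-constructed torch function mode entered inside a compiled function, including across a graph break. NoopMode is invented for the example.

import torch
from torch.overrides import BaseTorchFunctionMode


class NoopMode(BaseTorchFunctionMode):  # invented for this example
    # Keeps the inherited __enter__/__exit__, so it passes
    # is_supported_torch_function_mode and tolerates graph breaks.
    pass


@torch.compile(backend="eager")
def fn(x):
    with NoopMode():
        y = x + 1
        torch._dynamo.graph_break()  # graph break while the mode is active
        return y * 2


print(fn(torch.ones(3)))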
def _get_all_args(args, kwargs):
@ -231,9 +374,13 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source):
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
return tx.output.torch_function_enabled and any(
has_overridden_args = any(
has_torch_function(arg) for arg in _get_all_args(args, kwargs)
)
tf_state = tx.symbolic_torch_function_state
return (has_overridden_args and tf_state.torch_function_subclass_enabled) or (
tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode()
)
def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
@ -245,11 +392,20 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
_get_subclass_type,
)
types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args])
if tx.symbolic_torch_function_state.in_torch_function_mode():
res = tx.symbolic_torch_function_state.call_torch_function_mode(
tx, fn, types, args, kwargs
)
if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
return res
for arg in overloaded_args:
res = arg.call_torch_function(
tx,
fn,
TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]),
types,
args,
kwargs,
)
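For reference, a hedged eager-mode sketch of the dispatch order the tracing above is intended to mirror: an active mode is consulted first, and a NotImplemented result falls through to the per-argument subclass handlers. Both classes are invented for the example.

import torch
from torch.overrides import TorchFunctionMode


class MulOverrideSubclass(torch.Tensor):  # invented for this example
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.mul:
            return torch.full((2,), 42.0)
        return super().__torch_function__(func, types, args, kwargs)


class DeferToSubclass(TorchFunctionMode):  # invented for this example
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.mul:
            return NotImplemented  # fall through to the subclass handler
        return func(*args, **kwargs)


x = torch.ones(2).as_subclass(MulOverrideSubclass)
with DeferToSubclass():
    out = torch.mul(x, 3)  # the mode runs first and defers; the subclass answers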

View File

@ -82,11 +82,6 @@ def is_forbidden_context_manager(ctx):
from _pytest.python_api import RaisesContext
from _pytest.recwarn import WarningsChecker
# TODO mlazos: Temporary to get this stack to pass
# remove in subsequent PR
from torch.overrides import BaseTorchFunctionMode
f_ctxs.append(BaseTorchFunctionMode)
f_ctxs.append(RaisesContext)
f_ctxs.append(WarningsChecker)
except ImportError:
@ -413,15 +408,25 @@ class UserDefinedClassVariable(UserDefinedVariable):
and self.source
and not is_forbidden_context_manager(self.value)
):
# import here to avoid an unfortunate circular dependency.
from torch.overrides import TorchFunctionMode
from .ctx_manager import GenericContextWrappingVariable
from .torch_function import TorchFunctionModeVariable
if issubclass(
self.value, TorchFunctionMode
) and TorchFunctionModeVariable.is_supported_torch_function_mode(
self.value
):
var_cls = TorchFunctionModeVariable
else:
var_cls = GenericContextWrappingVariable
cm_obj = tx.output.side_effects.track_object_new(
self.source, self.value, GenericContextWrappingVariable, {}
self.source, self.value, var_cls, {}
)
cm_obj.call_method(tx, "__init__", args, kwargs)
return cm_obj
elif is_namedtuple_cls(self.value):
fields = namedtuple_fields(self.value)
# check if this a quasi-namedtuple or a real one

View File

@ -506,7 +506,11 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
if (
not torch.compiler.is_dynamo_compiling()
and log.isEnabledFor(logging.DEBUG)
and config.extended_debug_current_loc
):
frame = _find_user_code_frame()
if frame is not None:
log.debug(

View File

@ -28,6 +28,7 @@ from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
_temp_remove_pre_dispatch_torch_function_mode,
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
@ -129,6 +130,10 @@ def cond(pred, true_fn, false_fn, operands):
if torch.compiler.is_dynamo_compiling():
return cond_op(pred, true_fn, false_fn, operands)
from torch._dynamo.backends.debugging import (
make_eager_backend_with_torch_function_mode,
)
if isinstance(pred, (bool, int, float)):
log.warning(
"Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
@ -169,12 +174,15 @@ def cond(pred, true_fn, false_fn, operands):
def _cond_op_wrapper(*args, **kwargs):
return cond_op(*args, **kwargs)
with _set_compilation_env():
with torch._dynamo.utils.disable_cache_limit():
with _temp_remove_pre_dispatch_torch_function_mode():
return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)(
pred, true_fn, false_fn, operands
)
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode():
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
if metadata_mode:
backend = make_eager_backend_with_torch_function_mode(metadata_mode)
else:
backend = "eager"
return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
pred, true_fn, false_fn, operands
)
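For context, a minimal usage sketch of the eager entry point above: calling cond outside of a compiled region routes through this torch.compile wrapper. The branch functions and shapes are invented for the example.

import torch


def true_fn(x):
    return x.sin()


def false_fn(x):
    return x.cos()


x = torch.randn(4)
out = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))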
def create_fw_bw_graph_branches(true_fn, false_fn, *operands):

View File

@ -15,7 +15,11 @@ from torch._higher_order_ops.utils import (
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
ProxyTorchDispatchMode,
track_tensor_tree,
)
class WhileLoopOp(HigherOrderOperator):
@ -113,6 +117,9 @@ def while_loop(cond_fn, body_fn, carried_inputs):
- 'while_loop' only supports **inference** right now. Autograd will be supported in the future.
"""
from torch._dynamo.backends.debugging import (
make_eager_backend_with_torch_function_mode,
)
# Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo.
# parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs.
@ -140,9 +147,15 @@ def while_loop(cond_fn, body_fn, carried_inputs):
return while_loop_op(*args, **kwargs)
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)(
cond_fn, body_fn, carried_inputs, additional_inputs
)
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
if metadata_mode:
backend = make_eager_backend_with_torch_function_mode(metadata_mode)
else:
backend = "eager"
return torch.compile(
_while_loop_op_wrapper, backend=backend, fullgraph=True
)(cond_fn, body_fn, carried_inputs, additional_inputs)
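Likewise, a minimal usage sketch of the wrapper above, with invented cond/body functions; per the docstring, carried inputs are tensors and the loop currently supports inference only.

import torch
from torch._higher_order_ops.while_loop import while_loop


def cond_fn(i, x):  # invented for this example
    return i < 5


def body_fn(i, x):  # invented for this example
    return i + 1, x * 2.0


i0 = torch.tensor(0)
x0 = torch.ones(2)
final_i, final_x = while_loop(cond_fn, body_fn, (i0, x0))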
@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)

View File

@ -17,7 +17,7 @@ import typing_extensions
import warnings
import weakref
from collections import defaultdict
from contextlib import contextmanager, ExitStack, nullcontext
from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import (
Any,
@ -1084,38 +1084,43 @@ class PythonKeyTracer(Tracer):
return e
@contextmanager
def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]:
from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode
def _make_temp_remove_mode_context_manager(
mode_ty: Type[TorchFunctionMode],
) -> Callable[[], _GeneratorContextManager[None]]:
@contextmanager
def context_manager_fn() -> Generator[None, None, None]:
from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode
temp_elements = []
pre_dispatch_mode = None
temp_elements = []
removed_mode = None
while _len_torch_function_stack() > 0:
mode = _pop_mode()
if isinstance(mode, PreDispatchTorchFunctionMode):
pre_dispatch_mode = mode
break
else:
temp_elements.append(mode)
while _len_torch_function_stack() > 0:
mode = _pop_mode()
if isinstance(mode, mode_ty):
removed_mode = mode
break
else:
temp_elements.append(mode)
for mode in reversed(temp_elements):
_push_mode(mode)
for mode in reversed(temp_elements):
_push_mode(mode)
try:
yield
try:
yield removed_mode
finally:
if pre_dispatch_mode is not None:
count = len(temp_elements)
while count > 0:
mode = _pop_mode()
count -= 1
finally:
if removed_mode is not None:
count = len(temp_elements)
while count > 0:
mode = _pop_mode()
count -= 1
temp_elements.append(pre_dispatch_mode)
temp_elements.append(removed_mode)
for mode in reversed(temp_elements):
_push_mode(mode)
for mode in reversed(temp_elements):
_push_mode(mode)
return context_manager_fn
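A hedged usage sketch of the factory above, assuming it is called from within this module (MyTracingMode is invented for the example): build a remover for a specific mode class, then check that the mode is off the torch function stack inside the block and restored afterwards.

from torch.overrides import TorchFunctionMode, _len_torch_function_stack


class MyTracingMode(TorchFunctionMode):  # invented for this example
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        return func(*args, **kwargs)


_temp_remove_my_mode = _make_temp_remove_mode_context_manager(MyTracingMode)

with MyTracingMode():
    depth = _len_torch_function_stack()
    with _temp_remove_my_mode() as removed:
        assert isinstance(removed, MyTracingMode)
        assert _len_torch_function_stack() == depth - 1
    assert _len_torch_function_stack() == depth  # stack restored on exit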
@torch._disable_dynamo
@ -1230,6 +1235,11 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
return func(*args, **kwargs)
_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager(
TorchFunctionMetadataMode
)
# This mode is **only** used for pre_dispatch tracing.
# In particular, we need to make sure that autograd/autocast API's
# that do not desugar into dispatcher operators stay in the graph.
@ -1258,6 +1268,11 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
return func(*args, **kwargs)
_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager(
PreDispatchTorchFunctionMode
)
class ProxyTorchDispatchMode(TorchDispatchMode):
# Ensure this is read-only; this exists only for legacy reasons
@property

View File

@ -19,6 +19,7 @@ from torch._higher_order_ops.flex_attention import (
)
from torch._higher_order_ops.utils import _set_compilation_env
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
_temp_remove_pre_dispatch_torch_function_mode,
)
from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input
@ -1035,18 +1036,25 @@ def flex_attention(
with _set_compilation_env():
with torch._dynamo.utils.disable_cache_limit():
with _temp_remove_pre_dispatch_torch_function_mode():
out, lse = torch.compile(
_flex_attention_hop_wrapper, backend="eager", fullgraph=True
)(
query,
key,
value,
score_mod,
block_mask.as_tuple(),
scale,
kernel_options,
)
if return_lse:
return out, lse * math.log(2)
else:
return out
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
if metadata_mode:
from torch._dynamo.backends.debugging import (
    make_eager_backend_with_torch_function_mode,
)
backend = make_eager_backend_with_torch_function_mode(
metadata_mode
)
else:
backend = "eager"
out, lse = torch.compile(
_flex_attention_hop_wrapper, backend=backend, fullgraph=True
)(
query,
key,
value,
score_mod,
block_mask.as_tuple(),
scale,
kernel_options,
)
if return_lse:
return out, lse * math.log(2)
else:
return out
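To close, a minimal usage sketch of the public flex_attention entry point that reaches the code above when called outside a compiled region; the score_mod and shapes are invented for the example.

import torch
from torch.nn.attention.flex_attention import flex_attention


def relative_bias(score, b, h, q_idx, kv_idx):  # invented for this example
    # Add a simple relative-position bias to each attention score.
    return score + (q_idx - kv_idx)


q = torch.randn(2, 4, 128, 64)
k = torch.randn(2, 4, 128, 64)
v = torch.randn(2, 4, 128, 64)
out = flex_attention(q, k, v, score_mod=relative_bias)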