Compare commits

...

8 Commits

Author SHA1 Message Date
1cad5436f6 [Dynamo] add flex attention mode test
ghstack-source-id: 9bcf8043e94f0b3f9cdfd4e7c1bf787ccaa459a6
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137121
2024-10-08 14:11:04 -07:00
318075f6cb [Dynamo] Handle torch function subclass/mode dispatch on generic tensor methods
ghstack-source-id: c33006a10a33048d8186e91b2817bf60c2cee306
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137119

nested_tensor fix

realization fix

fixes

Fixes more

fixes

fixes 2

Fixes

fix mypy

Fix

fix

fixes
2024-10-08 14:11:04 -07:00
6e8b6b11cd [Dynamo] handle extracted unbound tensor methods for flex attention
fixes2

ghstack-source-id: a6d24ca120a9d53a79cff9be2c1ce07e764d8a5d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137227

handle tensor methods better

enable test

More test enablement
2024-10-08 14:11:04 -07:00
9655ebd499 [Dynamo] Add flex attention to mod inline list
ghstack-source-id: bb563dff0454c1afa70bcf0158d222f8fdb4e741
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137120

Move tf mode

Move shiz

Fixes

fixes

fix

fixes

fix

Add some skips

More skips

More skips
2024-10-08 14:11:03 -07:00
df9ef729fd [Dynamo] Dispatch torch function on builtins
ghstack-source-id: 3f67f72fd42457b138a9551328e3fbf311f60b58
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137117

fixes

Fixes2

Skip crossref

Additional fixes

et update

fixes again

update instruction counts
2024-10-08 14:11:03 -07:00
244dc9d802 [Dynamo] Remove ignored modes from torch function mode stack guard (#135503)
Approved by: https://github.com/anijain2305
ghstack dependencies: #134732, #133137, #135443, #135444, #135422, #135502

ghstack-source-id: b0ae44675d64ba9c34a15a65edda61e1d5445db5
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137116
2024-10-08 14:11:03 -07:00
39867e316c [Dynamo] Remove ignored modes workaround (#135502)
Approved by: https://github.com/anijain2305
ghstack dependencies: #134732, #133137, #135443, #135444, #135422

ghstack-source-id: d8b6e3185d7e93bf6b3c4c8eac81c1805b1b6ed2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137115
2024-10-08 14:11:02 -07:00
e1db3582d8 [Dynamo] Trace enter/exit of TorchFunctionModes (#135422)
This PR implements tracing of with-statement contexts for TorchFunctionModes that have the default enter/exit behavior (i.e. pushing/popping the mode onto/off of the torch function mode stack).

Typically the bytecode for a context manager looks like this during a graph break (a rough Python-level sketch follows the two lists below):
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call

resume fn structure:
1. enter context
2. jump
...
3. exit context
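A rough Python-level sketch of that split (illustrative only; these function names are made up and the real transformation happens at the bytecode level):

import contextlib

@contextlib.contextmanager
def ctx():                      # stand-in for the user's context manager
    print("enter")
    try:
        yield
    finally:
        print("exit")

def compiled_graph(x):          # the traced prefix (1. graph call)
    return x + 1

def unsupported(x):             # code Dynamo could not trace
    return x * 2

def resume_fn(x):               # resume fn: enter context, run the rest, exit
    with ctx():
        return x - 3

def transformed(x):
    y = compiled_graph(x)       # 1. graph call
    with ctx():                 # 2. enter context
        y = unsupported(y)      # 3. unsupported code
    # 4. exit context (on leaving the with block)
    return resume_fn(y)         # 5. resume call

print(transformed(2))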

Torch function modes differ because Dynamo's side-effects bytecode already replays any mutations made to the torch function mode stack during tracing. As a result, we do not need to enter and exit around the unsupported code in the original function (doing so would push a duplicate torch function mode during execution of the unsupported code), and we do not need to enter again in the resume function (the mode pushed by the side-effects bytecode is still on the stack).

So for torch function modes the structure of our output code is this:

1. graph call
2. mutate tf mode stack to replay mutations
3. unsupported code
4. on exception restore stack
5. resume function

Then our resume fn looks like this:

1. no-op enter torch function mode
2. jump
3. exit tf mode

To implement the no-op enter of the torch function mode, I added a torch function mode to the polyfills (NoEnterTorchFunctionMode) which does nothing on enter but exits normally. This is needed because we still want to trace the with context in the resume function and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context).
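A self-contained sketch of that polyfill; it mirrors the NoEnterTorchFunctionMode class added in this diff, and the push/pop below (private APIs) only set up the scenario the resume function sees:

import torch
from torch.overrides import BaseTorchFunctionMode

class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    def __enter__(self):
        # No-op: the side-effects bytecode has already pushed the real mode,
        # so the resume function must not push it a second time.
        pass
    # __exit__ is inherited and still pops the top of the mode stack, giving the
    # "enter exactly once, exit exactly once" behavior described above.

# Only meaningful when a mode is already on the stack, as in the resume function:
mode = BaseTorchFunctionMode()
torch._C._push_on_torch_function_stack(mode)
with NoEnterTorchFunctionMode():   # enter: no-op; exit: pops `mode`
    pass
print(torch._C._len_torch_function_stack())  # 0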

Separately from the bytecode, Dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally, at a graph break we exit these block-stack entries to fully reset the contexts so that we can soundly re-enter them around the unsupported code. Once again, though, in the torch function mode case we do not want to perform any exit side effects at a graph break: we want to preserve the state of the mode stack as is, so that the bytecode described above can update it. If we exited here, Dynamo would pop the mode off of the symbolic stack and never update the true Python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side-effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.
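In code, the exit decision above is roughly what the new BlockStackEntry.exit in this diff reduces to (paraphrased; exit_on_graph_break() is the hook that TorchFunctionModeVariable overrides to return False):

def exit_block(with_context, tx, is_graph_break):
    # At a graph break, only exit contexts that opt in; torch function mode
    # variables do not, so their stack state is preserved for the replay bytecode.
    assert with_context is not None
    if not is_graph_break or with_context.exit_on_graph_break():
        return with_context.exit(tx)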
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444

ghstack-source-id: 5174acedd4cd03a8f0d04b1df589b805b7264c08
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137114
2024-10-08 14:11:02 -07:00
48 changed files with 861 additions and 279 deletions

View File

@ -1,5 +1,5 @@
add_loop_eager, compile_time_instruction_count, 2834456320, 0.015
add_loop_eager_dynamic, compile_time_instruction_count, 5528896630, 0.025
add_loop_eager, compile_time_instruction_count, 3004749893, 0.015
add_loop_eager_dynamic, compile_time_instruction_count, 5726573328, 0.025
add_loop_inductor, compile_time_instruction_count, 24146845503, 0.015
add_loop_inductor_dynamic_gpu, compile_time_instruction_count, 39411706509, 0.025
add_loop_inductor_gpu, compile_time_instruction_count, 22171041650, 0.015


View File

@ -701,7 +701,7 @@ class CompileTest(TestCase):
FileCheck()
.check(
"buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
".default([arg0_1, arg1_1, arg2_1, arg3_1]"
".default([arg3_1, arg2_1, arg1_1, arg0_1]"
)
.check("buf1 = buf0[0]")
.check("buf2 = buf0[1]")
@ -717,8 +717,8 @@ class CompileTest(TestCase):
)
# Test aoti
out = AOTIRunnerUtil.run("cuda", func, (args,))
torch.cuda.synchronize()
# out = AOTIRunnerUtil.run("cuda", func, (args,))
# torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()

View File

@ -938,6 +938,16 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
else:
return x - 1
@make_test
def test_tensor_size(x):
fn = torch.Tensor.size
return fn(x + 1)
@make_test
def test_tensor_dim(x):
fn = torch.Tensor.dim
return fn(x + 1)
@make_test
def test_tensor_is_inference(x):
if x.is_inference():

View File

@ -646,10 +646,10 @@ print("arf")
self.assertExpectedInline(
munge_shape_guards(record.getMessage()),
"""\
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
+- LAMBDA_GUARD: L['z'].size()[0] == L['y'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
+- LAMBDA_GUARD: Eq(Mod(2*L['y'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in #
+- LAMBDA_GUARD: 2 <= L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
+- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
+- LAMBDA_GUARD: Eq(Mod(2*L['z'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in #
+- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
)
@make_logging_test(guards=True)

View File

@ -1,4 +1,6 @@
# Owner(s): ["module: dynamo"]
import operator
from unittest.mock import patch
import torch
@ -10,6 +12,7 @@ from torch._C import (
_push_on_torch_function_stack,
)
from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
from torch.testing._internal.triton_utils import requires_cuda
from torch.utils._device import DeviceContext
from torch.utils._python_dispatch import TorchDispatchMode
@ -107,70 +110,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
fn(inp)
self.assertEqual(cnt.frame_count, 4)
def _run_ignored_mode_types_test(self):
class IgnoredMode(BaseTorchFunctionMode):
pass
cnt = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnt.__call__, fullgraph=True)
def fn(x):
return x + 1
inp = torch.ones(2, 2)
with patch(
"torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
):
# initial compile
fn(inp)
# no recompile, mode ignored
# note: the ref stack is length 0, and the stack we are checking against has length 2
# we want to check both ref stack len > runtime stack, and ref stack len < runtime stack
with IgnoredMode(), IgnoredMode():
fn(inp)
self.assertEqual(cnt.frame_count, 1)
# recompile due to new mode on the stack
with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
fn(inp)
self.assertEqual(cnt.frame_count, 2)
# recompile
# tests both ref stack len > runtime stack len for the above guard check
# and ref stack len < runtime stack len for the initial zero mode case
with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode():
fn(inp)
self.assertEqual(cnt.frame_count, 3)
# no recompile
with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
fn(inp)
self.assertEqual(cnt.frame_count, 3)
# This is tricky, basically the ignored modes are baked into the guard
# IgnoredMode will be ignored forever by that guard.
# This is okay since we don't expect to be modifying IGNORED_MODES
# in the middle of execution except for the purposes of testing.
torch._dynamo.reset()
with IgnoredMode():
fn(inp)
self.assertEqual(cnt.frame_count, 4)
@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
def test_torch_function_mode_guards_ignored_types_py(self):
self._run_ignored_mode_types_test()
def test_torch_function_mode_guards_ignored_types_cpp(self):
self._run_ignored_mode_types_test()
@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
def test_torch_function_mode_guards_py(self):
self._run_torch_function_mode_guard_test()
@ -461,6 +400,205 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
self.assertEqual(expected, actual)
def test_torch_function_mode_enter_exit(self):
def fn(x, y):
with TestMode():
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn, fullgraph=True)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_graph_break(self):
def fn(x, y):
with TestMode():
torch._dynamo.graph_break()
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_and_pop_graph_break(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_restore_on_exc(self):
@torch._dynamo.disable()
def err():
raise RuntimeError("test")
@torch.compile()
def fn(x):
with TestMode():
x += 1
err()
x += 2
return x
try:
fn(torch.ones(2, 2))
except RuntimeError:
pass
self.assertEqual(_len_torch_function_stack(), 0)
def test_torch_function_mode_and_pop_graph_break_mutation(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
z.y = 5
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
o = torch.mul(o, z.y)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
# Needs larger cache size since we recompile for each op
@patch.object(torch._dynamo.config, "cache_size_limit", 48)
def test_builtin_equivalent_funcs(self):
from torch._dynamo.variables.torch_function import (
bin_int_ops,
bin_ops,
BUILTIN_TO_TENSOR_FN_MAP,
BUILTIN_TO_TENSOR_RFN_MAP,
tensor_and_int_ops,
un_int_ops,
un_ops,
)
expected_func = None
valid = False
class FuncEquivMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
nonlocal expected_func
nonlocal valid
if not kwargs:
kwargs = {}
if torch._dynamo.is_compiling():
valid = expected_func == func
return super().__torch_function__(func, types, args, kwargs)
inp0 = torch.ones(1, 1)
inp1 = torch.ones(1, 1)
inp0_int = torch.ones(1, 1, dtype=torch.int32)
inp1_int = torch.ones(1, 1, dtype=torch.int32)
@torch.compile(fullgraph=True)
def fn_un(op, inp):
return op(inp)
@torch.compile(fullgraph=True)
def fn_un_int(op, inp):
return op(inp)
@torch.compile(fullgraph=True)
def fn_bin(op, inp0, inp1):
return op(inp0, inp1)
@torch.compile(fullgraph=True)
def fn_bin_int(op, inp0, inp1):
return op(inp0, inp1)
@torch.compile(fullgraph=True)
def fn_tensor_and_int(op, inp0, inp1):
return op(inp0, inp1)
setups_and_oplists = [
(lambda o: fn_un(o, inp0), un_ops),
(lambda o: fn_un_int(o, inp0_int), un_int_ops),
(lambda o: fn_bin(o, inp0, inp1), bin_ops),
(lambda o: fn_bin_int(o, inp0_int, inp1_int), bin_int_ops),
(lambda o: fn_tensor_and_int(o, inp0_int, 0), tensor_and_int_ops),
]
# gather the reverse functions
rsetups_and_oplists = [
(
lambda o: fn_bin(o, 1, inp1),
bin_ops,
), # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
(lambda o: fn_bin_int(o, 1, inp1_int), bin_int_ops),
(lambda o: fn_tensor_and_int(o, 0, inp0_int), tensor_and_int_ops),
]
skips = {operator.not_} # Has local scalar dense call which graph breaks
rskips = {
operator.matmul,
operator.imatmul,
operator.getitem,
} # Doesn't type check with reversed args
def run_checks(setups_and_oplists, skips, ref_map):
nonlocal valid
nonlocal expected_func
for setup_fn, op_list in setups_and_oplists:
for op in op_list:
if op in skips or op not in ref_map:
continue
with FuncEquivMode():
expected_func = ref_map[op]
setup_fn(op)
self.assertTrue(valid)
expected_func = None
valid = False
run_checks(setups_and_oplists, skips, BUILTIN_TO_TENSOR_FN_MAP)
run_checks(rsetups_and_oplists, rskips, BUILTIN_TO_TENSOR_RFN_MAP)
@requires_cuda
def test_flex_attention(self):
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
torch.set_default_device("cuda")
flex_attention = torch.compile(flex_attention, dynamic=False)
prefix_lengths = torch.arange(8)
def prefix_lm(b, h, q, kv):
return prefix_lengths[b] >= kv
# This runs in fullgraph already
mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -672,7 +672,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
wrapped2 = y.as_subclass(SigmoidToExpSubclass)
def fn(w):
return w.sigmoid()
return w.exp()
fn_opt = compile_full_eager(fn)
@ -683,6 +683,38 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
self.assertEqual(res_exp, res_act)
self.assertEqual(res_exp, res_exp2)
def test_torch_function_call_on_method_arg(self):
class LocalSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if func == torch._C.TensorBase.add_:
func = torch._C.TensorBase.sub_
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
def sigmoid(self):
return None
x = torch.ones(2, 2)
y = torch.ones(2, 2)
z = torch.ones(2, 2)
wrapped = y.as_subclass(LocalSubclass)
wrapped2 = z.as_subclass(LocalSubclass)
def fn(a, w):
a.add_(w)
return a
fn_opt = torch.compile(fn)
with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
res_exp = fn(x, wrapped)
res_act = fn_opt(y, wrapped2)
self.assertEqual(res_exp, res_act)
def test_user_overidden_method_unsupported(self):
class LocalSubclass(torch.Tensor):
@classmethod

View File

@ -49,9 +49,9 @@ def forward(self, b_submodule_buffer1, x):
sin = torch.ops.aten.sin.default(x)
strict_graph_0 = self.strict_graph_0
strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None
getitem_2 = strict_mode[0]; strict_mode = None
getitem = strict_mode[0]; strict_mode = None
add = torch.ops.aten.add.Tensor(x, 3); x = None
return (getitem_2, add)""",
return (getitem, add)""",
)
self.assertExpectedInline(

View File

@ -64,6 +64,7 @@ from torch.testing._internal.common_utils import (
IS_SANDCASTLE,
IS_WINDOWS,
run_tests,
skipIfCrossRef,
TEST_TRANSFORMERS,
TestCase as TorchTestCase,
)
@ -6989,6 +6990,7 @@ def forward(self, x):
real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes]
self.assertEqual(expected_names_and_ops, real_names_and_ops)
@skipIfCrossRef # Dynamo changes the order of ops under Torch function modes
def test_placeholder_naming_collisions_hoo_subgraphs(self):
# test collisions between user inputs, top-level nodes, and HOO subgraph nodes
class Foo(torch.nn.Module):
@ -8325,6 +8327,7 @@ class TestOneOffModelExportResult(TestCase):
# getitem = _scaled_dot_product_flash_attention_for_cpu[0]; _scaled_dot_product_flash_attention_for_cpu = None
# return (getitem,)""")
@skipIfCrossRef
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
"Can't run fused SDPA on this platform",

View File

@ -4902,6 +4902,7 @@ def forward(self, arg0_1, arg1_1):
return [getitem]""", # noqa: B950
)
@skipIfCrossRef # Arg order changes with crossref
def test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph(self):
def true_fn(x):
return x + x.cos()
@ -5252,6 +5253,7 @@ def forward(self, arg0_1):
):
torch.cond(inp.sum() > 0, f, f, (inp, tmp))
@skipIfCrossRef # Arg order changes with crossref
def test_cond_trace_set__and_mutate_intermediate(self):
def f(a, tmp):
a = a.clone()

View File

@ -180,12 +180,10 @@ class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
self.assertExpectedInline(
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \
"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
# No stacktrace found for following nodes
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = \
arg3_1 = arg1_1 = arg0_1 = foo_default = None
return ()""",
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None
return ()""", # noqa: B950
)
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
@ -239,7 +237,7 @@ arg3_1 = arg1_1 = arg0_1 = foo_default = None
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = None
getitem_4: "f32[3][1]cpu" = foo_default[0]
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
return (getitem_4, getitem_5)""", # noqa: B950
@ -402,9 +400,9 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
post_grad_graphs,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1); arg3_1 = arg4_1 = arg1_1 = foo_default = None
foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1); arg5_1 = copy__1 = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -414,9 +412,9 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = foo_default = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@ -503,12 +501,11 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = None
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = None
getitem_4: "f32[3][1]cpu" = foo_default[0]
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return (getitem_4, getitem_5)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,

View File

@ -67,7 +67,7 @@ class GuardManager:
) -> None: ...
def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
def add_torch_function_mode_stack_guard(
self, initial_stack, ignored_types, verbose_code_parts: list[str]
self, initial_stack, verbose_code_parts: list[str]
) -> None: ...
class RootGuardManager(GuardManager):

View File

@ -1,15 +1,22 @@
# mypy: allow-untyped-defs
from typing import Dict, Optional, Tuple
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._ops import HigherOrderOperator, OpOverload
from torch._subclasses import FakeTensorMode
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.overrides import TorchFunctionMode
from torch.utils._python_dispatch import _get_current_dispatch_mode
from torch.utils._pytree import tree_map_only
Tensor = torch.Tensor
__all__ = ["trace_wrapped"]
@ -43,6 +50,27 @@ __all__ = ["trace_wrapped"]
# compiled autograd do we inline into the function.
class TransformGetItemToIndex(TorchFunctionMode):
# This is needed since we want to support calling
# A[q_idx], where q_idx is a scalar tensor in score_mod.
# Today, when q_idx is a scalar tensor, we implicitly convert it to a python
# scalar and create a view. We do not want that behavior in this case, so we
# use this torchfunctionmode to override that behavior for score_mod
# wherever we're running it.
def __torch_function__(
self,
func: OpOverload,
types: Tuple[torch._C._TensorMeta, ...],
args: Tuple[object, ...] = (),
kwargs: Optional[Dict[str, object]] = None,
) -> object:
if func == torch.Tensor.__getitem__:
index_args = pytree.tree_leaves(args[1])
if all(isinstance(x, torch.Tensor) for x in index_args):
return torch.ops.aten.index(args[0], index_args)
return func(*args, **(kwargs or {}))
def trace_wrapped(*args, **kwargs):
with torch.no_grad():
return _trace_wrapped_op(*args, **kwargs)
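A small usage illustration of the mode above (assuming the import location this diff moves the class to; plain eager tensors):

import torch
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

a = torch.arange(4.0)
q_idx = torch.tensor(2)   # scalar (0-dim) tensor index, as in score_mod

with TransformGetItemToIndex():
    out = a[q_idx]        # routed to torch.ops.aten.index(a, [q_idx]) instead of
                          # converting q_idx to a Python int and taking a view
print(out)                # tensor(2.)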

View File

@ -32,13 +32,23 @@ def eager(gm, fake_tensor_inputs, **kwargs):
def make_eager_backend_with_torch_function_mode(mode):
return make_eager_backend_with_torch_function_modes([mode])
def make_eager_backend_with_torch_function_modes(modes):
"""Used to trace HOPs (cond and while) for eager exectution, the metadata
TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
in the HOP, so we need to externally run this mode and not trace it."""
from contextlib import ExitStack
def fn(gm, fake_tensor_inputs, **kwargs):
with mode:
return gm.forward
stack = ExitStack()
for mode in modes:
stack.enter_context(mode)
result = gm.forward
stack.close()
return result
return fn
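A hypothetical standalone usage of the helper above (the PR itself wires it into strict_mode/HOP tracing; PassthroughMode is made up for illustration):

import torch
from torch.overrides import BaseTorchFunctionMode
from torch._dynamo.backends.debugging import (
    make_eager_backend_with_torch_function_modes,
)

class PassthroughMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        return super().__torch_function__(func, types, args, kwargs)

# The returned backend enters the modes around the backend call itself (not around
# the compiled code) and otherwise behaves like the plain "eager" backend.
backend = make_eager_backend_with_torch_function_modes([PassthroughMode()])
compiled = torch.compile(lambda x: x.sin() + 1, backend=backend)
print(compiled(torch.ones(2)))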

View File

@ -120,6 +120,7 @@ from .utils import (
troubleshooting_url,
write_record_to_file,
)
from .variables.torch_function import torch_function_mode_stack_state_mgr
np: Optional[ModuleType]
@ -218,15 +219,18 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
prior_fwd_from_src = torch.fx.graph_module._forward_from_src
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
cleanup = setup_compile_debug()
exit_stack = contextlib.ExitStack()
exit_stack.enter_context(
torch.fx._symbolic_trace._maybe_revert_all_patches()
)
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
try:
return fn(*args, **kwargs)
finally:
cleanup.close()
assert (
torch._C._len_torch_function_stack() == 0
), "Torch function mode stack state changed while dynamo tracing, please report a bug"
exit_stack.close()
torch._C._set_grad_enabled(prior_grad_mode)
torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)

View File

@ -2356,15 +2356,12 @@ class CheckFunctionManager:
)
if config.enable_cpp_guard_manager:
from .variables.torch_function import IGNORED_MODES
# Insert the global_state guard
assert self.guard_manager # to make mypy happy
self.guard_manager.root.add_global_state_guard(["___check_global_state()"])
self.guard_manager.root.add_torch_function_mode_stack_guard(
self.torch_function_mode_stack,
list(IGNORED_MODES),
["___check_torch_function_mode_stack()"],
)
# Clear references to torch_function modes held in the list
@ -2671,18 +2668,14 @@ def is_recompiles_verbose_enabled():
# this will only be used if cpp guards are disabled
def make_torch_function_mode_stack_guard(intial_stack):
types = [type(x) for x in intial_stack]
from .variables.torch_function import IGNORED_MODES
def check_torch_function_mode_stack():
cur_stack = get_torch_function_mode_stack()
types_ = [ty for ty in types if ty not in IGNORED_MODES]
cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES]
if len(cur_stack_) != len(types_):
if len(cur_stack) != len(types):
return False
for ty, mode in zip(types_, cur_stack_):
for ty, mode in zip(types, cur_stack):
if ty != type(mode):
return False

View File

@ -78,7 +78,6 @@ from .utils import (
get_instruction_source_311,
get_locals_to_steal,
get_static_address_type,
get_torch_function_mode_stack,
graph_break_reasons,
increment_op_count,
lazy_format_graph_code,
@ -250,6 +249,7 @@ class OutputGraph:
local_scope: Scope,
global_scope: Scope,
f_code,
torch_function_mode_stack,
):
super().__init__()
self.tracers = [SubgraphTracer(self, export_root=export)]
@ -368,7 +368,7 @@ class OutputGraph:
# This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
# This records the initial torch function mode stack for guarding
self.torch_function_mode_stack = get_torch_function_mode_stack()
self.torch_function_mode_stack = torch_function_mode_stack
# Tracks if the output graph has a user defined allowed function in the
# graph. This is used later to determine if we should fallback to eager
@ -1021,7 +1021,7 @@ class OutputGraph:
prefix_insts.clear()
for block in reversed(tx.block_stack):
block.exit(tx)
block.exit(tx, is_graph_break=reason.graph_break)
self.cleanup_graph()
tx.prune_dead_locals()

View File

@ -25,6 +25,26 @@ if TYPE_CHECKING:
sys as sys,
)
from torch.overrides import BaseTorchFunctionMode
# These classes handle support for TorchFunctionModes across
# graph breaks
# Today the TorchFunctionMode enter (for the classes we support)
# simply pushes the mode onto the stack. Since after this occurs
# the stack is mutated, and we replay these mutations, we don't need
# any cleanup logic to be run once the graph break occurs, we simply replay
# these mutations to ensure at the graph break the torch function mode stack is correct
# and reconstruct the torch function mode stack normally
# when we compile the resume function on the other side of the break.
# However, to ensure we exit properly
# in the resume function, we need to re-enter the contexts as we do other contexts.
# These contexts do nothing on enter, but provide the correct exit logic to ensure
# the stack state is correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
def __enter__(self):
pass
def index(iterator, item, start=0, end=None):
from itertools import islice

View File

@ -90,27 +90,25 @@ class ReenterWith:
stack_index: int
target_values: Optional[Tuple[Any, ...]] = None
# TODO(mlazos) - Uncomment with the reland of torch function mode support
# def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
# """
# Codegen based off of:
# try:
# (rest)
# except:
# (restore previous tf mode stack)
# raise
def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
"""
Codegen based off of:
try:
(rest)
except:
(restore previous tf mode stack)
raise
"""
from .variables.torch_function import get_prev_stack_var_name
# """
# from .variables.torch_function import get_prev_stack_var_name
setup_try_except, epilogue = _bytecode_from_template_with_split(
_try_except_tf_mode_template,
self.stack_index,
varname_map={"stack_var_name": get_prev_stack_var_name()},
)
cleanup[:] = epilogue + cleanup
# setup_try_except, epilogue = _bytecode_from_template_with_split(
# _try_except_tf_mode_template,
# self.stack_index,
# varname_map={"stack_var_name": get_prev_stack_var_name()},
# )
# cleanup[:] = epilogue + cleanup
# return setup_try_except
return setup_try_except
# If we do not want to destroy the stack, we can do the same thing as a
# `SETUP_WITH` block, only that we store the context manager in a local_symbol

View File

@ -629,11 +629,22 @@ class SideEffects:
elif isinstance(
var, variables.torch_function.TorchFunctionModeStackVariable
):
# Needed in the finally block for stack restoration
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "get_torch_function_mode_stack"
)
)
cg.call_function(0, False)
name = variables.torch_function.get_prev_stack_var_name()
cg.code_options["co_varnames"] += (name,)
cg.append_output(create_instruction("STORE_FAST", argval=name))
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "set_torch_function_mode_stack"
)
)
cg.foreach(var.symbolic_stack)
cg.append_output(
create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))

View File

@ -267,13 +267,12 @@ class BlockStackEntry:
else:
return ReenterWith(self.stack_index)
def exit(self, tx):
if hasattr(self, "graph_break") and isinstance(
self.with_context, TorchFunctionModeVariable
):
return
def exit(self, tx, is_graph_break):
assert self.with_context is not None
return self.with_context.exit(tx)
if (
is_graph_break and self.with_context.exit_on_graph_break()
) or not is_graph_break:
return self.with_context.exit(tx)
class ReturnValueOp(Exception):
@ -657,10 +656,17 @@ def break_graph_if_unsupported(*, push):
cleanup: List[Instruction] = []
# Reconstruct the context variable CLASS in the block stack
for b in self.block_stack:
# Don't exit any modes we have entered,
# output bytecode will mutate the tf mode stack accordingly
if isinstance(b.with_context, TorchFunctionModeVariable):
cg.extend_output(
b.resume_fn().try_except_torch_function_mode(
cg.code_options, cleanup
)
)
continue
assert b.with_context is not None
assert isinstance(
b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable)
)
assert isinstance(b.with_context, (ContextWrappingVariable))
b.with_context.reconstruct_type(cg)
cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup))
self.output.add_output_instructions(cg.get_instructions())
@ -2314,7 +2320,10 @@ class InstructionTranslatorBase(
):
unimplemented(f"{inst.opname} {ctx}")
if isinstance(ctx, GenericContextWrappingVariable):
if (
isinstance(ctx, GenericContextWrappingVariable)
and not ctx.supports_graph_breaks()
):
self.generic_context_manager_depth += 1
# Need this redundant check for mypy
@ -2687,6 +2696,7 @@ class InstructionTranslator(InstructionTranslatorBase):
local_scope=f_locals,
global_scope=f_globals,
f_code=f_code,
torch_function_mode_stack=torch_function_mode_stack,
),
instructions=instructions,
f_locals=f_locals,

View File

@ -187,6 +187,7 @@ def debug_insert_nops(
local_scope=locals(),
global_scope=globals(),
f_code=frame.f_code,
torch_function_mode_stack=[],
)
return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))

View File

@ -304,6 +304,7 @@ manual_torch_name_rule_map = {
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
"torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
"torch.set_default_device": UserFunctionVariable,
"torch.sparse_bsc_tensor": SkipFunctionVariable,
"torch.sparse_bsr_tensor": SkipFunctionVariable,
"torch.sparse_csc_tensor": SkipFunctionVariable,
@ -2802,7 +2803,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.random.initial_seed",
"torch.random.seed",
"torch.return_types.pytree_register_structseq",
"torch.set_default_device",
"torch.set_default_dtype",
"torch.set_default_tensor_type",
"torch.set_deterministic_debug_mode",
@ -2912,6 +2912,9 @@ def get_tensor_method():
method, (types.MethodDescriptorType, types.WrapperDescriptorType)
):
s.add(method)
# mlazos: this is a function which we handle specially in TensorVariable
s.add(torch.Tensor.__contains__) # type: ignore[arg-type]
return frozenset(s)

View File

@ -2912,18 +2912,28 @@ def is_torch_function_object(value):
def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool:
from torch._dynamo.variables import LazyVariableTracker, UserDefinedObjectVariable
from torch._dynamo.variables import UserDefinedObjectVariable
from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable
if isinstance(vt, TensorWithTFOverrideVariable):
return True
# Note on lazy vars: The value will either be realized or not throughout the course of execution
# if the value has a torch function, it will eventually be realized so we can realize it here
# if the value does not have a torch function, it may or may not be realized
# if it is realized it will be used and guards will be installed properly
# if it is not used, guards won't be installed, and it doesn't matter
# if the value has a torch function or not, so we should *not* realize it.
# NB: We technically know that if is_realized is False, LazyVariableTracker has the peek_value method
# but mypy does not unfortunately
if vt.is_realized() or (
hasattr(vt, "peek_value") and hasattr(vt.peek_value(), "__torch_function__")
):
if isinstance(vt, TensorWithTFOverrideVariable):
return True
if isinstance(vt, LazyVariableTracker):
LazyVariableTracker.realize(vt)
return isinstance(vt, UserDefinedObjectVariable) and hasattr(
vt.value, "__torch_function__"
)
return isinstance(vt, UserDefinedObjectVariable) and hasattr(
vt.value, "__torch_function__"
)
return False
# see note [Tensor Fakification and Symbol Caching]
@ -3116,16 +3126,10 @@ def is_parameter_freezing():
return torch._inductor.config.freezing and not torch.is_grad_enabled()
def get_torch_function_mode_stack(filter_ignored=True):
from .variables.torch_function import IGNORED_MODES
stack = [
def get_torch_function_mode_stack():
return [
get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
]
if filter_ignored:
stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]
return stack
def get_torch_function_mode_stack_at(ind):

View File

@ -204,6 +204,7 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import (
build_torch_function_fn,
TensorWithTFOverrideVariable,
torch_function_mode_stack_state_mgr,
TorchFunctionModeVariable,
)
from .user_defined import (
@ -1669,15 +1670,16 @@ class VariableBuilder:
# but warning is not the end of the world
assert isinstance(value.base, np.nditer)
try:
tensor_value = _util._try_convert_to_tensor(value)
if readonly:
from torch._prims_common import clone_preserve_strides
with torch_function_mode_stack_state_mgr.temp_restore_stack():
try:
tensor_value = _util._try_convert_to_tensor(value)
if readonly:
from torch._prims_common import clone_preserve_strides
tensor_value = clone_preserve_strides(tensor_value)
except NotImplementedError as e:
# failed to convert to tensor, graph break
unimplemented(str(e))
tensor_value = clone_preserve_strides(tensor_value)
except NotImplementedError as e:
# failed to convert to tensor, graph break
unimplemented(str(e))
# We do this because we want the full behavior of guarding the numpy ndarray as if it were
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here

View File

@ -200,7 +200,6 @@ class BuiltinVariable(VariableTracker):
operator.ne,
operator.eq,
operator.sub,
operator.getitem,
operator.length_hint,
operator.lshift,
operator.rshift,
@ -212,6 +211,7 @@ class BuiltinVariable(VariableTracker):
operator.imatmul,
operator.ifloordiv,
operator.itruediv,
operator.getitem,
operator.imod,
operator.iadd,
operator.isub,
@ -858,6 +858,39 @@ class BuiltinVariable(VariableTracker):
if kwargs and not self.tensor_args(*args, *kwargs.values()):
return
# insert handling for torch function here
from .builder import SourcelessBuilder
from .torch_function import (
BUILTIN_TO_TENSOR_FN_MAP,
BUILTIN_TO_TENSOR_RFN_MAP,
can_dispatch_torch_function,
dispatch_torch_function,
)
if can_dispatch_torch_function(tx, args, kwargs):
# Only remap the fn to tensor methods if we aren't exporting
# export serde does not handle method descriptors today
if not tx.export:
# Use sourceless builder, we built the map ourselves
if not isinstance(args[0], TensorVariable):
if self.fn in BUILTIN_TO_TENSOR_RFN_MAP:
func = BUILTIN_TO_TENSOR_RFN_MAP[self.fn]
else:
func = BUILTIN_TO_TENSOR_FN_MAP[self.fn]
tmp = args[0]
# swap args and call reverse version of func
args[0] = args[1]
args[1] = tmp
else:
func = BUILTIN_TO_TENSOR_FN_MAP[self.fn]
else:
func = self.fn
fn_var = SourcelessBuilder.create(tx, func)
return dispatch_torch_function(tx, fn_var, args, kwargs)
fn = self.fn
try:
# Constant fold for constant tensor and python constants

View File

@ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker):
if isinstance(args[0], UserFunctionVariable):
return WrappedUserFunctionVariable(args[0], self)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return True
class GenericContextWrappingVariable(UserDefinedObjectVariable):
# Some methods in ContextWrappingVariable assumes the arguments are
@ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
tx.generic_context_manager_depth -= 1
return x
def supports_graph_breaks(self):
return False
def exit_on_graph_break(self):
return True
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
"""represents torch grad requries grad"""

View File

@ -1998,8 +1998,7 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
fn: "VariableTracker",
fn_name: str,
):
from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
from .._trace_wrapped_higher_order_op import TransformGetItemToIndex
from .builder import SourcelessBuilder
tx: InstructionTranslator = tx

View File

@ -80,6 +80,14 @@ class LazyVariableTracker(VariableTracker):
self.realize()
return VariableTracker.clone(self.unwrap(), **kwargs)
def peek_type(self) -> type[Any]:
assert not self.is_realized()
return type(self._cache.value)
def peek_value(self) -> Any:
assert not self.is_realized()
return self._cache.value
def __str__(self) -> str:
if self.is_realized():
return self.unwrap().__str__()

View File

@ -510,9 +510,37 @@ class TensorVariable(VariableTracker):
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from .builder import SourcelessBuilder, VariableBuilder
from .torch_function import can_dispatch_torch_function, dispatch_torch_function
if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
unimplemented(f"Illegal method invocation {name} in strict mode")
# Only override builtin tensor methods
# The user can manually add override handling
# with a decorator for other methods (e.g. a dispatch subclass with other methods)
has_torch_function_override = False
try:
inspect.getattr_static(torch.Tensor, name)
has_torch_function_override = True
except AttributeError:
has_torch_function_override = False
if (
can_dispatch_torch_function(tx, tuple([self] + list(args)), kwargs)
and has_torch_function_override
):
if self.source:
func_var = VariableBuilder(
tx, AttrSource(AttrSource(self.source, "__class__"), name)
)(inspect.getattr_static(torch.Tensor, name))
else:
func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
return dispatch_torch_function(
tx, func_var, tuple([self] + list(args)), kwargs
)
"""
Dispatch to a method-specific handler defined below. If the
handler returns None (or doesn't exist) we put the method call
@ -772,6 +800,30 @@ class TensorVariable(VariableTracker):
self._warn_capture_scalar_outputs()
unimplemented("Tensor.item")
def method___getitem__(self, *args, **kwargs):
from ..symbolic_convert import InstructionTranslator
from .builder import wrap_fx_proxy
tx = InstructionTranslator.current_tx()
if isinstance(args[0], SymNodeVariable):
# Standard indexing will force specialization due to
# __index__. Rewrite as a regular torch op which will
# trace fine
fn, args = torch.select, [
variables.ConstantVariable.create(0),
args[0],
]
else:
fn = operator.getitem
proxy = tx.output.create_proxy(
"call_function",
fn,
*proxy_args_kwargs([self] + list(args), kwargs),
)
return wrap_fx_proxy(tx, proxy)
@staticmethod
@functools.lru_cache(None)
def _warn_capture_scalar_outputs():

View File

@ -159,7 +159,17 @@ def get_overridable_functions():
from torch.overrides import get_overridable_functions as get_overridable_functions_
return set(chain(*get_overridable_functions_().values()))
funcs = set(chain(*get_overridable_functions_().values()))
more = {
torch.ones,
torch.ones_like,
torch.zeros,
torch.zeros_like,
torch.empty,
torch.full,
}
funcs.update(more)
return funcs
class BaseTorchVariable(VariableTracker):
@ -835,6 +845,13 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
len(tx.symbolic_torch_function_state.mode_stack)
)
@register(torch._C._get_function_stack_at)
def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
assert len(args) == 1 and not kwargs
ind = args[0].as_python_constant()
assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
return tx.symbolic_torch_function_state.mode_stack[ind]
@register(torch.set_default_device)
def handle_set_default_device(
self, tx: "InstructionTranslator", *args, **kwargs
@ -852,7 +869,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
else:
TorchFunctionModeStackVariable.register_device_context_insertion(tx)
return None
return ConstantVariable.create(None)
return handlers
@ -883,6 +900,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
),
)
if self.is_tensor_method():
return self.call_tensor_method(tx, args, kwargs)
special_handler = self._get_handlers().get(self.value)
if special_handler:
result = special_handler(self, tx, *args, **kwargs)
@ -1155,6 +1175,16 @@ Either create the tensor outside the compiled region, or do not set the tensor t
)
return result
def call_tensor_method(self, tx, args, kwargs):
return args[0].call_method(tx, self.get_function().__name__, args[1:], kwargs)
def is_tensor_method(self):
return (
inspect.ismethoddescriptor(self.get_function())
and hasattr(self.get_function(), "__objclass__")
and self.get_function().__objclass__ == torch._C.TensorBase
) or self.get_function() is torch.Tensor.__contains__
def torch_function_override_enabled(self, tx, args, kwargs):
return (
self.get_function() in get_overridable_functions()

View File

@ -2,22 +2,37 @@
import collections
import contextlib
import functools
import inspect
import operator
from typing import Deque, Dict, List, TYPE_CHECKING
import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from torch.overrides import (
_get_overloaded_args,
BaseTorchFunctionMode,
get_default_nowrap_functions,
TorchFunctionMode,
)
from torch.utils._device import DeviceContext
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
from ..utils import (
class_has_getattribute,
clear_torch_function_mode_stack,
get_safe_global_name,
has_torch_function,
is_tensor_base_attr_getter,
set_torch_function_mode_stack,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .ctx_manager import GenericContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
@ -49,6 +64,125 @@ if TYPE_CHECKING:
# To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py
bin_ops = [
operator.pow,
operator.mul,
operator.matmul,
operator.floordiv,
operator.truediv,
operator.mod,
operator.add,
operator.lt,
operator.gt,
operator.ge,
operator.le,
operator.ne,
operator.eq,
operator.sub,
operator.ipow,
operator.imul,
operator.imatmul,
operator.ifloordiv,
operator.itruediv,
operator.imod,
operator.iadd,
operator.isub,
]
bin_int_ops = [
operator.and_,
operator.or_,
operator.xor,
operator.iand,
operator.ixor,
operator.ior,
]
un_int_ops = [operator.invert]
tensor_and_int_ops = [
operator.lshift,
operator.rshift,
operator.ilshift,
operator.irshift,
operator.getitem,
]
un_ops = [
operator.abs,
operator.pos,
operator.neg,
operator.not_, # Note: this has a local scalar dense call
operator.length_hint,
]
BUILTIN_TO_TENSOR_FN_MAP = {}
# These functions represent the r* versions of the above ops
# Basically, if __add__(1, Tensor) is called, it is translated
# to __radd__(Tensor, 1).
# In the builtin var, we check if there is a tensor in the first args position,
# if not, we swap the args and use the r* version of the op.
BUILTIN_TO_TENSOR_RFN_MAP = {}
def populate_builtin_to_tensor_fn_map():
global BUILTIN_TO_TENSOR_FN_MAP
most_recent_func = None
class GetMethodMode(BaseTorchFunctionMode):
"""
Mode to extract the correct methods from torch function invocations
(Used to get the correct torch.Tensor methods from builtins)
"""
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
nonlocal most_recent_func
most_recent_func = func
return func(*args, **kwargs)
inp0 = torch.ones(1)
inp1 = torch.ones(1)
inp0_int = torch.ones(1, dtype=torch.int32)
inp1_int = torch.ones(1, dtype=torch.int32)
with GetMethodMode():
setups_and_oplists = [
(lambda o: o(inp0), un_ops),
(lambda o: o(inp0_int), un_int_ops),
(lambda o: o(inp0, inp1), bin_ops),
(lambda o: o(inp0_int, inp1_int), bin_int_ops),
(lambda o: o(inp0_int, 0), tensor_and_int_ops),
]
for setup_fn, op_list in setups_and_oplists:
for op in op_list:
setup_fn(op)
assert most_recent_func is not None
BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func
# gather the reverse functions
rsetups_and_oplists = [
(
lambda o: o(1, inp1),
bin_ops,
), # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
(lambda o: o(1, inp1_int), bin_int_ops),
(lambda o: o(0, inp0_int), tensor_and_int_ops),
]
rskips = {operator.matmul, operator.imatmul, operator.getitem}
for setup_fn, op_list in rsetups_and_oplists:
for op in op_list:
if op in rskips:
continue
setup_fn(op)
assert most_recent_func is not None
if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]:
BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func
populate_builtin_to_tensor_fn_map()
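A quick sanity check of the maps built above (they are module-level globals populated at import time; exact method objects may vary by build, so nothing specific is asserted):

import operator
from torch._dynamo.variables.torch_function import (
    BUILTIN_TO_TENSOR_FN_MAP,
    BUILTIN_TO_TENSOR_RFN_MAP,
)

print(BUILTIN_TO_TENSOR_FN_MAP[operator.add])        # Tensor method dispatched for `a + b`
print(BUILTIN_TO_TENSOR_RFN_MAP.get(operator.sub))   # reverse (r*) method used for `1 - a`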
banned_attrs = [
fn.__self__.__name__
@ -56,11 +190,38 @@ banned_attrs = [
if is_tensor_base_attr_getter(fn)
]
# Today set default device is placed in the graph and guarded on separately
# so we should not trace through it. In the future we can trace it once
# mode tracing is implemented and not put in the graph, but this is more
# of a BE project and can be evaluated later
IGNORED_MODES = {DeviceContext}
@functools.lru_cache(None)
def get_prev_stack_var_name():
from ..bytecode_transformation import unique_id
return unique_id("___prev_torch_function_mode_stack")
# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
class TorchFunctionModeStackStateManager:
def __init__(self):
self.stack = []
def __enter__(self):
self.stack = torch.overrides._get_current_function_mode_stack()
clear_torch_function_mode_stack()
def __exit__(self, exc_type, exc_value, traceback):
set_torch_function_mode_stack(self.stack)
self.stack = []
@contextlib.contextmanager
def temp_restore_stack(self):
prev = torch.overrides._get_current_function_mode_stack()
set_torch_function_mode_stack(self.stack)
try:
yield
finally:
set_torch_function_mode_stack(prev)
torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
class SymbolicTorchFunctionState:
@ -189,9 +350,26 @@ class TorchFunctionModeStackVariable(VariableTracker):
return ind + cls.offset
class TorchFunctionModeVariable(ContextWrappingVariable):
class TorchFunctionModeVariable(GenericContextWrappingVariable):
@staticmethod
def is_supported_torch_function_mode(ty):
# Supported in this sense means we can support graph breaks under the
# context.
# We are able to trace custom modes but if there are graph breaks under them
# and they have a custom __enter__/__exit__ we don't handle this for the
# same reason we don't handle generic context managers: there may be side effects
# that are now affected by executing the function across two frames instead of one
# Today we support the enter/exit of the default TorchFunctionMode as well as
# DeviceContext (which is used for set_default_device)
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
not class_has_getattribute(ty)
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
)
def __init__(self, value, source=None, **kwargs):
super().__init__(value, **kwargs)
if value is not None:
super().__init__(value, **kwargs)
self.value = value
self.cm_obj = value # needed for BC with calling enter from CM code
self.source = source
@ -221,8 +399,39 @@ class TorchFunctionModeVariable(ContextWrappingVariable):
kwargs,
)
def _call_func(self, tx: "InstructionTranslator", values):
unimplemented("enter/exit for torch function mode NYI")
def enter(self, tx):
from .torch import TorchInGraphFunctionVariable
if isinstance(self.value, NoEnterTorchFunctionMode):
return ConstantVariable.create(None)
TorchInGraphFunctionVariable(
torch._C._push_on_torch_function_stack
).call_function(tx, [self], {})
return ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
from .torch import TorchInGraphFunctionVariable
TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
tx, [], {}
)
return ConstantVariable.create(None)
def reconstruct_type(self, codegen):
ty = NoEnterTorchFunctionMode
codegen(
AttrSource(
codegen.tx.import_source(ty.__module__),
ty.__name__,
)
)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return False
def _get_all_args(args, kwargs):
@ -233,7 +442,6 @@ def _flatten_vts(vts):
from collections import deque
from .dicts import ConstDictVariable
from .lazy import LazyVariableTracker
from .lists import ListVariable
vts = deque(vts)
@ -241,13 +449,17 @@ def _flatten_vts(vts):
while vts:
vt = vts.pop()
LazyVariableTracker.realize_all(vt)
if isinstance(vt, ListVariable):
vts.extend(vt.items)
elif isinstance(vt, ConstDictVariable):
vts.extend(vt.items.values())
else:
output.append(vt)
if not vt.is_realized() and vt.peek_type() in (dict, list, tuple):
vt.realize()
if vt.is_realized():
if isinstance(vt, ListVariable):
vts.extend(vt.items)
elif isinstance(vt, ConstDictVariable):
vts.extend(vt.items.values())
output.append(vt)
return output
@ -301,8 +513,15 @@ def call_torch_function(
def build_torch_function_fn(tx: "InstructionTranslator", value, source):
from types import FunctionType
from .builder import SourcelessBuilder, VariableBuilder
func = value.__torch_function__.__func__
if not isinstance(func, FunctionType):
unimplemented("Builtin/C++ torch function implementations NYI")
if source:
return VariableBuilder(
tx,

View File

@ -413,10 +413,22 @@ class UserDefinedClassVariable(UserDefinedVariable):
and self.source
and not is_forbidden_context_manager(self.value)
):
from torch.overrides import TorchFunctionMode
from .ctx_manager import GenericContextWrappingVariable
from .torch_function import TorchFunctionModeVariable
if issubclass(
self.value, TorchFunctionMode
) and TorchFunctionModeVariable.is_supported_torch_function_mode(
self.value
):
var_cls = TorchFunctionModeVariable
else:
var_cls = GenericContextWrappingVariable
cm_obj = tx.output.side_effects.track_object_new(
self.source, self.value, GenericContextWrappingVariable, {}
self.source, self.value, var_cls, {}
)
cm_obj.call_method(tx, "__init__", args, kwargs)
return cm_obj

View File

@ -11,7 +11,7 @@ from torch._higher_order_ops.utils import (
reenter_make_fx,
UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator, OpOverload
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
make_fx,
@ -19,7 +19,6 @@ from torch.fx.experimental.proxy_tensor import (
track_tensor_tree,
)
from torch.fx.graph_module import GraphModule
from torch.overrides import TorchFunctionMode
# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import
@ -69,27 +68,6 @@ def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch
return new_out
class TransformGetItemToIndex(TorchFunctionMode):
# This is needed since we want to support calling
# A[q_idx], where q_idx is a scalar tensor in score_mod.
# Today, when q_idx is a scalar tensor, we implicitly convert it to a python
# scalar and create a view. We do not want that behavior in this case, so we
# use this torchfunctionmode to override that behavior for score_mod
# wherever we're running it.
def __torch_function__(
self,
func: OpOverload,
types: Tuple[torch._C._TensorMeta, ...],
args: Tuple[object, ...] = (),
kwargs: Optional[Dict[str, object]] = None,
) -> object:
if func == torch.Tensor.__getitem__:
index_args = pytree.tree_leaves(args[1])
if all(isinstance(x, torch.Tensor) for x in index_args):
return torch.ops.aten.index(args[0], index_args)
return func(*args, **(kwargs or {}))
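For reference, the rewrite this mode performs can be reproduced directly: with a 0-dim tensor index, default __getitem__ goes through the implicit item()/view path, whereas the mode routes the same call to aten.index so the index stays a tensor. A small standalone sketch (the mode itself is simply entered around score_mod in the code that follows):

import torch

A = torch.arange(10.0)
q_idx = torch.tensor(3)          # 0-dim scalar tensor, as in score_mod

default = A[q_idx]                            # implicit scalar conversion + view
rewritten = torch.ops.aten.index(A, [q_idx])  # what the mode dispatches instead

print(default, rewritten)        # both tensor(3.), but via different paths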
class FlexAttentionHOP(HigherOrderOperator):
def __init__(self) -> None:
super().__init__("flex_attention", cacheable=True)
@ -185,6 +163,8 @@ def _math_attention_inner(
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32
scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision)
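A tiny standalone illustration of the dtype policy used here: only float64 inputs keep float64 as the working precision; everything else is computed in float32.

import torch

def working_precision_for(t: torch.Tensor) -> torch.dtype:
    # Same rule as above: float64 stays float64; float32/float16/bfloat16
    # inputs all get a float32 working precision for the score computation.
    return torch.float64 if t.dtype == torch.float64 else torch.float32

q32 = torch.randn(1, 1, 4, 8)
q64 = torch.randn(1, 1, 4, 8, dtype=torch.float64)
print(working_precision_for(q32), working_precision_for(q64))  # float32 float64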
@ -318,6 +298,8 @@ def trace_flex_attention(
This will produce a GraphModule that will be stored on the root tracer as "sdpa_score". We
access this graph module in inductor to inline the score_mod function into the Triton template.
"""
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
example_out = flex_attention(
query,
key,
@ -414,6 +396,8 @@ def flex_attention_functionalize(
guard against any mutations the score_mod function makes to the other_buffers, since those
are free variables.
"""
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
query_unwrapped = ctx.unwrap_tensors(query)
key_unwrapped = ctx.unwrap_tensors(key)
value_unwrapped = ctx.unwrap_tensors(value)
@ -715,6 +699,8 @@ def flex_attention_autograd(
score_mod_other_buffers: Tuple[Tensor, ...] = (),
mask_mod_other_buffers: Tuple[Tensor, ...] = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
with TransformGetItemToIndex():
input_requires_grad = any(t.requires_grad for t in (query, key, value))
if torch.is_grad_enabled() and input_requires_grad:
@ -765,6 +751,8 @@ def sdpa_dense_backward(
score_mod_other_buffers: Tuple,
mask_mod_other_buffers: Tuple,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
# Get outputs before calling repeat interleave
actual_grad_query = torch.empty_like(query)
actual_grad_key = torch.empty_like(key)
@ -892,6 +880,8 @@ def trace_flex_attention_backward(
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""We already have the forward graph and joint graph from the forward pass, so we create a proxy attach both graphs"""
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
example_out = flex_attention_backward(
query,
key,

View File

@ -8,6 +8,8 @@ from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_imp
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
_temp_remove_pre_dispatch_torch_function_mode,
disable_proxy_modes_tracing,
make_fx,
ProxyTorchDispatchMode,
@ -18,14 +20,26 @@ from torch.utils._python_dispatch import _get_current_dispatch_mode
@exposed_in("torch")
def strict_mode(callable, operands):
from torch._dynamo.backends.debugging import (
make_eager_backend_with_torch_function_modes,
)
if torch.compiler.is_dynamo_compiling():
return strict_mode_op(callable, operands)
with _set_compilation_env():
with torch._dynamo.utils.disable_cache_limit():
return torch.compile(strict_mode_op, backend="eager", fullgraph=True)(
callable, operands
)
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
with _temp_remove_pre_dispatch_torch_function_mode() as predispatch_mode:
modes = [metadata_mode, predispatch_mode]
modes = [mode for mode in modes if mode is not None]
if modes:
backend = make_eager_backend_with_torch_function_modes(modes)
else:
backend = "eager"
with torch._dynamo.utils.disable_cache_limit():
return torch.compile(
strict_mode_op, backend=backend, fullgraph=True
)(callable, operands)
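The same pop-then-reinstall pattern can be factored into a small helper. A hedged sketch, assuming the proxy_tensor context managers and make_eager_backend_with_torch_function_modes behave exactly as they are used in this hunk:

import torch
from torch._dynamo.backends.debugging import make_eager_backend_with_torch_function_modes
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
    _temp_remove_pre_dispatch_torch_function_mode,
)

def compile_with_infra_modes_reinstalled(fn, *args):
    # Pop the metadata/pre-dispatch torch function modes while Dynamo traces,
    # and hand them to an eager backend that re-enters them around the graph.
    with _temp_remove_metadata_torch_function_mode() as metadata_mode:
        with _temp_remove_pre_dispatch_torch_function_mode() as predispatch_mode:
            modes = [m for m in (metadata_mode, predispatch_mode) if m is not None]
            backend = make_eager_backend_with_torch_function_modes(modes) if modes else "eager"
            return torch.compile(fn, backend=backend, fullgraph=True)(*args)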
class StrictMode(HigherOrderOperator):

View File

@ -2540,90 +2540,40 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
public:
TORCH_FUNCTION_MODE_STACK(
const py::list& initial_stack,
const py::list& ignored_types,
py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)),
_ref_stack(),
_ignored_types() {
: LeafGuard(std::move(verbose_code_parts)), _ref_stack() {
Py_ssize_t len = PyList_Size(initial_stack.ptr());
for (Py_ssize_t idx = 0; idx < len; idx++) {
PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
auto type = Py_TYPE(mode);
this->_ref_stack.push_back(type);
}
len = PyList_Size(ignored_types.ptr());
for (Py_ssize_t idx = 0; idx < len; idx++) {
PyObject* type_obj =
PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref
if (PyType_Check(type_obj) == 0) {
PyErr_SetString(
PyExc_TypeError, "ignored_types should contain a list of types");
return;
}
PyTypeObject* type = (PyTypeObject*)type_obj;
this->_ignored_types.insert(type);
}
}
bool check_nopybind(PyObject* value) override {
// Ignore value arg, only used to satisfy the interface
size_t ref_ind = 0;
const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len();
const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len();
const size_t ref_stack_size = this->_ref_stack.size();
int64_t idx = 0;
while ((idx < len) && (ref_ind < ref_stack_size)) {
if (len != ref_stack_size) {
return false;
}
for (int64_t idx = 0; (size_t)idx < len; idx++) {
std::shared_ptr<c10::SafePyObject> mode =
at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
bool act_ignored = this->_ignored_types.count(mode_type) > 0;
bool ref_ignored =
this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0;
// skip ignored types
if (act_ignored && ref_ignored) {
idx++;
ref_ind++;
continue;
} else if (ref_ignored) {
ref_ind++;
continue;
} else if (act_ignored) {
idx++;
continue;
}
// if we already have more non-ignored modes than the ref stack
// or if the mode doesn't match at the current index, return false
else if (mode_type != _ref_stack.at(ref_ind)) {
return false;
}
ref_ind++;
idx++;
}
for (; ref_ind < ref_stack_size; ref_ind++) {
if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) {
if (mode_type != _ref_stack.at(idx)) {
return false;
}
}
for (; idx < len; idx++) {
std::shared_ptr<c10::SafePyObject> mode =
at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
if (!(this->_ignored_types.count(mode_type) > 0)) {
return false;
}
}
return ref_ind == ref_stack_size && idx == len;
return true;
}
private:
std::vector<PyTypeObject*> _ref_stack;
std::set<PyTypeObject*> _ignored_types;
};
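With the ignored-types handling gone, the guard is an exact length-and-type comparison against the recorded stack. A rough Python rendering of what check_nopybind now verifies (using the private torch._C stack accessors; this is a sketch of the semantics, not the guard itself):

import torch

def torch_function_mode_stack_matches(ref_types):
    # ref_types: the list of mode types captured when the guard was built.
    stack_len = torch._C._len_torch_function_stack()
    if stack_len != len(ref_types):
        return False
    return all(
        type(torch._C._get_function_stack_at(i)) is ref_types[i]
        for i in range(stack_len)
    )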
class TENSOR_MATCH : public LeafGuard {
@ -3792,7 +3742,7 @@ PyObject* torch_c_dynamo_guards_init() {
LeafGuard,
std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
py_m, "TORCH_FUNCTION_MODE_STACK")
.def(py::init<py::list, py::list, py::list>())
.def(py::init<py::list, py::list>())
.def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
py_m, "DATA_PTR_MATCH")
@ -4029,10 +3979,9 @@ PyObject* torch_c_dynamo_guards_init() {
"add_torch_function_mode_stack_guard",
[](GuardManager& self,
const py::list& initial_stack,
const py::list& ignored_types,
py::object verbose_code_parts) -> void {
self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
initial_stack, ignored_types, std::move(verbose_code_parts)));
initial_stack, std::move(verbose_code_parts)));
})
.def(
"add_data_ptr_guard",

View File

@ -13,10 +13,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch._higher_order_ops.flex_attention import (
flex_attention as flex_attention_hop,
TransformGetItemToIndex,
)
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch._higher_order_ops.utils import _set_compilation_env
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,