Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-14 22:25:03 +08:00

Compare commits: v2.9.1-rc2 ... mlazos/tf-

8 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 1cad5436f6 |  |
|  | 318075f6cb |  |
|  | 6e8b6b11cd |  |
|  | 9655ebd499 |  |
|  | df9ef729fd |  |
|  | 244dc9d802 |  |
|  | 39867e316c |  |
|  | e1db3582d8 |  |
@@ -1,5 +1,5 @@
add_loop_eager, compile_time_instruction_count, 2834456320, 0.015
add_loop_eager_dynamic, compile_time_instruction_count, 5528896630, 0.025
add_loop_eager, compile_time_instruction_count, 3004749893, 0.015
add_loop_eager_dynamic, compile_time_instruction_count, 5726573328, 0.025
add_loop_inductor, compile_time_instruction_count, 24146845503, 0.015
add_loop_inductor_dynamic_gpu, compile_time_instruction_count, 39411706509, 0.025
add_loop_inductor_gpu, compile_time_instruction_count, 22171041650, 0.015
@@ -701,7 +701,7 @@ class CompileTest(TestCase):
FileCheck()
.check(
"buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
".default([arg0_1, arg1_1, arg2_1, arg3_1]"
".default([arg3_1, arg2_1, arg1_1, arg0_1]"
)
.check("buf1 = buf0[0]")
.check("buf2 = buf0[1]")

@@ -717,8 +717,8 @@ class CompileTest(TestCase):
)

# Test aoti
out = AOTIRunnerUtil.run("cuda", func, (args,))
torch.cuda.synchronize()
# out = AOTIRunnerUtil.run("cuda", func, (args,))
# torch.cuda.synchronize()

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
@@ -938,6 +938,16 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
else:
return x - 1

@make_test
def test_tensor_size(x):
fn = torch.Tensor.size
return fn(x + 1)

@make_test
def test_tensor_dim(x):
fn = torch.Tensor.dim
return fn(x + 1)

@make_test
def test_tensor_is_inference(x):
if x.is_inference():
@@ -646,10 +646,10 @@ print("arf")
self.assertExpectedInline(
munge_shape_guards(record.getMessage()),
"""\
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
+- LAMBDA_GUARD: L['z'].size()[0] == L['y'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
+- LAMBDA_GUARD: Eq(Mod(2*L['y'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in #
+- LAMBDA_GUARD: 2 <= L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
+- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
+- LAMBDA_GUARD: Eq(Mod(2*L['z'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in #
+- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
)

@make_logging_test(guards=True)
@@ -1,4 +1,6 @@
# Owner(s): ["module: dynamo"]

import operator
from unittest.mock import patch

import torch

@@ -10,6 +12,7 @@ from torch._C import (
_push_on_torch_function_stack,
)
from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
from torch.testing._internal.triton_utils import requires_cuda
from torch.utils._device import DeviceContext
from torch.utils._python_dispatch import TorchDispatchMode
@@ -107,70 +110,6 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
fn(inp)
self.assertEqual(cnt.frame_count, 4)

def _run_ignored_mode_types_test(self):
class IgnoredMode(BaseTorchFunctionMode):
pass

cnt = torch._dynamo.testing.CompileCounter()

@torch.compile(backend=cnt.__call__, fullgraph=True)
def fn(x):
return x + 1

inp = torch.ones(2, 2)

with patch(
"torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
):
# initial compile
fn(inp)

# no recompile, mode ignored
# note: the ref stack is length 0, and the stack we are checking against has length 2
# we want to check both ref stack len > runtime stack, and ref stack len < runtime stack
with IgnoredMode(), IgnoredMode():
fn(inp)

self.assertEqual(cnt.frame_count, 1)

# recompile due to new mode on the stack
with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
fn(inp)

self.assertEqual(cnt.frame_count, 2)

# recompile
# tests both ref stack len > runtime stack len for the above guard check
# and ref stack len < runtime stack len for the initial zero mode case
with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode():
fn(inp)

self.assertEqual(cnt.frame_count, 3)

# no recompile
with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
fn(inp)

self.assertEqual(cnt.frame_count, 3)

# This is tricky, basically the ignored modes are baked into the guard
# IgnoredMode will be ignored forever by that guard.
# This is okay since we don't expect to be modifying IGNORED_MODES
# in the middle of execution except for the purposes of testing.
torch._dynamo.reset()

with IgnoredMode():
fn(inp)

self.assertEqual(cnt.frame_count, 4)

@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
def test_torch_function_mode_guards_ignored_types_py(self):
self._run_ignored_mode_types_test()

def test_torch_function_mode_guards_ignored_types_cpp(self):
self._run_ignored_mode_types_test()

@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
def test_torch_function_mode_guards_py(self):
self._run_torch_function_mode_guard_test()
@@ -461,6 +400,205 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
self.assertEqual(expected, actual)

def test_torch_function_mode_enter_exit(self):
def fn(x, y):
with TestMode():
o = torch.add(x, 3)

return torch.add(o, y)

inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn, fullgraph=True)

expected = fn(*inp)
actual = fn_opt(*inp)

self.assertEqual(expected, actual)

def test_torch_function_mode_graph_break(self):
def fn(x, y):
with TestMode():
torch._dynamo.graph_break()
o = torch.add(x, 3)

return torch.add(o, y)

inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)

expected = fn(*inp)
actual = fn_opt(*inp)

self.assertEqual(expected, actual)

def test_torch_function_mode_and_pop_graph_break(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)

return torch.add(o, y)

inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)

expected = fn(*inp)
actual = fn_opt(*inp)

self.assertEqual(expected, actual)

def test_torch_function_mode_restore_on_exc(self):
@torch._dynamo.disable()
def err():
raise RuntimeError("test")

@torch.compile()
def fn(x):
with TestMode():
x += 1
err()
x += 2
return x

try:
fn(torch.ones(2, 2))
except RuntimeError:
pass
self.assertEqual(_len_torch_function_stack(), 0)

def test_torch_function_mode_and_pop_graph_break_mutation(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
z.y = 5
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
o = torch.mul(o, z.y)

return torch.add(o, y)

inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)

expected = fn(*inp)
actual = fn_opt(*inp)

self.assertEqual(expected, actual)

# Needs larger cache size since we recompile for each op
@patch.object(torch._dynamo.config, "cache_size_limit", 48)
def test_builtin_equivalent_funcs(self):
from torch._dynamo.variables.torch_function import (
bin_int_ops,
bin_ops,
BUILTIN_TO_TENSOR_FN_MAP,
BUILTIN_TO_TENSOR_RFN_MAP,
tensor_and_int_ops,
un_int_ops,
un_ops,
)

expected_func = None
valid = False

class FuncEquivMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
nonlocal expected_func
nonlocal valid
if not kwargs:
kwargs = {}
if torch._dynamo.is_compiling():
valid = expected_func == func
return super().__torch_function__(func, types, args, kwargs)

inp0 = torch.ones(1, 1)
inp1 = torch.ones(1, 1)
inp0_int = torch.ones(1, 1, dtype=torch.int32)
inp1_int = torch.ones(1, 1, dtype=torch.int32)

@torch.compile(fullgraph=True)
def fn_un(op, inp):
return op(inp)

@torch.compile(fullgraph=True)
def fn_un_int(op, inp):
return op(inp)

@torch.compile(fullgraph=True)
def fn_bin(op, inp0, inp1):
return op(inp0, inp1)

@torch.compile(fullgraph=True)
def fn_bin_int(op, inp0, inp1):
return op(inp0, inp1)

@torch.compile(fullgraph=True)
def fn_tensor_and_int(op, inp0, inp1):
return op(inp0, inp1)

setups_and_oplists = [
(lambda o: fn_un(o, inp0), un_ops),
(lambda o: fn_un_int(o, inp0_int), un_int_ops),
(lambda o: fn_bin(o, inp0, inp1), bin_ops),
(lambda o: fn_bin_int(o, inp0_int, inp1_int), bin_int_ops),
(lambda o: fn_tensor_and_int(o, inp0_int, 0), tensor_and_int_ops),
]

# gather the reverse functions
rsetups_and_oplists = [
(
lambda o: fn_bin(o, 1, inp1),
bin_ops,
), # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
(lambda o: fn_bin_int(o, 1, inp1_int), bin_int_ops),
(lambda o: fn_tensor_and_int(o, 0, inp0_int), tensor_and_int_ops),
]

skips = {operator.not_} # Has local scalar dense call which graph breaks
rskips = {
operator.matmul,
operator.imatmul,
operator.getitem,
} # Doesn't type check with reversed args

def run_checks(setups_and_oplists, skips, ref_map):
nonlocal valid
nonlocal expected_func
for setup_fn, op_list in setups_and_oplists:
for op in op_list:
if op in skips or op not in ref_map:
continue
with FuncEquivMode():
expected_func = ref_map[op]
setup_fn(op)
self.assertTrue(valid)

expected_func = None
valid = False

run_checks(setups_and_oplists, skips, BUILTIN_TO_TENSOR_FN_MAP)
run_checks(rsetups_and_oplists, rskips, BUILTIN_TO_TENSOR_RFN_MAP)

@requires_cuda
def test_flex_attention(self):
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

torch.set_default_device("cuda")

flex_attention = torch.compile(flex_attention, dynamic=False)

prefix_lengths = torch.arange(8)

def prefix_lm(b, h, q, kv):
return prefix_lengths[b] >= kv

# This runs in fullgraph already
mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
@@ -672,7 +672,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
wrapped2 = y.as_subclass(SigmoidToExpSubclass)

def fn(w):
return w.sigmoid()
return w.exp()

fn_opt = compile_full_eager(fn)

@@ -683,6 +683,38 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
self.assertEqual(res_exp, res_act)
self.assertEqual(res_exp, res_exp2)

def test_torch_function_call_on_method_arg(self):
class LocalSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if func == torch._C.TensorBase.add_:
func = torch._C.TensorBase.sub_

if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)

def sigmoid(self):
return None

x = torch.ones(2, 2)
y = torch.ones(2, 2)
z = torch.ones(2, 2)
wrapped = y.as_subclass(LocalSubclass)
wrapped2 = z.as_subclass(LocalSubclass)

def fn(a, w):
a.add_(w)
return a

fn_opt = torch.compile(fn)

with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
res_exp = fn(x, wrapped)
res_act = fn_opt(y, wrapped2)

self.assertEqual(res_exp, res_act)

def test_user_overidden_method_unsupported(self):
class LocalSubclass(torch.Tensor):
@classmethod
@@ -49,9 +49,9 @@ def forward(self, b_submodule_buffer1, x):
sin = torch.ops.aten.sin.default(x)
strict_graph_0 = self.strict_graph_0
strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None
getitem_2 = strict_mode[0]; strict_mode = None
getitem = strict_mode[0]; strict_mode = None
add = torch.ops.aten.add.Tensor(x, 3); x = None
return (getitem_2, add)""",
return (getitem, add)""",
)

self.assertExpectedInline(
@@ -64,6 +64,7 @@ from torch.testing._internal.common_utils import (
IS_SANDCASTLE,
IS_WINDOWS,
run_tests,
skipIfCrossRef,
TEST_TRANSFORMERS,
TestCase as TorchTestCase,
)

@@ -6989,6 +6990,7 @@ def forward(self, x):
real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes]
self.assertEqual(expected_names_and_ops, real_names_and_ops)

@skipIfCrossRef # Dynamo changes the order of ops under Torch function modes
def test_placeholder_naming_collisions_hoo_subgraphs(self):
# test collisions between user inputs, top-level nodes, and HOO subgraph nodes
class Foo(torch.nn.Module):

@@ -8325,6 +8327,7 @@ class TestOneOffModelExportResult(TestCase):
# getitem = _scaled_dot_product_flash_attention_for_cpu[0]; _scaled_dot_product_flash_attention_for_cpu = None
# return (getitem,)""")

@skipIfCrossRef
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
"Can't run fused SDPA on this platform",
@@ -4902,6 +4902,7 @@ def forward(self, arg0_1, arg1_1):
return [getitem]""", # noqa: B950
)

@skipIfCrossRef # Arg order changes with crossref
def test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph(self):
def true_fn(x):
return x + x.cos()

@@ -5252,6 +5253,7 @@ def forward(self, arg0_1):
):
torch.cond(inp.sum() > 0, f, f, (inp, tmp))

@skipIfCrossRef # Arg order changes with crossref
def test_cond_trace_set__and_mutate_intermediate(self):
def f(a, tmp):
a = a.clone()
@@ -180,12 +180,10 @@ class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
self.assertExpectedInline(
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \
"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
# No stacktrace found for following nodes
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = \
arg3_1 = arg1_1 = arg0_1 = foo_default = None
return ()""",
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None
return ()""", # noqa: B950
)

eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

@@ -239,7 +237,7 @@ arg3_1 = arg1_1 = arg0_1 = foo_default = None
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = None
getitem_4: "f32[3][1]cpu" = foo_default[0]
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
return (getitem_4, getitem_5)""", # noqa: B950

@@ -402,9 +400,9 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
post_grad_graphs,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1); arg3_1 = arg4_1 = arg1_1 = foo_default = None
foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1); arg5_1 = copy__1 = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,

@@ -414,9 +412,9 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = foo_default = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,

@@ -503,12 +501,11 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = None
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = None
getitem_4: "f32[3][1]cpu" = foo_default[0]
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None

copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return (getitem_4, getitem_5)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -67,7 +67,7 @@ class GuardManager:
) -> None: ...
def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
def add_torch_function_mode_stack_guard(
self, initial_stack, ignored_types, verbose_code_parts: list[str]
self, initial_stack, verbose_code_parts: list[str]
) -> None: ...

class RootGuardManager(GuardManager):
@@ -1,15 +1,22 @@
# mypy: allow-untyped-defs
from typing import Dict, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._ops import HigherOrderOperator, OpOverload
from torch._subclasses import FakeTensorMode
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.overrides import TorchFunctionMode
from torch.utils._python_dispatch import _get_current_dispatch_mode
from torch.utils._pytree import tree_map_only


Tensor = torch.Tensor


__all__ = ["trace_wrapped"]
@@ -43,6 +50,27 @@ __all__ = ["trace_wrapped"]
# compiled autograd do we inline into the function.


class TransformGetItemToIndex(TorchFunctionMode):
# This is needed since we want to support calling
# A[q_idx], where q_idx is a scalar tensor in score_mod.
# Today, when q_idx is a scalar tensor, we implicitly convert it to a python
# scalar and create a view. We do not want that behavior in this case, so we
# use this torchfunctionmode to override that behavior for score_mod
# wherever we're running it.
def __torch_function__(
self,
func: OpOverload,
types: Tuple[torch._C._TensorMeta, ...],
args: Tuple[object, ...] = (),
kwargs: Optional[Dict[str, object]] = None,
) -> object:
if func == torch.Tensor.__getitem__:
index_args = pytree.tree_leaves(args[1])
if all(isinstance(x, torch.Tensor) for x in index_args):
return torch.ops.aten.index(args[0], index_args)
return func(*args, **(kwargs or {}))


def trace_wrapped(*args, **kwargs):
with torch.no_grad():
return _trace_wrapped_op(*args, **kwargs)
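Editor's note: a minimal illustrative sketch (not part of this diff) of what TransformGetItemToIndex changes, assuming the module path torch._dynamo._trace_wrapped_higher_order_op that the other hunks in this comparison import from:

```python
import torch
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

A = torch.arange(4.0)
q_idx = torch.tensor(2)  # 0-dim tensor index, as score_mod receives
with TransformGetItemToIndex():
    # routed to torch.ops.aten.index(A, [q_idx]) instead of converting
    # q_idx to a Python int via __index__
    out = A[q_idx]
print(out)  # tensor(2.)
```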
@@ -32,13 +32,23 @@ def eager(gm, fake_tensor_inputs, **kwargs):


def make_eager_backend_with_torch_function_mode(mode):
return make_eager_backend_with_torch_function_modes([mode])


def make_eager_backend_with_torch_function_modes(modes):
"""Used to trace HOPs (cond and while) for eager exectution, the metadata
TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
in the HOP, so we need to externally run this mode and not trace it."""
from contextlib import ExitStack

def fn(gm, fake_tensor_inputs, **kwargs):
with mode:
return gm.forward
stack = ExitStack()
for mode in modes:
stack.enter_context(mode)

result = gm.forward
stack.close()
return result

return fn
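Editor's note: a hypothetical usage sketch of the helper above. The import location is an assumption based on the surrounding eager backend in this diff, and the torch.compile wrapping is illustrative only; the helper is intended for internal HOP tracing.

```python
import torch
from torch.overrides import BaseTorchFunctionMode
# assumed location, next to the eager backend shown in this hunk
from torch._dynamo.backends.debugging import make_eager_backend_with_torch_function_modes

modes = [BaseTorchFunctionMode(), BaseTorchFunctionMode()]
backend = make_eager_backend_with_torch_function_modes(modes)

@torch.compile(backend=backend)
def f(x):
    return x + 1

# the modes are entered only while gm.forward is resolved, not traced into the graph
f(torch.ones(3))
```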
@@ -120,6 +120,7 @@ from .utils import (
troubleshooting_url,
write_record_to_file,
)
from .variables.torch_function import torch_function_mode_stack_state_mgr


np: Optional[ModuleType]
@@ -218,15 +219,18 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
prior_fwd_from_src = torch.fx.graph_module._forward_from_src
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
cleanup = setup_compile_debug()

exit_stack = contextlib.ExitStack()
exit_stack.enter_context(
torch.fx._symbolic_trace._maybe_revert_all_patches()
)
exit_stack.enter_context(torch_function_mode_stack_state_mgr)
try:
return fn(*args, **kwargs)
finally:
cleanup.close()
assert (
torch._C._len_torch_function_stack() == 0
), "Torch function mode stack state changed while dynamo tracing, please report a bug"
exit_stack.close()
torch._C._set_grad_enabled(prior_grad_mode)
torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
@@ -2356,15 +2356,12 @@ class CheckFunctionManager:
)

if config.enable_cpp_guard_manager:
from .variables.torch_function import IGNORED_MODES

# Insert the global_state guard
assert self.guard_manager # to make mypy happy
self.guard_manager.root.add_global_state_guard(["___check_global_state()"])

self.guard_manager.root.add_torch_function_mode_stack_guard(
self.torch_function_mode_stack,
list(IGNORED_MODES),
["___check_torch_function_mode_stack()"],
)
# Clear references to torch_function modes held in the list

@@ -2671,18 +2668,14 @@ def is_recompiles_verbose_enabled():
# this will only be used if cpp guards are disabled
def make_torch_function_mode_stack_guard(intial_stack):
types = [type(x) for x in intial_stack]
from .variables.torch_function import IGNORED_MODES

def check_torch_function_mode_stack():
cur_stack = get_torch_function_mode_stack()

types_ = [ty for ty in types if ty not in IGNORED_MODES]
cur_stack_ = [mode for mode in cur_stack if type(mode) not in IGNORED_MODES]

if len(cur_stack_) != len(types_):
if len(cur_stack) != len(types):
return False

for ty, mode in zip(types_, cur_stack_):
for ty, mode in zip(types, cur_stack):
if ty != type(mode):
return False
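Editor's note: in plain terms, with IGNORED_MODES removed the Python guard above reduces to a positional type check of the torch function mode stack. A minimal sketch of that semantics:

```python
def stack_types_match(initial_types, current_stack):
    # the guard fails on any length mismatch or per-position type mismatch
    if len(current_stack) != len(initial_types):
        return False
    return all(type(mode) is ty for ty, mode in zip(initial_types, current_stack))
```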
@@ -78,7 +78,6 @@ from .utils import (
get_instruction_source_311,
get_locals_to_steal,
get_static_address_type,
get_torch_function_mode_stack,
graph_break_reasons,
increment_op_count,
lazy_format_graph_code,

@@ -250,6 +249,7 @@ class OutputGraph:
local_scope: Scope,
global_scope: Scope,
f_code,
torch_function_mode_stack,
):
super().__init__()
self.tracers = [SubgraphTracer(self, export_root=export)]

@@ -368,7 +368,7 @@ class OutputGraph:
# This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
# This records the initial torch function mode stack for guarding
self.torch_function_mode_stack = get_torch_function_mode_stack()
self.torch_function_mode_stack = torch_function_mode_stack

# Tracks if the output graph has a user defined allowed function in the
# graph. This is used later to determine if we should fallback to eager

@@ -1021,7 +1021,7 @@ class OutputGraph:
prefix_insts.clear()

for block in reversed(tx.block_stack):
block.exit(tx)
block.exit(tx, is_graph_break=reason.graph_break)

self.cleanup_graph()
tx.prune_dead_locals()
@@ -25,6 +25,26 @@ if TYPE_CHECKING:
sys as sys,
)

from torch.overrides import BaseTorchFunctionMode


# These classes handle support for TorchFunctionModes across
# graph breaks
# Today the TorchFunctionMode enter (for the classes we support)
# simply pushes the mode onto the stack. Since after this occurs
# the stack is mutated, and we replay these mutations, we don't need
# any cleanup logic to be run once the graph break occurs, we simply replay
# these mutations to ensure at the graph break the torch function mode stack is correct
# and reconstruct the torch function mode stack normally
# when we compile the resume function on the other side of the break.
# However, to ensure we exit properly
# in the resume function, we need to re-enter the contexts as we do other contexts.
# These contexts do nothing on enter, but provide the correct exit logic to ensure
# the stack state is correct.
class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
def __enter__(self):
pass


def index(iterator, item, start=0, end=None):
from itertools import islice
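Editor's note: a small sketch of the intended behavior of NoEnterTorchFunctionMode, illustrative only; it relies on the private _push_on_torch_function_stack / _len_torch_function_stack helpers that appear elsewhere in this diff.

```python
import torch
from torch.overrides import BaseTorchFunctionMode

class NoEnterTorchFunctionMode(BaseTorchFunctionMode):
    def __enter__(self):
        pass  # the mode was already pushed before the graph break

mode = NoEnterTorchFunctionMode()
torch._C._push_on_torch_function_stack(mode)      # normally done by the emitted bytecode
with mode:                                        # __enter__ is a no-op, no double push
    assert torch._C._len_torch_function_stack() == 1
assert torch._C._len_torch_function_stack() == 0  # inherited __exit__ pops the mode
```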
@@ -90,27 +90,25 @@ class ReenterWith:
stack_index: int
target_values: Optional[Tuple[Any, ...]] = None

# TODO(mlazos) - Uncomment with the reland of torch function mode support
# def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
# """
# Codegen based off of:
# try:
# (rest)
# except:
# (restore previous tf mode stack)
# raise
def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]):
"""
Codegen based off of:
try:
(rest)
except:
(restore previous tf mode stack)
raise
"""
from .variables.torch_function import get_prev_stack_var_name

# """
# from .variables.torch_function import get_prev_stack_var_name
setup_try_except, epilogue = _bytecode_from_template_with_split(
_try_except_tf_mode_template,
self.stack_index,
varname_map={"stack_var_name": get_prev_stack_var_name()},
)
cleanup[:] = epilogue + cleanup

# setup_try_except, epilogue = _bytecode_from_template_with_split(
# _try_except_tf_mode_template,
# self.stack_index,
# varname_map={"stack_var_name": get_prev_stack_var_name()},
# )
# cleanup[:] = epilogue + cleanup

# return setup_try_except
return setup_try_except

# If we do not want to destroy the stack, we can do the same thing as a
# `SETUP_WITH` block, only that we store the context manager in a local_symbol
@@ -629,11 +629,22 @@ class SideEffects:
elif isinstance(
var, variables.torch_function.TorchFunctionModeStackVariable
):
# Needed in the finally block for stack restoration
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "get_torch_function_mode_stack"
)
)
cg.call_function(0, False)
name = variables.torch_function.get_prev_stack_var_name()
cg.code_options["co_varnames"] += (name,)
cg.append_output(create_instruction("STORE_FAST", argval=name))
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "set_torch_function_mode_stack"
)
)

cg.foreach(var.symbolic_stack)
cg.append_output(
create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
@@ -267,13 +267,12 @@ class BlockStackEntry:
else:
return ReenterWith(self.stack_index)

def exit(self, tx):
if hasattr(self, "graph_break") and isinstance(
self.with_context, TorchFunctionModeVariable
):
return
def exit(self, tx, is_graph_break):
assert self.with_context is not None
return self.with_context.exit(tx)
if (
is_graph_break and self.with_context.exit_on_graph_break()
) or not is_graph_break:
return self.with_context.exit(tx)


class ReturnValueOp(Exception):

@@ -657,10 +656,17 @@ def break_graph_if_unsupported(*, push):
cleanup: List[Instruction] = []
# Reconstruct the context variable CLASS in the block stack
for b in self.block_stack:
# Don't exit any modes we have entered,
# output bytecode will mutate the tf mode stack accordingly
if isinstance(b.with_context, TorchFunctionModeVariable):
cg.extend_output(
b.resume_fn().try_except_torch_function_mode(
cg.code_options, cleanup
)
)
continue
assert b.with_context is not None
assert isinstance(
b.with_context, (ContextWrappingVariable, TorchFunctionModeVariable)
)
assert isinstance(b.with_context, (ContextWrappingVariable))
b.with_context.reconstruct_type(cg)
cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup))
self.output.add_output_instructions(cg.get_instructions())

@@ -2314,7 +2320,10 @@ class InstructionTranslatorBase(
):
unimplemented(f"{inst.opname} {ctx}")

if isinstance(ctx, GenericContextWrappingVariable):
if (
isinstance(ctx, GenericContextWrappingVariable)
and not ctx.supports_graph_breaks()
):
self.generic_context_manager_depth += 1

# Need this redundant check for mypy

@@ -2687,6 +2696,7 @@ class InstructionTranslator(InstructionTranslatorBase):
local_scope=f_locals,
global_scope=f_globals,
f_code=f_code,
torch_function_mode_stack=torch_function_mode_stack,
),
instructions=instructions,
f_locals=f_locals,
@@ -187,6 +187,7 @@ def debug_insert_nops(
local_scope=locals(),
global_scope=globals(),
f_code=frame.f_code,
torch_function_mode_stack=[],
)

return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
@@ -304,6 +304,7 @@ manual_torch_name_rule_map = {
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
"torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
"torch.set_default_device": UserFunctionVariable,
"torch.sparse_bsc_tensor": SkipFunctionVariable,
"torch.sparse_bsr_tensor": SkipFunctionVariable,
"torch.sparse_csc_tensor": SkipFunctionVariable,

@@ -2802,7 +2803,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.random.initial_seed",
"torch.random.seed",
"torch.return_types.pytree_register_structseq",
"torch.set_default_device",
"torch.set_default_dtype",
"torch.set_default_tensor_type",
"torch.set_deterministic_debug_mode",
@@ -2912,6 +2912,9 @@ def get_tensor_method():
method, (types.MethodDescriptorType, types.WrapperDescriptorType)
):
s.add(method)

# mlazos: this is a function which we handle specially in TensorVariable
s.add(torch.Tensor.__contains__) # type: ignore[arg-type]
return frozenset(s)
@@ -2912,18 +2912,28 @@ def is_torch_function_object(value):


def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool:
from torch._dynamo.variables import LazyVariableTracker, UserDefinedObjectVariable
from torch._dynamo.variables import UserDefinedObjectVariable
from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable

if isinstance(vt, TensorWithTFOverrideVariable):
return True
# Note on lazy vars: The value will either be realized or not throughout the course of execution
# if the value has a torch function, it will eventually be realized so we can realize it here
# if the value does not have a torch function, it may or may not be realized
# if it is realized it will be used and guards will be installed properly
# if it is not used, guards won't be installed, and it doesn't matter
# if the value has a torch function or not, so we should *not* realize it.
# NB: We technically know that if is_realized is False, LazyVariableTracker has the peek_value method
# but mypy does not unfortunately
if vt.is_realized() or (
hasattr(vt, "peek_value") and hasattr(vt.peek_value(), "__torch_function__")
):
if isinstance(vt, TensorWithTFOverrideVariable):
return True

if isinstance(vt, LazyVariableTracker):
LazyVariableTracker.realize(vt)
return isinstance(vt, UserDefinedObjectVariable) and hasattr(
vt.value, "__torch_function__"
)

return isinstance(vt, UserDefinedObjectVariable) and hasattr(
vt.value, "__torch_function__"
)
return False


# see note [Tensor Fakification and Symbol Caching]
@@ -3116,16 +3126,10 @@ def is_parameter_freezing():
return torch._inductor.config.freezing and not torch.is_grad_enabled()


def get_torch_function_mode_stack(filter_ignored=True):
from .variables.torch_function import IGNORED_MODES

stack = [
def get_torch_function_mode_stack():
return [
get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
]
if filter_ignored:
stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]

return stack


def get_torch_function_mode_stack_at(ind):
@@ -204,6 +204,7 @@ from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import (
build_torch_function_fn,
TensorWithTFOverrideVariable,
torch_function_mode_stack_state_mgr,
TorchFunctionModeVariable,
)
from .user_defined import (

@@ -1669,15 +1670,16 @@ class VariableBuilder:
# but warning is not the end of the world
assert isinstance(value.base, np.nditer)

try:
tensor_value = _util._try_convert_to_tensor(value)
if readonly:
from torch._prims_common import clone_preserve_strides
with torch_function_mode_stack_state_mgr.temp_restore_stack():
try:
tensor_value = _util._try_convert_to_tensor(value)
if readonly:
from torch._prims_common import clone_preserve_strides

tensor_value = clone_preserve_strides(tensor_value)
except NotImplementedError as e:
# failed to convert to tensor, graph break
unimplemented(str(e))
tensor_value = clone_preserve_strides(tensor_value)
except NotImplementedError as e:
# failed to convert to tensor, graph break
unimplemented(str(e))

# We do this because we want the full behavior of guarding the numpy ndarray as if it were
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
@@ -200,7 +200,6 @@ class BuiltinVariable(VariableTracker):
operator.ne,
operator.eq,
operator.sub,
operator.getitem,
operator.length_hint,
operator.lshift,
operator.rshift,

@@ -212,6 +211,7 @@ class BuiltinVariable(VariableTracker):
operator.imatmul,
operator.ifloordiv,
operator.itruediv,
operator.getitem,
operator.imod,
operator.iadd,
operator.isub,

@@ -858,6 +858,39 @@ class BuiltinVariable(VariableTracker):
if kwargs and not self.tensor_args(*args, *kwargs.values()):
return

# insert handling for torch function here
from .builder import SourcelessBuilder
from .torch_function import (
BUILTIN_TO_TENSOR_FN_MAP,
BUILTIN_TO_TENSOR_RFN_MAP,
can_dispatch_torch_function,
dispatch_torch_function,
)

if can_dispatch_torch_function(tx, args, kwargs):
# Only remap the fn to tensor methods if we aren't exporting
# export serde does not handle method descriptors today
if not tx.export:
# Use sourceless builder, we built the map ourselves
if not isinstance(args[0], TensorVariable):
if self.fn in BUILTIN_TO_TENSOR_RFN_MAP:
func = BUILTIN_TO_TENSOR_RFN_MAP[self.fn]
else:
func = BUILTIN_TO_TENSOR_FN_MAP[self.fn]

tmp = args[0]
# swap args and call reverse version of func
args[0] = args[1]
args[1] = tmp
else:
func = BUILTIN_TO_TENSOR_FN_MAP[self.fn]
else:
func = self.fn

fn_var = SourcelessBuilder.create(tx, func)

return dispatch_torch_function(tx, fn_var, args, kwargs)

fn = self.fn
try:
# Constant fold for constant tensor and python constants
@@ -125,6 +125,12 @@ class ContextWrappingVariable(VariableTracker):
if isinstance(args[0], UserFunctionVariable):
return WrappedUserFunctionVariable(args[0], self)

def supports_graph_breaks(self):
return True

def exit_on_graph_break(self):
return True


class GenericContextWrappingVariable(UserDefinedObjectVariable):
# Some methods in ContextWrappingVariable assumes the arguments are

@@ -183,6 +189,12 @@ class GenericContextWrappingVariable(UserDefinedObjectVariable):
tx.generic_context_manager_depth -= 1
return x

def supports_graph_breaks(self):
return False

def exit_on_graph_break(self):
return True


class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
"""represents torch grad requries grad"""
@@ -1998,8 +1998,7 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
fn: "VariableTracker",
fn_name: str,
):
from torch._higher_order_ops.flex_attention import TransformGetItemToIndex

from .._trace_wrapped_higher_order_op import TransformGetItemToIndex
from .builder import SourcelessBuilder

tx: InstructionTranslator = tx
@@ -80,6 +80,14 @@ class LazyVariableTracker(VariableTracker):
self.realize()
return VariableTracker.clone(self.unwrap(), **kwargs)

def peek_type(self) -> type[Any]:
assert not self.is_realized()
return type(self._cache.value)

def peek_value(self) -> Any:
assert not self.is_realized()
return self._cache.value

def __str__(self) -> str:
if self.is_realized():
return self.unwrap().__str__()
@@ -510,9 +510,37 @@ class TensorVariable(VariableTracker):
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from .builder import SourcelessBuilder, VariableBuilder
from .torch_function import can_dispatch_torch_function, dispatch_torch_function

if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
unimplemented(f"Illegal method invocation {name} in strict mode")

# Only override builtin tensor methods
# The user can manually add override handling
# with a decorator for other methods (e.g. a dispatch subclass with other methods)
has_torch_function_override = False
try:
inspect.getattr_static(torch.Tensor, name)
has_torch_function_override = True
except AttributeError:
has_torch_function_override = False

if (
can_dispatch_torch_function(tx, tuple([self] + list(args)), kwargs)
and has_torch_function_override
):
if self.source:
func_var = VariableBuilder(
tx, AttrSource(AttrSource(self.source, "__class__"), name)
)(inspect.getattr_static(torch.Tensor, name))
else:
func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))

return dispatch_torch_function(
tx, func_var, tuple([self] + list(args)), kwargs
)

"""
Dispatch to a method-specific handler defined below. If the
handler returns None (or doesn't exist) we put the method call

@@ -772,6 +800,30 @@ class TensorVariable(VariableTracker):
self._warn_capture_scalar_outputs()
unimplemented("Tensor.item")

def method___getitem__(self, *args, **kwargs):
from ..symbolic_convert import InstructionTranslator
from .builder import wrap_fx_proxy

tx = InstructionTranslator.current_tx()
if isinstance(args[0], SymNodeVariable):
# Standard indexing will force specialization due to
# __index__. Rewrite as a regular torch op which will
# trace fine
fn, args = torch.select, [
variables.ConstantVariable.create(0),
args[0],
]
else:
fn = operator.getitem

proxy = tx.output.create_proxy(
"call_function",
fn,
*proxy_args_kwargs([self] + list(args), kwargs),
)

return wrap_fx_proxy(tx, proxy)

@staticmethod
@functools.lru_cache(None)
def _warn_capture_scalar_outputs():
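Editor's note: the getitem rewrite above leans on the fact that integer indexing along dim 0 and torch.select agree, while select avoids the __index__ specialization; a quick illustrative check (not part of the diff):

```python
import torch

x = torch.arange(12).reshape(4, 3)
i = 2
# torch.select(x, 0, i) returns the same row as x[i]
assert torch.equal(x[i], torch.select(x, 0, i))
```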
@@ -159,7 +159,17 @@ def get_overridable_functions():

from torch.overrides import get_overridable_functions as get_overridable_functions_

return set(chain(*get_overridable_functions_().values()))
funcs = set(chain(*get_overridable_functions_().values()))
more = {
torch.ones,
torch.ones_like,
torch.zeros,
torch.zeros_like,
torch.empty,
torch.full,
}
funcs.update(more)
return funcs


class BaseTorchVariable(VariableTracker):
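Editor's note: for context, factory functions such as torch.ones do dispatch to __torch_function__ when a mode is active, which is why they are added to the overridable set above; an illustrative check:

```python
import torch
from torch.overrides import TorchFunctionMode

class LogMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        print(func.__name__)  # e.g. "ones"
        return func(*args, **(kwargs or {}))

with LogMode():
    torch.ones(2)  # intercepted even though no tensor argument is passed
```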
@@ -835,6 +845,13 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
len(tx.symbolic_torch_function_state.mode_stack)
)

@register(torch._C._get_function_stack_at)
def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs):
assert len(args) == 1 and not kwargs
ind = args[0].as_python_constant()
assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack)
return tx.symbolic_torch_function_state.mode_stack[ind]

@register(torch.set_default_device)
def handle_set_default_device(
self, tx: "InstructionTranslator", *args, **kwargs

@@ -852,7 +869,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
else:
TorchFunctionModeStackVariable.register_device_context_insertion(tx)

return None
return ConstantVariable.create(None)

return handlers

@@ -883,6 +900,9 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
),
)

if self.is_tensor_method():
return self.call_tensor_method(tx, args, kwargs)

special_handler = self._get_handlers().get(self.value)
if special_handler:
result = special_handler(self, tx, *args, **kwargs)

@@ -1155,6 +1175,16 @@ Either create the tensor outside the compiled region, or do not set the tensor t
)
return result

def call_tensor_method(self, tx, args, kwargs):
return args[0].call_method(tx, self.get_function().__name__, args[1:], kwargs)

def is_tensor_method(self):
return (
inspect.ismethoddescriptor(self.get_function())
and hasattr(self.get_function(), "__objclass__")
and self.get_function().__objclass__ == torch._C.TensorBase
) or self.get_function() is torch.Tensor.__contains__

def torch_function_override_enabled(self, tx, args, kwargs):
return (
self.get_function() in get_overridable_functions()
@@ -2,22 +2,37 @@

import collections
import contextlib
import functools
import inspect
import operator
from typing import Deque, Dict, List, TYPE_CHECKING

import torch._C
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from torch.overrides import (
_get_overloaded_args,
BaseTorchFunctionMode,
get_default_nowrap_functions,
TorchFunctionMode,
)
from torch.utils._device import DeviceContext

from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
from ..utils import (
class_has_getattribute,
clear_torch_function_mode_stack,
get_safe_global_name,
has_torch_function,
is_tensor_base_attr_getter,
set_torch_function_mode_stack,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .ctx_manager import GenericContextWrappingVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
@@ -49,6 +64,125 @@ if TYPE_CHECKING:

# To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py

bin_ops = [
operator.pow,
operator.mul,
operator.matmul,
operator.floordiv,
operator.truediv,
operator.mod,
operator.add,
operator.lt,
operator.gt,
operator.ge,
operator.le,
operator.ne,
operator.eq,
operator.sub,
operator.ipow,
operator.imul,
operator.imatmul,
operator.ifloordiv,
operator.itruediv,
operator.imod,
operator.iadd,
operator.isub,
]

bin_int_ops = [
operator.and_,
operator.or_,
operator.xor,
operator.iand,
operator.ixor,
operator.ior,
]

un_int_ops = [operator.invert]

tensor_and_int_ops = [
operator.lshift,
operator.rshift,
operator.ilshift,
operator.irshift,
operator.getitem,
]

un_ops = [
operator.abs,
operator.pos,
operator.neg,
operator.not_, # Note: this has a local scalar dense call
operator.length_hint,
]

BUILTIN_TO_TENSOR_FN_MAP = {}

# These functions represent the r* versions of the above ops
# Basically, if __add__(1, Tensor) is called, it is translated
# to __radd__(Tensor, 1).
# In the builtin var, we check if there is a tensor in the first args position,
# if not, we swap the args and use the r* version of the op.
BUILTIN_TO_TENSOR_RFN_MAP = {}


def populate_builtin_to_tensor_fn_map():
global BUILTIN_TO_TENSOR_FN_MAP

most_recent_func = None

class GetMethodMode(BaseTorchFunctionMode):
"""
Mode to extract the correct methods from torch function invocations
(Used to get the correct torch.Tensor methods from builtins)
"""

def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
nonlocal most_recent_func
most_recent_func = func
return func(*args, **kwargs)

inp0 = torch.ones(1)
inp1 = torch.ones(1)
inp0_int = torch.ones(1, dtype=torch.int32)
inp1_int = torch.ones(1, dtype=torch.int32)
with GetMethodMode():
setups_and_oplists = [
(lambda o: o(inp0), un_ops),
(lambda o: o(inp0_int), un_int_ops),
(lambda o: o(inp0, inp1), bin_ops),
(lambda o: o(inp0_int, inp1_int), bin_int_ops),
(lambda o: o(inp0_int, 0), tensor_and_int_ops),
]
for setup_fn, op_list in setups_and_oplists:
for op in op_list:
setup_fn(op)
assert most_recent_func is not None
BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func

# gather the reverse functions
rsetups_and_oplists = [
(
lambda o: o(1, inp1),
bin_ops,
), # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
(lambda o: o(1, inp1_int), bin_int_ops),
(lambda o: o(0, inp0_int), tensor_and_int_ops),
]

rskips = {operator.matmul, operator.imatmul, operator.getitem}
for setup_fn, op_list in rsetups_and_oplists:
for op in op_list:
if op in rskips:
continue
setup_fn(op)
assert most_recent_func is not None
if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]:
BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func


populate_builtin_to_tensor_fn_map()
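Editor's note: an illustrative peek at what the populated maps contain. The exact targets are whatever GetMethodMode observed (shown here only as hypothetical examples), and the import path is the one used by the test file in this comparison.

```python
import operator
from torch._dynamo.variables.torch_function import (
    BUILTIN_TO_TENSOR_FN_MAP,
    BUILTIN_TO_TENSOR_RFN_MAP,
)

# forward entry: operator.add on two tensors, e.g. a Tensor.add-style method
print(BUILTIN_TO_TENSOR_FN_MAP[operator.add])
# reverse entry: used when the tensor is not the first argument, e.g. 1 - tensor
print(BUILTIN_TO_TENSOR_RFN_MAP[operator.sub])
```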
banned_attrs = [
fn.__self__.__name__
@@ -56,11 +190,38 @@ banned_attrs = [
if is_tensor_base_attr_getter(fn)
]

# Today set default device is placed in the graph and guarded on separately
# so we should not trace through it. In the future we can trace it once
# mode tracing is implemented and not put in the graph, but this is more
# of a BE project and can be evaluated later
IGNORED_MODES = {DeviceContext}

@functools.lru_cache(None)
def get_prev_stack_var_name():
from ..bytecode_transformation import unique_id

return unique_id("___prev_torch_function_mode_stack")


# Used to clear/restore the python torch function mode stack and temporarily restore it as needed
class TorchFunctionModeStackStateManager:
def __init__(self):
self.stack = []

def __enter__(self):
self.stack = torch.overrides._get_current_function_mode_stack()
clear_torch_function_mode_stack()

def __exit__(self, exc_type, exc_value, traceback):
set_torch_function_mode_stack(self.stack)
self.stack = []

@contextlib.contextmanager
def temp_restore_stack(self):
prev = torch.overrides._get_current_function_mode_stack()
set_torch_function_mode_stack(self.stack)
try:
yield
finally:
set_torch_function_mode_stack(prev)


torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
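Editor's note: a minimal usage sketch of the state manager above (illustrative only):

```python
mgr = TorchFunctionModeStackStateManager()

with mgr:  # stash the user's torch function mode stack and clear it for tracing
    ...  # dynamo tracing runs against a clean stack
    with mgr.temp_restore_stack():  # briefly run eager code under the user's modes
        ...
# the original stack is restored when the outer context exits
```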

class SymbolicTorchFunctionState:
@ -189,9 +350,26 @@ class TorchFunctionModeStackVariable(VariableTracker):
|
||||
return ind + cls.offset
|
||||
|
||||
|
||||
class TorchFunctionModeVariable(ContextWrappingVariable):
|
||||
class TorchFunctionModeVariable(GenericContextWrappingVariable):
|
||||
@staticmethod
|
||||
def is_supported_torch_function_mode(ty):
|
||||
# Supported in this sense means we can support graph breaks under the
|
||||
# context.
|
||||
# We are able to trace custom modes but if there are graph breaks under them
|
||||
# and they have a custom __enter__/__exit__ we don't handle this for the
|
||||
# same reason we don't handle generic context managers: there may be side effects
|
||||
# that are now affected by executing the funtion across two frames instead of one
|
||||
# Today we support the enter/exit of the default TorchFunctionMode as well as
|
||||
# DeviceContext (which is used for set_default_device)
|
||||
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
|
||||
not class_has_getattribute(ty)
|
||||
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
|
||||
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
|
||||
)
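Aside, not part of the patch: two illustrative modes showing what the predicate accepts and rejects.

from torch.overrides import TorchFunctionMode

class LoggingMode(TorchFunctionMode):  # illustrative: keeps the inherited __enter__/__exit__
    def __torch_function__(self, func, types, args=(), kwargs=None):
        print(func)
        return func(*args, **(kwargs or {}))

class CustomEnterMode(TorchFunctionMode):  # illustrative: custom __enter__ with side effects
    def __enter__(self):
        print("entering")
        return super().__enter__()

    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

# Per the predicate above, LoggingMode is "supported" (graph breaks under it can be
# handled), while CustomEnterMode falls back to the generic context manager path.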

    def __init__(self, value, source=None, **kwargs):
        super().__init__(value, **kwargs)
        if value is not None:
            super().__init__(value, **kwargs)
        self.value = value
        self.cm_obj = value  # needed for BC with calling enter from CM code
        self.source = source
@ -221,8 +399,39 @@ class TorchFunctionModeVariable(ContextWrappingVariable):
            kwargs,
        )

    def _call_func(self, tx: "InstructionTranslator", values):
        unimplemented("enter/exit for torch function mode NYI")
    def enter(self, tx):
        from .torch import TorchInGraphFunctionVariable

        if isinstance(self.value, NoEnterTorchFunctionMode):
            return ConstantVariable.create(None)

        TorchInGraphFunctionVariable(
            torch._C._push_on_torch_function_stack
        ).call_function(tx, [self], {})
        return ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        from .torch import TorchInGraphFunctionVariable

        TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
            tx, [], {}
        )
        return ConstantVariable.create(None)

    def reconstruct_type(self, codegen):
        ty = NoEnterTorchFunctionMode
        codegen(
            AttrSource(
                codegen.tx.import_source(ty.__module__),
                ty.__name__,
            )
        )

    def supports_graph_breaks(self):
        return True

    def exit_on_graph_break(self):
        return False
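Aside, not part of the patch: the traced enter()/exit() above emit the same private push/pop hooks into the graph that an eager `with mode:` performs, roughly:

import torch
from torch.overrides import TorchFunctionMode, _get_current_function_mode_stack

class MyMode(TorchFunctionMode):  # illustrative
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

mode = MyMode()
# `with mode:` is approximately this push/pop pair on the torch function stack.
torch._C._push_on_torch_function_stack(mode)
try:
    assert mode in _get_current_function_mode_stack()
finally:
    torch._C._pop_torch_function_stack()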


def _get_all_args(args, kwargs):
@ -233,7 +442,6 @@ def _flatten_vts(vts):
    from collections import deque

    from .dicts import ConstDictVariable
    from .lazy import LazyVariableTracker
    from .lists import ListVariable

    vts = deque(vts)
@ -241,13 +449,17 @@ def _flatten_vts(vts):

    while vts:
        vt = vts.pop()
        LazyVariableTracker.realize_all(vt)
        if isinstance(vt, ListVariable):
            vts.extend(vt.items)
        elif isinstance(vt, ConstDictVariable):
            vts.extend(vt.items.values())
        else:
            output.append(vt)

        if not vt.is_realized() and vt.peek_type() in (dict, list, tuple):
            vt.realize()

        if vt.is_realized():
            if isinstance(vt, ListVariable):
                vts.extend(vt.items)
            elif isinstance(vt, ConstDictVariable):
                vts.extend(vt.items.values())

        output.append(vt)

    return output
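Aside, not part of the patch: a simplified pure-Python analogue of this traversal, with a stand-in Lazy wrapper in place of LazyVariableTracker, showing why only container-typed lazy values get realized.

from collections import deque

class Lazy:
    # Stand-in for LazyVariableTracker: a thunk plus the type it will produce.
    def __init__(self, thunk, peeked_type):
        self._thunk, self._peeked_type = thunk, peeked_type
        self._value, self._realized = None, False

    def peek_type(self):
        return self._peeked_type

    def is_realized(self):
        return self._realized

    def realize(self):
        if not self._realized:
            self._value, self._realized = self._thunk(), True
        return self._value

def flatten(items):
    # Mirrors the traversal above: only force lazily-wrapped values whose peeked
    # type is a container, then keep walking into lists and dicts.
    items, out = deque(items), []
    while items:
        it = items.pop()
        if isinstance(it, Lazy):
            if not it.is_realized() and it.peek_type() in (dict, list, tuple):
                it.realize()
            if not it.is_realized():
                out.append(it)  # non-container lazy values stay unrealized
                continue
            it = it.realize()
        if isinstance(it, (list, tuple)):
            items.extend(it)
        elif isinstance(it, dict):
            items.extend(it.values())
        else:
            out.append(it)
    return out

print(flatten([1, Lazy(lambda: [2, 3], list), {"a": 4}]))  # [4, 3, 2, 1]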

@ -301,8 +513,15 @@ def call_torch_function(


def build_torch_function_fn(tx: "InstructionTranslator", value, source):
    from types import FunctionType

    from .builder import SourcelessBuilder, VariableBuilder

    func = value.__torch_function__.__func__

    if not isinstance(func, FunctionType):
        unimplemented("Builtin/C++ torch function implementations NYI")

    if source:
        return VariableBuilder(
            tx,

@ -413,10 +413,22 @@ class UserDefinedClassVariable(UserDefinedVariable):
            and self.source
            and not is_forbidden_context_manager(self.value)
        ):
            from torch.overrides import TorchFunctionMode

            from .ctx_manager import GenericContextWrappingVariable
            from .torch_function import TorchFunctionModeVariable

            if issubclass(
                self.value, TorchFunctionMode
            ) and TorchFunctionModeVariable.is_supported_torch_function_mode(
                self.value
            ):
                var_cls = TorchFunctionModeVariable
            else:
                var_cls = GenericContextWrappingVariable

            cm_obj = tx.output.side_effects.track_object_new(
                self.source, self.value, GenericContextWrappingVariable, {}
                self.source, self.value, var_cls, {}
            )
            cm_obj.call_method(tx, "__init__", args, kwargs)
            return cm_obj
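Aside, not part of the patch: the shape of user code this branch is intended to handle, constructing and entering a TorchFunctionMode subclass inside a compiled region (mode name is illustrative; whether a given mode traces end to end depends on the predicate above).

import torch
from torch.overrides import TorchFunctionMode

class NoopMode(TorchFunctionMode):  # illustrative
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

@torch.compile(backend="eager")
def f(x):
    with NoopMode():
        return x + 1

f(torch.ones(3))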

@ -11,7 +11,7 @@ from torch._higher_order_ops.utils import (
    reenter_make_fx,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator, OpOverload
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    make_fx,
@ -19,7 +19,6 @@ from torch.fx.experimental.proxy_tensor import (
    track_tensor_tree,
)
from torch.fx.graph_module import GraphModule
from torch.overrides import TorchFunctionMode


# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import
@ -69,27 +68,6 @@ def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch
    return new_out


class TransformGetItemToIndex(TorchFunctionMode):
    # This is needed since we want to support calling
    # A[q_idx], where q_idx is a scalar tensor in score_mod.
    # Today, when q_idx is a scalar tensor, we implicitly convert it to a python
    # scalar and create a view. We do not want that behavior in this case, so we
    # use this TorchFunctionMode to override that behavior for score_mod
    # wherever we're running it.
    def __torch_function__(
        self,
        func: OpOverload,
        types: Tuple[torch._C._TensorMeta, ...],
        args: Tuple[object, ...] = (),
        kwargs: Optional[Dict[str, object]] = None,
    ) -> object:
        if func == torch.Tensor.__getitem__:
            index_args = pytree.tree_leaves(args[1])
            if all(isinstance(x, torch.Tensor) for x in index_args):
                return torch.ops.aten.index(args[0], index_args)
        return func(*args, **(kwargs or {}))
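Aside, not part of the patch: what the mode changes at a call site. The import path below is the new location used throughout this diff; both results have the same value, but the mode routes tensor-index lookups to aten.index instead of the implicit scalar-conversion/view path, which is what score_mod tracing relies on.

import torch
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

A = torch.arange(4.0)
q_idx = torch.tensor(2)

plain = A[q_idx]            # implicit scalar conversion, produces a view
with TransformGetItemToIndex():
    indexed = A[q_idx]      # rerouted to torch.ops.aten.index(A, [q_idx])

print(plain, indexed)       # same value either way; the dispatched op differs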


class FlexAttentionHOP(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("flex_attention", cacheable=True)
@ -185,6 +163,8 @@ def _math_attention_inner(
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32

    scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision)
@ -318,6 +298,8 @@ def trace_flex_attention(
    This will produce a GraphModule that will be stored on the root tracer as "sdpa_score". We
    access this graph module in inductor to inline the score_mod function into the triton template.
    """
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    example_out = flex_attention(
        query,
        key,
@ -414,6 +396,8 @@ def flex_attention_functionalize(
    guard against any mutations in the score_mod function, to the other_buffers since those
    are free variables.
    """
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    query_unwrapped = ctx.unwrap_tensors(query)
    key_unwrapped = ctx.unwrap_tensors(key)
    value_unwrapped = ctx.unwrap_tensors(value)
@ -715,6 +699,8 @@ def flex_attention_autograd(
    score_mod_other_buffers: Tuple[Tensor, ...] = (),
    mask_mod_other_buffers: Tuple[Tensor, ...] = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    with TransformGetItemToIndex():
        input_requires_grad = any(t.requires_grad for t in (query, key, value))
        if torch.is_grad_enabled() and input_requires_grad:
@ -765,6 +751,8 @@ def sdpa_dense_backward(
    score_mod_other_buffers: Tuple,
    mask_mod_other_buffers: Tuple,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    # Get outputs before calling repeat interleave
    actual_grad_query = torch.empty_like(query)
    actual_grad_key = torch.empty_like(key)
@ -892,6 +880,8 @@ def trace_flex_attention_backward(
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """We already have the forward graph and joint graph from the forward pass, so we create a proxy to attach both graphs."""
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    example_out = flex_attention_backward(
        query,
        key,
@ -8,6 +8,8 @@ from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_imp
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
    _temp_remove_pre_dispatch_torch_function_mode,
    disable_proxy_modes_tracing,
    make_fx,
    ProxyTorchDispatchMode,
@ -18,14 +20,26 @@ from torch.utils._python_dispatch import _get_current_dispatch_mode

@exposed_in("torch")
def strict_mode(callable, operands):
    from torch._dynamo.backends.debugging import (
        make_eager_backend_with_torch_function_modes,
    )

    if torch.compiler.is_dynamo_compiling():
        return strict_mode_op(callable, operands)

    with _set_compilation_env():
        with torch._dynamo.utils.disable_cache_limit():
            return torch.compile(strict_mode_op, backend="eager", fullgraph=True)(
                callable, operands
            )
        with _temp_remove_metadata_torch_function_mode() as metadata_mode:
            with _temp_remove_pre_dispatch_torch_function_mode() as predispatch_mode:
                modes = [metadata_mode, predispatch_mode]
                modes = [mode for mode in modes if mode is not None]
                if modes:
                    backend = make_eager_backend_with_torch_function_modes(modes)
                else:
                    backend = "eager"
                with torch._dynamo.utils.disable_cache_limit():
                    return torch.compile(
                        strict_mode_op, backend=backend, fullgraph=True
                    )(callable, operands)
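Aside, not part of the patch: the same remove-then-reinstall pattern written as a helper (the helper name is illustrative; the imported utilities are the ones added in this hunk). Export-related torch function modes are popped so dynamo does not trace them, and the eager backend re-applies them when the captured graph runs.

import torch
from torch._dynamo.backends.debugging import (
    make_eager_backend_with_torch_function_modes,
)
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
    _temp_remove_pre_dispatch_torch_function_mode,
)

def compile_without_export_modes(fn, *args):  # illustrative helper name
    with _temp_remove_metadata_torch_function_mode() as metadata_mode:
        with _temp_remove_pre_dispatch_torch_function_mode() as predispatch_mode:
            modes = [m for m in (metadata_mode, predispatch_mode) if m is not None]
            backend = (
                make_eager_backend_with_torch_function_modes(modes) if modes else "eager"
            )
            return torch.compile(fn, backend=backend, fullgraph=True)(*args)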


class StrictMode(HigherOrderOperator):

@ -2540,90 +2540,40 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
 public:
  TORCH_FUNCTION_MODE_STACK(
      const py::list& initial_stack,
      const py::list& ignored_types,
      py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)),
        _ref_stack(),
        _ignored_types() {
      : LeafGuard(std::move(verbose_code_parts)), _ref_stack() {
    Py_ssize_t len = PyList_Size(initial_stack.ptr());
    for (Py_ssize_t idx = 0; idx < len; idx++) {
      PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
      auto type = Py_TYPE(mode);
      this->_ref_stack.push_back(type);
    }

    len = PyList_Size(ignored_types.ptr());
    for (Py_ssize_t idx = 0; idx < len; idx++) {
      PyObject* type_obj =
          PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref
      if (PyType_Check(type_obj) == 0) {
        PyErr_SetString(
            PyExc_TypeError, "ignored_types should contain a list of types");
        return;
      }
      PyTypeObject* type = (PyTypeObject*)type_obj;
      this->_ignored_types.insert(type);
    }
  }

  bool check_nopybind(PyObject* value) override {
    // Ignore value arg, only used to satisfy the interface
    size_t ref_ind = 0;
    const int64_t len = at::impl::PythonTorchFunctionTLS::stack_len();
    const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len();
    const size_t ref_stack_size = this->_ref_stack.size();

    int64_t idx = 0;
    while ((idx < len) && (ref_ind < ref_stack_size)) {
    if (len != ref_stack_size) {
      return false;
    }

    for (int64_t idx = 0; (size_t)idx < len; idx++) {
      std::shared_ptr<c10::SafePyObject> mode =
          at::impl::PythonTorchFunctionTLS::get_stack_at(idx);

      PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
      bool act_ignored = this->_ignored_types.count(mode_type) > 0;
      bool ref_ignored =
          this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0;
      // skip ignored types
      if (act_ignored && ref_ignored) {
        idx++;
        ref_ind++;
        continue;
      } else if (ref_ignored) {
        ref_ind++;
        continue;
      } else if (act_ignored) {
        idx++;
        continue;
      }
      // if we already have more non-ignored modes than the ref stack
      // or if the mode doesn't match at the current index, return false
      else if (mode_type != _ref_stack.at(ref_ind)) {
        return false;
      }
      ref_ind++;
      idx++;
    }

    for (; ref_ind < ref_stack_size; ref_ind++) {
      if (!(this->_ignored_types.count(this->_ref_stack.at(ref_ind)) > 0)) {
      if (mode_type != _ref_stack.at(idx)) {
        return false;
      }
    }

    for (; idx < len; idx++) {
      std::shared_ptr<c10::SafePyObject> mode =
          at::impl::PythonTorchFunctionTLS::get_stack_at(idx);

      PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
      if (!(this->_ignored_types.count(mode_type) > 0)) {
        return false;
      }
    }

    return ref_ind == ref_stack_size && idx == len;
    return true;
  }

 private:
  std::vector<PyTypeObject*> _ref_stack;
  std::set<PyTypeObject*> _ignored_types;
};
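Aside, not part of the patch: a Python rendering of the simplified guard semantics, using the same stack accessor referenced elsewhere in this diff. The guard now requires an exact, ordered type match against the recorded stack, with no "ignored types" skipping.

from torch.overrides import _get_current_function_mode_stack

def torch_function_mode_stack_guard(ref_types):
    # ref_types: the mode types recorded when the guard was built.
    stack = _get_current_function_mode_stack()
    return len(stack) == len(ref_types) and all(
        type(mode) is ty for mode, ty in zip(stack, ref_types)
    )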

class TENSOR_MATCH : public LeafGuard {
@ -3792,7 +3742,7 @@ PyObject* torch_c_dynamo_guards_init() {
      LeafGuard,
      std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
      py_m, "TORCH_FUNCTION_MODE_STACK")
      .def(py::init<py::list, py::list, py::list>())
      .def(py::init<py::list, py::list>())
      .def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
  py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
      py_m, "DATA_PTR_MATCH")
@ -4029,10 +3979,9 @@ PyObject* torch_c_dynamo_guards_init() {
          "add_torch_function_mode_stack_guard",
          [](GuardManager& self,
             const py::list& initial_stack,
             const py::list& ignored_types,
             py::object verbose_code_parts) -> void {
            self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
                initial_stack, ignored_types, std::move(verbose_code_parts)));
                initial_stack, std::move(verbose_code_parts)));
          })
      .def(
          "add_data_ptr_guard",

@ -13,10 +13,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch._higher_order_ops.flex_attention import (
    flex_attention as flex_attention_hop,
    TransformGetItemToIndex,
)
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch._higher_order_ops.utils import _set_compilation_env
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,