Compare commits

...

21 Commits

SHA1        Message                                  Date
4325dd1a3b  Update [ghstack-poisoned]                2025-11-08 23:01:01 -08:00
f2bddf2867  Update [ghstack-poisoned]                2025-11-08 16:31:08 -08:00
6cfc9d4ebe  Update (base update) [ghstack-poisoned]  2025-11-08 16:31:08 -08:00
0ef2a685a9  Update [ghstack-poisoned]                2025-11-08 10:46:59 -08:00
60206b5033  Update [ghstack-poisoned]                2025-11-08 09:08:59 -08:00
63b7316aa2  Update [ghstack-poisoned]                2025-11-08 08:57:03 -08:00
1ff2f99439  Update [ghstack-poisoned]                2025-11-08 07:41:42 -08:00
000f889110  Update (base update) [ghstack-poisoned]  2025-11-07 22:55:00 -08:00
d7b6b80875  Update [ghstack-poisoned]                2025-11-07 22:55:00 -08:00
af675fb733  Update [ghstack-poisoned]                2025-11-07 22:51:20 -08:00
490970fd48  Update (base update) [ghstack-poisoned]  2025-11-07 21:19:25 -08:00
6b2b940258  Update [ghstack-poisoned]                2025-11-07 21:19:25 -08:00
ff37e9b7b3  Update [ghstack-poisoned]                2025-11-07 20:30:46 -08:00
a76728cadd  Update (base update) [ghstack-poisoned]  2025-11-07 16:18:26 -08:00
a9a799624f  Update [ghstack-poisoned]                2025-11-07 16:18:26 -08:00
29324e30e2  Update (base update) [ghstack-poisoned]  2025-11-07 16:14:38 -08:00
c239c6b8cc  Update [ghstack-poisoned]                2025-11-07 16:14:38 -08:00
fc5f8f9fa9  Update [ghstack-poisoned]                2025-11-07 16:08:09 -08:00
810c4681cb  Update [ghstack-poisoned]                2025-11-07 15:21:35 -08:00
a0c9c5a497  Update (base update) [ghstack-poisoned]  2025-11-07 14:49:37 -08:00
016756ff0f  Update [ghstack-poisoned]                2025-11-07 14:49:37 -08:00
9 changed files with 495 additions and 428 deletions

View File

@@ -1681,14 +1681,13 @@ class GraphModule(torch.nn.Module):
wrap_body_0 = self.wrap_body_0
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]
getitem_1: "f32[4, 4]" = tag_activation_checkpoint[1]; tag_activation_checkpoint = None
return (getitem, getitem_1)
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[4, 4]"):
y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (y, y)
return (y,)
""",
)
@@ -1798,9 +1797,9 @@ class GraphModule(torch.nn.Module):
out: "f32[4, 4]" = l_x_.sin()
sin_1: "f32[4, 4]" = torch.sin(o)
child: "f32[4, 4]" = torch.cos(sin_1)
child_1: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (child, child_1, matmul, o, out, sin_1)
cos: "f32[4, 4]" = torch.cos(sin_1)
sin_2: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (cos, sin_2, matmul, o, out, sin_1)
""",
)

View File

@@ -222,13 +222,13 @@ class GraphModule(torch.nn.Module):
matmul: "f32[3, 3]" = l_x_ @ l_y_
sin: "f32[3, 3]" = matmul.sin(); matmul = None
child: "f32[3, 3]" = sin.cos(); sin = None
cos: "f32[3, 3]" = sin.cos(); sin = None
child_1: "f32[3, 3]" = l_x_ + l_y_
child_2: "f32[3, 3]" = l_x_ - l_y_
add: "f32[3, 3]" = l_x_ + l_y_
sub: "f32[3, 3]" = l_x_ - l_y_
child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (child, child_1, child_2, child_3)
matmul_1: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (cos, add, sub, matmul_1)
""", # noqa: B950
)
self.assertExpectedInline(

View File

@@ -249,7 +249,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
# when testing with dynamic shape, symbols are lifted as input
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)
def test_return_captured_vars(self):
freevar1 = torch.randn(3)
@@ -267,7 +267,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
# be the input.
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)
def test_return_captured_var_used_multiple_times(self):
freevar = torch.randn(3)
@@ -282,7 +282,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
x = torch.randn(3)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 3)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 2)
def test_capture_untracked_global(self):
def f(x):
@@ -762,15 +762,15 @@ class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None
child: "f32[s77]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
getitem: "f32[s77]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_0(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[s77]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
sin: "f32[s77]" = l_x_.sin(); l_x_ = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (sin, sin_1)
""",
)
else:
@@ -801,15 +801,15 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, u0, c); wrap_body_0 = l_x_ = u0 = c = None
child: "f32[3]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
getitem: "f32[3]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[3]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
sin: "f32[3]" = l_x_.sin(); l_x_ = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (sin, sin_1)
""",
)
@@ -922,16 +922,16 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, size, c); wrap_body_0 = l_x_ = size = c = None
child: "f32[3]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
getitem: "f32[3]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
sin: "f32[3]" = l_x_.sin(); l_x_ = None
child: "f32[3]" = sin + size; sin = size = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
add: "f32[3]" = sin + size; sin = size = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (add, sin_1)
""",
)
@@ -1872,8 +1872,8 @@ def forward(self, L_x_ : torch.Tensor):
getitem_3 = map_impl[3]
getitem_4 = map_impl[4]
getitem_5 = map_impl[5]
value = map_impl[6]; map_impl = None
return (getitem, getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, value)""",
getitem_6 = map_impl[6]; map_impl = None
return (getitem, getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6)""",
)
self.assertExpectedInline(
body_graph,
@@ -2458,10 +2458,10 @@ class GraphModule(torch.nn.Module):
class wrap_body_0(torch.nn.Module):
def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"):
child: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
add: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
child_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
return (child, child_1)
add_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
return (add, add_1)
""",
)
@@ -2655,9 +2655,9 @@ class GraphModule(torch.nn.Module):
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[2, 3]"):
child: "f32[2, 3]" = l_x_.sin()
child_1: "f32[2, 3]" = l_x_.cos(); l_x_ = None
return (child, child_1)
sin: "f32[2, 3]" = l_x_.sin()
cos: "f32[2, 3]" = l_x_.cos(); l_x_ = None
return (sin, cos)
""",
)
@@ -2687,13 +2687,13 @@ class GraphModule(torch.nn.Module):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
value: "f32[3]" = wrap[0]; wrap = None
return (value,)
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]"):
child: "f32[3]" = -l_x_; l_x_ = None
return (child,)
neg: "f32[3]" = -l_x_; l_x_ = None
return (neg,)
""",
)
@@ -3318,17 +3318,17 @@ class GraphModule(torch.nn.Module):
hints_wrapper_body_1 = self.hints_wrapper_body_1
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None
res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
return (res,)
getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
return (getitem,)
class hints_wrapper_body_1(torch.nn.Module):
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
hints_wrapper_body_0 = self.hints_wrapper_body_0
hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None
x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None
return (x_2,)
x_1: "f32[2, 4]" = torch.abs(getitem); getitem = None
return (x_1,)
class hints_wrapper_body_0(torch.nn.Module):
def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):

View File

@@ -2354,10 +2354,10 @@ class GraphModule(torch.nn.Module):
ge_1: "Sym(u0 >= 0)" = getitem_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
le_1: "Sym(u0 <= 1)" = getitem_1 <= 1
le_1: "Sym(u0 <= 1)" = getitem_1 <= 1; getitem_1 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 1 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None
select: "f32[3]" = torch.ops.aten.select.int(x, 0, getitem_1); x = getitem_1 = None
select: "f32[3]" = torch.ops.aten.select.int(x, 0, 0); x = None
return pytree.tree_unflatten((select,), self._out_spec)
class true_graph_0(torch.nn.Module):

View File

@@ -5530,9 +5530,9 @@ def forward(self, L_it_ : torch.Tensor, L_pytree_input_0_0_ : torch.Tensor, L_py
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_it_, l_pytree_input_0_0_, l_pytree_input_1_x_, l_pytree_input_1_y_), ()); cond_fn_0 = body_fn_0 = l_it_ = l_pytree_input_0_0_ = l_pytree_input_1_x_ = l_pytree_input_1_y_ = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
value = while_loop[2]
value_1 = while_loop[3]; while_loop = None
return (getitem, getitem_1, value, value_1)""", # noqa: B950
getitem_2 = while_loop[2]
getitem_3 = while_loop[3]; while_loop = None
return (getitem, getitem_1, getitem_2, getitem_3)""", # noqa: B950
)
def _wrap_with_functionalize(self, fn, func_type):
@@ -5745,9 +5745,9 @@ class GraphModule(torch.nn.Module):
class body_fn_0(torch.nn.Module):
def forward(self, child_2: "i64[]", child_3: "f32[2, 2]", l_self_buffers_dec__cond_fn: "i64[]", l_self_modules_linear_parameters_bias__body_fn: "f32[2]", l_self_modules_linear_parameters_weight__body_fn: "f32[2, 2]"):
child: "i64[]" = child_2 - 1; child_2 = None
child_4: "f32[2, 2]" = torch._C._nn.linear(child_3, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); child_3 = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None
return (child, child_4)
sub: "i64[]" = child_2 - 1; child_2 = None
linear: "f32[2, 2]" = torch._C._nn.linear(child_3, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); child_3 = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None
return (sub, linear)
""", # noqa: B950
)
@@ -8428,14 +8428,14 @@ class GraphModule(torch.nn.Module):
gt_1: "Sym(u2 > 0)" = getitem_4 > 0
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u2 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None
out_x: "f32[s77, s27]" = while_loop[1]; while_loop = None
getitem_1: "f32[s77, s27]" = while_loop[1]; while_loop = None
gt: "Sym(u2 > 0)" = getitem_4 > 0
_check = torch._check(gt); gt = _check = None
add: "Sym(u2 + 1)" = getitem_4 + 1
add_1: "f32[s77, s27]" = getitem_4 + out_x; out_x = None
add_1: "f32[s77, s27]" = getitem_4 + getitem_1; getitem_1 = None
lt: "Sym(u2 < s77)" = getitem_4 < s77; s77 = None
@@ -8592,7 +8592,6 @@ class GraphModule(torch.nn.Module):
getitem_10: "Sym(u17)" = while_loop[2]
getitem_11: "Sym(u18)" = while_loop[3]
getitem_12: "Sym(u19)" = while_loop[4]
getitem_13: "Sym(u20)" = while_loop[5]
getitem_14: "Sym(u21)" = while_loop[6]
child: "f32[2, 3]" = while_loop[7]; while_loop = None
@@ -8602,19 +8601,18 @@ class GraphModule(torch.nn.Module):
add_2: "Sym(u17 + 1)" = getitem_10 + 1
add_3: "Sym(u18 + 1)" = getitem_11 + 1
add_4: "Sym(u19 + 1)" = getitem_12 + 1
add_5: "Sym(u20 + 1)" = getitem_13 + 1
add_6: "Sym(u21 + 1)" = getitem_14 + 1
add_7: "f32[2, 3]" = child + 1
add_5: "Sym(u21 + 1)" = getitem_14 + 1
add_6: "f32[2, 3]" = child + 1
add_8: "f32[2, 3]" = getitem_8 + l_t_; getitem_8 = None
add_9: "f32[2, 3]" = getitem_9 + l_t_; getitem_9 = None
add_10: "f32[2, 3]" = getitem_10 + l_t_; getitem_10 = None
add_11: "f32[2, 3]" = getitem_11 + l_t_; getitem_11 = None
add_12: "f32[2, 3]" = getitem_12 + l_t_; getitem_12 = None
add_13: "f32[2, 3]" = getitem_13 + l_t_; getitem_13 = None
add_14: "f32[2, 3]" = getitem_14 + l_t_; getitem_14 = None
add_15: "f32[2, 3]" = child + l_t_; child = l_t_ = None
return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14, add_15)
add_7: "f32[2, 3]" = getitem_8 + l_t_; getitem_8 = None
add_8: "f32[2, 3]" = getitem_9 + l_t_; getitem_9 = None
add_9: "f32[2, 3]" = getitem_10 + l_t_; getitem_10 = None
add_10: "f32[2, 3]" = getitem_11 + l_t_; getitem_11 = None
add_11: "f32[2, 3]" = getitem_12 + l_t_; getitem_12 = None
add_12: "f32[2, 3]" = 0 + l_t_
add_13: "f32[2, 3]" = getitem_14 + l_t_; getitem_14 = None
add_14: "f32[2, 3]" = child + l_t_; child = l_t_ = None
return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14)
class cond_fn_0(torch.nn.Module):
def forward(self, unbacked_symint: "Sym(u1)", unbacked_symint_0: "Sym(u2)", unbacked_symint_1: "Sym(u3)", unbacked_symint_2: "Sym(u4)", unbacked_symint_3: "Sym(u5)", unbacked_symint_4: "Sym(u6)", unbacked_symint_5: "Sym(u7)", child: "f32[2, 3]"):
@@ -8627,8 +8625,8 @@ class GraphModule(torch.nn.Module):
class body_fn_0(torch.nn.Module):
def forward(self, unbacked_symint_6: "Sym(u8)", unbacked_symint_7: "Sym(u9)", unbacked_symint_8: "Sym(u10)", unbacked_symint_9: "Sym(u11)", unbacked_symint_10: "Sym(u12)", unbacked_symint_11: "Sym(u13)", unbacked_symint_12: "Sym(u14)", child_1: "f32[2, 3]"):
add: "Sym(u14 + 1)" = unbacked_symint_12 + 1; unbacked_symint_12 = None
child: "f32[2, 3]" = child_1 + 1; child_1 = None
return (unbacked_symint_7, unbacked_symint_8, unbacked_symint_9, unbacked_symint_10, unbacked_symint_6, 0, add, child)
add_1: "f32[2, 3]" = child_1 + 1; child_1 = None
return (unbacked_symint_7, unbacked_symint_8, unbacked_symint_9, unbacked_symint_10, unbacked_symint_6, 0, add, add_1)
""", # noqa: B950
)
@@ -8763,8 +8761,8 @@ class GraphModule(torch.nn.Module):
add_3: "Sym(u8 + 1)" = unbacked_symint_7 + 1; unbacked_symint_7 = None
add_4: "Sym(u9 + 1)" = unbacked_symint_8 + 1; unbacked_symint_8 = None
child: "f32[s77, s27]" = child_2 + 1; child_2 = None
return (add, add_1, add_2, add_3, add_4, child)
add_5: "f32[s77, s27]" = child_2 + 1; child_2 = None
return (add, add_1, add_2, add_3, add_4, add_5)
""", # noqa: B950
)

View File

@@ -899,14 +899,14 @@ class GraphModule(torch.nn.Module):
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
mul: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None
child: "f32[8]" = mul * 2; mul = None
return (child,)
mul_1: "f32[8]" = mul * 2; mul = None
return (mul_1,)
class subgraph_1(torch.nn.Module):
def forward(self, a: "f32[8]", l_y_: "f32[8]"):
mul: "f32[8]" = torch.mul(a, l_y_); a = l_y_ = None
child: "f32[8]" = mul * 3; mul = None
return (child,)
mul_1: "f32[8]" = mul * 3; mul = None
return (mul_1,)
""",
)
@@ -983,20 +983,20 @@ class GraphModule(torch.nn.Module):
subgraph_0 = self.subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None
x: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
subgraph_1 = self.subgraph_0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', x, l_y_); subgraph_1 = x = None
x_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', getitem, l_y_); subgraph_1 = getitem = None
getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
subgraph_2 = self.subgraph_0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', x_1, l_y_); subgraph_2 = x_1 = None
x_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', getitem_1, l_y_); subgraph_2 = getitem_1 = None
getitem_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
subgraph_3 = self.subgraph_0
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', x_2, l_y_); subgraph_3 = x_2 = None
x_3: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', getitem_2, l_y_); subgraph_3 = getitem_2 = None
getitem_3: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
subgraph_4 = self.subgraph_0
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', x_3, l_y_); subgraph_4 = x_3 = l_y_ = None
x_4: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
return (x_4,)
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', getitem_3, l_y_); subgraph_4 = getitem_3 = l_y_ = None
getitem_4: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
return (getitem_4,)
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
@@ -1495,9 +1495,9 @@ class GraphModule(torch.nn.Module):
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8, 8]"):
child: "f32[8, 8]" = l_x_ * 2
child_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
return (child, child_1)
mul: "f32[8, 8]" = l_x_ * 2
mul_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
return (mul, mul_1)
""",
)
@@ -2504,6 +2504,89 @@ class GraphModule(torch.nn.Module):
self.assertEqual(f(x, other), f_compile(x, other))
self.assertTrue(called)
def test_udf_output(self):
class Foo:
def __init__(self, a, b):
self.a = a
self.b = b
@nested_compile_region
def gn(x, y):
a = torch.sin(x)
b = torch.cos(y)
return Foo(a, b)
def fn(x, y):
foo1 = gn(x, y)
foo2 = gn(foo1.a, y)
return foo1.b + foo2.a # + foo2.b
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
x = torch.randn(8, 8, requires_grad=True)
y = torch.randn(8, 8, requires_grad=True)
x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True)
ref = fn(x, y)
res = opt_fn(x_clone, y_clone)
ref.sum().backward()
res.sum().backward()
self.assertEqual(ref, res)
self.assertEqual(x.grad, x_clone.grad)
# High priority - grads are wrong
@unittest.expectedFailure
def test_grad_accuracy_check(self):
class Foo:
def __init__(self, a, b):
self.a = a
self.b = b
# @nested_compile_region
# def gn(x):
# a = torch.sin(x)
# b = torch.cos(x)
# return Foo(a, b)
# def fn(x):
# foo1 = gn(x)
# foo2 = gn(foo1.a)
# return foo1.b + foo2.a + foo2.b
@nested_compile_region
def gn(x):
a = torch.sin(x)
b = torch.cos(x)
return (a, b)
def fn(x):
foo1 = gn(x)
foo2 = gn(foo1[0])
return foo1[1] + foo2[0] + foo2[1]
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
x = torch.randn(8, 8, requires_grad=True)
x_clone = x.detach().clone().requires_grad_(True)
x.grad = None
x_clone.grad = None
ref = fn(x)
res = opt_fn(x_clone)
ref.sum().backward()
res.sum().backward()
self.assertEqual(ref, res)
self.assertEqual(x.grad, x_clone.grad)
@skipIfTorchDynamo("Not a torch._dynamo test")
@parameterized_class(

View File

@@ -286,47 +286,31 @@ class GraphModule(torch.nn.Module):
l_self_modules_wo_parameters_weight_ = L_self_modules_wo_parameters_weight_
l_self_modules_w1_parameters_weight_ = L_self_modules_w1_parameters_weight_
l_self_modules_w2_parameters_weight_ = L_self_modules_w2_parameters_weight_
q: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wq_parameters_weight_, None); l_self_modules_wq_parameters_weight_ = None
k: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wk_parameters_weight_, None); l_self_modules_wk_parameters_weight_ = None
v: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wv_parameters_weight_, None); l_self_modules_wv_parameters_weight_ = None
unflatten: "f32[8, 16, 16, 6]" = q.unflatten(-1, (16, -1)); q = None
q_1: "f32[8, 16, 16, 6]" = unflatten.permute(0, 2, 1, 3); unflatten = None
unflatten_1: "f32[8, 16, 16, 6]" = k.unflatten(-1, (16, -1)); k = None
k_1: "f32[8, 16, 16, 6]" = unflatten_1.permute(0, 2, 1, 3); unflatten_1 = None
unflatten_2: "f32[8, 16, 16, 6]" = v.unflatten(-1, (16, -1)); v = None
v_1: "f32[8, 16, 16, 6]" = unflatten_2.permute(0, 2, 1, 3); unflatten_2 = None
subgraph_0 = self.subgraph_0
local_map_hop = torch.ops.higher_order.local_map_hop(subgraph_0, q_1, k_1, v_1); subgraph_0 = q_1 = k_1 = v_1 = None
o: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None
permute_3: "f32[8, 16, 16, 6]" = o.permute(0, 2, 1, 3); o = None
o_1: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None
o_2: "f32[8, 16, 96]" = torch._C._nn.linear(o_1, l_self_modules_wo_parameters_weight_, None); o_1 = l_self_modules_wo_parameters_weight_ = None
o0: "f32[8, 16, 96]" = o_2 + l_x_; o_2 = l_x_ = None
o_3: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None
o_4: "f32[8, 16, 384]" = torch.nn.functional.relu(o_3); o_3 = None
o_5: "f32[8, 16, 96]" = torch._C._nn.linear(o_4, l_self_modules_w2_parameters_weight_, None); o_4 = l_self_modules_w2_parameters_weight_ = None
o_6: "f32[8, 16, 96]" = o0 + o_5; o0 = o_5 = None
return (o_6,)
getitem: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None
permute_3: "f32[8, 16, 16, 6]" = getitem.permute(0, 2, 1, 3); getitem = None
o: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None
o_1: "f32[8, 16, 96]" = torch._C._nn.linear(o, l_self_modules_wo_parameters_weight_, None); o = l_self_modules_wo_parameters_weight_ = None
o0: "f32[8, 16, 96]" = o_1 + l_x_; o_1 = l_x_ = None
o_2: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None
o_3: "f32[8, 16, 384]" = torch.nn.functional.relu(o_2); o_2 = None
o_4: "f32[8, 16, 96]" = torch._C._nn.linear(o_3, l_self_modules_w2_parameters_weight_, None); o_3 = l_self_modules_w2_parameters_weight_ = None
o_5: "f32[8, 16, 96]" = o0 + o_4; o0 = o_4 = None
return (o_5,)
class subgraph_0(torch.nn.Module):
def forward(self, q_1: "f32[1, 2, 4, 6]", k_1: "f32[1, 2, 16, 6]", v_1: "f32[1, 2, 16, 6]"):
out: "f32[1, 2, 4, 6]" = torch._C._nn.scaled_dot_product_attention(query = q_1, key = k_1, value = v_1, is_causal = False); q_1 = k_1 = v_1 = None
return (out,)
""",
return (out,)""",
ignore_empty_lines=True,
)

View File

@@ -255,33 +255,3 @@ def wrap_inline_with_error_on_graph_break(
return fn(*args, **kwargs)
return wrapper
def filter_out_const_values(tup: tuple[Any, ...], masks: list[bool]) -> tuple[Any, ...]:
"""
masks is a list of bools, where True means the corresponding element in tup
is a const value. Filter out the const values.
"""
out = []
for mask_idx, mask in enumerate(masks):
if not mask:
out.append(tup[mask_idx])
return tuple(out)
def insert_const_values_with_mask(
tup: tuple[Any, ...], masks: list[bool], values: tuple[Any, ...]
) -> tuple[Any, ...]:
"""
masks and values are of same length. For indices where the mask is True, use
the const_values to fill in.
"""
out = []
idx = 0
for mask_idx, mask in enumerate(masks):
if mask:
out.append(values[mask_idx])
else:
out.append(tup[idx])
idx += 1
return tuple(out)
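For reference, the following standalone sketch (not part of the diff; sample data and the equivalent re-implementations are illustrative only) shows the mask behavior documented in the docstrings of the two removed helpers and how they round-trip:

# Illustration of the removed helpers' documented behavior (not from this diff).
from typing import Any

def filter_out_const_values(tup: tuple[Any, ...], masks: list[bool]) -> tuple[Any, ...]:
    # Drop the positions whose mask is True (the const values).
    return tuple(v for v, is_const in zip(tup, masks) if not is_const)

def insert_const_values_with_mask(
    tup: tuple[Any, ...], masks: list[bool], values: tuple[Any, ...]
) -> tuple[Any, ...]:
    # Where the mask is True take the const value from `values` (indexed like
    # `masks`); otherwise consume the next element of `tup` in order.
    out, rest = [], iter(tup)
    for is_const, const in zip(masks, values):
        out.append(const if is_const else next(rest))
    return tuple(out)

masks = [False, True, False]                     # position 1 holds a const value
full = ("a", 42, "b")
filtered = filter_out_const_values(full, masks)  # ("a", "b")
assert insert_const_values_with_mask(filtered, masks, (None, 42, None)) == full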

File diff suppressed because it is too large