Revert " [while_loop][autograd] support autograd_key of while_loop (#160483)"

This reverts commit 2b8a83901c58a0858ea9e4ce00055f48e6ed164c.

Reverted https://github.com/pytorch/pytorch/pull/160483 on behalf of https://github.com/huydhn due to Sorry for reverting your PR, but some trunk tests are failing either from this PR or the previous one in the stack ([comment](https://github.com/pytorch/pytorch/pull/160483#issuecomment-3263597325))
PyTorch MergeBot
2025-09-07 08:50:49 +00:00
parent ada43ed39c
commit 7a83cf430e
5 changed files with 117 additions and 591 deletions

View File

@@ -394,14 +394,14 @@ def _while_loop_tests():
([torch.randn(3, 3)], {"x": torch.randn(3, 3), "y": torch.randn(3, 3)}),
),
),
"int_carry": (int_carry, (torch.randn(2, 3),)),
"int_carry": (int_carry, (torch.randn(2, 3, requires_grad=True),)),
"pytree_int_carry": (
pytree_int_carry,
(torch.randn(2, 3),),
(torch.randn(2, 3, requires_grad=True),),
),
"const_and_symint_output": (
const_and_symint_output,
(torch.randn(2, 3),),
(torch.randn(2, 3, requires_grad=True),),
),
}
@@ -5513,35 +5513,69 @@ def forward(self, arg0_1):
gm = backend.graphs[0]
if torch._dynamo.config.inline_inbuilt_nn_modules:
self.assertExpectedInline(
normalize_gm(gm.print_readable(print_output=False)),
gm.code.strip(),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_iter_: "i64[]", L_x_: "f32[2, 2]", L_self_buffers_dec_: "i64[]", L_self_modules_linear_parameters_weight_: "f32[2, 2]", L_self_modules_linear_parameters_bias_: "f32[2]"):
l_iter_ = L_iter_
l_x_ = L_x_
l_self_buffers_dec_ = L_self_buffers_dec_
l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_
l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_
cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l_self_buffers_dec_, l_self_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l_self_buffers_dec_ = l_self_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = None
getitem: "i64[]" = while_loop[0]
getitem_1: "f32[2, 2]" = while_loop[1]; while_loop = None
return (getitem, getitem_1)
class cond_fn_0(torch.nn.Module):
def forward(self, child: "i64[]", child_1: "f32[2, 2]", l_self_buffers_dec__cond_fn: "i64[]", l_self_modules_linear_parameters_bias__body_fn: "f32[2]", l_self_modules_linear_parameters_weight__body_fn: "f32[2, 2]"):
sub: "i64[]" = child - l_self_buffers_dec__cond_fn; child = l_self_buffers_dec__cond_fn = None
gt: "b8[]" = sub > 0; sub = None
return gt
class body_fn_0(torch.nn.Module):
def forward(self, child_2: "i64[]", child_3: "f32[2, 2]", l_self_buffers_dec__cond_fn: "i64[]", l_self_modules_linear_parameters_bias__body_fn: "f32[2]", l_self_modules_linear_parameters_weight__body_fn: "f32[2, 2]"):
child: "i64[]" = child_2 - 1; child_2 = None
child_4: "f32[2, 2]" = torch._C._nn.linear(child_3, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); child_3 = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None
return (child, child_4)
""", # noqa: B950
def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor, L_self_buffers_dec_ : torch.Tensor, L_self_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter, L_self_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter):
l_iter_ = L_iter_
l_x_ = L_x_
l_self_buffers_dec_ = L_self_buffers_dec_
l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_
l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_
cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l_self_buffers_dec_, l_self_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l_self_buffers_dec_ = l_self_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = None
getitem = while_loop[0]
getitem_1 = while_loop[1]; while_loop = None
return (getitem, getitem_1)""", # noqa: B950
)
self.assertExpectedInline(
gm.cond_fn_0.code.strip(),
"""\
def forward(self, child : torch.Tensor, child_1 : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn):
sub = child - l_self_buffers_dec__cond_fn; child = l_self_buffers_dec__cond_fn = None
gt = sub > 0; sub = None
return gt""", # noqa: B950
)
self.assertExpectedInline(
gm.body_fn_0.code.strip(),
"""\
def forward(self, child_2 : torch.Tensor, child_3 : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn):
child = child_2 - 1; child_2 = None
child_4 = torch._C._nn.linear(child_3, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); child_3 = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None
return (child, child_4)""", # noqa: B950
)
else:
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor):
l_iter_ = L_iter_
l_x_ = L_x_
l__self___dec = self.L__self___dec
l__self___linear_weight = self.L__self___linear_weight
l__self___linear_bias = self.L__self___linear_bias
cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l__self___dec, l__self___linear_bias, l__self___linear_weight)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l__self___dec = l__self___linear_bias = l__self___linear_weight = None
getitem = while_loop[0]
getitem_1 = while_loop[1]; while_loop = None
return (getitem, getitem_1)""", # noqa: B950
)
self.assertExpectedInline(
gm.cond_fn_0.code.strip(),
"""\
def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
sub = l_iter_ - l__self___dec_cond_fn; l_iter_ = l__self___dec_cond_fn = None
gt = sub > 0; sub = None
return gt""", # noqa: B950
)
self.assertExpectedInline(
gm.body_fn_0.code.strip(),
"""\
def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
child = l_iter_ - 1; l_iter_ = None
child_1 = torch._C._nn.linear(l_x_, l__self___linear_weight_body_fn, l__self___linear_bias_body_fn); l_x_ = l__self___linear_weight_body_fn = l__self___linear_bias_body_fn = None
return (child, child_1)""", # noqa: B950
)
def test_while_loop_nested2_traced(self):
@@ -8077,7 +8111,7 @@ class GraphModule(torch.nn.Module):
m, args = WHILE_LOOP_TESTS["pytree_int_carry"]
dynamic_shapes = {"x": {0: torch.export.Dim("dim_x")}} if dynamic else None
ep = self._check_export(m, args, strict=strict, dynamic_shapes=dynamic_shapes)
if strict and dynamic and not TEST_WITH_CROSSREF:
if strict and dynamic:
self.assertExpectedInline(
normalize_gm(ep.module().print_readable(print_output=False)),
"""\
@@ -8235,154 +8269,6 @@ class GraphModule(torch.nn.Module):
self.assertEqual(compiled_out[1].size(0), 3)
self.assertEqual(compiled_out, mod(x))
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_while_loop_autograd_simple(self):
backend = torch._dynamo.testing.AotEagerAndRecordGraphs()
class ModEager(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
def forward(self, x):
while x.sum() < 2:
x = x * x + 1 + self.linear(x)
return x
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
def forward(self, x):
def cond_fn(x):
return x.sum() < 2
def body_fn(x):
return x * x + 1 + self.linear(x)
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (x,))
x = torch.randn(3, 3, requires_grad=True)
x_clone = x.clone()
mod = Mod()
mod_eager = ModEager()
# Copy weights from mod to mod_eager
mod_eager.load_state_dict(mod.state_dict())
compiled_out = torch.compile(mod, backend=backend, fullgraph=True)(x)
exp_out = mod_eager(x_clone)
compiled_out.sum().backward()
exp_out.sum().backward()
self.assertEqual(compiled_out, exp_out)
eager_parameters = dict(mod_eager.named_parameters())
compiled_parameters = dict(mod.named_parameters())
for name, param in compiled_parameters.items():
self.assertEqual(param, eager_parameters[name])
self.assertEqual(param.grad, eager_parameters[name].grad)
self.assertEqual(
len(
backend.fw_graphs[0].graph.find_nodes(
op="call_function",
target=torch.ops.higher_order.while_loop_stack_output,
)
),
1,
)
self.assertEqual(
len(
backend.bw_graphs[0].graph.find_nodes(
op="call_function", target=torch.ops.higher_order.while_loop
)
),
1,
)
if not TEST_WITH_CROSSREF:
self.assertExpectedInline(
normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]", primals_3: "f32[3]"):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop_stack_output = torch.ops.higher_order.while_loop_stack_output(while_loop_cond_graph_0, while_loop_body_graph_0, (primals_1,), (primals_3, primals_2)); while_loop_cond_graph_0 = while_loop_body_graph_0 = None
getitem: "f32[u2, 3, 3]" = while_loop_stack_output[0]; while_loop_stack_output = None
select: "f32[3, 3]" = torch.ops.aten.select.int(getitem, 0, -1)
unsqueeze: "f32[1, 3, 3]" = torch.ops.aten.unsqueeze.default(primals_1, 0); primals_1 = None
slice_1: "f32[u2 - 1, 3, 3]" = torch.ops.aten.slice.Tensor(getitem, 0, 0, -1); getitem = None
cat: "f32[u2, 3, 3]" = torch.ops.aten.cat.default([unsqueeze, slice_1]); unsqueeze = slice_1 = None
return (select, primals_2, primals_3, cat)
class while_loop_cond_graph_0(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3]", arg2_1: "f32[3, 3]"):
sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
lt: "b8[]" = torch.ops.aten.lt.Scalar(sum_1, 2); sum_1 = None
return lt
class while_loop_body_graph_0(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3]", arg2_1: "f32[3, 3]"):
mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1)
add: "f32[3, 3]" = torch.ops.aten.add.Tensor(mul, 1); mul = None
t: "f32[3, 3]" = torch.ops.aten.t.default(arg2_1); arg2_1 = None
addmm: "f32[3, 3]" = torch.ops.aten.addmm.default(arg1_1, arg0_1, t); arg1_1 = arg0_1 = t = None
add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, addmm); add = addmm = None
return (add_1,)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_2: "f32[3, 3]", primals_3: "f32[3]", cat: "f32[u2, 3, 3]", tangents_1: "f32[3, 3]"):
zeros: "i64[]" = torch.ops.aten.zeros.default([], dtype = torch.int64, device = device(type='cpu'), pin_memory = False)
zeros_like: "f32[3]" = torch.ops.aten.zeros_like.default(primals_3, pin_memory = False)
zeros_like_1: "f32[3, 3]" = torch.ops.aten.zeros_like.default(primals_2, pin_memory = False)
while_loop_cond_graph_1 = self.while_loop_cond_graph_1
while_loop_body_graph_1 = self.while_loop_body_graph_1
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_1, while_loop_body_graph_1, (zeros, tangents_1, zeros_like, zeros_like_1), (cat, primals_3, primals_2)); while_loop_cond_graph_1 = while_loop_body_graph_1 = zeros = tangents_1 = zeros_like = zeros_like_1 = cat = primals_3 = primals_2 = None
getitem_2: "f32[3, 3]" = while_loop[1]
getitem_3: "f32[3]" = while_loop[2]
getitem_4: "f32[3, 3]" = while_loop[3]; while_loop = None
return (getitem_2, getitem_4, getitem_3)
class while_loop_cond_graph_1(torch.nn.Module):
def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 3]", arg2_1: "f32[3]", arg3_1: "f32[3, 3]", arg4_1: "f32[u2, 3, 3]", arg5_1: "f32[3]", arg6_1: "f32[3, 3]"):
sym_size_int_1: "Sym(u2)" = torch.ops.aten.sym_size.int(arg4_1, 0); arg4_1 = None
lt: "b8[]" = torch.ops.aten.lt.Scalar(arg0_1, sym_size_int_1); arg0_1 = sym_size_int_1 = None
return lt
class while_loop_body_graph_1(torch.nn.Module):
def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 3]", arg2_1: "f32[3]", arg3_1: "f32[3, 3]", arg4_1: "f32[u2, 3, 3]", arg5_1: "f32[3]", arg6_1: "f32[3, 3]"):
sym_size_int_1: "Sym(u2)" = torch.ops.aten.sym_size.int(arg4_1, 0)
rsub: "i64[]" = torch.ops.aten.rsub.Scalar(arg0_1, sym_size_int_1); sym_size_int_1 = None
sub_1: "i64[]" = torch.ops.aten.sub.Tensor(rsub, 1); rsub = None
_local_scalar_dense: "Sym(u9)" = torch.ops.aten._local_scalar_dense.default(sub_1); sub_1 = None
select: "f32[3, 3]" = torch.ops.aten.select.int(arg4_1, 0, _local_scalar_dense); arg4_1 = _local_scalar_dense = None
t: "f32[3, 3]" = torch.ops.aten.t.default(arg6_1); arg6_1 = None
t_1: "f32[3, 3]" = torch.ops.aten.t.default(t); t = None
mm: "f32[3, 3]" = torch.ops.aten.mm.default(arg1_1, t_1); t_1 = None
t_2: "f32[3, 3]" = torch.ops.aten.t.default(arg1_1)
mm_1: "f32[3, 3]" = torch.ops.aten.mm.default(t_2, select); t_2 = None
t_3: "f32[3, 3]" = torch.ops.aten.t.default(mm_1); mm_1 = None
sum_1: "f32[1, 3]" = torch.ops.aten.sum.dim_IntList(arg1_1, [0], True)
view: "f32[3]" = torch.ops.aten.view.default(sum_1, [3]); sum_1 = None
t_4: "f32[3, 3]" = torch.ops.aten.t.default(t_3); t_3 = None
mul_4: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg1_1, select)
mul_5: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg1_1, select); arg1_1 = select = None
add_7: "f32[3, 3]" = torch.ops.aten.add.Tensor(mm, mul_5); mm = mul_5 = None
add_8: "f32[3, 3]" = torch.ops.aten.add.Tensor(add_7, mul_4); add_7 = mul_4 = None
add_9: "i64[]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
add_10: "f32[3]" = torch.ops.aten.add.Tensor(view, arg2_1); view = arg2_1 = None
add_11: "f32[3, 3]" = torch.ops.aten.add.Tensor(t_4, arg3_1); t_4 = arg3_1 = None
return (add_9, add_8, add_10, add_11)
""", # noqa: B950
)
def test_input_output_alias(self):
def fn(f, *args):
return torch.cond(args[0].sum() > 0, f, f, args)

View File

@@ -2172,19 +2172,8 @@ class AOTInductorTestsTemplate:
dynamic_shapes=dynamic_shapes,
)
# mps doesn't support float64
@skipIfMPS
def test_while_loop_with_parameters(self):
inputs = (
torch.randn(
(
10,
20,
),
dtype=torch.float64,
device=self.device,
),
)
inputs = (torch.randn((10, 20), device=self.device),)
dim0_a = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"c": {},

View File

@@ -804,12 +804,8 @@ class WhileLoopModels:
class InnerModel(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.layer1 = torch.nn.Linear(
20, 30, device=device, dtype=torch.float64
)
self.layer2 = torch.nn.Linear(
30, 20, device=device, dtype=torch.float64
)
self.layer1 = torch.nn.Linear(20, 30, device=device)
self.layer2 = torch.nn.Linear(30, 20, device=device)
def forward(self, c, x):
return c - 1, self.layer2(self.layer1(x - 2)) * 3.14
@@ -1029,7 +1025,7 @@ class WhileLoopModels:
e = torch.nonzero(b).size(0)
def cond_fn(c, a, b):
return c + d + e + a.shape[0] - b.shape[0] < 10
return d + e + a.shape[0] - b.shape[0] < 10
def body_fn(c, a, b):
return c + 1, a + e, b + d
@@ -1112,32 +1108,31 @@ class WhileLoopModels:
class WhileLoopTests(TestCase):
def _run_test(
self, model, inputs, device, dynamic=False, num_counters=1, autograd=False
self,
model,
inputs,
device,
dynamic=False,
num_counters=1,
):
import torch.utils._pytree as pytree
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
import copy
if not autograd:
for p in model.parameters():
p.requires_grad_(False)
compiled_model = copy.deepcopy(model)
compiled_fn = torch.compile(backend=cnt, fullgraph=True)(compiled_model)
compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)
inputs = pytree.tree_map(lambda t: t.to(device=device), inputs)
input_sets = [inputs]
def mark_first_dim_dyn(inp):
torch._dynamo.mark_dynamic(inp, 0)
if dynamic:
def mark_first_dim_dyn(inp):
torch._dynamo.mark_dynamic(inp, 0)
pytree.tree_map(mark_first_dim_dyn, input_sets)
def tile_fn(inp):
# tile every first dim 5x
tiling = [5] + [1] * (inp.ndim - 1)
t = torch.tile(inp, tiling)
# mark every first dim as dynamic
torch._dynamo.mark_dynamic(inp, 0)
return t
larger_inputs = pytree.tree_map(tile_fn, inputs)
@@ -1154,78 +1149,24 @@ class WhileLoopTests(TestCase):
)
unflat_inputs = pytree.tree_unflatten(flat, inp_spec)
inputs_with_counters = counters + unflat_inputs
def process_inputs(inp):
inp = inp.clone()
if dynamic:
mark_first_dim_dyn(inp)
if autograd and inp.dtype.is_floating_point:
inp.requires_grad_(True)
return inp
cloned_inputs = pytree.tree_map(process_inputs, inputs_with_counters)
cloned_inputs2 = pytree.tree_map(process_inputs, inputs_with_counters)
result = model(*cloned_inputs)
result_compiled = compiled_fn(*cloned_inputs2)
cloned_inputs = pytree.tree_map(
lambda t: t.clone(), inputs_with_counters
)
result = model(*inputs_with_counters)
with torch.no_grad():
result_compiled = compiled_model(*inputs_with_counters)
# inputs must not be mutated
torch.testing.assert_close(cloned_inputs, inputs_with_counters)
torch.testing.assert_close(
result, result_compiled, atol=1e-4, rtol=1e-4
)
if autograd and any(
pytree.tree_map_only(
torch.Tensor, lambda t: t.requires_grad, cloned_inputs
)
):
result_loss = loss_fn(pytree.tree_flatten(result)[0])
compiled_loss = loss_fn(pytree.tree_flatten(result_compiled)[0])
self.assertTrue(
not torch.isnan(result_loss) and not torch.isinf(compiled_loss)
)
self.assertTrue(
not torch.isnan(compiled_loss)
and not torch.isinf(compiled_loss)
)
self.assertEqual(result_loss, compiled_loss)
result_loss.backward()
compiled_loss.backward()
model_parameters = dict(model.named_parameters())
compiled_parameters = dict(compiled_model.named_parameters())
for name, param in model_parameters.items():
self.assertEqual(param, compiled_parameters[name])
self.assertEqual(
param.grad,
compiled_parameters[name].grad,
atol=1e-4,
rtol=1e-4,
)
for inp1, inp2 in zip(
pytree.tree_flatten(cloned_inputs)[0],
pytree.tree_flatten(cloned_inputs2)[0],
):
if inp1.requires_grad:
self.assertEqual(
inp1.grad,
inp2.grad,
atol=1e-4,
rtol=1e-4,
)
self.assertEqual(cnt.frame_count, 1, "only one compilation expected")
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_simple_control_flow(self, device, dynamic, autograd):
def test_while_loop_simple_control_flow(self, device, dynamic):
# while_loop control flow without nesting
self._run_test(
model=WhileLoopModels.Simple(),
@@ -1235,15 +1176,12 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_nested_control_flow(self, device, dynamic, autograd):
def test_while_loop_nested_control_flow(self, device, dynamic):
# while_loop control flow with nesting
self._run_test(
model=WhileLoopModels.Nested(),
@@ -1254,15 +1192,12 @@ class WhileLoopTests(TestCase):
device=device,
dynamic=dynamic,
num_counters=2,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_outer_code(self, device, dynamic, autograd):
def test_while_loop_with_outer_code(self, device, dynamic):
# while_loop control flow with outer code
self._run_test(
model=WhileLoopModels.OuterCode(),
@@ -1272,22 +1207,18 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [False, True])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_parameters(self, device, dynamic, autograd):
def test_while_loop_with_parameters(self, device, dynamic):
# while_loop control flow with parameters
self._run_test(
model=WhileLoopModels.Parameters(device),
inputs=(torch.randn(10, 20, dtype=torch.float64),),
inputs=(torch.randn(10, 20),),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@@ -1295,9 +1226,7 @@ class WhileLoopTests(TestCase):
# dynamic=True doesn't work now due to
# https://github.com/pytorch/pytorch/issues/123596
@parametrize("dynamic", [False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_outer_buffers(self, device, dynamic, autograd):
def test_while_loop_with_outer_buffers(self, device, dynamic):
# while_loop control flow with outer code
self._run_test(
model=WhileLoopModels.OuterBuffers(),
@@ -1307,15 +1236,13 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
# dynamic=True doesn't work because we haven't handled lifted symbols
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_pytree_inputs(self, device, dynamic, autograd):
def test_while_loop_with_pytree_inputs(self, device, dynamic):
self._run_test(
model=WhileLoopModels.PytreeCarry(),
inputs=(
@@ -1326,15 +1253,12 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_data_dependent_ops(self, device, dynamic, autograd):
def test_while_loop_with_data_dependent_ops(self, device, dynamic):
with torch._dynamo.config.patch(
{
"capture_dynamic_output_shape_ops": True,
@@ -1350,15 +1274,12 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_data_dependent_in_out(self, device, dynamic, autograd):
def test_while_loop_with_data_dependent_in_out(self, device, dynamic):
with torch._dynamo.config.patch(
{
"capture_dynamic_output_shape_ops": True,
@@ -1375,7 +1296,6 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@parametrize("dynamic", [True, False])
@@ -1436,8 +1356,7 @@ class WhileLoopTests(TestCase):
@torch._dynamo.config.patch(
{"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
)
@parametrize("autograd", [False, True])
def test_while_loop_with_unbacked_symint_closure(self, device, dynamic, autograd):
def test_while_loop_with_unbacked_symint_closure(self, device, dynamic):
self._run_test(
model=WhileLoopModels.UnbackedSymIntClosure(),
inputs=(
@@ -1446,7 +1365,6 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@@ -1481,11 +1399,10 @@ class WhileLoopTests(TestCase):
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch(
{"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
)
def test_while_loop_with_sym_expr_cond(self, device, dynamic, autograd):
def test_while_loop_with_sym_expr_cond(self, device, dynamic):
self._run_test(
model=WhileLoopModels.SymExprCond(),
inputs=(
@@ -1494,27 +1411,22 @@ class WhileLoopTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [False, True])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_with_conv(self, device, dynamic, autograd):
def test_while_loop_with_conv(self, device, dynamic):
self._run_test(
model=WhileLoopModels.Conv(device),
inputs=(torch.randn(2, 4, 4, 4, dtype=torch.float64),),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_while_loop_stack_output_simple(self, device, dynamic):
self._run_test(
model=WhileLoopModels.WhileLoopStackOutputSimple(device),
@@ -2177,6 +2089,16 @@ class MapTests(TestCase):
self.assertEqual(result, result_compiled)
if autograd:
def loss_fn(result) -> torch.Tensor:
flat_results, _ = pytree.tree_flatten(result)
return sum(
[
torch.sqrt(torch.pow(res.sum() / res.max(), 2)).sum()
for res in flat_results
]
)
loss_fn(result).backward()
loss_fn(result_exp).backward()
loss_fn(result_compiled).backward()

View File

@@ -12,9 +12,6 @@ from torch._higher_order_ops.utils import (
autograd_not_implemented,
check_input_alias_and_mutation_return_outputs,
check_meta_consistency,
fill_none_with_masks,
filter_with_masks,
materialize_as_graph,
reenter_make_fx,
validate_subgraph_args_types,
)
@@ -329,16 +326,9 @@ def while_loop_dense(
return carried_vals
@while_loop_op.py_autograd_impl
def while_loop_autograd(cond_fn, body_fn, operands, additional_inputs):
return WhileLoopAutogradOp.apply(
cond_fn,
body_fn,
len(operands),
len(additional_inputs),
*operands,
*additional_inputs,
)
while_loop_op.py_autograd_impl(
autograd_not_implemented(while_loop_op, deferred_error=True)
)
def _find_or_create_fake_mode() -> FakeTensorMode:
@@ -644,268 +634,6 @@ class WhileLoopStackOutputOp(HigherOrderOperator):
return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
# Note [while_loop autograd]
# Consider the following while_loop, which can be visualized as:
# additional_inputs
# ┌─────┬─────┼─────┬─────┐
# | | | | |
# ↓ ↓ ↓ ↓ ↓
# x ──→ y0 ─→ y1 ─→ y2 ─→ y3 ─→ y4
#
# The backward can be visualized as follows:
#
# g_additional_inputs
# ┌──────┬──────┼──────┬──────┐
# | | | | |
# | | | | |
# gx <── gy0 <─ gy1 <─ gy2 <─ gy3 <─ gy4
#
# We can compute gx using chain rule:
#
# gx = gy0 * bw(y0, x),
#
# where gy0 denotes the gradient of the loss with respect to y0, and bw(y0, x) denotes the gradient of y0 with
# respect to x. Note that bw can be computed easily from the forward body_fn using torch.autograd.grad.
# We can substitute for the unknowns gy0, gy1, ... using the chain rule until we reach gy4:
#
# gx = gy1 * bw(y1, y0) * bw(y0, x)
# = gy2 * bw(y2, y1) * bw(y1, y0) * bw(y0, x)
# = ...
# = gy4 * bw(y4, y3) * bw(y3, y2) * bw(y2, y1) * bw(y1, y0) * bw(y0, x)
#
# Since gy4 is the gradient of the final output, which is given as the backward input, we now have a formula
# to compute gx. An abbreviation for the formula is: gy4 * bw43210x
#
# In a similar way, we can compute g_additional_inputs using chain rule:
#
# g_additional_inputs = gy0 * bw(y0, addi) + gy1 * bw(y1, addi) + gy2 * bw(y2, addi) + ... + gy4 * bw(y4, addi)
#
# Noticing that gy0 = gy4 * bw43210, gy1 = gy4 * bw4321, etc., we also get a formula for g_additional_inputs.
#
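# As a concrete, illustrative instance (the names body, a, x here are hypothetical and used only for this
# example), take a toy body(y) = a * y with two iterations, x -> y0 -> y1, where a is the only additional input:
#
#   gx  = gy1 * bw(y1, y0) * bw(y0, x) = gy1 * a * a
#   g_a = gy1 * bw(y1, a) + gy0 * bw(y0, a) = gy1 * y0 + (gy1 * a) * x
#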
# Implementation:
# The implementation idea is to construct a while_loop that calculates both gx and g_additional_inputs.
# Specifically, we can implement the backward of while_loop as follows:
#
# def cond_fn(idx, grad_carries, grad_additional_inputs, fw_additional_inputs, fw_inps):
# return idx < fw_inps.size(0)
#
# def body_fn(idx, grad_carries, grad_additional_inputs, fw_additional_inputs, fw_inps):
# reversed_idx = fw_inps.size(0) - 1 - idx
# next_grad_carry, next_grad_additional_inputs = bw(fw_inps[reversed_idx], fw_additional_inputs, grad_carries)
# return idx + 1, next_grad_carry, next_grad_additional_inputs + grad_additional_inputs
#
# idx = 0
# init_grad_carries = grads
# init_grad_additional_inputs = torch.zeros_like(g_additional_inputs)
# fw_inps = torch.cat([ctx.fw_carried_inputs, fw_outputs[:-1]])
# while_loop(cond_fn, body_fn, (idx, init_grad_carries, init_grad_additional_inputs,), (fw_additional_inputs, fw_inps))
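#
# For illustration only, a minimal eager-mode sketch of this reverse accumulation (not used by the
# implementation). It assumes a single tensor carry, a single additional input w, and a body_fn(carry, w)
# that returns the next carry; fw_inps[i] is the carry fed into iteration i (i.e. x, y0, ..., y_{n-1}) and
# g_out is the gradient of the loss w.r.t. the final carry:
#
# def _naive_while_loop_backward(body_fn, fw_inps, w, g_out):
#     g_carry = g_out
#     g_w = torch.zeros_like(w)
#     # Walk the iterations in reverse, chaining gy_{i+1} -> gy_i and accumulating the gradient of w.
#     for x_i in reversed(fw_inps):
#         x_i = x_i.detach().requires_grad_(True)
#         w_i = w.detach().requires_grad_(True)
#         y_i = body_fn(x_i, w_i)
#         g_x, g_wi = torch.autograd.grad(y_i, (x_i, w_i), grad_outputs=g_carry)
#         g_carry, g_w = g_x, g_w + g_wi
#     return g_carry, g_w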
class WhileLoopAutogradOp(torch.autograd.Function):
@staticmethod
def forward(
ctx,
cond_fn,
body_fn,
num_carried_inputs,
num_additional_inputs,
*carries_and_inputs,
):
from torch._higher_order_ops.scan import split_into_chunks
carries, additional_inputs = split_into_chunks(
carries_and_inputs, [num_carried_inputs, num_additional_inputs]
)
with torch._C._AutoDispatchBelowAutograd():
fw_outputs = while_loop_stack_output_op(
cond_fn, body_fn, carries, additional_inputs
)
assert not hasattr(ctx, "fw_cond_fn")
assert not hasattr(ctx, "fw_body_fn")
assert not hasattr(ctx, "carries")
assert not hasattr(ctx, "additional_inputs")
assert not hasattr(ctx, "fw_outputs")
ctx.fw_cond_fn = cond_fn
ctx.fw_body_fn = body_fn
ctx.carries = carries
ctx.additional_inputs = additional_inputs
ctx.fw_outputs = fw_outputs
loop_count = None
for out in fw_outputs:
if isinstance(out, torch.Tensor):
if loop_count is not None:
assert out.size(0) == loop_count
else:
loop_count = out.size(0)
assert loop_count is not None
# Remove the loop_count from pending_fresh_unbacked_symbols
# because it's not part of the forward output and it's impossible
# to bind it to a proxy in the forward graph anyway.
if (
isinstance(loop_count, torch.SymInt)
and (shape_env := loop_count.node.shape_env)
and loop_count in shape_env.pending_fresh_unbacked_symbols
):
shape_env.pending_fresh_unbacked_symbols.remove(loop_count)
# Even when the body function is not executed, we clone and unsqueeze the input
# to avoid aliasing, so loop_count is always >= 1
torch._check(loop_count >= 1)
# We snapshot the dispatch keys in forward for materializing the
# bw_graph in backward.
ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
assert len(fw_outputs) > 0, "fw_outputs shouldn't be empty"
# Only the last slice of each stacked output in fw_outputs needs to be returned
return tuple(ckp[-1] for ckp in fw_outputs)
@staticmethod
def backward(ctx, *grads):
from torch._higher_order_ops.cond import create_bw_fn
from torch._higher_order_ops.scan import split_into_chunks
# set up single step bw fn
bw_body_fn = create_bw_fn(ctx.fw_body_fn, ctx.carries + ctx.additional_inputs)
# Note [Handle inputs that're not differentiable]
# When a forward input is non-differentiable, e.g. a symint or an integer tensor, its gradient
# will be None. However, we don't want to return None from the subgraph because this complicates
# inductor codegen, where we would need non-uniform treatment of None and tensors.
# So we set up masks and filter out the None gradients so that only tensors are returned from each step.
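# For illustration, assuming the helpers behave as their names suggest: filter_with_masks keeps only the
# entries whose mask is True, and fill_none_with_masks re-inserts None at the masked-out positions, e.g.
#   masks = [True, False, True]
#   filter_with_masks((gx, None, gz), masks)   # -> (gx, gz)
#   fill_none_with_masks((gx, gz), masks)      # -> (gx, None, gz)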
carries_tensor_masks = [
True if isinstance(t, torch.Tensor) and t.dtype.is_floating_point else False
for t in ctx.carries
]
additional_inputs_tensor_masks = [
True if isinstance(t, torch.Tensor) and t.dtype.is_floating_point else False
for t in ctx.additional_inputs
]
init_idx = torch.zeros((), dtype=torch.int64)
init_grad_carries = filter_with_masks(grads, carries_tensor_masks) # type: ignore[arg-type]
init_grad_additional_inputs = tuple(
torch.zeros_like(t)
for need_keep, t in zip(
additional_inputs_tensor_masks, ctx.additional_inputs
)
if need_keep
)
# We need the forward inputs to each iteration to compute the backward,
# which is the concatenation of the first iteration's input (i.e. ctx.carries) and every iteration's
# output except the last one.
fw_carries = [
torch.cat([carry.unsqueeze(0), carries[:-1]])
for carry, carries in zip(ctx.carries, ctx.fw_outputs)
]
for fw_carry, carry in zip(fw_carries, ctx.carries):
fw_carry.requires_grad_(carry.requires_grad)
_, spec = pytree.tree_flatten(
(
init_idx,
init_grad_carries,
init_grad_additional_inputs,
ctx.fw_outputs,
ctx.additional_inputs,
)
)
def cond_fn(*flat_args):
(
idx,
grad_carries,
grad_additional_inputs,
fw_carries,
additional_inputs,
) = pytree.tree_unflatten(flat_args, spec)
assert isinstance(fw_carries[0], torch.Tensor), fw_carries[0]
# excluding the last iteration's output
return idx < fw_carries[0].size(0)
def body_fn(*flat_args):
(
idx,
grad_carries,
grad_additional_inputs,
fw_carries,
additional_inputs,
) = pytree.tree_unflatten(flat_args, spec)
reversed_idx = fw_carries[0].size(0) - idx - 1
selected_fw_carries = [
ckp.select(0, reversed_idx.item()) for ckp in fw_carries
]
cur_grad_carries, cur_grad_additional_inputs = split_into_chunks(
bw_body_fn(*selected_fw_carries, *additional_inputs, *grad_carries),
[len(ctx.carries), len(ctx.additional_inputs)],
)
assert all(isinstance(t, torch.Tensor) for t in cur_grad_carries)
cur_grad_carries_tensors = filter_with_masks(
cur_grad_carries, carries_tensor_masks
)
cur_grad_additional_inputs_tensors = filter_with_masks(
cur_grad_additional_inputs, additional_inputs_tensor_masks
)
return (
idx + 1,
*cur_grad_carries_tensors,
*(
cur_grad + grad
for cur_grad, grad in zip(
cur_grad_additional_inputs_tensors, grad_additional_inputs
)
),
)
args_single_step_bw = (
init_idx,
*init_grad_carries,
*init_grad_additional_inputs,
*fw_carries,
*ctx.additional_inputs,
)
cond_gm = materialize_as_graph(
cond_fn,
args_single_step_bw,
ctx._fw_include_key_set,
ctx._fw_exclude_key_set,
force_enable_grad=True,
)
body_gm = materialize_as_graph(
body_fn,
args_single_step_bw,
ctx._fw_include_key_set,
ctx._fw_exclude_key_set,
force_enable_grad=True,
)
_, final_grad_carries, final_grad_additional_inputs = split_into_chunks(
while_loop_op(
cond_gm,
body_gm,
(
init_idx,
*init_grad_carries,
*init_grad_additional_inputs,
),
(*fw_carries, *ctx.additional_inputs),
),
[1, len(init_grad_carries), len(init_grad_additional_inputs)],
)
return (
None,
None,
None,
None,
*fill_none_with_masks(final_grad_carries, carries_tensor_masks),
*fill_none_with_masks(
final_grad_additional_inputs, additional_inputs_tensor_masks
),
)
while_loop_stack_output_op = WhileLoopStackOutputOp()
while_loop_stack_output_op.py_impl(DispatchKey.CompositeExplicitAutograd)(

View File

@@ -2345,6 +2345,7 @@ class _MakefxTracer:
insert_deferred_runtime_asserts(t, fake_mode.shape_env, "reenter_make_fx")
t.recompile()
# TODO: kind of a bad way to do it, should maybe figure out a better way
if self.tracing_mode == "symbolic":
assert self.fake_tensor_mode is not None