diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 1f345706c27c..e2c14302034a 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -737,8 +737,7 @@ def forward(self, pred_1, x_1): getitem_1 = cond_1[0]; getitem_1 = None getitem_2 = cond_1[1] getitem_3 = cond_1[2]; getitem_3 = None - getitem_4 = cond_1[3]; getitem_4 = None - getitem_5 = cond_1[4]; cond_1 = getitem_5 = None + getitem_4 = cond_1[3]; cond_1 = getitem_4 = None return (getitem_2,)""", # noqa: B950 ) @@ -854,10 +853,7 @@ def forward(self, pred_1, a_1, b_1, c_1): cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2, ones_like)); pred_1 = true_graph_1 = false_graph_1 = a_1 = b_1 = sym_size_int = sym_size_int_1 = c_1 = sym_size_int_2 = ones_like = None getitem_1 = cond_1[0] getitem_2 = cond_1[1] - getitem_3 = cond_1[2]; getitem_3 = None - getitem_4 = cond_1[3]; getitem_4 = None - getitem_5 = cond_1[4]; getitem_5 = None - getitem_6 = cond_1[5]; cond_1 = getitem_6 = None + getitem_3 = cond_1[2]; cond_1 = getitem_3 = None return (getitem_1, getitem_2)""", # noqa: B950 ) # Forward @@ -877,7 +873,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1): clone = torch.ops.aten.clone.default(arg6_1) clone_1 = torch.ops.aten.clone.default(arg6_1); arg6_1 = None zeros_like = torch.ops.aten.zeros_like.default(arg4_1, pin_memory = False); arg4_1 = None - return [clone, clone_1, None, None, zeros_like, None]""", + return [clone, clone_1, zeros_like]""", ) def test_cond_autograd_pytree_input(self): @@ -1302,15 +1298,11 @@ def forward(self, pred_1, x_1): return cond_outputs, cond_inputs - # TODO: The compile_mode = `compile_dynamic_shape` raises the Error - # torch._inductor.exc.LoweringException: NotImplementedError: get_size() is not - # implemented by ! @skipIfTorchDynamo("don't test compile on compile") @unittest.skipIf(not SM70OrLater, "triton") @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") @parametrize("compile_mode", ["compile_dynamic_shape"]) @parametrize("scalar", [False]) - @unittest.expectedFailure def test_cond_autograd_zeros_unused_branch_complex_compile_fail( self, compile_mode, scalar ): diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index 511b9cea5e14..a2bdfc9c4ea9 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -5,6 +5,7 @@ import unittest import torch import torch._dynamo.testing +import torch.utils._pytree as pytree from torch._higher_order_ops.associative_scan import associative_scan from torch._higher_order_ops.map import _fake_map from torch._higher_order_ops.scan import _fake_scan, scan @@ -37,6 +38,24 @@ def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)): return _prepend_product_of_values(inputs, counter_values, num_counters) +# a testing loss_fn +def loss_fn(result) -> torch.Tensor: + flat_results, _ = pytree.tree_flatten(result) + total_loss = torch.tensor( + 0.0, device=flat_results[0].device if flat_results else torch.device("cpu") + ) + + for res in flat_results: + # Convert to float if integer tensor to avoid numerical issues + if not res.dtype.is_floating_point: + res = res.float() + + # Simple robust loss: abs values + small constant to avoid inf/nan + total_loss = total_loss + (torch.abs(res) / (1.0 + torch.abs(res))).sum() + + return total_loss + + class CondModels: class Simple(torch.nn.Module): def forward(self, p, a, b): @@ -1036,8 +1055,6 @@ class WhileLoopTests(TestCase): dynamic=False, num_counters=1, ): - import torch.utils._pytree as pytree - cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor") compiled_model = torch.compile(backend=cnt, fullgraph=True)(model) @@ -1566,8 +1583,6 @@ class ScanModels: def forward(self, scan_op, _input, weight, bias): def combine_fn(carry, x): - from torch.utils import _pytree as pytree - new_carry = { "param": carry["param"] @ x + carry["bias"], "bias": carry["bias"].sin(), @@ -1977,51 +1992,88 @@ class MapTests(TestCase): inputs, device, dynamic=False, + autograd=False, ): - cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor") - compiled_model = torch.compile(backend=cnt, fullgraph=True, dynamic=dynamic)( - model - ) + import copy inputs = [inp.to(device=device) for inp in inputs] model = model.to(device=device) + model_eager = copy.deepcopy(model) + model_compiled = copy.deepcopy(model) + cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor") + compiled_model = torch.compile(backend=cnt, fullgraph=True, dynamic=dynamic)( + model_compiled + ) + + if autograd: + pytree.tree_map_only(torch.Tensor, lambda t: t.requires_grad_(True), inputs) + cloned_inputs = [inp.clone() for inp in inputs] result = model(torch._higher_order_ops.map, *cloned_inputs) - result_exp = model(_fake_map, *cloned_inputs) + result_exp = model_eager(_fake_map, *cloned_inputs) result_compiled = compiled_model(torch._higher_order_ops.map, *cloned_inputs) self.assertEqual(result, result_exp) self.assertEqual(result, result_compiled) + if autograd: + + def loss_fn(result) -> torch.Tensor: + flat_results, _ = pytree.tree_flatten(result) + return sum( + [ + torch.sqrt(torch.pow(res.sum() / res.max(), 2)).sum() + for res in flat_results + ] + ) + + loss_fn(result).backward() + loss_fn(result_exp).backward() + loss_fn(result_compiled).backward() + + model_params = dict(model.named_parameters()) + model_eager_params = dict(model_eager.named_parameters()) + model_compiled_params = dict(model_compiled.named_parameters()) + for name, param in model_eager_params.items(): + self.assertEqual(param, model_params[name]) + self.assertEqual(param, model_compiled_params[name]) + self.assertEqual(param.grad, model_params[name].grad) + self.assertEqual(param.grad, model_compiled_params[name].grad) + @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [True, False]) + @parametrize("autograd", [True, False]) @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_map_simple(self, device, dynamic): + def test_map_simple(self, device, dynamic, autograd): self._run_test( model=MapModels.Simple(), inputs=(torch.randn(3, 4),), device=device, dynamic=dynamic, + autograd=autograd, ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [True, False]) + @parametrize("autograd", [True, False]) @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_map_simple_linear_with_view(self, device, dynamic): + def test_map_simple_linear_with_view(self, device, dynamic, autograd): self._run_test( model=MapModels.SimpleWithLinearWithView(), inputs=(torch.randn(3, 4),), device=device, dynamic=dynamic, + autograd=autograd, ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [True, False]) + @parametrize("autograd", [True, False]) @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_map_pytree_in_out(self, device, dynamic): + def test_map_pytree_in_out(self, device, dynamic, autograd): self._run_test( model=MapModels.PytreeInOut(), inputs=( @@ -2031,13 +2083,15 @@ class MapTests(TestCase): ), device=device, dynamic=dynamic, + autograd=autograd, ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [True, False]) + @parametrize("autograd", [True, False]) @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_map_nested_with_cond(self, device, dynamic): + def test_map_nested_with_cond(self, device, dynamic, autograd): self._run_test( model=MapModels.NestedWithCond(), inputs=( @@ -2047,6 +2101,7 @@ class MapTests(TestCase): ), device=device, dynamic=dynamic, + autograd=autograd, ) diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index a0175371cc9d..7c13b9a0fd14 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs import contextlib +import functools import logging import warnings from typing import Any, Callable, Optional, Union @@ -20,6 +21,8 @@ from torch._higher_order_ops.utils import ( _set_compilation_env, check_input_alias_and_mutation_return_outputs, create_bw_fn, + fill_none_with_masks, + filter_with_masks, materialize_as_graph, reenter_make_fx, save_tensors_and_symints_for_backward, @@ -342,15 +345,32 @@ class CondAutogradOp(torch.autograd.Function): args = operands + flat_grads # TODO: we need to materialize the bw graphs because dynamo is unable to # trace through the joint function when torch.compile torch.autograd.grad. + + grads_tensor_masks = [] + + def create_fn_remove_none(fn): + @functools.wraps(fn) + def wrapped(*args): + nonlocal grads_tensor_masks + + true_outputs = fn(*args) + grads_tensor_masks = [ + True if isinstance(out, torch.Tensor) else False + for out in true_outputs + ] + return filter_with_masks(true_outputs, grads_tensor_masks) + + return wrapped + true_bw_gm = materialize_as_graph( - ctx._true_bw_fn, + create_fn_remove_none(ctx._true_bw_fn), args, ctx._fw_include_key_set, ctx._fw_exclude_key_set, force_enable_grad=True, ) false_bw_gm = materialize_as_graph( - ctx._false_bw_fn, + create_fn_remove_none(ctx._false_bw_fn), args, ctx._fw_include_key_set, ctx._fw_exclude_key_set, @@ -362,7 +382,7 @@ class CondAutogradOp(torch.autograd.Function): false_bw_gm, args, ) - return None, None, None, *grads + return None, None, None, *fill_none_with_masks(grads, grads_tensor_masks) # Note: diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index 332bde7e464f..57d2cd3cb900 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -22,6 +22,8 @@ from .utils import ( _stack_pytree, _unstack_pytree, create_bw_fn, + fill_none_with_masks, + filter_with_masks, materialize_as_graph, save_tensors_and_symints_for_backward, saved_tensors_and_symints, @@ -154,8 +156,12 @@ class MapAutogradOp(torch.autograd.Function): bw_f = create_bw_fn(ctx._f, fw_args) + grads_tensor_masks = [] + # Create a wrapper around thefor the bw_f def bw_f_wrapper(*args): + nonlocal grads_tensor_masks + # Dissect args and re-order them for the ``ctx._bw_f`` # args provided to the wrapper are composed of [*fw_mapped_args, *flat_grads, *pos_args] # The content of ``bw_f_tangents`` are the upstream gradients, i.e. flat_grads @@ -165,7 +171,11 @@ class MapAutogradOp(torch.autograd.Function): args, [num_mapped_args, num_grads, num_pos_args] ) bw_f_primals = *fw_m_args, *pos_args - return bw_f(*bw_f_primals, *bw_f_tangents) + gradients = bw_f(*bw_f_primals, *bw_f_tangents) + grads_tensor_masks = [ + True if isinstance(out, torch.Tensor) else out for out in gradients + ] + return filter_with_masks(gradients, grads_tensor_masks) def construct_args_single_step_bw(): unwrapped_mapped_xs = pytree.tree_map(_from_fun, fw_mapped_args) @@ -194,7 +204,7 @@ class MapAutogradOp(torch.autograd.Function): grads = map_impl(fn_bw_gm, fw_mapped_args + flat_grads, pos_args) - return None, None, *grads + return None, None, *fill_none_with_masks(grads, grads_tensor_masks) def trace_map(proxy_mode, func_overload, f, xs, pos_args): diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index b41e19b7177b..4b1a8a272cd8 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -1217,3 +1217,13 @@ def _has_gen_schema(op: HigherOrderOperator): return hasattr(type(op), method) and getattr(type(op), method) is not getattr( HigherOrderOperator, method ) + + +def filter_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]): + assert len(data) == len(masks) + return [item for item, keep in zip(data, masks) if keep] + + +def fill_none_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]): + data_iter = iter(data) + return [next(data_iter) if kept else None for kept in masks]