[map] filter none gradients and add autograd inductor tests (#160548)

Will filter the None outputs in autograd backward for the other HOPs as follow-ups. A minimal sketch of the None-gradient filter/refill pattern is included right after the changed-files summary below.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160548
Approved by: https://github.com/zou3519
Yidi Wu
2025-08-14 16:49:22 -07:00
committed by PyTorch MergeBot
parent fa75ba9303
commit ff86509a06
5 changed files with 116 additions and 29 deletions
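For context, here is a minimal, self-contained sketch of the filter/refill pattern this PR applies to map (and cond): non-tensor (None) gradients are dropped before the backward graph is materialized, so inductor never has to lower a None output, and None placeholders are restored at their original positions when returning from autograd. The two helper definitions mirror filter_with_masks / fill_none_with_masks from the diff below; toy_bw_fn and the driver lines are illustrative assumptions, not part of the PR.

from typing import Optional

import torch


def filter_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]):
    # Keep only the entries whose mask is True (the real tensor gradients).
    assert len(data) == len(masks)
    return [item for item, keep in zip(data, masks) if keep]


def fill_none_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]):
    # Re-insert None placeholders at the positions that were filtered out.
    data_iter = iter(data)
    return [next(data_iter) if kept else None for kept in masks]


def toy_bw_fn(x, y):
    # Hypothetical backward: only x gets a gradient, y's gradient is None.
    return [torch.ones_like(x), None]


grads = toy_bw_fn(torch.randn(2), torch.randn(2))
masks = [isinstance(g, torch.Tensor) for g in grads]
tensor_only = filter_with_masks(grads, masks)        # what gets traced/lowered
restored = fill_none_with_masks(tensor_only, masks)  # [grad_x, None] again
assert restored[1] is None and torch.equal(restored[0], grads[0])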

View File

@@ -737,8 +737,7 @@ def forward(self, pred_1, x_1):
getitem_1 = cond_1[0]; getitem_1 = None
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; getitem_3 = None
getitem_4 = cond_1[3]; getitem_4 = None
getitem_5 = cond_1[4]; cond_1 = getitem_5 = None
getitem_4 = cond_1[3]; cond_1 = getitem_4 = None
return (getitem_2,)""", # noqa: B950
)
@@ -854,10 +853,7 @@ def forward(self, pred_1, a_1, b_1, c_1):
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2, ones_like)); pred_1 = true_graph_1 = false_graph_1 = a_1 = b_1 = sym_size_int = sym_size_int_1 = c_1 = sym_size_int_2 = ones_like = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; getitem_3 = None
getitem_4 = cond_1[3]; getitem_4 = None
getitem_5 = cond_1[4]; getitem_5 = None
getitem_6 = cond_1[5]; cond_1 = getitem_6 = None
getitem_3 = cond_1[2]; cond_1 = getitem_3 = None
return (getitem_1, getitem_2)""", # noqa: B950
)
# Forward
@@ -877,7 +873,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1):
clone = torch.ops.aten.clone.default(arg6_1)
clone_1 = torch.ops.aten.clone.default(arg6_1); arg6_1 = None
zeros_like = torch.ops.aten.zeros_like.default(arg4_1, pin_memory = False); arg4_1 = None
return [clone, clone_1, None, None, zeros_like, None]""",
return [clone, clone_1, zeros_like]""",
)
def test_cond_autograd_pytree_input(self):
@@ -1302,15 +1298,11 @@ def forward(self, pred_1, x_1):
return cond_outputs, cond_inputs
# TODO: The compile_mode = `compile_dynamic_shape` raises the Error
# torch._inductor.exc.LoweringException: NotImplementedError: get_size() is not
# implemented by <class 'torch._inductor.ir.NoneAsConstantBuffer'>!
@skipIfTorchDynamo("don't test compile on compile")
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("compile_mode", ["compile_dynamic_shape"])
@parametrize("scalar", [False])
@unittest.expectedFailure
def test_cond_autograd_zeros_unused_branch_complex_compile_fail(
self, compile_mode, scalar
):

View File

@@ -5,6 +5,7 @@ import unittest
import torch
import torch._dynamo.testing
import torch.utils._pytree as pytree
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.map import _fake_map
from torch._higher_order_ops.scan import _fake_scan, scan
@@ -37,6 +38,24 @@ def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)):
return _prepend_product_of_values(inputs, counter_values, num_counters)
# a testing loss_fn
def loss_fn(result) -> torch.Tensor:
flat_results, _ = pytree.tree_flatten(result)
total_loss = torch.tensor(
0.0, device=flat_results[0].device if flat_results else torch.device("cpu")
)
for res in flat_results:
# Convert to float if integer tensor to avoid numerical issues
if not res.dtype.is_floating_point:
res = res.float()
# Simple robust loss: abs values + small constant to avoid inf/nan
total_loss = total_loss + (torch.abs(res) / (1.0 + torch.abs(res))).sum()
return total_loss
class CondModels:
class Simple(torch.nn.Module):
def forward(self, p, a, b):
@@ -1036,8 +1055,6 @@ class WhileLoopTests(TestCase):
dynamic=False,
num_counters=1,
):
import torch.utils._pytree as pytree
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)
@@ -1566,8 +1583,6 @@ class ScanModels:
def forward(self, scan_op, _input, weight, bias):
def combine_fn(carry, x):
from torch.utils import _pytree as pytree
new_carry = {
"param": carry["param"] @ x + carry["bias"],
"bias": carry["bias"].sin(),
@@ -1977,51 +1992,88 @@ class MapTests(TestCase):
inputs,
device,
dynamic=False,
autograd=False,
):
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
compiled_model = torch.compile(backend=cnt, fullgraph=True, dynamic=dynamic)(
model
)
import copy
inputs = [inp.to(device=device) for inp in inputs]
model = model.to(device=device)
model_eager = copy.deepcopy(model)
model_compiled = copy.deepcopy(model)
cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
compiled_model = torch.compile(backend=cnt, fullgraph=True, dynamic=dynamic)(
model_compiled
)
if autograd:
pytree.tree_map_only(torch.Tensor, lambda t: t.requires_grad_(True), inputs)
cloned_inputs = [inp.clone() for inp in inputs]
result = model(torch._higher_order_ops.map, *cloned_inputs)
result_exp = model(_fake_map, *cloned_inputs)
result_exp = model_eager(_fake_map, *cloned_inputs)
result_compiled = compiled_model(torch._higher_order_ops.map, *cloned_inputs)
self.assertEqual(result, result_exp)
self.assertEqual(result, result_compiled)
if autograd:
def loss_fn(result) -> torch.Tensor:
flat_results, _ = pytree.tree_flatten(result)
return sum(
[
torch.sqrt(torch.pow(res.sum() / res.max(), 2)).sum()
for res in flat_results
]
)
loss_fn(result).backward()
loss_fn(result_exp).backward()
loss_fn(result_compiled).backward()
model_params = dict(model.named_parameters())
model_eager_params = dict(model_eager.named_parameters())
model_compiled_params = dict(model_compiled.named_parameters())
for name, param in model_eager_params.items():
self.assertEqual(param, model_params[name])
self.assertEqual(param, model_compiled_params[name])
self.assertEqual(param.grad, model_params[name].grad)
self.assertEqual(param.grad, model_compiled_params[name].grad)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_map_simple(self, device, dynamic):
def test_map_simple(self, device, dynamic, autograd):
self._run_test(
model=MapModels.Simple(),
inputs=(torch.randn(3, 4),),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_map_simple_linear_with_view(self, device, dynamic):
def test_map_simple_linear_with_view(self, device, dynamic, autograd):
self._run_test(
model=MapModels.SimpleWithLinearWithView(),
inputs=(torch.randn(3, 4),),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_map_pytree_in_out(self, device, dynamic):
def test_map_pytree_in_out(self, device, dynamic, autograd):
self._run_test(
model=MapModels.PytreeInOut(),
inputs=(
@@ -2031,13 +2083,15 @@ class MapTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_map_nested_with_cond(self, device, dynamic):
def test_map_nested_with_cond(self, device, dynamic, autograd):
self._run_test(
model=MapModels.NestedWithCond(),
inputs=(
@@ -2047,6 +2101,7 @@ class MapTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)

View File

@@ -1,6 +1,7 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import contextlib
import functools
import logging
import warnings
from typing import Any, Callable, Optional, Union
@@ -20,6 +21,8 @@ from torch._higher_order_ops.utils import (
_set_compilation_env,
check_input_alias_and_mutation_return_outputs,
create_bw_fn,
fill_none_with_masks,
filter_with_masks,
materialize_as_graph,
reenter_make_fx,
save_tensors_and_symints_for_backward,
@@ -342,15 +345,32 @@ class CondAutogradOp(torch.autograd.Function):
args = operands + flat_grads
# TODO: we need to materialize the bw graphs because dynamo is unable to
# trace through the joint function when torch.compile is applied to torch.autograd.grad.
grads_tensor_masks = []
def create_fn_remove_none(fn):
@functools.wraps(fn)
def wrapped(*args):
nonlocal grads_tensor_masks
true_outputs = fn(*args)
grads_tensor_masks = [
True if isinstance(out, torch.Tensor) else False
for out in true_outputs
]
return filter_with_masks(true_outputs, grads_tensor_masks)
return wrapped
true_bw_gm = materialize_as_graph(
ctx._true_bw_fn,
create_fn_remove_none(ctx._true_bw_fn),
args,
ctx._fw_include_key_set,
ctx._fw_exclude_key_set,
force_enable_grad=True,
)
false_bw_gm = materialize_as_graph(
ctx._false_bw_fn,
create_fn_remove_none(ctx._false_bw_fn),
args,
ctx._fw_include_key_set,
ctx._fw_exclude_key_set,
@@ -362,7 +382,7 @@ class CondAutogradOp(torch.autograd.Function):
false_bw_gm,
args,
)
return None, None, None, *grads
return None, None, None, *fill_none_with_masks(grads, grads_tensor_masks)
# Note:

View File

@@ -22,6 +22,8 @@ from .utils import (
_stack_pytree,
_unstack_pytree,
create_bw_fn,
fill_none_with_masks,
filter_with_masks,
materialize_as_graph,
save_tensors_and_symints_for_backward,
saved_tensors_and_symints,
@@ -154,8 +156,12 @@ class MapAutogradOp(torch.autograd.Function):
bw_f = create_bw_fn(ctx._f, fw_args)
grads_tensor_masks = []
# Create a wrapper around the bw_f that records which gradient outputs are tensors
def bw_f_wrapper(*args):
nonlocal grads_tensor_masks
# Dissect args and re-order them for the ``ctx._bw_f``
# args provided to the wrapper are composed of [*fw_mapped_args, *flat_grads, *pos_args]
# The content of ``bw_f_tangents`` are the upstream gradients, i.e. flat_grads
@@ -165,7 +171,11 @@
args, [num_mapped_args, num_grads, num_pos_args]
)
bw_f_primals = *fw_m_args, *pos_args
return bw_f(*bw_f_primals, *bw_f_tangents)
gradients = bw_f(*bw_f_primals, *bw_f_tangents)
grads_tensor_masks = [
True if isinstance(out, torch.Tensor) else False for out in gradients
]
return filter_with_masks(gradients, grads_tensor_masks)
def construct_args_single_step_bw():
unwrapped_mapped_xs = pytree.tree_map(_from_fun, fw_mapped_args)
@@ -194,7 +204,7 @@ class MapAutogradOp(torch.autograd.Function):
grads = map_impl(fn_bw_gm, fw_mapped_args + flat_grads, pos_args)
return None, None, *grads
return None, None, *fill_none_with_masks(grads, grads_tensor_masks)
def trace_map(proxy_mode, func_overload, f, xs, pos_args):

View File

@@ -1217,3 +1217,13 @@ def _has_gen_schema(op: HigherOrderOperator):
return hasattr(type(op), method) and getattr(type(op), method) is not getattr(
HigherOrderOperator, method
)
def filter_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]):
assert len(data) == len(masks)
return [item for item, keep in zip(data, masks) if keep]
def fill_none_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]):
data_iter = iter(data)
return [next(data_iter) if kept else None for kept in masks]