[scan] materialize combine_fn in forward, add more autograd tests (#161732)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161732
Approved by: https://github.com/zou3519
ghstack dependencies: #161557, #161664, #161808, #162025
Yidi Wu
2025-09-26 16:47:26 -07:00
committed by PyTorch MergeBot
parent b85bee3bbb
commit 3413490f53
5 changed files with 136 additions and 72 deletions
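For orientation: scan threads a carry through a combine function applied to each leading slice of xs and stacks the per-step outputs. Below is a minimal eager sketch of those semantics (illustrative only; reference_scan is a hypothetical helper, while the real HOP in torch._higher_order_ops.scan additionally handles pytrees, a scan dim, reverse scanning, and autograd):

import torch

def reference_scan(combine_fn, init, xs):
    # combine_fn maps (carry, x) -> (next_carry, y); the ys are stacked along dim 0.
    carry, ys = init, []
    for x in xs.unbind(0):
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, torch.stack(ys)

# e.g. a cumulative sum expressed as a scan
xs = torch.arange(5.0)
last, cumsum = reference_scan(lambda c, x: (c + x, c + x), torch.zeros(()), xs)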

View File

@@ -1661,7 +1661,7 @@ class ScanModels:
super().__init__()
self.reverse = reverse
self.dim = dim
self.linear = torch.nn.Linear(4, 4)
self.linear = torch.nn.Linear(4, 4, dtype=torch.float64)
def forward(self, scan_op, init, xs):
def combine_fn(carry, x):
@@ -1893,7 +1893,10 @@ class ScanTests(TestCase):
):
import copy
inputs = [inp.requires_grad_(autograd) for inp in inputs]
inputs = [
inp.requires_grad_(autograd) if inp.dtype.is_floating_point else inp
for inp in inputs
]
inputs = [inp.to(device=device) for inp in inputs]
model = model.to(device=device)
for p in model.parameters():
@@ -1903,7 +1906,8 @@ class ScanTests(TestCase):
model2 = copy.deepcopy(model)
model3 = copy.deepcopy(model)
model4 = copy.deepcopy(model)
torch.compile(fullgraph=True, dynamic=dynamic)(model)
model3.compile(fullgraph=True, dynamic=dynamic)
model4.compile(fullgraph=True, dynamic=dynamic)
def _run_model(model, inputs):
cloned_inputs = [
@@ -1928,11 +1932,9 @@ class ScanTests(TestCase):
result_exp = _run_model(model1, [_fake_scan] + inputs)
result_eager = _run_model(model2, [scan] + inputs)
result_compiled = _run_model(
torch.compile(fullgraph=True, dynamic=dynamic)(model3), [scan] + inputs
)
result_compiled = _run_model(model3, [scan] + inputs)
result_compiled_exp = _run_model(
torch.compile(fullgraph=True, dynamic=dynamic)(model4),
model4,
[_fake_scan] + inputs,
)
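The harness runs the same model eagerly with _fake_scan and scan and compiled via model3/model4; presumably it then compares outputs and, when autograd is enabled, gradients. A condensed sketch of that comparison pattern (hypothetical helper, not the actual _run_test assertions):

import torch

def compare_with_grads(model, compiled_model, inputs):
    # Assumes the models return a flat tuple of tensors.
    outs_eager = model(*inputs)
    outs_compiled = compiled_model(*inputs)
    torch.testing.assert_close(outs_eager, outs_compiled)

    grad_inputs = [t for t in inputs if t.requires_grad]
    if grad_inputs:
        g_eager = torch.autograd.grad(sum(o.sum() for o in outs_eager), grad_inputs)
        g_compiled = torch.autograd.grad(sum(o.sum() for o in outs_compiled), grad_inputs)
        torch.testing.assert_close(g_eager, g_compiled)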
@@ -1958,8 +1960,9 @@ class ScanTests(TestCase):
@parametrize("dynamic", [True, False])
@parametrize("reverse", [True, False])
@parametrize("dim", [0, 1, 2])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_pytree_in_out(self, device, dynamic, reverse, dim):
def test_scan_pytree_in_out(self, device, dynamic, reverse, dim, autograd):
self._run_test(
model=ScanModels.SimpleWithPytreeInOuts(reverse=reverse, dim=dim),
inputs=(
@@ -1969,6 +1972,7 @@ class ScanTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@@ -1977,10 +1981,13 @@ class ScanTests(TestCase):
@parametrize("reverse", [True, False])
@parametrize("dim", [0, 1, 3])
@parametrize("scan_length", [1, 5])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_nn_modules(self, device, dynamic, reverse, dim, scan_length):
init = torch.randn(20, 16, 4, 4)
xs = torch.randn(scan_length, 20, 16, 4, 4)
def test_scan_nn_modules(
self, device, dynamic, reverse, dim, scan_length, autograd
):
init = torch.randn(20, 16, 4, 4, dtype=torch.float64)
xs = torch.randn(scan_length, 20, 16, 4, 4, dtype=torch.float64)
xs = xs.movedim(0, dim)
self._run_test(
model=ScanModels.ScanLinearWithView(reverse=reverse, dim=dim),
@@ -1990,6 +1997,7 @@ class ScanTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
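An aside on why the test inputs (and the Linear module above) switch to torch.float64: double precision is the usual choice when gradients are checked numerically, e.g. torch.autograd.gradcheck is only reliable in float64. Illustrative snippet, not code from this PR:

import torch

lin = torch.nn.Linear(4, 4, dtype=torch.float64)
x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
# gradcheck compares analytical and numerical gradients; it raises on mismatch.
assert torch.autograd.gradcheck(lambda t: lin(t).sum(), (x,))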
@@ -1998,8 +2006,9 @@ class ScanTests(TestCase):
@parametrize("reverse", [True, False])
@parametrize("dim", [0, 1, 3])
@parametrize("scan_length", [1, 5])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_conv(self, device, dynamic, reverse, dim, scan_length):
def test_scan_conv(self, device, dynamic, reverse, dim, scan_length, autograd):
init = torch.randn(2, 4, 4, 4, dtype=torch.float64)
xs = torch.randn(scan_length, 2, 4, 4, 4, dtype=torch.float64)
xs = xs.movedim(0, dim)
@@ -2011,6 +2020,7 @@ class ScanTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@@ -2020,10 +2030,13 @@ class ScanTests(TestCase):
@parametrize("dim", [0, 1, 3])
@parametrize("pred", [True, False])
@parametrize("scan_length", [1, 5])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_in_cond(self, device, dynamic, reverse, dim, pred, scan_length):
init = torch.randn(4, 4, 4)
xs = torch.randn(scan_length, 4, 4, 4)
def test_scan_in_cond(
self, device, dynamic, reverse, dim, pred, scan_length, autograd
):
init = torch.randn(4, 4, 4, dtype=torch.float64)
xs = torch.randn(scan_length, 4, 4, 4, dtype=torch.float64)
xs = xs.movedim(0, dim)
self._run_test(
model=ScanModels.ScanInCond(reverse=reverse, dim=dim),
@@ -2034,6 +2047,7 @@ class ScanTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@@ -2042,8 +2056,9 @@ class ScanTests(TestCase):
@parametrize("reverse", [True, False])
@parametrize("dim", [0, 1, 3])
@parametrize("scan_length", [1, 5])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_cond_in_scan(self, device, dynamic, reverse, dim, scan_length):
def test_cond_in_scan(self, device, dynamic, reverse, dim, scan_length, autograd):
init = torch.randn(2, 4, 4, 4)
xs = torch.randn(scan_length, 4, 4, 4)
xs = xs.movedim(0, dim)
@@ -2055,13 +2070,15 @@ class ScanTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_chunked_ce(self, device, dynamic):
def test_scan_chunked_ce(self, device, dynamic, autograd):
self._run_test(
model=ScanModels.ChunkedCE(10),
inputs=(
@@ -2072,6 +2089,7 @@ class ScanTests(TestCase):
),
device=device,
dynamic=dynamic,
autograd=autograd,
)
@requires_gpu
@@ -2095,8 +2113,9 @@ class ScanTests(TestCase):
@requires_gpu
@parametrize("device", ["cpu", GPU_TYPE])
@parametrize("dynamic", [True, False])
@parametrize("autograd", [True, False])
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_scan_with_clamp(self, device, dynamic):
def test_scan_with_clamp(self, device, dynamic, autograd):
B = 4
T = 8
H = 16
@@ -2104,10 +2123,11 @@ class ScanTests(TestCase):
model=ScanModels.ScanWithClamp(),
inputs=(
torch.randn((B, H)),
torch.randn((T, B, H), requires_grad=True),
torch.randn((T, B, H)),
),
device=device,
dynamic=dynamic,
autograd=autograd,
)

View File

@@ -12,6 +12,7 @@ from torch._higher_order_ops.utils import (
check_input_alias_and_mutation_return_outputs,
check_meta_consistency,
create_bw_fn,
filter_with_masks,
first_slice_copy,
first_slice_copy_with_grad,
get_tensor_mask,
@@ -591,11 +592,20 @@ class ScanAutogradOp(torch.autograd.Function):
*y,
]
# Materialize the ``combine_fn_with_carry_checkpoint`` with grad enabled;
# we need enable_grad to support torch.func.grad_and_value in the subgraph.
gm = materialize_as_graph(
combine_fn_with_carry_checkpoint,
(*init, *[x[0] for x in xs], *additional_inputs),
force_enable_grad=True,
)
with torch._C._AutoDispatchBelowAutograd():
# 2.) Compute all the carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint``
c_T, carries_ys = _extract_carry_and_out(
scan_op(
combine_fn_with_carry_checkpoint,
gm,
init,
xs,
additional_inputs,
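The new forward path traces combine_fn_with_carry_checkpoint into a graph once, with grad mode forced on, and hands that graph to scan_op under _AutoDispatchBelowAutograd. A hedged sketch of what such materialization amounts to, using make_fx purely for illustration (the actual helper is materialize_as_graph with force_enable_grad=True, as shown above):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def combine_fn(carry, x):
    next_carry = carry + x
    return next_carry, next_carry.sum()

# Force grad mode on while tracing so grad transforms (e.g.
# torch.func.grad_and_value) used inside the subgraph still see
# requires_grad inputs even when invoked below the autograd dispatch key.
with torch.enable_grad():
    gm = make_fx(combine_fn)(torch.randn(3, requires_grad=True), torch.randn(3))

print(gm.code)  # a flat traced GraphModule that can stand in for the callable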
@@ -621,6 +631,7 @@ class ScanAutogradOp(torch.autograd.Function):
Args:
flat_grads (torch.Tensor): The tensor of flattened upstream gradients.
"""
from torch._higher_order_ops.utils import fill_none_with_masks
# Collect the saved items from the forward
num_leaves_init = ctx._num_leaves_init
@@ -640,10 +651,11 @@ class ScanAutogradOp(torch.autograd.Function):
def initialize_g_additional_inputs(
additional_inputs,
):
# The initial gradients for the additional_inputs are all zeros
g_additional_inputs = [
torch.zeros_like(ai) if ai_tm else None
for ai_tm, ai in zip(additional_inputs_tensor_mask, additional_inputs)
torch.zeros_like(ai)
for ai in filter_with_masks(
additional_inputs, additional_inputs_tensor_mask
)
]
return g_additional_inputs
@@ -668,6 +680,11 @@ class ScanAutogradOp(torch.autograd.Function):
)
ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands)
# Initialize the g_additional_inputs with zero-tensors.
# This step is necessary because the gradients of the additional inputs are accumulated in the
# ``wrapper_bwd_combine_fn`` and thus need a zero-initialized starting point
initial_g_additional_inputs = initialize_g_additional_inputs(additional_inputs)
# 4.) Create the BW wrapper to accumulate the gradients for the additional_inputs
def combine_fn_bw_grad_accumulation(*args):
# Dissect args and re-order them for the ``ctx._combine_fn_bw``
@@ -680,7 +697,7 @@ class ScanAutogradOp(torch.autograd.Function):
) = split_into_chunks(
args,
[
num_additional_inputs,
len(initial_g_additional_inputs),
num_leaves_init + num_leaves_ys,
num_leaves_init + num_leaves_xs + num_additional_inputs,
],
@@ -694,13 +711,15 @@ class ScanAutogradOp(torch.autograd.Function):
new_g_additional_inputs = [
# If the additional inputs are ints or SymInts, those values are taken as is and no gradients are added
carr_g + curr_g if add_inp_tm else carr_g
for add_inp_tm, carr_g, curr_g in zip(
additional_inputs_tensor_mask,
carr_g + curr_g
for carr_g, curr_g in zip(
carried_g_additional_input,
g_additional_inputs_t,
filter_with_masks(
g_additional_inputs_t, additional_inputs_tensor_mask
),
)
]
assert all(isinstance(t, torch.Tensor) for t in new_g_additional_inputs)
# The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator
# The ``g_xs_t`` is encoded as the output of the backward scan operator
@@ -718,9 +737,9 @@ class ScanAutogradOp(torch.autograd.Function):
# Because only tensor elements of additional inputs can have requires_grad=True,
# the values for non-tensor elements of additional inputs are None
masked_additional_inputs = [
a.clone() if add_inp_tm else None
for add_inp_tm, a in zip(
additional_inputs_tensor_mask, additional_inputs
a.clone()
for a in filter_with_masks(
additional_inputs, additional_inputs_tensor_mask
)
]
@@ -771,11 +790,12 @@ class ScanAutogradOp(torch.autograd.Function):
# Decompose the flat_grads into g_c_T, g_ys
g_c_T, g_ys = split_into_chunks(flat_grads, [num_leaves_init, num_leaves_ys])
# Initialize the g_additional_inputs with zero-tensors.
# This step is necessary because the gradients of the additional inputs are accumulated in the
# ``wrapper_bwd_combine_fn`` and thus need a zero-initialized starting point
initial_g_additional_inputs = initialize_g_additional_inputs(additional_inputs)
assert all(
isinstance(t, torch.Tensor) and t.dtype.is_floating_point for t in g_c_T
)
assert all(
isinstance(t, torch.Tensor) and t.dtype.is_floating_point for t in g_ys
)
# Prepend the inits to the carries.
# This is needed, because when computing the gradients, the last carry is not needed
@@ -804,13 +824,19 @@ class ScanAutogradOp(torch.autograd.Function):
# Unpack the computed gradients
g_additional_inputs, g_init, g_xs = split_into_chunks(
gradients, [num_additional_inputs, num_leaves_init, num_leaves_xs]
gradients,
[len(initial_g_additional_inputs), num_leaves_init, num_leaves_xs],
)
# The flipping back along the scan dimension is required to get the gradients in the right order for ``xs``
g_xs = [torch.flip(elem, [0]) for elem in g_xs]
return *[None] * 4, *g_init, *g_xs, *g_additional_inputs
return (
*[None] * 4,
*g_init,
*g_xs,
*fill_none_with_masks(g_additional_inputs, additional_inputs_tensor_mask),
)
@scan_op.py_autograd_impl
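The backward above leans on two small mask utilities: filter_with_masks keeps only the tensor-valued additional_inputs, and fill_none_with_masks scatters the computed gradients back to the original positions with None for non-tensor entries. From their usage they appear to behave like the following local sketches (illustrative re-implementations, not the actual torch._higher_order_ops.utils code):

from typing import Any, Optional, Sequence

def _filter_with_masks(items: Sequence[Any], masks: Sequence[bool]) -> list[Any]:
    # Keep only the entries whose mask is True.
    return [item for item, keep in zip(items, masks) if keep]

def _fill_none_with_masks(
    items: Sequence[Any], masks: Sequence[bool]
) -> list[Optional[Any]]:
    # Inverse: place the filtered values back at the True positions, None elsewhere
    # (ints/SymInts in additional_inputs get no gradient).
    it = iter(items)
    return [next(it) if keep else None for keep in masks]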

View File

@@ -734,6 +734,31 @@ def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[A
return elements
def _clone_aliasing_output(inputs: Sequence[Any], outputs: Sequence[Any]):
# For tensors whose grad is None, create zero tensors as gradients
# This invariant is useful for cudagraph.
# Eliminate input-output and output-output aliasing
seen_input_storages = {
StorageWeakRef(t._typed_storage())
for t in inputs
if isinstance(t, torch.Tensor)
}
seen_output_storages = set()
final_outputs = []
for out in outputs:
if isinstance(out, torch.Tensor):
out_storage = StorageWeakRef(out._typed_storage())
if (
out_storage in seen_input_storages
or out_storage in seen_output_storages
):
out = out.clone()
seen_output_storages.add(StorageWeakRef(out._typed_storage()))
final_outputs.append(out)
return final_outputs
def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
"""
For a fn that accepts flat inputs and returns flat outputs:
@@ -773,34 +798,14 @@ def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
# For tensors whose grad is None, create zero tensors as gradients
# This invariant is useful for cudagraph.
grads = [
grad_args = [
torch.zeros_like(arg)
if isinstance(arg, torch.Tensor) and grad is None
else grad
for grad, arg in zip(grad_args, primals)
]
# Elimitate input-output, output-output aliasing
seen_input_storages = {
StorageWeakRef(t._typed_storage())
for t in args_and_grad_outs
if isinstance(t, torch.Tensor)
}
seen_output_storages = set()
final_grads = []
for grad, arg in zip(grads, primals):
if isinstance(arg, torch.Tensor):
assert isinstance(grad, torch.Tensor)
seen_input_storages.add(StorageWeakRef(arg._typed_storage()))
grad_storage = StorageWeakRef(grad._typed_storage())
if (
grad_storage in seen_input_storages
or grad_storage in seen_output_storages
):
grad = grad.clone()
seen_output_storages.add(StorageWeakRef(grad._typed_storage()))
final_grads.append(grad)
final_grads = _clone_aliasing_output(args_and_grad_outs, grad_args)
return final_grads
return flat_fn
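A usage sketch of the de-aliasing helper factored out above, assuming it is importable as the private _clone_aliasing_output added to torch._higher_order_ops.utils: any output that shares storage with an input (or with an earlier output) is cloned, so the returned list is alias-free, which is what cudagraphs require.

import torch
from torch._higher_order_ops.utils import _clone_aliasing_output

x = torch.randn(4)
y = x.view(2, 2)    # aliases x's storage
z = torch.randn(4)  # independent storage

outs = _clone_aliasing_output(inputs=[x], outputs=[y, z])
assert outs[0].data_ptr() != x.data_ptr()  # the aliasing output was cloned
assert outs[1] is z                        # non-aliasing outputs pass through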

View File

@@ -8664,6 +8664,17 @@ class WhileLoop(ExternKernel):
return result
@staticmethod
def _maybe_wrap_as_tensor_box(out: IRNode) -> IRNode:
if isinstance(out, TensorBox):
return out
elif isinstance(out, (StorageBox, ReinterpretView)):
return TensorBox(out)
elif isinstance(out, MultiOutput):
return TensorBox.create(out)
else:
raise RuntimeError(f"NYI unsupported output type: {type(out)}")
@classmethod
def create(
cls,
@@ -8686,9 +8697,21 @@ class WhileLoop(ExternKernel):
ret = []
for tb, fk in zip(tensor_boxes, fake_tensors):
if isinstance(fk, torch.Tensor):
# Subgraph lowering always returns StorageBox as graph_outputs because
# it realizes the outputs.
#
# However, require_exact_strides expects a TensorBox. For example, when
# an expand happens, the fake tensor's stride is (0, 0, 0) but the
# StorageBox may have a different stride, so lowering.slice_ is used to
# make the strides consistent, and it expects its input to be a TensorBox.
#
# So we wrap the inputs as TensorBoxes if they are not already.
new_tb = WhileLoop._maybe_wrap_as_tensor_box(tb)
ret.append(
ExternKernel.require_exact_strides(
tb, fk.stride(), allow_padding=False
new_tb, fk.stride(), allow_padding=False
)
)
else:

View File

@@ -7095,21 +7095,11 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs, stack_output
msg = f"{msg} Found from : \n {stack_trace}"
V.graph.disable_cudagraphs_reason = msg
def _map_output(out: Any):
if isinstance(out, TensorBox):
return out
elif isinstance(out, ir.StorageBox):
return TensorBox(out)
elif isinstance(out, ir.MultiOutput):
return TensorBox.create(out)
else:
raise RuntimeError(f"NYI unsupported output type: {type(out)}")
result = ir.WhileLoop.create(
cond_fn, body_fn, carried_inputs, additional_inputs, stack_output
)
assert isinstance(result, Sequence)
return list(map(_map_output, result))
return list(map(ir.WhileLoop._maybe_wrap_as_tensor_box, result))
register_lowering(