[scan] materialize combine_fn in forward, add more autograd tests (#161732)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161732
Approved by: https://github.com/zou3519
ghstack dependencies: #161557, #161664, #161808, #162025
committed by: PyTorch MergeBot
parent: b85bee3bbb
commit: 3413490f53
@@ -1661,7 +1661,7 @@ class ScanModels:
             super().__init__()
             self.reverse = reverse
             self.dim = dim
-            self.linear = torch.nn.Linear(4, 4)
+            self.linear = torch.nn.Linear(4, 4, dtype=torch.float64)

         def forward(self, scan_op, init, xs):
             def combine_fn(carry, x):
@@ -1893,7 +1893,10 @@ class ScanTests(TestCase):
    ):
        import copy

-       inputs = [inp.requires_grad_(autograd) for inp in inputs]
+       inputs = [
+           inp.requires_grad_(autograd) if inp.dtype.is_floating_point else inp
+           for inp in inputs
+       ]
        inputs = [inp.to(device=device) for inp in inputs]
        model = model.to(device=device)
        for p in model.parameters():
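The filter on inp.dtype.is_floating_point is needed because requires_grad_ is only valid for floating-point (and complex) tensors. A minimal standalone sketch of that behavior; the tensors here are illustrative, not taken from the test file:

import torch

x = torch.randn(3)        # floating point: gradients may be requested
x.requires_grad_(True)

idx = torch.arange(3)     # integer dtype: requesting gradients raises a RuntimeError
try:
    idx.requires_grad_(True)
except RuntimeError as err:
    print(f"expected: {err}")  # only floating point and complex tensors can require gradients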
@@ -1903,7 +1906,8 @@ class ScanTests(TestCase):
        model2 = copy.deepcopy(model)
        model3 = copy.deepcopy(model)
        model4 = copy.deepcopy(model)
-       torch.compile(fullgraph=True, dynamic=dynamic)(model)
+       model3.compile(fullgraph=True, dynamic=dynamic)
+       model4.compile(fullgraph=True, dynamic=dynamic)

        def _run_model(model, inputs):
            cloned_inputs = [
@@ -1928,11 +1932,9 @@ class ScanTests(TestCase):

        result_exp = _run_model(model1, [_fake_scan] + inputs)
        result_eager = _run_model(model2, [scan] + inputs)
-       result_compiled = _run_model(
-           torch.compile(fullgraph=True, dynamic=dynamic)(model3), [scan] + inputs
-       )
+       result_compiled = _run_model(model3, [scan] + inputs)
        result_compiled_exp = _run_model(
-           torch.compile(fullgraph=True, dynamic=dynamic)(model4),
+           model4,
            [_fake_scan] + inputs,
        )

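The compiled variants are now compiled in place with nn.Module.compile() instead of being wrapped with torch.compile() at call time. Module.compile() forwards its arguments to torch.compile() and installs the compiled forward on the module itself, whereas torch.compile(module) returns a wrapper and leaves the original module untouched. A small illustrative sketch (the module here is not from the test file):

import torch

m = torch.nn.Linear(4, 4)

wrapped = torch.compile(m, fullgraph=True)  # returns a compiled wrapper; m itself is unchanged
m.compile(fullgraph=True)                   # compiles m in place; later m(...) calls use the compiled path
out = m(torch.randn(2, 4))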
@@ -1958,8 +1960,9 @@ class ScanTests(TestCase):
    @parametrize("dynamic", [True, False])
    @parametrize("reverse", [True, False])
    @parametrize("dim", [0, 1, 2])
+   @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-   def test_scan_pytree_in_out(self, device, dynamic, reverse, dim):
+   def test_scan_pytree_in_out(self, device, dynamic, reverse, dim, autograd):
        self._run_test(
            model=ScanModels.SimpleWithPytreeInOuts(reverse=reverse, dim=dim),
            inputs=(
@@ -1969,6 +1972,7 @@ class ScanTests(TestCase):
            ),
            device=device,
            dynamic=dynamic,
+           autograd=autograd,
        )

    @requires_gpu
@@ -1977,10 +1981,13 @@ class ScanTests(TestCase):
    @parametrize("reverse", [True, False])
    @parametrize("dim", [0, 1, 3])
    @parametrize("scan_length", [1, 5])
+   @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-   def test_scan_nn_modules(self, device, dynamic, reverse, dim, scan_length):
-       init = torch.randn(20, 16, 4, 4)
-       xs = torch.randn(scan_length, 20, 16, 4, 4)
+   def test_scan_nn_modules(
+       self, device, dynamic, reverse, dim, scan_length, autograd
+   ):
+       init = torch.randn(20, 16, 4, 4, dtype=torch.float64)
+       xs = torch.randn(scan_length, 20, 16, 4, 4, dtype=torch.float64)
        xs = xs.movedim(0, dim)
        self._run_test(
            model=ScanModels.ScanLinearWithView(reverse=reverse, dim=dim),
@@ -1990,6 +1997,7 @@ class ScanTests(TestCase):
            ),
            device=device,
            dynamic=dynamic,
+           autograd=autograd,
        )

    @requires_gpu
@@ -1998,8 +2006,9 @@ class ScanTests(TestCase):
    @parametrize("reverse", [True, False])
    @parametrize("dim", [0, 1, 3])
    @parametrize("scan_length", [1, 5])
+   @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-   def test_scan_conv(self, device, dynamic, reverse, dim, scan_length):
+   def test_scan_conv(self, device, dynamic, reverse, dim, scan_length, autograd):
        init = torch.randn(2, 4, 4, 4, dtype=torch.float64)
        xs = torch.randn(scan_length, 2, 4, 4, 4, dtype=torch.float64)
        xs = xs.movedim(0, dim)
@@ -2011,6 +2020,7 @@ class ScanTests(TestCase):
            ),
            device=device,
            dynamic=dynamic,
+           autograd=autograd,
        )

    @requires_gpu
@@ -2020,10 +2030,13 @@ class ScanTests(TestCase):
    @parametrize("dim", [0, 1, 3])
    @parametrize("pred", [True, False])
    @parametrize("scan_length", [1, 5])
+   @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-   def test_scan_in_cond(self, device, dynamic, reverse, dim, pred, scan_length):
-       init = torch.randn(4, 4, 4)
-       xs = torch.randn(scan_length, 4, 4, 4)
+   def test_scan_in_cond(
+       self, device, dynamic, reverse, dim, pred, scan_length, autograd
+   ):
+       init = torch.randn(4, 4, 4, dtype=torch.float64)
+       xs = torch.randn(scan_length, 4, 4, 4, dtype=torch.float64)
        xs = xs.movedim(0, dim)
        self._run_test(
            model=ScanModels.ScanInCond(reverse=reverse, dim=dim),
@@ -2034,6 +2047,7 @@ class ScanTests(TestCase):
            ),
            device=device,
            dynamic=dynamic,
+           autograd=autograd,
        )

    @requires_gpu
@@ -2042,8 +2056,9 @@ class ScanTests(TestCase):
    @parametrize("reverse", [True, False])
    @parametrize("dim", [0, 1, 3])
    @parametrize("scan_length", [1, 5])
+   @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-   def test_cond_in_scan(self, device, dynamic, reverse, dim, scan_length):
+   def test_cond_in_scan(self, device, dynamic, reverse, dim, scan_length, autograd):
        init = torch.randn(2, 4, 4, 4)
        xs = torch.randn(scan_length, 4, 4, 4)
        xs = xs.movedim(0, dim)
@@ -2055,13 +2070,15 @@ class ScanTests(TestCase):
            ),
            device=device,
            dynamic=dynamic,
+           autograd=autograd,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
+   @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-   def test_scan_chunked_ce(self, device, dynamic):
+   def test_scan_chunked_ce(self, device, dynamic, autograd):
        self._run_test(
            model=ScanModels.ChunkedCE(10),
            inputs=(
@@ -2072,6 +2089,7 @@ class ScanTests(TestCase):
            ),
            device=device,
            dynamic=dynamic,
+           autograd=autograd,
        )

    @requires_gpu
@@ -2095,8 +2113,9 @@ class ScanTests(TestCase):
    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [True, False])
+   @parametrize("autograd", [True, False])
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-   def test_scan_with_clamp(self, device, dynamic):
+   def test_scan_with_clamp(self, device, dynamic, autograd):
        B = 4
        T = 8
        H = 16
@@ -2104,10 +2123,11 @@ class ScanTests(TestCase):
            model=ScanModels.ScanWithClamp(),
            inputs=(
                torch.randn((B, H)),
-               torch.randn((T, B, H), requires_grad=True),
+               torch.randn((T, B, H)),
            ),
            device=device,
            dynamic=dynamic,
+           autograd=autograd,
        )

@@ -12,6 +12,7 @@ from torch._higher_order_ops.utils import (
    check_input_alias_and_mutation_return_outputs,
    check_meta_consistency,
    create_bw_fn,
+   filter_with_masks,
    first_slice_copy,
    first_slice_copy_with_grad,
    get_tensor_mask,
@@ -591,11 +592,20 @@ class ScanAutogradOp(torch.autograd.Function):
                *y,
            ]

+       # Materialize the ``combine_fn_with_carry_checkpoint`` with enable_grad
+       # we need enable_grad to support torch.func.grad_and_value
+       # in subgraph.
+       gm = materialize_as_graph(
+           combine_fn_with_carry_checkpoint,
+           (*init, *[x[0] for x in xs], *additional_inputs),
+           force_enable_grad=True,
+       )
+
        with torch._C._AutoDispatchBelowAutograd():
            # 2.) Compute the all carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint``
            c_T, carries_ys = _extract_carry_and_out(
                scan_op(
-                   combine_fn_with_carry_checkpoint,
+                   gm,
                    init,
                    xs,
                    additional_inputs,
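The new comment explains that combine_fn_with_carry_checkpoint is materialized as a graph with grad mode forced on so that torch.func.grad_and_value can be used inside the traced subgraph. For reference, a minimal standalone use of torch.func.grad_and_value, unrelated to the scan implementation (f and x are illustrative):

import torch

def f(x):
    return (x ** 2).sum()

x = torch.randn(3)
grads, value = torch.func.grad_and_value(f)(x)  # returns (df/dx, f(x)) == (2 * x, (x ** 2).sum())
print(torch.allclose(grads, 2 * x), value)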
@@ -621,6 +631,7 @@ class ScanAutogradOp(torch.autograd.Function):
        Args:
            flat_grads (torch.Tensor): The tensor of flattened upstream gradients.
        """
+       from torch._higher_order_ops.utils import fill_none_with_masks

        # Collect the saved items from the forward
        num_leaves_init = ctx._num_leaves_init
@@ -640,10 +651,11 @@ class ScanAutogradOp(torch.autograd.Function):
        def initialize_g_additional_inputs(
            additional_inputs,
        ):
            # The initial gradients for the additional_inputs are all zeros
            g_additional_inputs = [
-               torch.zeros_like(ai) if ai_tm else None
-               for ai_tm, ai in zip(additional_inputs_tensor_mask, additional_inputs)
+               torch.zeros_like(ai)
+               for ai in filter_with_masks(
+                   additional_inputs, additional_inputs_tensor_mask
+               )
            ]
            return g_additional_inputs

@@ -668,6 +680,11 @@ class ScanAutogradOp(torch.autograd.Function):
        )
        ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands)

+       # Initialize the g_additional_inputs with zero-tensors.
+       # This step is necessary because the gradients of the additional inputs are accumulated in the
+       # ``wrapper_bwd_combine_fn`` and thus need a zero-initialized starting point
+       initial_g_additional_inputs = initialize_g_additional_inputs(additional_inputs)
+
        # 4.) Create the BW wrapper to accumulate the gradients for the additional_inputs
        def combine_fn_bw_grad_accumulation(*args):
            # Dissect args and re-order them for the ``ctx._combine_fn_bw``
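The zero-initialized gradients exist because the same additional_inputs feed every scan step, so their gradients have to be accumulated step by step starting from zero. A toy, scan-free illustration of that accumulation with plain autograd (w and xs are illustrative; this is not the scan_op implementation):

import torch

w = torch.randn(3, requires_grad=True)   # stands in for an additional input shared by all steps
xs = torch.randn(5, 3)                   # stands in for the scanned inputs

carry = torch.zeros(3)
for x in xs:                             # unrolled "scan": carry <- carry + w * x
    carry = carry + w * x
carry.sum().backward()

# The gradient of w is the sum of the per-step contributions, here xs.sum(0).
print(torch.allclose(w.grad, xs.sum(0)))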
@@ -680,7 +697,7 @@ class ScanAutogradOp(torch.autograd.Function):
            ) = split_into_chunks(
                args,
                [
-                   num_additional_inputs,
+                   len(initial_g_additional_inputs),
                    num_leaves_init + num_leaves_ys,
                    num_leaves_init + num_leaves_xs + num_additional_inputs,
                ],
@@ -694,13 +711,15 @@ class ScanAutogradOp(torch.autograd.Function):

            new_g_additional_inputs = [
                # If the additional inputs are ints or SymInts, those values are taken as is and no gradients are added
-               carr_g + curr_g if add_inp_tm else carr_g
-               for add_inp_tm, carr_g, curr_g in zip(
-                   additional_inputs_tensor_mask,
+               carr_g + curr_g
+               for carr_g, curr_g in zip(
                    carried_g_additional_input,
-                   g_additional_inputs_t,
+                   filter_with_masks(
+                       g_additional_inputs_t, additional_inputs_tensor_mask
+                   ),
                )
            ]
+           assert all(isinstance(t, torch.Tensor) for t in new_g_additional_inputs)

            # The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator
            # The ``g_xs_t`` is encoded as the output of the backward scan operator
@@ -718,9 +737,9 @@ class ScanAutogradOp(torch.autograd.Function):
        # Because only tensor elements of additional inputs can have requires_grad=True,
        # the values for non-tensor elements of additional inputs are None
        masked_additional_inputs = [
-           a.clone() if add_inp_tm else None
-           for add_inp_tm, a in zip(
-               additional_inputs_tensor_mask, additional_inputs
+           a.clone()
+           for a in filter_with_masks(
+               additional_inputs, additional_inputs_tensor_mask
            )
        ]

@@ -771,11 +790,12 @@ class ScanAutogradOp(torch.autograd.Function):

        # Decompose the flat_grads into g_c_T, g_ys
        g_c_T, g_ys = split_into_chunks(flat_grads, [num_leaves_init, num_leaves_ys])
-
-       # Initialize the g_additional_inputs with zero-tensors.
-       # This step is necessary because the gradients of the additional inputs are accumulated in the
-       # ``wrapper_bwd_combine_fn`` and thus need a zero-initialized starting point
-       initial_g_additional_inputs = initialize_g_additional_inputs(additional_inputs)
+       assert all(
+           isinstance(t, torch.Tensor) and t.dtype.is_floating_point for t in g_c_T
+       )
+       assert all(
+           isinstance(t, torch.Tensor) and t.dtype.is_floating_point for t in g_ys
+       )

        # Prepend the inits to the carries.
        # This is needed, because when computing the gradients, the last carry is not needed
@@ -804,13 +824,19 @@ class ScanAutogradOp(torch.autograd.Function):

        # Unpack the computed gradients
        g_additional_inputs, g_init, g_xs = split_into_chunks(
-           gradients, [num_additional_inputs, num_leaves_init, num_leaves_xs]
+           gradients,
+           [len(initial_g_additional_inputs), num_leaves_init, num_leaves_xs],
        )

        # The flipping back along the scan dimension is required to get the gradients in the right order for ``xs``
        g_xs = [torch.flip(elem, [0]) for elem in g_xs]

-       return *[None] * 4, *g_init, *g_xs, *g_additional_inputs
+       return (
+           *[None] * 4,
+           *g_init,
+           *g_xs,
+           *fill_none_with_masks(g_additional_inputs, additional_inputs_tensor_mask),
+       )


 @scan_op.py_autograd_impl
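filter_with_masks and fill_none_with_masks are helpers from torch._higher_order_ops.utils: the backward pass filters additional_inputs down to their tensor entries, and the final return re-expands the gradient list so non-tensor entries get None. The toy functions below sketch the behavior implied by their use here; they are illustrative reimplementations under that assumption, not the library code:

from typing import Any, Optional, Sequence

def toy_filter_with_masks(items: Sequence[Any], masks: Sequence[bool]) -> list[Any]:
    # Keep only the entries whose mask is True (here: the tensor entries).
    return [it for it, keep in zip(items, masks) if keep]

def toy_fill_none_with_masks(values: Sequence[Any], masks: Sequence[bool]) -> list[Optional[Any]]:
    # Re-expand a filtered list, inserting None where the mask is False.
    it = iter(values)
    return [next(it) if keep else None for keep in masks]

masks = [True, False, True]
assert toy_filter_with_masks(["a", 7, "b"], masks) == ["a", "b"]
assert toy_fill_none_with_masks(["ga", "gb"], masks) == ["ga", None, "gb"]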
@@ -734,6 +734,31 @@ def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]:
    return elements


+def _clone_aliasing_output(inputs: Sequence[Any], outputs: Sequence[Any]):
+   # For tensors whose grad is None, create zero tensors as gradients
+   # This invariant is useful for cudagraph.
+
+   # Elimitate input-output, output-output aliasing
+   seen_input_storages = {
+       StorageWeakRef(t._typed_storage())
+       for t in inputs
+       if isinstance(t, torch.Tensor)
+   }
+   seen_output_storages = set()
+   final_outputs = []
+   for out in outputs:
+       if isinstance(out, torch.Tensor):
+           out_storage = StorageWeakRef(out._typed_storage())
+           if (
+               out_storage in seen_input_storages
+               or out_storage in seen_output_storages
+           ):
+               out = out.clone()
+           seen_output_storages.add(StorageWeakRef(out._typed_storage()))
+       final_outputs.append(out)
+   return final_outputs
+
+
 def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
    """
    For a fn that accepts flat inputs and returns flat outputs:
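The new _clone_aliasing_output helper clones any output tensor that shares storage with an input or with an earlier output, so the returned gradients never alias their sources (which matters for cudagraphs). A standalone sketch of the same check, using raw storage pointers instead of the StorageWeakRef used above (the helper name and tensors here are illustrative):

import torch

def toy_declone_outputs(inputs, outputs):
    # Clone any output tensor that shares storage with an input or an earlier output.
    input_storages = {t.untyped_storage().data_ptr() for t in inputs if isinstance(t, torch.Tensor)}
    seen, result = set(), []
    for out in outputs:
        if isinstance(out, torch.Tensor):
            ptr = out.untyped_storage().data_ptr()
            if ptr in input_storages or ptr in seen:
                out = out.clone()
            seen.add(out.untyped_storage().data_ptr())
        result.append(out)
    return result

x = torch.randn(4)
view = x[:2]                                   # aliases x's storage
outs = toy_declone_outputs([x], [view, x + 1])
print(outs[0].untyped_storage().data_ptr() != x.untyped_storage().data_ptr())  # True: the alias was cloned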
@@ -773,34 +798,14 @@ def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:

        # For tensors whose grad is None, create zero tensors as gradients
        # This invariant is useful for cudagraph.
-       grads = [
+       grad_args = [
            torch.zeros_like(arg)
            if isinstance(arg, torch.Tensor) and grad is None
            else grad
            for grad, arg in zip(grad_args, primals)
        ]

-       # Elimitate input-output, output-output aliasing
-       seen_input_storages = {
-           StorageWeakRef(t._typed_storage())
-           for t in args_and_grad_outs
-           if isinstance(t, torch.Tensor)
-       }
-       seen_output_storages = set()
-       final_grads = []
-       for grad, arg in zip(grads, primals):
-           if isinstance(arg, torch.Tensor):
-               assert isinstance(grad, torch.Tensor)
-               seen_input_storages.add(StorageWeakRef(arg._typed_storage()))
-               grad_storage = StorageWeakRef(grad._typed_storage())
-               if (
-                   grad_storage in seen_input_storages
-                   or grad_storage in seen_output_storages
-               ):
-                   grad = grad.clone()
-               seen_output_storages.add(StorageWeakRef(grad._typed_storage()))
-           final_grads.append(grad)
-
+       final_grads = _clone_aliasing_output(args_and_grad_outs, grad_args)
        return final_grads

    return flat_fn
@@ -8664,6 +8664,17 @@ class WhileLoop(ExternKernel):

        return result

+   @staticmethod
+   def _maybe_wrap_as_tensor_box(out: IRNode) -> IRNode:
+       if isinstance(out, TensorBox):
+           return out
+       elif isinstance(out, (StorageBox, ReinterpretView)):
+           return TensorBox(out)
+       elif isinstance(out, MultiOutput):
+           return TensorBox.create(out)
+       else:
+           raise RuntimeError(f"NYI unsupported output type: {type(out)}")
+
    @classmethod
    def create(
        cls,
@@ -8686,9 +8697,21 @@ class WhileLoop(ExternKernel):
            ret = []
            for tb, fk in zip(tensor_boxes, fake_tensors):
                if isinstance(fk, torch.Tensor):
+                   # Subgraph lowering always return StorageBox as graph_outputs because
+                   # it realizes the outputs.
+                   #
+                   # However, require_exact_strides is expecting TensorBox
+                   # e.g. in require_exact_strides when an expand happens,
+                   # the fake tensor's stride is (0, 0, 0) but the storage
+                   # box might have a different stride so lowering.slice_
+                   # is used to make the stride consistent and it expects input to
+                   # be TensorBox.
+                   #
+                   # So we wrap the inputs as tensor boxes if they're not yet.
+                   new_tb = WhileLoop._maybe_wrap_as_tensor_box(tb)
                    ret.append(
                        ExternKernel.require_exact_strides(
-                           tb, fk.stride(), allow_padding=False
+                           new_tb, fk.stride(), allow_padding=False
                        )
                    )
                else:
@@ -7095,21 +7095,11 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs, stack_output
        msg = f"{msg} Found from : \n {stack_trace}"
        V.graph.disable_cudagraphs_reason = msg

-   def _map_output(out: Any):
-       if isinstance(out, TensorBox):
-           return out
-       elif isinstance(out, ir.StorageBox):
-           return TensorBox(out)
-       elif isinstance(out, ir.MultiOutput):
-           return TensorBox.create(out)
-       else:
-           raise RuntimeError(f"NYI unsupported output type: {type(out)}")
-
    result = ir.WhileLoop.create(
        cond_fn, body_fn, carried_inputs, additional_inputs, stack_output
    )
    assert isinstance(result, Sequence)
-   return list(map(_map_output, result))
+   return list(map(ir.WhileLoop._maybe_wrap_as_tensor_box, result))


 register_lowering(