Revert "[associative_scan] Autograd separated (#139939)"

This reverts commit 103f725afa8dbf0204a1be6a042ab93aa16d85d8.

Reverted https://github.com/pytorch/pytorch/pull/139939 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I am seeing a weird failure after this lands in trunk ([comment](https://github.com/pytorch/pytorch/pull/139939#issuecomment-3267945657))
PyTorch MergeBot
2025-09-08 20:42:47 +00:00
parent 015423bef8
commit 5d819f3faf
4 changed files with 110 additions and 587 deletions

View File

@ -144,7 +144,7 @@ def get_scan_combine_fn(name, associative=True, parameters=None):
}
def non_pointwise(x: torch.Tensor, y: torch.Tensor):
W = torch.arange(4, dtype=torch.float, device=x.device).view(2, 2)
W = torch.diag(torch.ones(2, device=x.device))
return x @ W + y @ W
def RNN(x: torch.Tensor, y: torch.Tensor):
@ -3717,38 +3717,11 @@ class AssociativeScanTests(TestCase):
torch._dynamo.reset()
super().setUp()
def _check_autograd(self, result, result_exp, autograd_param):
grad_param = [p for p in autograd_param if p.requires_grad]
result_flatten, _ = pytree.tree_flatten(result)
result_exp_flatten, _ = pytree.tree_flatten(result_exp)
result_flatten = [r for r in result_flatten if r.requires_grad]
result_exp_flatten = [r for r in result_exp_flatten if r.requires_grad]
# Check the result and parameter lists
assert len(result_flatten) == len(result_exp_flatten), (
"The number of elements requiring gradients is different for the results and the expected results"
)
grad_exp_init = [torch.ones_like(el) for el in result_exp_flatten]
expected_grads = torch.autograd.grad(
result_exp_flatten, grad_param, grad_exp_init
)
grad_init = [torch.ones_like(el) for el in result_flatten]
grads = torch.autograd.grad(result_flatten, grad_param, grad_init)
self.assertEqual(grads, expected_grads, atol=6e-05, rtol=6e-06)
def _run_test(self, model, model_fake, inputs, autograd_param=None):
def _run_test(self, model, model_fake, inputs):
result = model(inputs)
result_exp = model_fake(inputs)
self.assertEqual(result, result_exp)
if autograd_param is not None and any(
par.requires_grad for par in autograd_param
):
self._check_autograd(result, result_exp, autograd_param)
# Return the result of the functions under test for further investigation
return result
@ -3763,7 +3736,6 @@ class AssociativeScanTests(TestCase):
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise only supports CUDA devices
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
@ -3779,22 +3751,10 @@ class AssociativeScanTests(TestCase):
)
),
)
# # Skipping this combination as there is a CPP compilation failure that
# # may be unrelated to associative_scan itself. There is a dedicated test for
# # this case below.
# @decorateIf(
# unittest.skip,
# lambda params: (
# params["compile_mode"] == "compile_dynamic_shape"
# and params["combine_mode"] == "generic"
# and params["device"] == torch.device("cpu")
# and params["autograd"]
# ),
# )
def test_associative_scan_compile(
self, combine_mode, reverse, compile_mode, device, autograd
self, combine_mode, reverse, compile_mode, device
):
x = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
x = torch.randn(3, 10, 2, device=device)
kwargs = {
"dim": 0,
"reverse": reverse,
@ -3806,7 +3766,6 @@ class AssociativeScanTests(TestCase):
model=AssociativeScanModels.Simple(**kwargs),
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
inputs=x,
autograd_param=None if not autograd else (x,),
)
if not reverse:
@ -3816,9 +3775,7 @@ class AssociativeScanTests(TestCase):
self.assertEqual(results, results_torch)
# Jax Examples
x = torch.arange(
0, 4, device=device, dtype=torch.float32, requires_grad=autograd
)
x = torch.arange(0, 4, device=device)
kwargs = {
"dim": 0,
"reverse": reverse,
@ -3831,13 +3788,12 @@ class AssociativeScanTests(TestCase):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=x,
autograd_param=None if not autograd else (x,),
)
if not reverse:
results_torch = torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.float32)
results_torch = torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64)
else:
results_torch = torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.float32)
results_torch = torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64)
self.assertEqual(result, results_torch)
@ -3847,7 +3803,6 @@ class AssociativeScanTests(TestCase):
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise only supports CUDA devices
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
@ -3863,9 +3818,7 @@ class AssociativeScanTests(TestCase):
)
),
)
def test_associative_scan_dim(
self, combine_mode, compile_mode, reverse, device, autograd
):
def test_associative_scan_dim(self, combine_mode, compile_mode, reverse, device):
import random
random.seed(1234)
@ -3876,7 +3829,7 @@ class AssociativeScanTests(TestCase):
torch._dynamo.reset()
shapes = [random.randint(1, 9) for _ in range(num_dim)]
rnd_scan_dim = random.randint(0, num_dim - 1)
x = torch.randn(*shapes, device=device, requires_grad=autograd)
x = torch.randn(*shapes, device=device)
kwargs = {
"dim": rnd_scan_dim,
@ -3889,7 +3842,6 @@ class AssociativeScanTests(TestCase):
model=AssociativeScanModels.Simple(**kwargs),
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
inputs=x,
autograd_param=None if not autograd else (x,),
)
if not reverse:
@ -3928,7 +3880,6 @@ class AssociativeScanTests(TestCase):
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise only supports CUDA devices
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
@ -3944,11 +3895,9 @@ class AssociativeScanTests(TestCase):
)
),
)
def test_associative_scan_tuple(
self, compile_mode, combine_mode, reverse, device, autograd
):
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
y = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
def test_associative_scan_tuple(self, compile_mode, combine_mode, reverse, device):
x = torch.randn(3, 2, 2, device=device)
y = torch.randn(3, 2, 2, device=device)
inp = (x, y)
kwargs = {
@ -3963,19 +3912,18 @@ class AssociativeScanTests(TestCase):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else inp,
)
@unittest.skipIf(not SM70OrLater, "triton")
@requires_cuda
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
def test_associative_scan_expand_in_combine_fn(
self, compile_mode, reverse, device, autograd
self, compile_mode, combine_mode, reverse, device
):
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
x = torch.randn(3, 2, 2, device=device)
def combine_fn(x, y):
return x * torch.sum(y, -1).expand(x.shape)
@ -3992,7 +3940,6 @@ class AssociativeScanTests(TestCase):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=x,
autograd_param=None if not autograd else (x,),
)
@unittest.skipIf(not SM70OrLater, "triton")
@ -4000,15 +3947,10 @@ class AssociativeScanTests(TestCase):
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
def test_associative_scan_non_contiguous_tensor(
self, compile_mode, reverse, device, autograd
self, compile_mode, reverse, device
):
x = (
torch.arange(30, device=device, dtype=torch.float32, requires_grad=autograd)
.view(10, 3)
.t()
)
x = torch.arange(30, device=device).view(10, 3).t()
assert not x.is_contiguous()
kwargs = {
@ -4023,7 +3965,6 @@ class AssociativeScanTests(TestCase):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=x,
autograd_param=None if not autograd else (x,),
)
@unittest.skipIf(not SM70OrLater, "triton")
@ -4032,7 +3973,6 @@ class AssociativeScanTests(TestCase):
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise only supports CUDA devices
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
@ -4049,11 +3989,11 @@ class AssociativeScanTests(TestCase):
),
)
def test_associative_scan_complex_pytree(
self, compile_mode, combine_mode, reverse, device, autograd
self, compile_mode, combine_mode, reverse, device
):
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
y = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
z = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
x = torch.randn(3, 2, 2, device=device)
y = torch.randn(3, 2, 2, device=device)
z = torch.randn(3, 2, 2, device=device)
inp = {"i": x, "j": ([y], [{"o": z}])}
kwargs = {
@ -4068,7 +4008,6 @@ class AssociativeScanTests(TestCase):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (x, y, z),
)
@skipIfTorchDynamo("don't test compile on compile")
@ -4218,7 +4157,6 @@ class GraphModule(torch.nn.Module):
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise only supports CUDA devices
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
@ -4235,7 +4173,7 @@ class GraphModule(torch.nn.Module):
),
)
def test_associative_scan_downstream_scan_matmul(
self, combine_mode, compile_mode, reverse, device, autograd
self, combine_mode, compile_mode, reverse, device
):
def first_chain_fct(scan_fct, inp, **kwargs):
o = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
@ -4245,7 +4183,7 @@ class GraphModule(torch.nn.Module):
W = torch.ones(2, 5, device=device)
return inp @ W
inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
inp = torch.randn(3, 10, 2, device=device)
kwargs = {
"dim": 1,
"reverse": reverse,
@ -4258,7 +4196,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.ChainFn(**kwargs),
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (inp,),
)
@unittest.skipIf(not SM70OrLater, "triton")
@ -4267,7 +4204,6 @@ class GraphModule(torch.nn.Module):
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise only supports CUDA devices
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
@ -4284,7 +4220,7 @@ class GraphModule(torch.nn.Module):
),
)
def test_associative_scan_downstream_scan_scan(
self, combine_mode, compile_mode, reverse, device, autograd
self, combine_mode, compile_mode, reverse, device
):
def first_chain_fct(scan_fct, inp, **kwargs):
o1 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
@ -4294,7 +4230,7 @@ class GraphModule(torch.nn.Module):
o2 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
return o2
inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
inp = torch.randn(3, 10, 2, device=device)
kwargs = {
"dim": 1,
@ -4308,7 +4244,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.ChainFn(**kwargs),
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (inp,),
)
@unittest.skipIf(not SM70OrLater, "triton")
@ -4318,7 +4253,6 @@ class GraphModule(torch.nn.Module):
@parametrize("reverse_first", [False, True])
@parametrize("same_direction", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise only supports CUDA devices
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
@ -4334,20 +4268,8 @@ class GraphModule(torch.nn.Module):
)
),
)
# Skipping the autograd=True because
# associative_scan currently does not support gradients for lifted parameters
@decorateIf(
unittest.skip,
lambda params: (params["combine_mode"] == "pointwise" and params["autograd"]),
)
def test_associative_scan_downstream_scan_scan_different_dim(
self,
combine_mode,
compile_mode,
reverse_first,
same_direction,
device,
autograd,
self, combine_mode, compile_mode, reverse_first, same_direction, device
):
reverse_second = reverse_first if same_direction else not reverse_first
@ -4359,7 +4281,7 @@ class GraphModule(torch.nn.Module):
o2 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
return o2
inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
inp = torch.randn(3, 10, 2, device=device)
kwargs = {
"dim": [1, 0],
@ -4373,7 +4295,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.ChainFn(**kwargs),
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (inp,),
)
# TODO: Does not work because of the usage of vmap within associative_scan
@ -4432,9 +4353,8 @@ class GraphModule(torch.nn.Module):
@parametrize("loop_type", ["for"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
def test_associative_scan_loop_in_combine_fn(
self, compile_mode, loop_type, reverse, device, autograd
self, compile_mode, loop_type, reverse, device
):
def combine_fn(x, y):
cnt = torch.zeros_like(y[0, :])
@ -4459,7 +4379,7 @@ class GraphModule(torch.nn.Module):
cnt += torch.abs(y[ind])
return x * cnt
inp = torch.randn(3, 10, 1, device=device, requires_grad=autograd) * 2
inp = torch.randn(3, 10, 1, device=device) * 2
kwargs = {
"dim": 0,
@ -4473,7 +4393,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (inp,),
)
# TODO: Does not work because of the usage of vmap within associative_scan
@ -4518,7 +4437,6 @@ class GraphModule(torch.nn.Module):
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combination of compile_mode=compile_dynamic_shape
# as the current implementation does not support lifted arguments
@decorateIf(
@ -4529,14 +4447,12 @@ class GraphModule(torch.nn.Module):
or torch.version.hip
),
)
def test_associative_scan_cond_in_combine_fn(
self, compile_mode, reverse, device, autograd
):
def test_associative_scan_cond_in_combine_fn(self, compile_mode, reverse, device):
def combine_fn(x, y):
val = cond(torch.sum(y) > 0.0, lambda y: y.clone(), lambda y: 1.0 - y, (y,))
return x * val
inp = torch.randn(3, 10, 1, device=device, requires_grad=autograd)
inp = torch.randn(3, 10, 1, device=device)
kwargs = {
"dim": 0,
@ -4550,7 +4466,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (inp,),
)
# TODO: Does not work because of the usage of vmap within associative_scan
@ -4592,10 +4507,7 @@ class GraphModule(torch.nn.Module):
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
def test_associative_scan_vmap_in_combine_fn(
self, compile_mode, reverse, device, autograd
):
def test_associative_scan_vmap_in_combine_fn(self, compile_mode, reverse, device):
def combine_fn(x, y):
def body(x):
return x**2
@ -4604,7 +4516,7 @@ class GraphModule(torch.nn.Module):
y_new = mapped_body(y)
return x + y_new
inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
inp = torch.randn(3, 10, 2, device=device)
kwargs = {
"dim": 0,
@ -4618,7 +4530,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (inp,),
)
@unittest.skipIf(not SM70OrLater, "triton")
@ -4626,7 +4537,6 @@ class GraphModule(torch.nn.Module):
@parametrize("reverse", [False, True])
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combination of associative_scan and device=cpu
# as the current implementation of pointwise only supports CUDA devices
@decorateIf(
@ -4634,9 +4544,9 @@ class GraphModule(torch.nn.Module):
lambda params: (params["device"] == torch.device("cpu")),
)
def test_associative_scan_non_pointwise_generic(
self, reverse, compile_mode, device, autograd
self, reverse, compile_mode, device
):
x = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
x = torch.randn(3, 10, 2, device=device)
kwargs = {
"dim": 0,
@ -4650,7 +4560,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=x,
autograd_param=None if not autograd else (x,),
)
@skipIfRocm(msg="Unsupported on ROCM yet")
@ -4660,7 +4569,6 @@ class GraphModule(torch.nn.Module):
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise only supports CUDA devices
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
@ -4677,14 +4585,14 @@ class GraphModule(torch.nn.Module):
),
)
def test_associative_scan_binary_operator(
self, compile_mode, combine_mode, reverse, device, autograd
self, compile_mode, combine_mode, reverse, device
):
state_dim = 20
timesteps = 10
projected_inputs = torch.randn(
timesteps, state_dim, device=device, requires_grad=autograd
timesteps, state_dim, requires_grad=True, device=device
)
A = torch.randn(state_dim, device=device, requires_grad=autograd)
A = torch.randn(state_dim, requires_grad=True, device=device)
elements = (A.repeat((timesteps, 1)), projected_inputs)
kwargs = {
@ -4699,7 +4607,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=elements,
autograd_param=None if not autograd else elements,
)
@skipIfRocm(msg="Unsupported on ROCM yet")
@ -4781,7 +4688,6 @@ class GraphModule(torch.nn.Module):
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combine_mode=pointwise
# as the current implementation of associative_scan lowering
# does not support lifted arguments
@ -4790,9 +4696,9 @@ class GraphModule(torch.nn.Module):
lambda params: (params["combine_mode"] == "pointwise"),
)
def test_associative_scan_freevars_simple(
self, compile_mode, combine_mode, reverse, device, autograd
self, compile_mode, combine_mode, reverse, device
):
H = torch.rand(2, device=device, requires_grad=autograd)
H = torch.rand(2, device=device)
def fct_freevars1(x: torch.Tensor, y: torch.Tensor):
return x * H + y * 2
@ -4800,13 +4706,13 @@ class GraphModule(torch.nn.Module):
def fct_freevars2(x: torch.Tensor, y: torch.Tensor):
return x * H + y * H
H1 = torch.rand(1, device=device, requires_grad=autograd)
H2 = torch.rand(1, device=device, requires_grad=autograd)
H1 = torch.rand(1, device=device)
H2 = torch.rand(1, device=device)
def fct_freevars3(x: torch.Tensor, y: torch.Tensor):
return x * H1 + y * H2
inp = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
inp = torch.randn(3, 2, 2, device=device)
for fct, param in [
(fct_freevars1, (H,)),
@ -4825,7 +4731,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (inp, *param),
)
@unittest.skipIf(not SM70OrLater, "triton")
@ -4834,7 +4739,6 @@ class GraphModule(torch.nn.Module):
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combine_mode=pointwise
# as the current implementation of associative_scan lowering
# does not support lifted arguments
@ -4843,10 +4747,10 @@ class GraphModule(torch.nn.Module):
lambda params: (params["combine_mode"] == "pointwise"),
)
def test_associative_scan_freevars_nested(
self, compile_mode, combine_mode, reverse, device, autograd
self, compile_mode, combine_mode, reverse, device
):
H1 = torch.rand(4, 5, device=device, requires_grad=autograd)
H2 = torch.rand(4, 1, device=device, requires_grad=autograd)
H1 = torch.rand(4, 5, device=device)
H2 = torch.rand(4, 1, device=device)
def fct_nested_outside(x: torch.Tensor, y: torch.Tensor):
def inner(xi):
@ -4862,10 +4766,13 @@ class GraphModule(torch.nn.Module):
ret = inner(y)
return x + ret * H1
H1_i = torch.rand(4, 5, device=device)
# TODO: Using random tensors in the `combine_fn` triggers the vmap randomness error:
# RuntimeError: vmap: called random operation while in randomness error mode.
# Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap
def fct_nested_inside(x: torch.Tensor, y: torch.Tensor):
# H2_i = torch.rand(4, 1, device=device)
H2_i = torch.ones(4, 1, device=device) * 42
def inner(xi):
@ -4875,6 +4782,7 @@ class GraphModule(torch.nn.Module):
return x + ret * H1
def fct_nested_inside_fake(x: torch.Tensor, y: torch.Tensor):
# H2_i = torch.rand(4, 1, device=device)
H2_i = torch.ones(4, 1, device=device) * 42
def inner(xi):
@ -4883,11 +4791,11 @@ class GraphModule(torch.nn.Module):
ret = inner(y)
return x + ret * H1
inp = torch.randn(3, 4, 5, device=device, requires_grad=autograd)
inp = torch.randn(3, 4, 5, device=device)
for fct, fct_fake, param in [
(fct_nested_outside, fct_nested_outside_fake, (H1, H2)),
(fct_nested_inside, fct_nested_inside_fake, ()),
(fct_nested_inside, fct_nested_inside_fake, (H1_i,)),
]:
kwargs = {
"dim": 0,
@ -4902,7 +4810,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (inp, *param),
)
@unittest.skipIf(not SM70OrLater, "triton")
@ -4911,7 +4818,6 @@ class GraphModule(torch.nn.Module):
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combine_mode=pointwise
# as the current implementation of associative_scan lowering
# does not support lifted arguments
@ -4920,7 +4826,7 @@ class GraphModule(torch.nn.Module):
lambda params: (params["combine_mode"] == "pointwise"),
)
def test_associative_scan_freevars_fct(
self, compile_mode, combine_mode, reverse, device, autograd
self, compile_mode, combine_mode, reverse, device
):
def additional_fct_no_add_inp(x, y):
return x * y
@ -4929,7 +4835,7 @@ class GraphModule(torch.nn.Module):
ret = additional_fct_no_add_inp(y, y)
return x + ret
inp = torch.randn(3, 4, 5, device=device, requires_grad=autograd)
inp = torch.randn(3, 4, 5, device=device)
kwargs = {
"dim": 0,
@ -4943,7 +4849,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (inp,),
)
@unittest.skipIf(not SM70OrLater, "triton")
@ -4951,10 +4856,7 @@ class GraphModule(torch.nn.Module):
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
def test_associative_scan_freevars_fct_generic(
self, compile_mode, reverse, device, autograd
):
def test_associative_scan_freevars_fct_generic(self, compile_mode, reverse, device):
def additional_fct_no_add_inp(x, y):
return x * y
@ -4968,7 +4870,7 @@ class GraphModule(torch.nn.Module):
ret = _fake_associative_scan(additional_fct_no_add_inp, y, 1)
return x + ret
inp = torch.randn(3, 4, 5, device=device, requires_grad=autograd)
inp = torch.randn(3, 4, 5, device=device)
kwargs = {
"dim": 0,
@ -4983,7 +4885,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (inp,),
)
@unittest.skipIf(not SM70OrLater, "triton")
@ -4992,7 +4893,6 @@ class GraphModule(torch.nn.Module):
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("autograd", [False, True])
# Skipping the combine_mode=pointwise
# as the current implementation of associative_scan lowering
# does not support lifted arguments
@ -5001,7 +4901,7 @@ class GraphModule(torch.nn.Module):
lambda params: (params["combine_mode"] == "pointwise"),
)
def test_associative_scan_freevars_shape_check(
self, compile_mode, combine_mode, reverse, device, autograd
self, compile_mode, combine_mode, reverse, device
):
H = torch.eye(2, device=device, requires_grad=True)
@ -5022,7 +4922,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (inp,),
)
@unittest.skipIf(not SM70OrLater, "triton")
@ -5031,7 +4930,6 @@ class GraphModule(torch.nn.Module):
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("autograd", [False, True])
# Skipping the combine_mode=pointwise
# as the current implementation of associative_scan lowering
# does not support lifted arguments
@ -5040,11 +4938,11 @@ class GraphModule(torch.nn.Module):
lambda params: (params["combine_mode"] == "pointwise"),
)
def test_associative_scan_freevars_pytree(
self, compile_mode, combine_mode, reverse, device, autograd
self, compile_mode, combine_mode, reverse, device
):
xf = torch.randn(2, 2, device=device, requires_grad=autograd)
yf = torch.randn(2, 2, device=device, requires_grad=autograd)
zf = torch.randn(2, 2, device=device, requires_grad=autograd)
xf = torch.randn(2, 2, device=device, requires_grad=True)
yf = torch.randn(2, 2, device=device, requires_grad=True)
zf = torch.randn(2, 2, device=device, requires_grad=True)
inpf = {"i": xf, "j": ([yf], [{"o": zf}])}
def fct_pointwise(x, y):
@ -5061,9 +4959,9 @@ class GraphModule(torch.nn.Module):
),
}
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
y = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
z = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
x = torch.randn(3, 2, 2, device=device, requires_grad=True)
y = torch.randn(3, 2, 2, device=device, requires_grad=True)
z = torch.randn(3, 2, 2, device=device, requires_grad=True)
inp = {"i": x, "j": ([y], [{"o": z}])}
kwargs = {
@ -5078,7 +4976,6 @@ class GraphModule(torch.nn.Module):
model=AssociativeScanModels.CombineFn(**kwargs),
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
inputs=inp,
autograd_param=None if not autograd else (*pytree.tree_leaves(inp),),
)
@unittest.skipIf(not SM70OrLater, "triton")

View File

@ -5,21 +5,17 @@ from typing import Any, Callable
import torch
import torch._prims_common as utils
import torch._subclasses.functional_tensor
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
_maybe_compile_and_run_fn,
_maybe_run_with_interpreter,
autograd_not_implemented,
check_input_alias_and_mutation_return_outputs,
check_meta_consistency,
create_bw_fn,
first_slice_copy,
first_slice_copy_with_grad,
materialize_as_graph,
reenter_make_fx,
save_tensors_and_symints_for_backward,
saved_tensors_and_symints,
split_into_chunks,
unique_graph_id,
validate_subgraph_args_types,
)
@ -195,9 +191,6 @@ def associative_scan(
cumsum = associative_scan(add, x, dim)
"""
# TODO: Support lifted arguments in inductor for associative_scan
# TODO: Support autograd for cases with lifted arguments for combine_mode=pointwise
# The reason we flatten xs before calling into dynamo is that
# we want to create a consistent input ordering for combine_fn
# and we also want the input ordering to match the output ordering.
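For orientation, a minimal eager usage sketch of the public entry point above. It assumes the (combine_fn, xs, dim, ...) signature shown in the docstring and uses combine_mode="generic" so the sketch also runs on CPU, since the pointwise mode currently targets CUDA only; it is illustrative, not part of this file.
import torch
from torch._higher_order_ops.associative_scan import associative_scan

def add(a: torch.Tensor, b: torch.Tensor):
    return a + b

x = torch.arange(1.0, 5.0)  # [1., 2., 3., 4.]
# combine_mode="generic" works on CPU; combine_mode="pointwise" requires CUDA.
cumsum = associative_scan(add, x, 0, combine_mode="generic")
# cumsum is the inclusive prefix sum [1., 3., 6., 10.]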
@ -249,6 +242,9 @@ def associative_scan(
if reverse:
leaves_xs = [torch.flip(elem, [0]) for elem in leaves_xs]
# TODO: Support Autograd
# TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc.
if combine_mode == "generic":
# The generic_associative_scan implementation calls the combine_fn with a `batch` along the scan dimension
# For example, consider:
@ -472,378 +468,9 @@ def associative_scan_op_dense(combine_fn, xs, additional_inputs):
return generic_associative_scan(combine_fn, xs, additional_inputs=additional_inputs)
class AssociativeScanAutogradOp(torch.autograd.Function):
r""" associative_scan
Example::
xs = torch.arange(1, 5) = [1, 2, 3, 4]
def combine_fn(a: torch.Tensor, b: torch.Tensor):
return a * b
ys = associative_scan(combine_fn, xs),
which can be unpacked as:
ys0 = xs0 = 1
ys1 = combine_fn(ys0, xs1) = combine_fn(1, 2) = 2
...
ysT = combine_fn(ys{T-1}, xsT) = combine_fn(6, 4) = 24
ys = [1, 2, 6, 24]
This creates a recursive data dependency structure where each output yst
depends on all prior inputs xs0 through xst. The dependency can be visualized as:
Level 0 (Input): xs0 xs1 xs2 xs3 xs4
\ / | | |
\ / | | |
Level 1: ys1 ───────┘ | |
\ / |
\ / |
Level 2: ys2 ────────┘ |
\ /
\ /
Level 3: ys3 ────────────┘
\
\
Level 4: ys4
We could get the following backward gradient graph:
Level 0 (output): g_xs0 g_xs1 g_xs2 g_xs3 g_xs4
\ / | | |
\ / | | |
Level 1: gl_ys1 ─> g_ys1 ──────┘ | |
\ / |
\ / |
Level 2: gl_ys2 ─> g_ys2 ────────┘ |
\ /
\ /
Level 3: gl_ys3 ─> g_ys3 ───────────┘
\
\
Level 4: gl_ys4 ─> g_ys4,
where gl_ys1 is the gradient of the loss with respect to ys1 and an input to the backward pass.
To calculate the gradients of the inputs, the chain rule suggests:
g_xs0 = g_ys1
g_xs1 = g_ys1 * bw(ys0, xs1) = g_ys1 * bwxs01
g_xs2 = g_ys2 * bw(ys1, xs2) = g_ys2 * bwxs12
g_xs3 = g_ys3 * bw(ys2, xs3) = g_ys3 * bwxs23
g_xs4 = g_ys4 * bw(ys3, xs4) = g_ys4 * bwxs34
Notice that bw(...) is just the single step bw (instantaneous gradients), whose formula can be computed from combine_fn.
For example bw(ys3, xs4) (also abbreviated with bwxs34) computes the gradients ∂/∂xs4 combine_fn(ys3, xs4).
Similarly, bw(ys4, ys3) (also abbreviated with bwys43) computes the gradients ∂/∂ys3 combine_fn(ys3, xs4).
Let's break down how to calculate g_ys by recursively substituting the unknowns:
g_ys1 = gl_ys1 + g_ys2 * bw(ys2, ys1)
= gl_ys1 + (gl_ys2 + g_ys3 * bw(ys3, ys2)) * bw(ys2, ys1)
= gl_ys1 + gl_ys2 * bw(ys2, ys1) + g_ys3 * bw(ys3, ys2) * bw(ys2, ys1)
= gl_ys1 + gl_ys2 * bw(ys2, ys1) + gl_ys3 * bw(ys3, ys2) * bw(ys2, ys1) \
+ g_ys4 * bw(ys4, ys3) * bw(ys3, ys2) * bw(ys2, ys1)
= gl_ys1 + gl_ys2 * bw(ys2, ys1) + gl_ys3 * bw(ys3, ys2) * bw(ys2, ys1) \
+ gl_ys4 * bw(ys4, ys3) * bw(ys3, ys2) * bw(ys2, ys1)
Let's do the same for all the g_ys:
g_ys2 = gl_ys2 + gl_ys3 * bw(ys3, ys2) + gl_ys4 * bw(ys4, ys3) * bw(ys3, ys2)
g_ys3 = gl_ys3 + gl_ys4 * bw(ys4, ys3)
g_ys4 = gl_ys4
Notice that the above can be re-written as a matrix-vector product of y_mat and gl_ys:
g_ys1     1, bwys21, bwys321, bwys4321     gl_ys1
g_ys2  =  0, 1     , bwys32 , bwys432   .  gl_ys2
g_ys3     0, 0     , 1      , bwys43       gl_ys3
g_ys4     0, 0     , 0      , 1            gl_ys4,
where bwys21 is an abbreviation for bw(ys2, ys1),
bwys321 is an abbreviation for bw(ys3, ys2) * bw(ys2, ys1), and so on and so forth.
We could effectively compute the upper triangular matrix y_mat with:
cumprod([1, bwys21, bwys32, bwys43]) then masking out the values as needed.
Thus, only [1, bwys21, bwys32, bwys43] are required to compute the y_mat.
References: https://justintchiu.com/blog/pscan_diff/
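As a concrete instance, take the multiplication example from the top of this docstring with xs = [1, 2, 3, 4] and ys = [1, 2, 6, 24] (indexed here as ys1..ys4 to match the derivation). The single step gradients of combine_fn(a, b) = a * b are bwys21 = xs2 = 2, bwys32 = xs3 = 3 and bwys43 = xs4 = 4, so with all upstream gradients gl_ys = 1 the identity above evaluates to:
g_ys1     1, 2, 6, 24     1     33
g_ys2  =  0, 1, 3, 12  .  1  =  16
g_ys3     0, 0, 1,  4     1      5
g_ys4     0, 0, 0,  1     1      1
Scaling these with bwxs = [1, ys1, ys2, ys3] = [1, 1, 2, 6] gives g_xs = [33, 16, 10, 6], which matches what direct differentiation of the cumulative product of xs yields.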
NOTE: [associative_scan autograd implementation]
The forward of associative_scan can be computed with the following steps:
1.) Compute the forward output of the associative_scan
ys = associative_scan(combine_fn, xs, additional_inputs)
The backward of associative_scan can be computed with the following steps:
2.) Prepare the backward graph
We prepare the backward graph to be used in the backward function.
We utilize ``create_bw_fn`` to generate the joint function:
combine_fn_bw = create_bw_fn(combine_fn, operands)
where operands = [ys{t-1}, xst, additional_inputs]
3.) Materialize the ``combine_fn_bw``
This is required because torch.compile and torch.autograd.grad
cannot trace through the joint backward function dynamically.
4.) Compute the single step bw (instantaneous gradients) at every step t
bwys{t-1}, bwxst = combine_fn_bw(ys{t-1}, xst, 1.)
Here we pass 1 as the upstream gradient to obtain the local partial derivatives.
This gives:
bwys = [bw(ys1, ys0), bw(ys2, ys1), ..., bw(ysT, ys{T-1})]
bwxs = [bw(ys0, xs1), bw(ys1, xs2), ..., bw(ys{T-1}, xsT)]
5.) Compute the gradient transition matrix y_mat
As shown in the example above, each input xst affects all later outputs ysi for i ≥ t.
According to the chain rule, each such path contributes a product of local gradients bw(ys{k}, ys{k-1}).
For example:
∂ysT/∂xst = ∂ysT/∂ys{T-1} * ∂ys{T-1}/∂ys{T-2} * ... * ∂ys{t+1}/∂yst * ∂yst/∂xst
= bw(ysT, ys{T-1}) * bw(ys{T-1}, ys{T-2}) * ... * bw(ys{t+1}, yst) * bw(ys{t-1}, xst)
This motivates the use of a cumulative product over bwys to compute all such paths efficiently.
We now construct the matrix of gradient transition paths:
5.1 Repeat the bwys values to form the base matrix
y_mat = [[1, bwys21, bwys32, bwys43],
[1, bwys21, bwys32, bwys43],
[1, bwys21, bwys32, bwys43],
[1, bwys21, bwys32, bwys43]]
5.2 Mask the lower triangle (inclusive) with 1s
y_mat = [[1, bwys21, bwys32, bwys43],
[1, 1 , bwys32, bwys43],
[1, 1 , 1 , bwys43],
[1, 1 , 1 , 1 ]]
5.3 Apply cumulative product row-wise
y_mat = cumprod(y_mat, dim=1)
Resulting in:
y_mat = [[1, bwys21, bwys32 * bwys21, bwys43 * bwys32 * bwys21],
[1, 1 , bwys32 , bwys43 * bwys32 ],
[1, 1 , 1 , bwys43 ],
[1, 1 , 1 , 1 ]]
5.4 Zero out the lower triangle (exclusive)
Final y_mat:
y_mat = [[1, bwys21, bwys32 * bwys21, bwys43 * bwys32 * bwys21],
[0, 1 , bwys32 , bwys43 * bwys32 ],
[0, 0 , 1 , bwys43 ],
[0, 0 , 0 , 1 ]]
6.) Scale the y_mat with the upstream gradients gl_ys
scaled_y_mat = y_mat * gl_ys
Each entry now holds the full contribution of ∂L/∂ysj to ∂L/∂xsi via the path through ysj.
7.) Reduce the scaled_y_mat with a row-wise sum
summed_y_mat = scaled_y_mat.sum(dim=1)
This accumulates all downstream contributions for each xst.
8.) Scale with the instantaneous input gradients bwxs
g_xs = summed_y_mat * bwxs
This gives the final input gradients:
g_xs = [∂L/∂xs0, ∂L/∂xs1, ..., ∂L/∂xsT]
NOTE: [scan partial grad handling]
If any element of xs or of the outputs does not require gradients
(i.e., requires_grad=False), then the corresponding gradients will be returned
as tensors of zeros with the same shape as the element.
"""
@staticmethod
def forward(
ctx,
combine_fn,
num_xs,
num_additional_inputs,
*operands,
):
ctx._num_xs = num_xs
ctx._num_additional_inputs = num_additional_inputs
ctx._combine_fn = combine_fn
xs, additional_inputs = split_into_chunks(
operands, [num_xs, num_additional_inputs]
)
scan_length = xs[0].shape[0]
ctx._scan_length = scan_length
# We snapshot the dispatch keys in forward for materializing
# the bw_graph in backward.
ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
with torch._C._AutoDispatchBelowAutograd():
# 1.) Compute the forward output of the associative_scan
ys = associative_scan_op(combine_fn, xs, additional_inputs)
save_tensors_and_symints_for_backward(ctx, list(operands) + list(ys))
return (*ys,)
@staticmethod
def backward(ctx, *gl_ys):
r"""
This function computes the gradients of the scan operation.
For a detailed description see the document above.
Args:
gl_ys (torch.Tensor): The tensors of upstream gradients, or a nested pytree of tensors.
E.g.: Gradient of the loss with respect to the forward output ys
"""
# The backward of associative_scan is always performed on the first dimension
dim = 0
scan_length = ctx._scan_length
num_xs = ctx._num_xs
num_additional_inputs = ctx._num_additional_inputs
# Extract the inputs to the forward path and outputs from the forward path
flat_args = saved_tensors_and_symints(ctx)
xs, additional_inputs, outs = split_into_chunks(
flat_args, [num_xs, num_additional_inputs, num_xs]
)
ndim = outs[0].ndim
# First_slice_copy does not keep the original requires_grad flag,
# but we need it here in order to compute the correct gradients
xs_slices = first_slice_copy_with_grad(itertools.chain(xs, xs))
# Construct the operands from the forward, fw_operands
# and the operands for a single event t of the forward, fw_operands_slice
fw_operands = (*xs, *additional_inputs)
fw_operands_slice = (*xs_slices, *additional_inputs)
# 2.) Prepare the backward graph
combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands_slice)
# 3.) Materialize the ``combine_fn_bw``
# TODO: we need to materialize the bw graphs because dynamo is unable to
# trace through the joint function when torch.autograd.grad is used under torch.compile.
combine_fn_bw_gm = materialize_as_graph(
combine_fn_bw,
(
*fw_operands_slice,
*[first_slice_copy(o) for o in outs],
),
ctx._fw_include_key_set,
ctx._fw_exclude_key_set,
force_enable_grad=True,
)
# vmap joint graph over scan dimension to compute the individual
# gradients for each time slice ``t`` in parallel.
# This computation can be parallelized, as these are just the instantaneous gradients and not the full chain-rule
mapped_combine_fn_bw_gm = torch.vmap(combine_fn_bw_gm, 0, 0)
# 4.) Compute the single step bw (instantaneous gradients) at every step ``t``
# Use a ones_like tensor in order not to scale the bwyst and bwxst
# with the upstream gradients yet.
# Note: All bwyst and bwxst are computed in parallel, thus the tensors bwys and bwxs are the result.
dummy_upstream_grad = (torch.ones_like(x) for x in gl_ys)
grads = mapped_combine_fn_bw_gm(
*(o.roll(1, dim) for o in outs), *fw_operands, *dummy_upstream_grad
)
bwys, bwxs = split_into_chunks(grads, [num_xs, num_xs])
def compute_y_mat(bwys: torch.Tensor) -> torch.Tensor:
# Prepare a ones and a zeros helper mask in order to easily compute the y_mat
def compute_helper_tril_mask(diagonal):
def expand_masks(mask):
for _ in range(ndim - 1):
mask = mask.unsqueeze(-1)
return mask
tril_mask = torch.tril(
torch.ones(
scan_length, scan_length, device=bwys.device, dtype=torch.bool
),
diagonal=diagonal,
)
tril_mask = expand_masks(tril_mask)
tril_mask = tril_mask.expand(-1, -1, *bwys.shape[1:])
return tril_mask
# The ones mask is used to fill the main diagonal and all elements below it with 1s
ones_mask = compute_helper_tril_mask(0)
# The zero mask is used to set all elements below the main diagonal to 0
zeros_mask = compute_helper_tril_mask(-1)
# 5.1) Repeat the elements of bwys to form the square matrix
y_mat = bwys.unsqueeze(dim).repeat_interleave(scan_length, dim)
# 5.2) Fill the lower triangular part, including the diagonal,
# of the y_mat with 1s, i.e., use the ones_mask to fill with 1s.
y_mat.masked_fill_(ones_mask, 1.0)
# 5.3) Compute the cumulative products across dim + 1
y_mat = y_mat.cumprod(dim=dim + 1)
# 5.4) Replace the elements we filled with 1s before with 0s
y_mat.masked_fill_(zeros_mask, 0.0)
return y_mat
def compute_grad(bwxs, bwys, gl_ys):
# Set the first gradient component of bwxs to 1.0, per definition.
torch.select(bwxs, dim, 0).fill_(1.0)
# 5.) Compute the gradient transition matrix
y_mat = compute_y_mat(bwys)
# 6.) scale the y_mat with the upstream gradients gl_ys
scaled_y_mat = y_mat * gl_ys
# 7.) Reduce the scaled_y_mat with a sum along the columns to get the total contributions for xs_t
summed_y_mat = scaled_y_mat.sum(dim + 1)
# 8.) Scale with the bwxs to obtain the final gradients g_xs
g_xs = summed_y_mat * bwxs
return g_xs
# Stack all leaves of the gradients along the first dimension.
# This is useful as later the gradients of those leaves can be computed in parallel.
bwxs_stacked_leaves = torch.stack(bwxs)
bwys_stacked_leaves = torch.stack(bwys)
gl_ys_stacked_leaves = torch.stack(gl_ys)
# The compute_grad function is parallelized across all individual leaves of xs
# as these gradients can be computed independently from each other
# TODO: torch.vmap may create composability issues
compute_grad_mapped = torch.vmap(compute_grad, 0, 0)
g_xs = compute_grad_mapped(
bwxs_stacked_leaves, bwys_stacked_leaves, gl_ys_stacked_leaves
)
# TODO: Currently the gradients for the additional_inputs are not computed properly
return *[None] * 3, *g_xs, *[None] * num_additional_inputs
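For intuition, here is a minimal, self-contained numerical sketch of steps 4.) through 8.) for combine_fn(a, b) = a * b (a cumulative product), checked against torch.autograd.grad. The variable names mirror the docstring above, but the code is illustrative only and is not the code path used by the operator:
import torch

T = 4
xs = torch.arange(1.0, T + 1, requires_grad=True)   # [1., 2., 3., 4.]
ys = torch.cumprod(xs, 0)                            # [1., 2., 6., 24.]
gl_ys = torch.ones_like(ys)                          # upstream gradients

# 4.) single step gradients of combine_fn(a, b) = a * b at (ys[t-1], xs[t])
bwys = xs.detach().clone()                           # d(a*b)/da = b = xs[t]
bwxs = torch.roll(ys.detach(), 1)                    # d(a*b)/db = a = ys[t-1]
bwxs[0] = 1.0                                        # first entry is 1 by definition

# 5.) gradient transition matrix via the repeat/mask/cumprod trick
ones_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
zeros_mask = torch.tril(torch.ones(T, T, dtype=torch.bool), -1)
y_mat = bwys.repeat(T, 1).masked_fill(ones_mask, 1.0)
y_mat = y_mat.cumprod(dim=1).masked_fill(zeros_mask, 0.0)

# 6.) - 8.) scale with gl_ys, reduce along the columns, scale with bwxs
g_xs = (y_mat * gl_ys).sum(dim=1) * bwxs

expected = torch.autograd.grad(ys, xs, gl_ys)[0]
torch.testing.assert_close(g_xs, expected)           # both are [33., 16., 10., 6.]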
@associative_scan_op.py_autograd_impl
def associative_scan_autograd(combine_fn, xs, additional_inputs):
num_xs = len(xs)
num_additional_inputs = len(additional_inputs)
if num_additional_inputs > 0:
raise RuntimeError(
"Associative_scan does currently not support gradients for lifted parameters!"
)
flat_out = AssociativeScanAutogradOp.apply(
combine_fn,
num_xs,
num_additional_inputs,
*(tuple(xs) + tuple(additional_inputs)),
)
return (*flat_out,)
associative_scan_op.py_autograd_impl(
autograd_not_implemented(associative_scan_op, deferred_error=True)
)
@associative_scan_op.py_impl(ProxyTorchDispatchMode)

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import functools
import itertools
from typing import Any, Callable
from typing import Any, Callable, Optional
import torch
import torch._prims_common as utils
@ -13,9 +13,6 @@ from torch._higher_order_ops.utils import (
check_meta_consistency,
create_bw_fn,
first_slice_copy,
first_slice_copy_with_grad,
get_tensor_mask,
mask_list,
materialize_as_graph,
reenter_make_fx,
save_tensors_and_symints_for_backward,
@ -63,6 +60,42 @@ def stack_y(y: torch.Tensor, scan_length: int) -> torch.Tensor:
)
# NOTE: These functions can be reused in associative_scan and eventually moved to
# torch._higher_order_ops.utils
def get_tensor_mask(tensor_list: list[Any]) -> list[bool]:
# Returns a mask indicating whether each list element is a tensor or not
return [True if isinstance(v, torch.Tensor) else False for v in tensor_list]
def mask_list(
mask: list[bool], inp: list[Any], other: Optional[list[Any]] = None
) -> list[Any]:
# Masks elements of an `inp` list.
# If other is None, then the elements of the `inp` list where the mask is False are removed
# If other is not None, then the elements of the `inp` list where the mask is False are
# replaced with the elements of the `other` list
assert len(mask) == len(inp), (
"The length of the mask needs to be identical to the length of the input"
)
if other is not None:
assert len(inp) == len(other), (
"If an input and an other list is provided, they need to have the same length"
)
return [i if m else o for m, i, o in zip(mask, inp, other)]
else:
return [i for m, i in zip(mask, inp) if m]
def first_slice_copy_with_grad(li: list[Any]) -> list[Any]:
# First_slice_copy does not keep the original requires_grad flag,
# but we need it for materialize_as_graph
# in order to compute the correct gradients
# The reason why first_slice_copy doesn't keep the requires_grad flag is
# because it's called in torch.autograd.Function.backward/forward.
slc = [first_slice_copy(x).requires_grad_(x.requires_grad) for x in li]
return slc
def call_operator(operator, *args):
return pytree.tree_leaves(operator(*args))
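For illustration, a small sketch of how the helpers above behave, with hypothetical inputs and assuming get_tensor_mask, mask_list and first_slice_copy_with_grad are in scope (e.g. imported from this module):
import torch

vals = [torch.ones(2), 3, torch.zeros(1)]
mask = get_tensor_mask(vals)                      # [True, False, True]
# Drop the non-tensor entries ...
mask_list(mask, vals)                             # [tensor([1., 1.]), tensor([0.])]
# ... or replace them with entries from a second list of the same length.
mask_list(mask, vals, ["a", "b", "c"])            # [tensor([1., 1.]), 'b', tensor([0.])]

# first_slice_copy_with_grad slices each leaf along dim 0 and keeps requires_grad.
xs = [torch.randn(5, 3, requires_grad=True), torch.randn(5, 2, requires_grad=True)]
slices = first_slice_copy_with_grad(xs)
assert all(s.shape == x.shape[1:] and s.requires_grad for s, x in zip(slices, xs))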

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import contextlib
import functools
from collections.abc import Iterable, Sequence
from collections.abc import Sequence
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import Any, Callable, Optional, overload, TypeVar, Union
@ -804,40 +804,6 @@ def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
return torch.select_copy(t, dim, 0)
# Returns a mask indicating whether each list element is a tensor or not
def get_tensor_mask(tensor_list: Iterable[Any]) -> list[bool]:
return [True if isinstance(v, torch.Tensor) else False for v in tensor_list]
def mask_list(
mask: list[bool], inp: list[Any], other: Optional[list[Any]] = None
) -> list[Any]:
# Masks elements of an `inp` list.
# If other is None, then the elements of the `inp` list where the mask is False are removed
# If other is not None, then the elements of the `inp` list where the mask is False are
# replaced with the elements of the `other` list
assert len(mask) == len(inp), (
"The length of the mask needs to be identical to the length of the input"
)
if other is not None:
assert len(inp) == len(other), (
"If an input and an other list is provided, they need to have the same length"
)
return [i if m else o for m, i, o in zip(mask, inp, other)]
else:
return [i for m, i in zip(mask, inp) if m]
def first_slice_copy_with_grad(li: Iterable[Any]) -> list[Any]:
# First_slice_copy does not keep the original requires_grad flag,
# but we need it for materialize_as_graph
# in order to compute the correct gradients
# The reason why first_slice_copy doesn't keep the requires_grad flag is
# because it's called in torch.autograd.Function.backward/forward.
slc = [first_slice_copy(x).requires_grad_(x.requires_grad) for x in li]
return slc
# Reports the difference between meta of two tensors in a string
def diff_tensor_meta(
meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True