Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[associative_scan] Autograd separated (#139939)"
This reverts commit 103f725afa8dbf0204a1be6a042ab93aa16d85d8. Reverted https://github.com/pytorch/pytorch/pull/139939 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I am seeing a weird failure after this lands in trunk ([comment](https://github.com/pytorch/pytorch/pull/139939#issuecomment-3267945657))
@ -144,7 +144,7 @@ def get_scan_combine_fn(name, associative=True, parameters=None):
|
||||
}
|
||||
|
||||
def non_pointwise(x: torch.Tensor, y: torch.Tensor):
|
||||
W = torch.arange(4, dtype=torch.float, device=x.device).view(2, 2)
|
||||
W = torch.diag(torch.ones(2, device=x.device))
|
||||
return x @ W + y @ W
|
||||
|
||||
def RNN(x: torch.Tensor, y: torch.Tensor):
|
||||
@ -3717,38 +3717,11 @@ class AssociativeScanTests(TestCase):
|
||||
torch._dynamo.reset()
|
||||
super().setUp()
|
||||
|
||||
    def _check_autograd(self, result, result_exp, autograd_param):
        grad_param = [p for p in autograd_param if p.requires_grad]

        result_flatten, _ = pytree.tree_flatten(result)
        result_exp_flatten, _ = pytree.tree_flatten(result_exp)
        result_flatten = [r for r in result_flatten if r.requires_grad]
        result_exp_flatten = [r for r in result_exp_flatten if r.requires_grad]

        # Check the result and parameter lists
        assert len(result_flatten) == len(result_exp_flatten), (
            "The number of elements requiring gradients is different for the results and the expected results"
        )

        grad_exp_init = [torch.ones_like(el) for el in result_exp_flatten]
        expected_grads = torch.autograd.grad(
            result_exp_flatten, grad_param, grad_exp_init
        )
        grad_init = [torch.ones_like(el) for el in result_flatten]
        grads = torch.autograd.grad(result_flatten, grad_param, grad_init)

        self.assertEqual(grads, expected_grads, atol=6e-05, rtol=6e-06)

    def _run_test(self, model, model_fake, inputs, autograd_param=None):
    def _run_test(self, model, model_fake, inputs):
        result = model(inputs)
        result_exp = model_fake(inputs)
        self.assertEqual(result, result_exp)

        if autograd_param is not None and any(
            par.requires_grad for par in autograd_param
        ):
            self._check_autograd(result, result_exp, autograd_param)

        # Return the result of the functions under test for further investigations
        return result
|
||||
|
||||
@ -3763,7 +3736,6 @@ class AssociativeScanTests(TestCase):
|
||||
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
||||
@parametrize("combine_mode", ["pointwise", "generic"])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combination of combine_mode=pointwise and device=cpu
|
||||
# as the current implementation of pointwise does only support CUDA device
|
||||
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
||||
@ -3779,22 +3751,10 @@ class AssociativeScanTests(TestCase):
|
||||
)
|
||||
),
|
||||
)
|
||||
# # Skipping this combination as there is a CPP compilation failure that
|
||||
# # may be unrelated to associative_scan itself. There is a dedicated tests for
|
||||
# # this case below.
|
||||
# @decorateIf(
|
||||
# unittest.skip,
|
||||
# lambda params: (
|
||||
# params["compile_mode"] == "compile_dynamic_shape"
|
||||
# and params["combine_mode"] == "generic"
|
||||
# and params["device"] == torch.device("cpu")
|
||||
# and params["autograd"]
|
||||
# ),
|
||||
# )
|
||||
def test_associative_scan_compile(
|
||||
self, combine_mode, reverse, compile_mode, device, autograd
|
||||
self, combine_mode, reverse, compile_mode, device
|
||||
):
|
||||
x = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
||||
x = torch.randn(3, 10, 2, device=device)
|
||||
kwargs = {
|
||||
"dim": 0,
|
||||
"reverse": reverse,
|
||||
@ -3806,7 +3766,6 @@ class AssociativeScanTests(TestCase):
|
||||
model=AssociativeScanModels.Simple(**kwargs),
|
||||
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
|
||||
inputs=x,
|
||||
autograd_param=None if not autograd else (x,),
|
||||
)
|
||||
|
||||
if not reverse:
|
||||
@ -3816,9 +3775,7 @@ class AssociativeScanTests(TestCase):
|
||||
self.assertEqual(results, results_torch)
|
||||
|
||||
# Jax Examples
|
||||
x = torch.arange(
|
||||
0, 4, device=device, dtype=torch.float32, requires_grad=autograd
|
||||
)
|
||||
x = torch.arange(0, 4, device=device)
|
||||
kwargs = {
|
||||
"dim": 0,
|
||||
"reverse": reverse,
|
||||
@ -3831,13 +3788,12 @@ class AssociativeScanTests(TestCase):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=x,
|
||||
autograd_param=None if not autograd else (x,),
|
||||
)
|
||||
|
||||
if not reverse:
|
||||
results_torch = torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.float32)
|
||||
results_torch = torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64)
|
||||
else:
|
||||
results_torch = torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.float32)
|
||||
results_torch = torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64)
|
||||
|
||||
self.assertEqual(result, results_torch)
|
||||
|
||||
@ -3847,7 +3803,6 @@ class AssociativeScanTests(TestCase):
|
||||
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
||||
@parametrize("combine_mode", ["pointwise", "generic"])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combination of combine_mode=pointwise and device=cpu
|
||||
# as the current implementation of pointwise does only support CUDA device
|
||||
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
||||
@ -3863,9 +3818,7 @@ class AssociativeScanTests(TestCase):
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_associative_scan_dim(
|
||||
self, combine_mode, compile_mode, reverse, device, autograd
|
||||
):
|
||||
def test_associative_scan_dim(self, combine_mode, compile_mode, reverse, device):
|
||||
import random
|
||||
|
||||
random.seed(1234)
|
||||
@ -3876,7 +3829,7 @@ class AssociativeScanTests(TestCase):
|
||||
torch._dynamo.reset()
|
||||
shapes = [random.randint(1, 9) for _ in range(num_dim)]
|
||||
rnd_scan_dim = random.randint(0, num_dim - 1)
|
||||
x = torch.randn(*shapes, device=device, requires_grad=autograd)
|
||||
x = torch.randn(*shapes, device=device)
|
||||
|
||||
kwargs = {
|
||||
"dim": rnd_scan_dim,
|
||||
@ -3889,7 +3842,6 @@ class AssociativeScanTests(TestCase):
|
||||
model=AssociativeScanModels.Simple(**kwargs),
|
||||
model_fake=AssociativeScanModels.Simple(**kwargs_fake),
|
||||
inputs=x,
|
||||
autograd_param=None if not autograd else (x,),
|
||||
)
|
||||
|
||||
if not reverse:
|
||||
@ -3928,7 +3880,6 @@ class AssociativeScanTests(TestCase):
|
||||
@parametrize("combine_mode", ["pointwise", "generic"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combination of combine_mode=pointwise and device=cpu
|
||||
# as the current implementation of pointwise does only support CUDA device
|
||||
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
||||
@ -3944,11 +3895,9 @@ class AssociativeScanTests(TestCase):
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_associative_scan_tuple(
|
||||
self, compile_mode, combine_mode, reverse, device, autograd
|
||||
):
|
||||
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
||||
y = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
||||
def test_associative_scan_tuple(self, compile_mode, combine_mode, reverse, device):
|
||||
x = torch.randn(3, 2, 2, device=device)
|
||||
y = torch.randn(3, 2, 2, device=device)
|
||||
inp = (x, y)
|
||||
|
||||
kwargs = {
|
||||
@ -3963,19 +3912,18 @@ class AssociativeScanTests(TestCase):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else inp,
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@requires_cuda
|
||||
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
||||
@parametrize("combine_mode", ["pointwise", "generic"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
def test_associative_scan_expand_in_combine_fn(
|
||||
self, compile_mode, reverse, device, autograd
|
||||
self, compile_mode, combine_mode, reverse, device
|
||||
):
|
||||
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
||||
x = torch.randn(3, 2, 2, device=device)
|
||||
|
||||
def combine_fn(x, y):
|
||||
return x * torch.sum(y, -1).expand(x.shape)
|
||||
@ -3992,7 +3940,6 @@ class AssociativeScanTests(TestCase):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=x,
|
||||
autograd_param=None if not autograd else (x,),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@ -4000,15 +3947,10 @@ class AssociativeScanTests(TestCase):
|
||||
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
def test_associative_scan_non_contiguous_tensor(
|
||||
self, compile_mode, reverse, device, autograd
|
||||
self, compile_mode, reverse, device
|
||||
):
|
||||
x = (
|
||||
torch.arange(30, device=device, dtype=torch.float32, requires_grad=autograd)
|
||||
.view(10, 3)
|
||||
.t()
|
||||
)
|
||||
x = torch.arange(30, device=device).view(10, 3).t()
|
||||
assert not x.is_contiguous()
|
||||
|
||||
kwargs = {
|
||||
@ -4023,7 +3965,6 @@ class AssociativeScanTests(TestCase):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=x,
|
||||
autograd_param=None if not autograd else (x,),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@ -4032,7 +3973,6 @@ class AssociativeScanTests(TestCase):
|
||||
@parametrize("combine_mode", ["pointwise", "generic"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combination of combine_mode=pointwise and device=cpu
|
||||
# as the current implementation of pointwise does only support CUDA device
|
||||
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
||||
@ -4049,11 +3989,11 @@ class AssociativeScanTests(TestCase):
|
||||
),
|
||||
)
|
||||
def test_associative_scan_complex_pytree(
|
||||
self, compile_mode, combine_mode, reverse, device, autograd
|
||||
self, compile_mode, combine_mode, reverse, device
|
||||
):
|
||||
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
||||
y = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
||||
z = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
||||
x = torch.randn(3, 2, 2, device=device)
|
||||
y = torch.randn(3, 2, 2, device=device)
|
||||
z = torch.randn(3, 2, 2, device=device)
|
||||
inp = {"i": x, "j": ([y], [{"o": z}])}
|
||||
|
||||
kwargs = {
|
||||
@ -4068,7 +4008,6 @@ class AssociativeScanTests(TestCase):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (x, y, z),
|
||||
)
|
||||
|
||||
@skipIfTorchDynamo("don't test compile on compile")
|
||||
@ -4218,7 +4157,6 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combination of combine_mode=pointwise and device=cpu
|
||||
# as the current implementation of pointwise does only support CUDA device
|
||||
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
||||
@ -4235,7 +4173,7 @@ class GraphModule(torch.nn.Module):
|
||||
),
|
||||
)
|
||||
def test_associative_scan_downstream_scan_matmul(
|
||||
self, combine_mode, compile_mode, reverse, device, autograd
|
||||
self, combine_mode, compile_mode, reverse, device
|
||||
):
|
||||
def first_chain_fct(scan_fct, inp, **kwargs):
|
||||
o = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
||||
@ -4245,7 +4183,7 @@ class GraphModule(torch.nn.Module):
|
||||
W = torch.ones(2, 5, device=device)
|
||||
return inp @ W
|
||||
|
||||
inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
||||
inp = torch.randn(3, 10, 2, device=device)
|
||||
kwargs = {
|
||||
"dim": 1,
|
||||
"reverse": reverse,
|
||||
@ -4258,7 +4196,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.ChainFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (inp,),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@ -4267,7 +4204,6 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combination of combine_mode=pointwise and device=cpu
|
||||
# as the current implementation of pointwise does only support CUDA device
|
||||
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
||||
@ -4284,7 +4220,7 @@ class GraphModule(torch.nn.Module):
|
||||
),
|
||||
)
|
||||
def test_associative_scan_downstream_scan_scan(
|
||||
self, combine_mode, compile_mode, reverse, device, autograd
|
||||
self, combine_mode, compile_mode, reverse, device
|
||||
):
|
||||
def first_chain_fct(scan_fct, inp, **kwargs):
|
||||
o1 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
||||
@ -4294,7 +4230,7 @@ class GraphModule(torch.nn.Module):
|
||||
o2 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
||||
return o2
|
||||
|
||||
inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
||||
inp = torch.randn(3, 10, 2, device=device)
|
||||
|
||||
kwargs = {
|
||||
"dim": 1,
|
||||
@ -4308,7 +4244,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.ChainFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (inp,),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@ -4318,7 +4253,6 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("reverse_first", [False, True])
|
||||
@parametrize("same_direction", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combination of combine_mode=pointwise and device=cpu
|
||||
# as the current implementation of pointwise does only support CUDA device
|
||||
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
||||
@ -4334,20 +4268,8 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
),
|
||||
)
|
||||
# Skipping the autograd=True because
|
||||
# associative_scan does not currently support gradients for lifted parameters
|
||||
@decorateIf(
|
||||
unittest.skip,
|
||||
lambda params: (params["combine_mode"] == "pointwise" and params["autograd"]),
|
||||
)
|
||||
def test_associative_scan_downstream_scan_scan_different_dim(
|
||||
self,
|
||||
combine_mode,
|
||||
compile_mode,
|
||||
reverse_first,
|
||||
same_direction,
|
||||
device,
|
||||
autograd,
|
||||
self, combine_mode, compile_mode, reverse_first, same_direction, device
|
||||
):
|
||||
reverse_second = reverse_first if same_direction else not reverse_first
|
||||
|
||||
@ -4359,7 +4281,7 @@ class GraphModule(torch.nn.Module):
|
||||
o2 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs)
|
||||
return o2
|
||||
|
||||
inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
||||
inp = torch.randn(3, 10, 2, device=device)
|
||||
|
||||
kwargs = {
|
||||
"dim": [1, 0],
|
||||
@ -4373,7 +4295,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.ChainFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.ChainFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (inp,),
|
||||
)
|
||||
|
||||
# TODO: Does not work because of the usage of vmap within associative_scan
|
||||
@ -4432,9 +4353,8 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("loop_type", ["for"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
def test_associative_scan_loop_in_combine_fn(
|
||||
self, compile_mode, loop_type, reverse, device, autograd
|
||||
self, compile_mode, loop_type, reverse, device
|
||||
):
|
||||
def combine_fn(x, y):
|
||||
cnt = torch.zeros_like(y[0, :])
|
||||
@ -4459,7 +4379,7 @@ class GraphModule(torch.nn.Module):
|
||||
cnt += torch.abs(y[ind])
|
||||
return x * cnt
|
||||
|
||||
inp = torch.randn(3, 10, 1, device=device, requires_grad=autograd) * 2
|
||||
inp = torch.randn(3, 10, 1, device=device) * 2
|
||||
|
||||
kwargs = {
|
||||
"dim": 0,
|
||||
@ -4473,7 +4393,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (inp,),
|
||||
)
|
||||
|
||||
# TODO: Does not work because of the usage of vmap within associative_scan
|
||||
@ -4518,7 +4437,6 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combination of compile_mode=compile_dynamic_shape
|
||||
# as the current implementation does not support lifted arguments
|
||||
@decorateIf(
|
||||
@ -4529,14 +4447,12 @@ class GraphModule(torch.nn.Module):
|
||||
or torch.version.hip
|
||||
),
|
||||
)
|
||||
def test_associative_scan_cond_in_combine_fn(
|
||||
self, compile_mode, reverse, device, autograd
|
||||
):
|
||||
def test_associative_scan_cond_in_combine_fn(self, compile_mode, reverse, device):
|
||||
def combine_fn(x, y):
|
||||
val = cond(torch.sum(y) > 0.0, lambda y: y.clone(), lambda y: 1.0 - y, (y,))
|
||||
return x * val
|
||||
|
||||
inp = torch.randn(3, 10, 1, device=device, requires_grad=autograd)
|
||||
inp = torch.randn(3, 10, 1, device=device)
|
||||
|
||||
kwargs = {
|
||||
"dim": 0,
|
||||
@ -4550,7 +4466,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (inp,),
|
||||
)
|
||||
|
||||
# TODO: Does not work because of the usage of vmap within associative_scan
|
||||
@ -4592,10 +4507,7 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
def test_associative_scan_vmap_in_combine_fn(
|
||||
self, compile_mode, reverse, device, autograd
|
||||
):
|
||||
def test_associative_scan_vmap_in_combine_fn(self, compile_mode, reverse, device):
|
||||
def combine_fn(x, y):
|
||||
def body(x):
|
||||
return x**2
|
||||
@ -4604,7 +4516,7 @@ class GraphModule(torch.nn.Module):
|
||||
y_new = mapped_body(y)
|
||||
return x + y_new
|
||||
|
||||
inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
||||
inp = torch.randn(3, 10, 2, device=device)
|
||||
|
||||
kwargs = {
|
||||
"dim": 0,
|
||||
@ -4618,7 +4530,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (inp,),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@ -4626,7 +4537,6 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combination of associative_scan and device=cpu
|
||||
# as the current implementation of pointwise does only support CUDA device
|
||||
@decorateIf(
|
||||
@ -4634,9 +4544,9 @@ class GraphModule(torch.nn.Module):
|
||||
lambda params: (params["device"] == torch.device("cpu")),
|
||||
)
|
||||
def test_associative_scan_non_pointwise_generic(
|
||||
self, reverse, compile_mode, device, autograd
|
||||
self, reverse, compile_mode, device
|
||||
):
|
||||
x = torch.randn(3, 10, 2, device=device, requires_grad=autograd)
|
||||
x = torch.randn(3, 10, 2, device=device)
|
||||
|
||||
kwargs = {
|
||||
"dim": 0,
|
||||
@ -4650,7 +4560,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=x,
|
||||
autograd_param=None if not autograd else (x,),
|
||||
)
|
||||
|
||||
@skipIfRocm(msg="Unsupported on ROCM yet")
|
||||
@ -4660,7 +4569,6 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("combine_mode", ["pointwise", "generic"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combination of combine_mode=pointwise and device=cpu
|
||||
# as the current implementation of pointwise does only support CUDA device
|
||||
# Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape
|
||||
@ -4677,14 +4585,14 @@ class GraphModule(torch.nn.Module):
|
||||
),
|
||||
)
|
||||
def test_associative_scan_binary_operator(
|
||||
self, compile_mode, combine_mode, reverse, device, autograd
|
||||
self, compile_mode, combine_mode, reverse, device
|
||||
):
|
||||
state_dim = 20
|
||||
timesteps = 10
|
||||
projected_inputs = torch.randn(
|
||||
timesteps, state_dim, device=device, requires_grad=autograd
|
||||
timesteps, state_dim, requires_grad=True, device=device
|
||||
)
|
||||
A = torch.randn(state_dim, device=device, requires_grad=autograd)
|
||||
A = torch.randn(state_dim, requires_grad=True, device=device)
|
||||
elements = (A.repeat((timesteps, 1)), projected_inputs)
|
||||
|
||||
kwargs = {
|
||||
@ -4699,7 +4607,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=elements,
|
||||
autograd_param=None if not autograd else elements,
|
||||
)
|
||||
|
||||
@skipIfRocm(msg="Unsupported on ROCM yet")
|
||||
@ -4781,7 +4688,6 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("combine_mode", ["pointwise", "generic"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combine_mode=pointwise
|
||||
# as the current implementation of associative_scan lowering
|
||||
# does not support lifted arguments
|
||||
@ -4790,9 +4696,9 @@ class GraphModule(torch.nn.Module):
|
||||
lambda params: (params["combine_mode"] == "pointwise"),
|
||||
)
|
||||
def test_associative_scan_freevars_simple(
|
||||
self, compile_mode, combine_mode, reverse, device, autograd
|
||||
self, compile_mode, combine_mode, reverse, device
|
||||
):
|
||||
H = torch.rand(2, device=device, requires_grad=autograd)
|
||||
H = torch.rand(2, device=device)
|
||||
|
||||
def fct_freevars1(x: torch.Tensor, y: torch.Tensor):
|
||||
return x * H + y * 2
|
||||
@ -4800,13 +4706,13 @@ class GraphModule(torch.nn.Module):
|
||||
def fct_freevars2(x: torch.Tensor, y: torch.Tensor):
|
||||
return x * H + y * H
|
||||
|
||||
H1 = torch.rand(1, device=device, requires_grad=autograd)
|
||||
H2 = torch.rand(1, device=device, requires_grad=autograd)
|
||||
H1 = torch.rand(1, device=device)
|
||||
H2 = torch.rand(1, device=device)
|
||||
|
||||
def fct_freevars3(x: torch.Tensor, y: torch.Tensor):
|
||||
return x * H1 + y * H2
|
||||
|
||||
inp = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
||||
inp = torch.randn(3, 2, 2, device=device)
|
||||
|
||||
for fct, param in [
|
||||
(fct_freevars1, (H,)),
|
||||
@ -4825,7 +4731,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (inp, *param),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@ -4834,7 +4739,6 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("combine_mode", ["pointwise", "generic"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combine_mode=pointwise
|
||||
# as the current implementation of associative_scan lowering
|
||||
# does not support lifted arguments
|
||||
@ -4843,10 +4747,10 @@ class GraphModule(torch.nn.Module):
|
||||
lambda params: (params["combine_mode"] == "pointwise"),
|
||||
)
|
||||
def test_associative_scan_freevars_nested(
|
||||
self, compile_mode, combine_mode, reverse, device, autograd
|
||||
self, compile_mode, combine_mode, reverse, device
|
||||
):
|
||||
H1 = torch.rand(4, 5, device=device, requires_grad=autograd)
|
||||
H2 = torch.rand(4, 1, device=device, requires_grad=autograd)
|
||||
H1 = torch.rand(4, 5, device=device)
|
||||
H2 = torch.rand(4, 1, device=device)
|
||||
|
||||
def fct_nested_outside(x: torch.Tensor, y: torch.Tensor):
|
||||
def inner(xi):
|
||||
@ -4862,10 +4766,13 @@ class GraphModule(torch.nn.Module):
|
||||
ret = inner(y)
|
||||
return x + ret * H1
|
||||
|
||||
H1_i = torch.rand(4, 5, device=device)
|
||||
|
||||
# TODO: Using random tensors in the `combine_fn` triggers the vmap randomness error:
|
||||
# RuntimeError: vmap: called random operation while in randomness error mode.
|
||||
# Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap
|
||||
def fct_nested_inside(x: torch.Tensor, y: torch.Tensor):
|
||||
# H2_i = torch.rand(4, 1, device=device)
|
||||
H2_i = torch.ones(4, 1, device=device) * 42
|
||||
|
||||
def inner(xi):
|
||||
@ -4875,6 +4782,7 @@ class GraphModule(torch.nn.Module):
|
||||
return x + ret * H1
|
||||
|
||||
def fct_nested_inside_fake(x: torch.Tensor, y: torch.Tensor):
|
||||
# H2_i = torch.rand(4, 1, device=device)
|
||||
H2_i = torch.ones(4, 1, device=device) * 42
|
||||
|
||||
def inner(xi):
|
||||
@ -4883,11 +4791,11 @@ class GraphModule(torch.nn.Module):
|
||||
ret = inner(y)
|
||||
return x + ret * H1
|
||||
|
||||
inp = torch.randn(3, 4, 5, device=device, requires_grad=autograd)
|
||||
inp = torch.randn(3, 4, 5, device=device)
|
||||
|
||||
for fct, fct_fake, param in [
|
||||
(fct_nested_outside, fct_nested_outside_fake, (H1, H2)),
|
||||
(fct_nested_inside, fct_nested_inside_fake, ()),
|
||||
(fct_nested_inside, fct_nested_inside_fake, (H1_i,)),
|
||||
]:
|
||||
kwargs = {
|
||||
"dim": 0,
|
||||
@ -4902,7 +4810,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (inp, *param),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@ -4911,7 +4818,6 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("combine_mode", ["pointwise", "generic"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combine_mode=pointwise
|
||||
# as the current implementation of associative_scan lowering
|
||||
# does not support lifted arguments
|
||||
@ -4920,7 +4826,7 @@ class GraphModule(torch.nn.Module):
|
||||
lambda params: (params["combine_mode"] == "pointwise"),
|
||||
)
|
||||
def test_associative_scan_freevars_fct(
|
||||
self, compile_mode, combine_mode, reverse, device, autograd
|
||||
self, compile_mode, combine_mode, reverse, device
|
||||
):
|
||||
def additional_fct_no_add_inp(x, y):
|
||||
return x * y
|
||||
@ -4929,7 +4835,7 @@ class GraphModule(torch.nn.Module):
|
||||
ret = additional_fct_no_add_inp(y, y)
|
||||
return x + ret
|
||||
|
||||
inp = torch.randn(3, 4, 5, device=device, requires_grad=autograd)
|
||||
inp = torch.randn(3, 4, 5, device=device)
|
||||
|
||||
kwargs = {
|
||||
"dim": 0,
|
||||
@ -4943,7 +4849,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (inp,),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@ -4951,10 +4856,7 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
def test_associative_scan_freevars_fct_generic(
|
||||
self, compile_mode, reverse, device, autograd
|
||||
):
|
||||
def test_associative_scan_freevars_fct_generic(self, compile_mode, reverse, device):
|
||||
def additional_fct_no_add_inp(x, y):
|
||||
return x * y
|
||||
|
||||
@ -4968,7 +4870,7 @@ class GraphModule(torch.nn.Module):
|
||||
ret = _fake_associative_scan(additional_fct_no_add_inp, y, 1)
|
||||
return x + ret
|
||||
|
||||
inp = torch.randn(3, 4, 5, device=device, requires_grad=autograd)
|
||||
inp = torch.randn(3, 4, 5, device=device)
|
||||
|
||||
kwargs = {
|
||||
"dim": 0,
|
||||
@ -4983,7 +4885,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (inp,),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@ -4992,7 +4893,6 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("combine_mode", ["pointwise", "generic"])
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combine_mode=pointwise
|
||||
# as the current implementation of associative_scan lowering
|
||||
# does not support lifted arguments
|
||||
@ -5001,7 +4901,7 @@ class GraphModule(torch.nn.Module):
|
||||
lambda params: (params["combine_mode"] == "pointwise"),
|
||||
)
|
||||
def test_associative_scan_freevars_shape_check(
|
||||
self, compile_mode, combine_mode, reverse, device, autograd
|
||||
self, compile_mode, combine_mode, reverse, device
|
||||
):
|
||||
H = torch.eye(2, device=device, requires_grad=True)
|
||||
|
||||
@ -5022,7 +4922,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (inp,),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@ -5031,7 +4930,6 @@ class GraphModule(torch.nn.Module):
|
||||
@parametrize("reverse", [False, True])
|
||||
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
|
||||
@parametrize("combine_mode", ["pointwise", "generic"])
|
||||
@parametrize("autograd", [False, True])
|
||||
# Skipping the combine_mode=pointwise
|
||||
# as the current implementation of associative_scan lowering
|
||||
# does not support lifted arguments
|
||||
@ -5040,11 +4938,11 @@ class GraphModule(torch.nn.Module):
|
||||
lambda params: (params["combine_mode"] == "pointwise"),
|
||||
)
|
||||
def test_associative_scan_freevars_pytree(
|
||||
self, compile_mode, combine_mode, reverse, device, autograd
|
||||
self, compile_mode, combine_mode, reverse, device
|
||||
):
|
||||
xf = torch.randn(2, 2, device=device, requires_grad=autograd)
|
||||
yf = torch.randn(2, 2, device=device, requires_grad=autograd)
|
||||
zf = torch.randn(2, 2, device=device, requires_grad=autograd)
|
||||
xf = torch.randn(2, 2, device=device, requires_grad=True)
|
||||
yf = torch.randn(2, 2, device=device, requires_grad=True)
|
||||
zf = torch.randn(2, 2, device=device, requires_grad=True)
|
||||
inpf = {"i": xf, "j": ([yf], [{"o": zf}])}
|
||||
|
||||
def fct_pointwise(x, y):
|
||||
@ -5061,9 +4959,9 @@ class GraphModule(torch.nn.Module):
|
||||
),
|
||||
}
|
||||
|
||||
x = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
||||
y = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
||||
z = torch.randn(3, 2, 2, device=device, requires_grad=autograd)
|
||||
x = torch.randn(3, 2, 2, device=device, requires_grad=True)
|
||||
y = torch.randn(3, 2, 2, device=device, requires_grad=True)
|
||||
z = torch.randn(3, 2, 2, device=device, requires_grad=True)
|
||||
inp = {"i": x, "j": ([y], [{"o": z}])}
|
||||
|
||||
kwargs = {
|
||||
@ -5078,7 +4976,6 @@ class GraphModule(torch.nn.Module):
|
||||
model=AssociativeScanModels.CombineFn(**kwargs),
|
||||
model_fake=AssociativeScanModels.CombineFn(**kwargs_fake),
|
||||
inputs=inp,
|
||||
autograd_param=None if not autograd else (*pytree.tree_leaves(inp),),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
|
@ -5,21 +5,17 @@ from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
import torch._prims_common as utils
|
||||
import torch._subclasses.functional_tensor
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._C import DispatchKey
|
||||
from torch._higher_order_ops.utils import (
|
||||
_maybe_compile_and_run_fn,
|
||||
_maybe_run_with_interpreter,
|
||||
autograd_not_implemented,
|
||||
check_input_alias_and_mutation_return_outputs,
|
||||
check_meta_consistency,
|
||||
create_bw_fn,
|
||||
first_slice_copy,
|
||||
first_slice_copy_with_grad,
|
||||
materialize_as_graph,
|
||||
reenter_make_fx,
|
||||
save_tensors_and_symints_for_backward,
|
||||
saved_tensors_and_symints,
|
||||
split_into_chunks,
|
||||
unique_graph_id,
|
||||
validate_subgraph_args_types,
|
||||
)
|
||||
@ -195,9 +191,6 @@ def associative_scan(
|
||||
cumsum = associative_scan(add, x, dim)
|
||||
|
||||
"""
|
||||
# TODO: Support lifted arguments in inductor for associative_scan
|
||||
# TODO: Support autograd for cases with lifted arguments for combine_mode=pointwise
|
||||
|
||||
# The reason we flatten xs before calling into dynamo is that
|
||||
# we want to create a consistent input ordering for combine_fn
|
||||
# and we also want the input ordering to match the output ordering.
|
||||
@ -249,6 +242,9 @@ def associative_scan(
|
||||
if reverse:
|
||||
leaves_xs = [torch.flip(elem, [0]) for elem in leaves_xs]
|
||||
|
||||
# TODO: Support Autograd
|
||||
# TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc.
|
||||
|
||||
if combine_mode == "generic":
|
||||
# The generic_associative_scan implementation calls the combine_fn with a `batch` along the scan dimension
|
||||
# For example, consider:
|
||||
@ -472,378 +468,9 @@ def associative_scan_op_dense(combine_fn, xs, additional_inputs):
|
||||
return generic_associative_scan(combine_fn, xs, additional_inputs=additional_inputs)
|
||||
|
||||
|
||||
class AssociativeScanAutogradOp(torch.autograd.Function):
|
||||
r""" associative_scan
|
||||
Example::
|
||||
xs = torch.arange(1, 5) = [1, 2, 3, 4]
|
||||
|
||||
def combine_fn(a: torch.Tensor, b: torch.Tensor):
|
||||
return a * b
|
||||
|
||||
ys = associative_scan(combine_fn, xs),
|
||||
which can be unpacked as:
|
||||
ys0 = xs0 = 1
|
||||
ys1 = combine_fn(ys0, xs1) = combine_fn(1, 2) = 2
|
||||
...
|
||||
ysT = combine_fn(ys(T-1), xsT) = combine_fn(6, 4) = 24
|
||||
ys = [1, 2, 6, 24]
|
||||
|
||||
This creates a recursive data dependency structure where each output yst
|
||||
depends on all prior inputs xs0 through xst. The dependency can be visualized as:
|
||||
|
||||
Level 0 (Input):  xs0   xs1   xs2   xs3   xs4
                    \   /      |     |     |
                     \ /       |     |     |
Level 1:             ys1 ──────┘     |     |
                        \     /      |     |
                         \   /       |     |
Level 2:                  ys2 ───────┘     |
                             \      /      |
                              \    /       |
Level 3:                       ys3 ────────┘
                                  \
                                   \
Level 4:                            ys4
|
||||
|
||||
|
||||
We could get the following backward gradient graph:
|
||||
|
||||
|
||||
Level 0 (output):  g_xs0   g_xs1   g_xs2   g_xs3   g_xs4
                       \   /         |       |       |
                        \ /          |       |       |
Level 1:  gl_ys1 ─>    g_ys1 ────────┘       |       |
                           \        /        |       |
                            \      /         |       |
Level 2:  gl_ys2 ─>     g_ys2 ───────────────┘       |
                            \                /       |
                             \              /        |
Level 3:  gl_ys3 ─>      g_ys3 ──────────────────────┘
                             \
                              \
Level 4:  gl_ys4 ─>       g_ys4,
|
||||
|
||||
where gl_ys1 is the gradient of the loss with respect to ys1 and is an input of the backward.
|
||||
|
||||
To calculate the gradients of the inputs, the chain rule suggests:
|
||||
|
||||
g_xs0 = g_ys1
g_xs1 = g_ys1 * bw(ys0, xs1) = g_ys1 * bwxs01
g_xs2 = g_ys2 * bw(ys1, xs2) = g_ys2 * bwxs12
g_xs3 = g_ys3 * bw(ys2, xs3) = g_ys3 * bwxs23
g_xs4 = g_ys4 * bw(ys3, xs4) = g_ys4 * bwxs34
|
||||
|
||||
Notice the bw(...) is just the single step bw (instantaneous gradients), whose formula can be computed from combine_fn.
|
||||
For example bw(ys3, xs4) (also abbreviated with bwxs34) computes the gradients ∂/∂xs4 combine_fn(ys3, xs4).
|
||||
Similarly, bw(ys4, ys3) (also abbreviated with bwys43) computes the gradients ∂/∂ys3 combine_fn(ys3, xs4).
|
||||
|
||||
Let's break down how to calculate g_ys by recursively substituting the unknowns:
|
||||
|
||||
g_ys1 = gl_ys1 + g_ys2 * bw(ys2, ys1)
      = gl_ys1 + (gl_ys2 + g_ys3 * bw(ys3, ys2)) * bw(ys2, ys1)
      = gl_ys1 + gl_ys2 * bw(ys2, ys1) + g_ys3 * bw(ys3, ys2) * bw(ys2, ys1)
      = gl_ys1 + gl_ys2 * bw(ys2, ys1) + gl_ys3 * bw(ys3, ys2) * bw(ys2, ys1) \
        + g_ys4 * bw(ys4, ys3) * bw(ys3, ys2) * bw(ys2, ys1)
      = gl_ys1 + gl_ys2 * bw(ys2, ys1) + gl_ys3 * bw(ys3, ys2) * bw(ys2, ys1) \
        + gl_ys4 * bw(ys4, ys3) * bw(ys3, ys2) * bw(ys2, ys1)
|
||||
|
||||
Let's do the same for all the g_ys:
g_ys2 = gl_ys2 + gl_ys3 * bw(ys3, ys2) + gl_ys4 * bw(ys4, ys3) * bw(ys3, ys2)
g_ys3 = gl_ys3 + gl_ys4 * bw(ys4, ys3)
g_ys4 = gl_ys4
|
||||
|
||||
Notice that the above can be re-written as columnwise multiplication of y_mat and gl_ys:
|
||||
|
||||
g_ys1       1, bwys21, bwys321, bwys4321       gl_ys1
g_ys2   =   0, 1     , bwys32 , bwys432    .   gl_ys2
g_ys3       0, 0     , 1      , bwys43         gl_ys3
g_ys4       0, 0     , 0      , 1              gl_ys4,
|
||||
|
||||
where bwys21 is an abbreviation for bw(ys2, ys1),
|
||||
bwys321 is an abbreviation for bw(ys3, ys2) * bw(ys2, ys1) so on and so forth.
|
||||
|
||||
We could effectively compute the upper triangular matrix y_mat with:
|
||||
cumprod([1, bwys21, bwys32, bwys43]) then masking out the values as needed.
|
||||
Thus, only [1, bwys21, bwys32, bwys43] are required to compute the y_mat.
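
To make this concrete, here is a worked example (added for illustration) for the
multiplication combine_fn above, i.e. xs = [1, 2, 3, 4] and ys = [1, 2, 6, 24],
assuming unit upstream gradients gl_ys = [1, 1, 1, 1]:

    bwys  = [1, 2, 3, 4]          (since bw(ys_t, ys_{t-1}) = xs_t)
    bwxs  = [1, 1, 2, 6]          (since bw(ys_{t-1}, xs_t) = ys_{t-1}; the first entry is 1 by definition)

    y_mat = [[1, 2, 6, 24],
             [0, 1, 3, 12],
             [0, 0, 1,  4],
             [0, 0, 0,  1]]

    g_xs  = (y_mat @ gl_ys) * bwxs = [33, 16, 5, 1] * [1, 1, 2, 6] = [33, 16, 10, 6],

which matches the analytic gradients of sum(cumprod(xs)) with respect to xs.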
|
||||
|
||||
|
||||
References: https://justintchiu.com/blog/pscan_diff/
|
||||
|
||||
NOTE: [associative_scan autograd implementation]
|
||||
|
||||
The forward of associative_scan can be computed with the following steps:
|
||||
|
||||
1.) Compute the forward output of the associative_scan
|
||||
ys = associative_scan(combine_fn, xs, additional_inputs)
|
||||
|
||||
The backward of associative_scan can be computed with the following steps:
|
||||
|
||||
2.) Prepare the backward graph
|
||||
We prepare the backward graph to be used in the backward function.
|
||||
We utilize ``create_bw_fn`` to generate the joint function:
|
||||
combine_fn_bw = create_bw_fn(combine_fn, operands)
|
||||
where operands = [ys{t-1}, xst, additional_inputs]
|
||||
|
||||
3.) Materialize the ``combine_fn_bw``
|
||||
This is required because torch.compile and torch.autograd.grad
|
||||
cannot trace through the joint backward function dynamically.
|
||||
|
||||
4.) Compute the single step bw (instantaneous gradients) at every step t
|
||||
bwys{t-1}, bwxst = combine_fn_bw(ys{t-1}, xst, 1.)
|
||||
Here we pass 1 as the upstream gradient to obtain the local partial derivatives.
|
||||
|
||||
This gives:
|
||||
bwys = [bw(ys1, ys0), bw(ys2, ys1), ..., bw(ysT, ys{T-1})]
bwxs = [bw(ys0, xs1), bw(ys1, xs2), ..., bw(ys{T-1}, xsT)]
|
||||
|
||||
5.) Compute the gradient transition matrix y_mat
|
||||
|
||||
As shown in the example above, each input xst affects all later outputs ysi for i ≥ t.
|
||||
According to the chain rule, each such path contributes a product of local gradients g_ysk.
|
||||
|
||||
For example:
|
||||
∂ysT/∂xst = ∂ysT/∂ys{T-1} * ∂ys{T-1}/∂ys{T-2} * ... * ∂ys{t+1}/∂yst * ∂yst/∂xst
|
||||
= bw(ysT, ys{T-1}) * bw(ys{T-1}, ys{T-2}) * ... * bw(ys{t+1}, yst) * bw(ys{t-1}, xst)
|
||||
|
||||
This motivates the use of a cumulative product over bwys to compute all such paths efficiently.
|
||||
|
||||
We now construct the matrix of gradient transition paths:
|
||||
|
||||
5.1 Repeat g_y values to form the base matrix
|
||||
y_mat = [[1, bwys21, bwys32, bwys43],
         [1, bwys21, bwys32, bwys43],
         [1, bwys21, bwys32, bwys43],
         [1, bwys21, bwys32, bwys43]]
|
||||
|
||||
5.2 Mask the lower triangle (inclusive) with 1s
|
||||
y_mat = [[1, bwys21, bwys32, bwys43],
         [1, 1     , bwys32, bwys43],
         [1, 1     , 1     , bwys43],
         [1, 1     , 1     , 1     ]]
|
||||
|
||||
5.3 Apply cumulative product row-wise
|
||||
y_mat = cumprod(y_mat, dim=1)
|
||||
Resulting in:
|
||||
y_mat = [[1, bwys21, bwys32 * bwys21, bwys43 * bwys32 * bwys21],
         [1, 1     , bwys32         , bwys43 * bwys32         ],
         [1, 1     , 1              , bwys43                  ],
         [1, 1     , 1              , 1                       ]]
|
||||
|
||||
5.4 Zero out the lower triangle (exclusive)
|
||||
Final y_mat:
|
||||
y_mat = [[1, bwys21, bwys32 * bwys21, bwys43 * bwys32 * bwys21],
         [0, 1     , bwys32         , bwys43 * bwys32         ],
         [0, 0     , 1              , bwys43                  ],
         [0, 0     , 0              , 1                       ]]
|
||||
|
||||
6.) Scale the y_mat with the upstream gradients gl_ys
|
||||
scaled_y_mat = y_mat * gl_ys
|
||||
Each entry now holds the full contribution of ∂L/∂ysj to ∂L/∂xsi via the path through ysj.
|
||||
|
||||
7.) Reduce the scaled_y_mat with a row-wise sum
|
||||
summed_y_mat = scaled_y_mat.sum(dim=1)
|
||||
This accumulates all downstream contributions for each xst.
|
||||
|
||||
8.) Scale with the instantaneous input gradients bwxs
|
||||
g_xs = summed_y_mat * bwxs
|
||||
|
||||
This gives the final input gradients:
|
||||
g_xs = [∂L/∂xs0, ∂L/∂xs1, ..., ∂L/∂xsT]
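
The following standalone sketch (added for illustration; it is not part of this file and
assumes combine_fn(a, b) = a * b, so the forward scan is a cumulative product) reproduces
steps 4.) to 8.) with plain tensor ops and checks the result against torch.autograd:

    import torch

    xs = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
    ys = torch.cumprod(xs, dim=0)      # forward scan -> [1., 2., 6., 24.]
    gl_ys = torch.ones_like(ys)        # upstream gradients, e.g. from loss = ys.sum()

    with torch.no_grad():
        T = xs.numel()
        # 4.) instantaneous gradients for combine_fn(a, b) = a * b
        bwys = torch.cat([torch.ones(1), xs[1:]])   # [1, bwys21, bwys32, bwys43]
        bwxs = torch.cat([torch.ones(1), ys[:-1]])  # first entry fixed to 1 by definition
        # 5.1) - 5.4) build the gradient transition matrix
        y_mat = bwys.repeat(T, 1)
        y_mat = y_mat.masked_fill(torch.ones(T, T, dtype=torch.bool).tril(), 1.0)
        y_mat = y_mat.cumprod(dim=1)
        y_mat = y_mat.masked_fill(torch.ones(T, T, dtype=torch.bool).tril(-1), 0.0)
        # 6.) - 8.) scale with gl_ys, reduce row-wise, scale with bwxs
        g_xs = (y_mat * gl_ys).sum(dim=1) * bwxs

    print(g_xs)                                  # tensor([33., 16., 10.,  6.])
    print(torch.autograd.grad(ys.sum(), xs)[0])  # tensor([33., 16., 10.,  6.])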
|
||||
|
||||
NOTE: [scan partial grad handling]
|
||||
If any element of xs or of the outputs does not require gradients
|
||||
(i.e., requires_grad=False), then the corresponding gradients will be returned
|
||||
as tensors of zeros with the same shape as the element.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
combine_fn,
|
||||
num_xs,
|
||||
num_additional_inputs,
|
||||
*operands,
|
||||
):
|
||||
ctx._num_xs = num_xs
|
||||
ctx._num_additional_inputs = num_additional_inputs
|
||||
ctx._combine_fn = combine_fn
|
||||
xs, additional_inputs = split_into_chunks(
|
||||
operands, [num_xs, num_additional_inputs]
|
||||
)
|
||||
|
||||
scan_length = xs[0].shape[0]
|
||||
ctx._scan_length = scan_length
|
||||
|
||||
# We snapshot the dispatch keys in forward for materializing the
# bw_graph in backward.
|
||||
ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
|
||||
ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
|
||||
|
||||
with torch._C._AutoDispatchBelowAutograd():
|
||||
# 1.) Compute the forward output of the associative_scan
|
||||
ys = associative_scan_op(combine_fn, xs, additional_inputs)
|
||||
save_tensors_and_symints_for_backward(ctx, list(operands) + list(ys))
|
||||
|
||||
return (*ys,)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *gl_ys):
|
||||
r"""
|
||||
This function computes the gradients of the scan operation.
|
||||
For a detailed description see the document above.
|
||||
|
||||
Args:
|
||||
gl_ys (torch.Tensor): The tensors of upstream gradients, or a nested pytree of tensors.
    E.g. the gradient of the loss with respect to the forward output ys
|
||||
"""
|
||||
|
||||
# The backward of associative_scan is always performed on the first dimension
|
||||
dim = 0
|
||||
scan_length = ctx._scan_length
|
||||
num_xs = ctx._num_xs
|
||||
num_additional_inputs = ctx._num_additional_inputs
|
||||
|
||||
# Extract the inputs to the forward path and outputs from the forward path
|
||||
flat_args = saved_tensors_and_symints(ctx)
|
||||
xs, additional_inputs, outs = split_into_chunks(
|
||||
flat_args, [num_xs, num_additional_inputs, num_xs]
|
||||
)
|
||||
ndim = outs[0].ndim
|
||||
|
||||
# First_slice_copy does not keep the original requires_grad flag,
|
||||
# but we need it here in order to compute the correct gradients
|
||||
xs_slices = first_slice_copy_with_grad(itertools.chain(xs, xs))
|
||||
|
||||
# Construct the operands from the forward, fw_operands
|
||||
# and the operands for a single event t of the forward, fw_operands_slice
|
||||
fw_operands = (*xs, *additional_inputs)
|
||||
fw_operands_slice = (*xs_slices, *additional_inputs)
|
||||
|
||||
# 2.) Prepare the backward graph
|
||||
combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands_slice)
|
||||
|
||||
# 3.) Materialize the ``combine_fn_bw``
|
||||
# TODO: we need to materialize the bw graphs because dynamo is unable to
|
||||
# trace through the joint function when torch.compile traces torch.autograd.grad.
|
||||
combine_fn_bw_gm = materialize_as_graph(
|
||||
combine_fn_bw,
|
||||
(
|
||||
*fw_operands_slice,
|
||||
*[first_slice_copy(o) for o in outs],
|
||||
),
|
||||
ctx._fw_include_key_set,
|
||||
ctx._fw_exclude_key_set,
|
||||
force_enable_grad=True,
|
||||
)
|
||||
|
||||
# vmap joint graph over scan dimension to compute the individual
|
||||
# gradients for each time slice ``t`` in parallel.
|
||||
# This computation can be parallelized, as these are just the instantaneous gradients and not the full chain-rule
|
||||
mapped_combine_fn_bw_gm = torch.vmap(combine_fn_bw_gm, 0, 0)
|
||||
|
||||
# 4.) Compute the single step bw (instantaneous gradients) at every step ``t``
|
||||
# Use a ones_like tensor in order not to scale the bwyst and bwxst,
|
||||
# with the upstream gradients yet.
|
||||
# Note: All bwyst and bwxst are computed in parallel, thus the tensors bwys and bwxs are the result.
|
||||
dummy_upstream_grad = (torch.ones_like(x) for x in gl_ys)
|
||||
grads = mapped_combine_fn_bw_gm(
|
||||
*(o.roll(1, dim) for o in outs), *fw_operands, *dummy_upstream_grad
|
||||
)
|
||||
bwys, bwxs = split_into_chunks(grads, [num_xs, num_xs])
|
||||
|
||||
def compute_y_mat(bwys: torch.Tensor) -> torch.Tensor:
|
||||
# Prepare a ones and a zeros helper mask in order to easily compute the y_mat
|
||||
def compute_helper_tril_mask(diagonal):
|
||||
def expand_masks(mask):
|
||||
for _ in range(ndim - 1):
|
||||
mask = mask.unsqueeze(-1)
|
||||
return mask
|
||||
|
||||
tril_mask = torch.tril(
|
||||
torch.ones(
|
||||
scan_length, scan_length, device=bwys.device, dtype=torch.bool
|
||||
),
|
||||
diagonal=diagonal,
|
||||
)
|
||||
tril_mask = expand_masks(tril_mask)
|
||||
tril_mask = tril_mask.expand(-1, -1, *bwys.shape[1:])
|
||||
return tril_mask
|
||||
|
||||
# The ones mask is used to fill the main diagonal and all elements below it with 1s
|
||||
ones_mask = compute_helper_tril_mask(0)
|
||||
|
||||
# The zero mask is used to set all elements below the main diagonal to 0
|
||||
zeros_mask = compute_helper_tril_mask(-1)
|
||||
|
||||
# 5.1) Repeat the elements of bwys to form the square matrix
|
||||
y_mat = bwys.unsqueeze(dim).repeat_interleave(scan_length, dim)
|
||||
|
||||
# 5.2) Fill the lower triangular part, including the diagonal,
|
||||
# of the h_mat with 1s. I.e., use the ones_mask to fill with 1s.
|
||||
y_mat.masked_fill_(ones_mask, 1.0)
|
||||
|
||||
# 5.3) Compute the cumulative products across dim + 1
|
||||
y_mat = y_mat.cumprod(dim=dim + 1)
|
||||
|
||||
# 5.4) Replace the elements we filled with 1s before with 0s
|
||||
y_mat.masked_fill_(zeros_mask, 0.0)
|
||||
|
||||
return y_mat
|
||||
|
||||
def compute_grad(bwxs, bwys, gl_ys):
|
||||
# Set the first gradient component of bwxs to 1.0, per definition.
|
||||
torch.select(bwxs, dim, 0).fill_(1.0)
|
||||
|
||||
# 5.) Compute the gradient transition matrix
|
||||
y_mat = compute_y_mat(bwys)
|
||||
|
||||
# 6.) scale the y_mat with the upstream gradients gl_ys
|
||||
scaled_y_mat = y_mat * gl_ys
|
||||
|
||||
# 7.) Reduce the y_mat with sum along the columns to get the total contributions for xs_t
|
||||
summed_y_mat = scaled_y_mat.sum(dim + 1)
|
||||
|
||||
# 8.) Scale with the bwxs to obtain the final gradients g_xs
|
||||
g_xs = summed_y_mat * bwxs
|
||||
|
||||
return g_xs
|
||||
|
||||
# Stack all leaves of the gradients along the first dimension.
|
||||
# This is useful as later the gradients of those leaves can be computed in parallel.
|
||||
bwxs_stacked_leaves = torch.stack(bwxs)
|
||||
bwys_stacked_leaves = torch.stack(bwys)
|
||||
gl_ys_stacked_leaves = torch.stack(gl_ys)
|
||||
|
||||
# The compute_grad function is parallelized across all individual leaves of xs
|
||||
# as these gradients can be computed independently from each other
|
||||
# TODO: torch.vmap may create composability issues
|
||||
compute_grad_mapped = torch.vmap(compute_grad, 0, 0)
|
||||
|
||||
g_xs = compute_grad_mapped(
|
||||
bwxs_stacked_leaves, bwys_stacked_leaves, gl_ys_stacked_leaves
|
||||
)
|
||||
|
||||
# TODO: Currently the gradients for the additional_inputs are not computed properly
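# Note: the three leading Nones correspond to the non-tensor inputs of forward
# (combine_fn, num_xs, num_additional_inputs), which do not receive gradients.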
|
||||
return *[None] * 3, *g_xs, *[None] * num_additional_inputs
|
||||
|
||||
|
||||
@associative_scan_op.py_autograd_impl
|
||||
def associative_scan_autograd(combine_fn, xs, additional_inputs):
|
||||
num_xs = len(xs)
|
||||
num_additional_inputs = len(additional_inputs)
|
||||
|
||||
if num_additional_inputs > 0:
|
||||
raise RuntimeError(
|
||||
"Associative_scan does currently not support gradients for lifted parameters!"
|
||||
)
|
||||
|
||||
flat_out = AssociativeScanAutogradOp.apply(
|
||||
combine_fn,
|
||||
num_xs,
|
||||
num_additional_inputs,
|
||||
*(tuple(xs) + tuple(additional_inputs)),
|
||||
)
|
||||
return (*flat_out,)
|
||||
associative_scan_op.py_autograd_impl(
|
||||
autograd_not_implemented(associative_scan_op, deferred_error=True)
|
||||
)
|
||||
|
||||
|
||||
@associative_scan_op.py_impl(ProxyTorchDispatchMode)
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
import itertools
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch._prims_common as utils
|
||||
@ -13,9 +13,6 @@ from torch._higher_order_ops.utils import (
|
||||
check_meta_consistency,
|
||||
create_bw_fn,
|
||||
first_slice_copy,
|
||||
first_slice_copy_with_grad,
|
||||
get_tensor_mask,
|
||||
mask_list,
|
||||
materialize_as_graph,
|
||||
reenter_make_fx,
|
||||
save_tensors_and_symints_for_backward,
|
||||
@ -63,6 +60,42 @@ def stack_y(y: torch.Tensor, scan_length: int) -> torch.Tensor:
|
||||
)
|
||||
|
||||
|
||||
# NOTE: These functions can be reused in associative_scan and eventually moved to
|
||||
# torch._higher_order_ops.utils
|
||||
def get_tensor_mask(tensor_list: list[Any]) -> list[bool]:
|
||||
# Returns a mask whether a list element is a tensor or not
|
||||
return [True if isinstance(v, torch.Tensor) else False for v in tensor_list]
|
||||
|
||||
|
||||
def mask_list(
|
||||
mask: list[bool], inp: list[Any], other: Optional[list[Any]] = None
|
||||
) -> list[Any]:
|
||||
# Masks elements on an `inp` list.
|
||||
# If other is None, then the elements of the `inp` list where the mask is False are removed
|
||||
# If other is not None, then the elements of the `inp` list where the mask is False are
|
||||
# replaced with the elements of the `other` list
|
||||
assert len(mask) == len(inp), (
|
||||
"The length of the mask needs to be identical to the length of the input"
|
||||
)
|
||||
if other is not None:
|
||||
assert len(inp) == len(other), (
|
||||
"If an input and an other list is provided, they need to have the same length"
|
||||
)
|
||||
return [i if m else o for m, i, o in zip(mask, inp, other)]
|
||||
else:
|
||||
return [i for m, i in zip(mask, inp) if m]
|
||||
|
||||
|
||||
def first_slice_copy_with_grad(li: list[Any]) -> list[Any]:
|
||||
# First_slice_copy does not keep the original requires_grad flag,
|
||||
# but we need it for materialize_as_graph
|
||||
# in order to compute the correct gradients
|
||||
# The reason why first_slice_copy doesn't keep requires_grad flag is
|
||||
# because it's called in torch.autograd.Function.backward/forward.
|
||||
slc = [first_slice_copy(x).requires_grad_(x.requires_grad) for x in li]
|
||||
return slc
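

# Illustration (added, not part of the original file): first_slice_copy_with_grad slices
# index 0 along dim 0 and restores each input's requires_grad flag, e.g.:
#
#     xs = [torch.randn(4, 3, requires_grad=True), torch.randn(4, 2)]
#     slices = first_slice_copy_with_grad(xs)
#     # slices[0].shape == (3,)  and slices[0].requires_grad is True
#     # slices[1].shape == (2,)  and slices[1].requires_grad is False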
|
||||
|
||||
|
||||
def call_operator(operator, *args):
|
||||
return pytree.tree_leaves(operator(*args))
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
import functools
|
||||
from collections.abc import Iterable, Sequence
|
||||
from collections.abc import Sequence
|
||||
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, overload, TypeVar, Union
|
||||
@ -804,40 +804,6 @@ def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
|
||||
return torch.select_copy(t, dim, 0)
|
||||
|
||||
|
||||
# Returns a mask whether a list element is a tensor or not
|
||||
def get_tensor_mask(tensor_list: Iterable[Any]) -> list[bool]:
|
||||
return [True if isinstance(v, torch.Tensor) else False for v in tensor_list]
|
||||
|
||||
|
||||
def mask_list(
|
||||
mask: list[bool], inp: list[Any], other: Optional[list[Any]] = None
|
||||
) -> list[Any]:
|
||||
# Masks elements on an `inp` list.
|
||||
# If other is None, then the elements of the `inp` list where the mask is False are removed
|
||||
# If other is not None, then the elements of the `inp` list where the mask is False are
|
||||
# replaced with the elements of the `other` list
|
||||
assert len(mask) == len(inp), (
|
||||
"The length of the mask needs to be identical to the length of the input"
|
||||
)
|
||||
if other is not None:
|
||||
assert len(inp) == len(other), (
|
||||
"If an input and an other list is provided, they need to have the same length"
|
||||
)
|
||||
return [i if m else o for m, i, o in zip(mask, inp, other)]
|
||||
else:
|
||||
return [i for m, i in zip(mask, inp) if m]
|
||||
|
||||
|
||||
def first_slice_copy_with_grad(li: Iterable[Any]) -> list[Any]:
|
||||
# First_slice_copy does not keep the original requires_grad flag,
|
||||
# but we need it for materialize_as_graph
|
||||
# in order to compute the correct gradients
|
||||
# The reason why first_slice_copy doesn't keep requires_grad flag is
|
||||
# because it's called in torch.autograd.Function.backward/forward.
|
||||
slc = [first_slice_copy(x).requires_grad_(x.requires_grad) for x in li]
|
||||
return slc
|
||||
|
||||
|
||||
# Reports the difference between meta of two tensors in a string
|
||||
def diff_tensor_meta(
|
||||
meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True