[expanded weights] add RNN support via decomp (#91807)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91807
Approved by: https://github.com/albanD
This commit is contained in:
lezcano
2023-02-08 09:36:19 +00:00
committed by PyTorch MergeBot
parent c2a92687e0
commit 20d01d2dc9
7 changed files with 165 additions and 16 deletions

View File

@ -11,12 +11,15 @@ from torch.nn import CrossEntropyLoss
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_nn import TestBase, module_tests, new_module_tests
from torch.testing._internal.common_utils import TestCase, freeze_rng_state, make_tensor, run_tests, parametrize
from torch.testing._internal.common_methods_invocations import SampleInput, op_db
from torch.nn.utils._expanded_weights import ExpandedWeight
from torch.nn.utils._expanded_weights.expanded_weights_utils import forward_helper, set_grad_sample_if_exists, \
unpack_expanded_weight_or_tensor, sum_over_all_but_batch_and_last_n, standard_kwargs
from torch.utils._pytree import tree_map_only
class TestContext:
pass
@ -383,14 +386,22 @@ class TestExpandedWeightFunctional(TestCase):
F.group_norm(inp, 2) # 5 is not divisible by 2
class TestExpandedWeightModule(TestCase):
def _do_test(self, module, input):
batch_size = input.shape[0]
def _do_test(self, module, input, args=None, kwargs=None, batch_first=True):
args = args or ()
kwargs = kwargs or {}
batch_dim = 0 if batch_first else 1
batch_size = input.shape[batch_dim]
diff_input = input.dtype == torch.float or input.dtype == torch.double
if diff_input:
input.requires_grad_()
with freeze_rng_state():
# get per sample grads with ExpandedWeights context manager
actual_res = call_for_per_sample_grads(module, loss_reduction="sum")(input).sum()
actual_res = call_for_per_sample_grads(module,
batch_size=batch_size,
loss_reduction="sum",
batch_first=batch_first)(input, *args, **kwargs).sum()
actual_res.backward()
actual_grads = []
for param in module.parameters():
@ -401,18 +412,24 @@ class TestExpandedWeightModule(TestCase):
input.grad = torch.zeros_like(input.grad)
# get per sample grads with a for loop
expected_res = torch.tensor(0., device=input.device, dtype=torch.double)
expected_res = torch.tensor(0., device=input.device, dtype=actual_res.dtype)
expected_grads = []
for i in range(batch_size):
input_slice = input[i]
input_slice = input.narrow(batch_dim, i, 1)
input_slice = input_slice.squeeze(batch_dim)
# h's batch dim is always the first dim. Must be contiguous for CUDA
sliced_args = tree_map_only(torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args)
diff_params = module.parameters()
if diff_input:
diff_params = chain(diff_params, (input_slice,))
res = module(input_slice.unsqueeze(0)).sum()
res = module(input_slice.unsqueeze(batch_dim).contiguous(), *sliced_args, **kwargs).sum()
out_grads = torch.autograd.grad(res, diff_params, torch.ones_like(res), allow_unused=True)
expected_grads.append(out_grads)
expected_res += res
expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
if not batch_first:
expected_grads[-1] = expected_grads[-1].transpose(0, 1)
self.assertEqual(actual_res, expected_res)
[self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
@ -457,6 +474,52 @@ class TestExpandedWeightModule(TestCase):
expected_grads = tuple(expected_grad for expected_grad in expected_grads if expected_grad is not None)
assert [self.assertEqual(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads)]
@modules(filter(lambda m_info: m_info.module_cls == torch.nn.RNN, module_db))
def test_module(self, device, dtype, module_info, training):
class RNNWrapper(torch.nn.Module):
def __init__(self, m_cons, args, kwargs):
super().__init__()
self.m = m_cons(*args, **kwargs)
def forward(self, *inps):
ret = self.m(*inps)
assert isinstance(ret, tuple)
return ret[0]
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=True, training=training)
for module_input in module_inputs:
if module_input.forward_input is None:
continue
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = RNNWrapper(module_cls, args, kwargs)
batch_first = m.m.batch_first
m.to(device).to(dtype)
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
# if the RNN tests use unbatched inputs--batch the inputs
input = args[0].detach()
if input.dim() == 2:
new_input_shape = [1] * (len(input.shape) + 1)
if batch_first:
new_input_shape[0] = 2
input = input.repeat(new_input_shape)
else:
new_input_shape[1] = 2
input = input.unsqueeze(1).repeat(new_input_shape)
h = args[1] if len(args) > 1 else None
if h is not None:
new_h_shape = [1] * (len(h.shape) + 1)
new_h_shape[1] = 2
h = h.unsqueeze(1).repeat(new_h_shape)
args = list(args)
args[1] = h
self._do_test(m, input, args[1:], kwargs, batch_first=batch_first)
def test_per_sample_api_failing(self):
module = nn.Linear(10, 10)
input = torch.randn(64, 10)
@ -665,5 +728,6 @@ def clone_if_tensor(t):
instantiate_device_type_tests(TestExpandedWeightHelperFunction, globals())
instantiate_device_type_tests(TestExpandedWeightFunctional, globals())
instantiate_device_type_tests(TestExpandedWeightModule, globals())
if __name__ == '__main__':
run_tests()