[expanded weights] add RNN support via decomp (#91807)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91807
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Commit: 20d01d2dc9
Parent: c2a92687e0
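For orientation, here is a minimal sketch (not part of the commit) of the per-sample-gradient API that the new RNN test exercises. The keyword arguments batch_size, loss_reduction and batch_first are the ones visible in the diff below; the module sizes, input shapes and the assumption that per-sample gradients land in p.grad_sample are illustrative only.

# Minimal sketch, not from this commit: per-sample gradients for an RNN via
# call_for_per_sample_grads, using the keyword arguments shown in the diff below.
# Sizes and shapes are made up; per-sample grads are assumed to land in p.grad_sample.
import torch
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

rnn = torch.nn.RNN(input_size=3, hidden_size=5, batch_first=True)
inp = torch.randn(4, 7, 3)  # (batch, seq, feature) since batch_first=True

# The wrapped call swaps the parameters for ExpandedWeights; backward() through
# the summed output then records one gradient per sample instead of a single .grad.
out, _ = call_for_per_sample_grads(rnn,
                                   batch_size=4,
                                   loss_reduction="sum",
                                   batch_first=True)(inp)
out.sum().backward()

for name, p in rnn.named_parameters():
    print(name, p.grad_sample.shape)  # leading dimension is the batch size (4)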
@@ -11,12 +11,15 @@ from torch.nn import CrossEntropyLoss
 from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_nn import TestBase, module_tests, new_module_tests
 from torch.testing._internal.common_utils import TestCase, freeze_rng_state, make_tensor, run_tests, parametrize
 from torch.testing._internal.common_methods_invocations import SampleInput, op_db
 from torch.nn.utils._expanded_weights import ExpandedWeight
 from torch.nn.utils._expanded_weights.expanded_weights_utils import forward_helper, set_grad_sample_if_exists, \
     unpack_expanded_weight_or_tensor, sum_over_all_but_batch_and_last_n, standard_kwargs
 from torch.utils._pytree import tree_map_only


 class TestContext:
     pass
@@ -383,14 +386,22 @@ class TestExpandedWeightFunctional(TestCase):
             F.group_norm(inp, 2)  # 5 is not divisible by 2

 class TestExpandedWeightModule(TestCase):
-    def _do_test(self, module, input):
-        batch_size = input.shape[0]
+    def _do_test(self, module, input, args=None, kwargs=None, batch_first=True):
+        args = args or ()
+        kwargs = kwargs or {}
+
+        batch_dim = 0 if batch_first else 1
+        batch_size = input.shape[batch_dim]
         diff_input = input.dtype == torch.float or input.dtype == torch.double
         if diff_input:
             input.requires_grad_()

         with freeze_rng_state():
             # get per sample grads with ExpandedWeights context manager
-            actual_res = call_for_per_sample_grads(module, loss_reduction="sum")(input).sum()
+            actual_res = call_for_per_sample_grads(module,
+                                                   batch_size=batch_size,
+                                                   loss_reduction="sum",
+                                                   batch_first=batch_first)(input, *args, **kwargs).sum()
             actual_res.backward()
             actual_grads = []
             for param in module.parameters():
@@ -401,18 +412,24 @@ class TestExpandedWeightModule(TestCase):
                 input.grad = torch.zeros_like(input.grad)

             # get per sample grads with a for loop
-            expected_res = torch.tensor(0., device=input.device, dtype=torch.double)
+            expected_res = torch.tensor(0., device=input.device, dtype=actual_res.dtype)
             expected_grads = []
             for i in range(batch_size):
-                input_slice = input[i]
+                input_slice = input.narrow(batch_dim, i, 1)
+                input_slice = input_slice.squeeze(batch_dim)
+
+                # h's batch dim is always the first dim. Must be contiguous for CUDA
+                sliced_args = tree_map_only(torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args)
                 diff_params = module.parameters()
                 if diff_input:
                     diff_params = chain(diff_params, (input_slice,))
-                res = module(input_slice.unsqueeze(0)).sum()
+                res = module(input_slice.unsqueeze(batch_dim).contiguous(), *sliced_args, **kwargs).sum()
                 out_grads = torch.autograd.grad(res, diff_params, torch.ones_like(res), allow_unused=True)
                 expected_grads.append(out_grads)
                 expected_res += res
-            expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
+            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
+            if not batch_first:
+                expected_grads[-1] = expected_grads[-1].transpose(0, 1)
         self.assertEqual(actual_res, expected_res)
         [self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]

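For context (again not from the commit), the reference computation in _do_test above is the usual per-sample-gradient baseline: run each sample through the module on its own, take gradients with torch.autograd.grad, and stack them along a new leading batch dimension. A self-contained sketch with an assumed nn.Linear module:

# Standalone sketch, not from this commit: the per-sample reference loop in isolation.
import torch

torch.manual_seed(0)
module = torch.nn.Linear(3, 2)
inp = torch.randn(4, 3)  # batch of 4 samples

per_sample_grads = []
for i in range(inp.shape[0]):
    res = module(inp[i].unsqueeze(0)).sum()
    grads = torch.autograd.grad(res, module.parameters())
    per_sample_grads.append(grads)

# stack so each parameter's gradient gains a leading batch dimension
expected = [torch.stack(g) for g in zip(*per_sample_grads)]
print([e.shape for e in expected])  # [torch.Size([4, 2, 3]), torch.Size([4, 2])]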
@@ -457,6 +474,52 @@
         expected_grads = tuple(expected_grad for expected_grad in expected_grads if expected_grad is not None)
         assert [self.assertEqual(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads)]

+    @modules(filter(lambda m_info: m_info.module_cls == torch.nn.RNN, module_db))
+    def test_module(self, device, dtype, module_info, training):
+        class RNNWrapper(torch.nn.Module):
+            def __init__(self, m_cons, args, kwargs):
+                super().__init__()
+                self.m = m_cons(*args, **kwargs)
+
+            def forward(self, *inps):
+                ret = self.m(*inps)
+                assert isinstance(ret, tuple)
+                return ret[0]
+
+        module_cls = module_info.module_cls
+        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
+                                                       requires_grad=True, training=training)
+        for module_input in module_inputs:
+            if module_input.forward_input is None:
+                continue
+            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+            m = RNNWrapper(module_cls, args, kwargs)
+            batch_first = m.m.batch_first
+            m.to(device).to(dtype)
+
+            args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+
+            # if the RNN tests use unbatched inputs--batch the inputs
+            input = args[0].detach()
+            if input.dim() == 2:
+                new_input_shape = [1] * (len(input.shape) + 1)
+                if batch_first:
+                    new_input_shape[0] = 2
+                    input = input.repeat(new_input_shape)
+                else:
+                    new_input_shape[1] = 2
+                    input = input.unsqueeze(1).repeat(new_input_shape)
+
+                h = args[1] if len(args) > 1 else None
+                if h is not None:
+                    new_h_shape = [1] * (len(h.shape) + 1)
+                    new_h_shape[1] = 2
+                    h = h.unsqueeze(1).repeat(new_h_shape)
+                    args = list(args)
+                    args[1] = h
+
+            self._do_test(m, input, args[1:], kwargs, batch_first=batch_first)
+
     def test_per_sample_api_failing(self):
         module = nn.Linear(10, 10)
         input = torch.randn(64, 10)
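The unbatched-input handling in test_module can be checked in isolation. The following sketch (not part of the commit) just traces the shapes the repeat logic produces for an assumed (seq_len, input_size) input under both batch_first settings.

# Illustrative only: how an unbatched RNN input is turned into a batch of two samples.
import torch

seq_len, input_size = 7, 3
unbatched = torch.randn(seq_len, input_size)

# batch_first=False: insert the batch dim at position 1 and repeat it twice,
# giving (seq_len, 2, input_size)
new_input_shape = [1] * (unbatched.dim() + 1)
new_input_shape[1] = 2
batched = unbatched.unsqueeze(1).repeat(new_input_shape)
print(batched.shape)  # torch.Size([7, 2, 3])

# batch_first=True: the batch dim is dim 0 instead, giving (2, seq_len, input_size);
# repeat() with an extra leading size implicitly prepends the new dimension.
new_input_shape = [1] * (unbatched.dim() + 1)
new_input_shape[0] = 2
batched_bf = unbatched.repeat(new_input_shape)
print(batched_bf.shape)  # torch.Size([2, 7, 3])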
@@ -665,5 +728,6 @@ def clone_if_tensor(t):

 instantiate_device_type_tests(TestExpandedWeightHelperFunction, globals())
 instantiate_device_type_tests(TestExpandedWeightFunctional, globals())
+instantiate_device_type_tests(TestExpandedWeightModule, globals())
 if __name__ == '__main__':
     run_tests()