[expanded weights] add RNN support via decomp (#91807)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91807
Approved by: https://github.com/albanD
Author: lezcano
Date: 2023-02-08 09:36:19 +00:00
Committed by: PyTorch MergeBot
Parent: c2a92687e0
Commit: 20d01d2dc9
7 changed files with 165 additions and 16 deletions
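As context for the diff below, here is a minimal usage sketch (not taken from the commit; sizes are made up) of the per-sample-gradient API this change extends to RNNs: call_for_per_sample_grads wraps the module, and after backward every parameter carries a grad_sample tensor with a leading batch dimension.

import torch
import torch.nn as nn
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

rnn = nn.RNN(input_size=3, hidden_size=5, batch_first=True)
inp = torch.randn(4, 7, 3)  # (batch, seq, feature); batch size 4 is inferred from dim 0

# The wrapped call behaves like rnn(inp); backward populates .grad_sample.
out, _ = call_for_per_sample_grads(rnn, loss_reduction="sum", batch_first=True)(inp)
out.sum().backward()
print(rnn.weight_hh_l0.grad_sample.shape)  # expected: (4, 5, 5), one gradient per sample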

test/test_expanded_weights.py

@@ -11,12 +11,15 @@ from torch.nn import CrossEntropyLoss
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_nn import TestBase, module_tests, new_module_tests
from torch.testing._internal.common_utils import TestCase, freeze_rng_state, make_tensor, run_tests, parametrize
from torch.testing._internal.common_methods_invocations import SampleInput, op_db
from torch.nn.utils._expanded_weights import ExpandedWeight
from torch.nn.utils._expanded_weights.expanded_weights_utils import forward_helper, set_grad_sample_if_exists, \
unpack_expanded_weight_or_tensor, sum_over_all_but_batch_and_last_n, standard_kwargs
from torch.utils._pytree import tree_map_only
class TestContext:
pass
@@ -383,14 +386,22 @@ class TestExpandedWeightFunctional(TestCase):
F.group_norm(inp, 2) # 5 is not divisible by 2
class TestExpandedWeightModule(TestCase):
def _do_test(self, module, input):
batch_size = input.shape[0]
def _do_test(self, module, input, args=None, kwargs=None, batch_first=True):
args = args or ()
kwargs = kwargs or {}
batch_dim = 0 if batch_first else 1
batch_size = input.shape[batch_dim]
diff_input = input.dtype == torch.float or input.dtype == torch.double
if diff_input:
input.requires_grad_()
with freeze_rng_state():
# get per sample grads with ExpandedWeights context manager
actual_res = call_for_per_sample_grads(module, loss_reduction="sum")(input).sum()
actual_res = call_for_per_sample_grads(module,
batch_size=batch_size,
loss_reduction="sum",
batch_first=batch_first)(input, *args, **kwargs).sum()
actual_res.backward()
actual_grads = []
for param in module.parameters():
@@ -401,18 +412,24 @@ class TestExpandedWeightModule(TestCase):
input.grad = torch.zeros_like(input.grad)
# get per sample grads with a for loop
expected_res = torch.tensor(0., device=input.device, dtype=torch.double)
expected_res = torch.tensor(0., device=input.device, dtype=actual_res.dtype)
expected_grads = []
for i in range(batch_size):
input_slice = input[i]
input_slice = input.narrow(batch_dim, i, 1)
input_slice = input_slice.squeeze(batch_dim)
# h's batch dim is always dim 1 (the second dim). Must be contiguous for CUDA
sliced_args = tree_map_only(torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args)
diff_params = module.parameters()
if diff_input:
diff_params = chain(diff_params, (input_slice,))
res = module(input_slice.unsqueeze(0)).sum()
res = module(input_slice.unsqueeze(batch_dim).contiguous(), *sliced_args, **kwargs).sum()
out_grads = torch.autograd.grad(res, diff_params, torch.ones_like(res), allow_unused=True)
expected_grads.append(out_grads)
expected_res += res
expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
if not batch_first:
expected_grads[-1] = expected_grads[-1].transpose(0, 1)
self.assertEqual(actual_res, expected_res)
[self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
@@ -457,6 +474,52 @@ class TestExpandedWeightModule(TestCase):
expected_grads = tuple(expected_grad for expected_grad in expected_grads if expected_grad is not None)
assert [self.assertEqual(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads)]
@modules(filter(lambda m_info: m_info.module_cls == torch.nn.RNN, module_db))
def test_module(self, device, dtype, module_info, training):
class RNNWrapper(torch.nn.Module):
def __init__(self, m_cons, args, kwargs):
super().__init__()
self.m = m_cons(*args, **kwargs)
def forward(self, *inps):
ret = self.m(*inps)
assert isinstance(ret, tuple)
return ret[0]
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=True, training=training)
for module_input in module_inputs:
if module_input.forward_input is None:
continue
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = RNNWrapper(module_cls, args, kwargs)
batch_first = m.m.batch_first
m.to(device).to(dtype)
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
# if the RNN tests use unbatched inputs, batch the inputs
input = args[0].detach()
if input.dim() == 2:
new_input_shape = [1] * (len(input.shape) + 1)
if batch_first:
new_input_shape[0] = 2
input = input.repeat(new_input_shape)
else:
new_input_shape[1] = 2
input = input.unsqueeze(1).repeat(new_input_shape)
h = args[1] if len(args) > 1 else None
if h is not None:
new_h_shape = [1] * (len(h.shape) + 1)
new_h_shape[1] = 2
h = h.unsqueeze(1).repeat(new_h_shape)
args = list(args)
args[1] = h
self._do_test(m, input, args[1:], kwargs, batch_first=batch_first)
def test_per_sample_api_failing(self):
module = nn.Linear(10, 10)
input = torch.randn(64, 10)
@@ -665,5 +728,6 @@ def clone_if_tensor(t):
instantiate_device_type_tests(TestExpandedWeightHelperFunction, globals())
instantiate_device_type_tests(TestExpandedWeightFunctional, globals())
instantiate_device_type_tests(TestExpandedWeightModule, globals())
if __name__ == '__main__':
run_tests()
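The loop in _do_test above is the reference the ExpandedWeights result is compared against. A stripped-down sketch of that reference computation (hypothetical helper, not part of the test; assumes every parameter participates in the forward and the module returns a single tensor):

import torch

def per_sample_grads_by_loop(module, inp, batch_dim=0):
    # Run one sample at a time and stack the per-parameter gradients so each
    # gradient gains a leading batch dimension.
    grads = []
    for i in range(inp.shape[batch_dim]):
        sample = inp.narrow(batch_dim, i, 1)  # keep the batch dim with size 1
        loss = module(sample).sum()
        grads.append(torch.autograd.grad(loss, list(module.parameters())))
    return [torch.stack(per_param) for per_param in zip(*grads)]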

torch/_decomp/decompositions.py

@@ -2157,15 +2157,15 @@ def one_layer_rnn(inp, hidden, params, has_biases, nonlinearity, reverse=False):
precomputed_input = F.linear(inp, ih_weight, ih_bias)
precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
cur_hidden = hidden
cur_hidden = hidden.unsqueeze(0)
step_output = []
for inp in precomputed_input:
cur_hidden = nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + inp)
step_output.append(cur_hidden)
out = torch.stack(step_output, 0)
out = torch.cat(step_output, 0)
return out, cur_hidden
return out, cur_hidden.squeeze(0)
def _rnn_helper(
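For readers unfamiliar with the decomposition above, a self-contained sketch of the same pattern (single direction, tanh cell, weights passed explicitly; simplified relative to the real decomposition): the hidden state keeps a leading dimension of 1, so the per-step outputs can be concatenated along time rather than stacked.

import torch
import torch.nn.functional as F

def one_layer_rnn_sketch(inp, hidden, ih_weight, ih_bias, hh_weight, hh_bias):
    # inp: (seq, batch, input_size), hidden: (batch, hidden_size)
    precomputed_input = F.linear(inp, ih_weight, ih_bias)
    cur_hidden = hidden.unsqueeze(0)        # (1, batch, hidden_size)
    step_output = []
    for x_t in precomputed_input:           # iterate over the time dimension
        cur_hidden = torch.tanh(F.linear(cur_hidden, hh_weight, hh_bias) + x_t)
        step_output.append(cur_hidden)
    out = torch.cat(step_output, 0)         # (seq, batch, hidden_size)
    return out, cur_hidden.squeeze(0)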

torch/nn/utils/_expanded_weights/expanded_weights_impl.py

@@ -1,11 +1,35 @@
from contextlib import contextmanager
from torch._C import _TensorBase
import torch
import functools
from torch._decomp import decomposition_table
from typing import Callable, Dict, cast
from torch.utils._pytree import tree_map_only
HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {}
# __torch_function__ runs before the pydispatcher so we need to use the same
# decompositions indexed by their torch equivalent
expanded_weights_rnn_decomps = {
# func: (input_decomp, data_decomp)
torch.rnn_relu: (decomposition_table[torch._ops.ops.aten.rnn_relu.input], None),
torch.rnn_tanh: (decomposition_table[torch._ops.ops.aten.rnn_tanh.input], None)
}
@contextmanager
def batch_second(args, kwargs):
tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=False), args)
tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=False), kwargs)
try:
yield
finally:
tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=True), args)
tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=True), kwargs)
def implements_per_sample_grads(torch_function):
@functools.wraps(torch_function)
def decorator(autograd_func):
@@ -28,6 +52,7 @@ def implements_per_sample_grads(torch_function):
class ExpandedWeight(torch.Tensor):
def __init__(self, orig_weight, batch_size, loss_reduction):
self.batch_size = batch_size
self.batch_first = True
self.orig_weight = orig_weight
self.loss_reduction = loss_reduction
@@ -45,6 +70,18 @@ class ExpandedWeight(torch.Tensor):
def __torch_function__(cls, func, _, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func in expanded_weights_rnn_decomps:
# in aten, choosing the input or data variants is done by parsing logic. This mimics some of that
decomp_opts = expanded_weights_rnn_decomps[func]
use_input_variant = isinstance(args[1], torch.Tensor) # data variant uses a list here
decomp = decomp_opts[0] if use_input_variant else decomp_opts[1]
if decomp is not None:
with batch_second(args, kwargs):
return decomp(*args, **kwargs)
if func == torch._cudnn_rnn_flatten_weight:
# since we aren't using the fused cuda kernels for RNNs, don't do this
return
if func in cls.handled_functions:
return cls.handled_functions[func].apply(tuple(kwargs.keys()), func, *(args + tuple(kwargs.values())))
# We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs,
@@ -55,6 +92,27 @@ class ExpandedWeight(torch.Tensor):
def dtype(self):
return self.orig_weight.dtype
@property
def data(self):
return self.orig_weight.data
@property
def shape(self):
return self.orig_weight.shape
@property
def device(self):
return self.orig_weight.device
@property
def is_cuda(self):
return self.orig_weight.is_cuda
def data_ptr(self):
return self.orig_weight.data_ptr()
def get_device(self):
return self.orig_weight.get_device()
def set_batch_first(self, is_batch_first=True):
self.batch_first = is_batch_first
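ExpandedWeight relies on the __torch_function__ protocol to intercept calls that receive it. A generic, hypothetical sketch of that mechanism (unrelated class name; it only illustrates the routing, including why unhandled functions error instead of silently falling back):

import torch

class InterceptingTensor(torch.Tensor):
    # Map from torch function to a handler; anything else is rejected so a
    # silent fallback cannot produce wrong per-sample gradients.
    handled_functions = {}

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in cls.handled_functions:
            return cls.handled_functions[func](*args, **kwargs)
        raise RuntimeError(f"{func.__name__} is not supported for InterceptingTensor")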

torch/nn/utils/_expanded_weights/expanded_weights_utils.py

@@ -3,6 +3,18 @@ from typing import Optional
import torch
from .expanded_weights_impl import ExpandedWeight
def is_batch_first(expanded_args_and_kwargs):
batch_first = None
for arg in expanded_args_and_kwargs:
if not isinstance(arg, ExpandedWeight):
continue
if batch_first is None:
batch_first = arg.batch_first
elif arg.batch_first != batch_first:
raise RuntimeError("Got conflicting batch_first arguments in the same layer")
return batch_first
def standard_kwargs(kwarg_names, expanded_args):
r'''Most `__torch_function__`s standardize the kwargs that they give, so this will separate
the args and kwargs they pass. Functions that don't are linear and convND
@@ -46,9 +58,11 @@ def _check_and_unexpand_args(func, expanded_args, expanded_kwargs):
if input.shape[0] == 0:
raise RuntimeError("0 is not a valid batch size for Expanded Weights but got input tensor of "
f"{input} in function {func.__name__}")
batch_size = input.shape[0]
for arg in expanded_args + tuple(expanded_kwargs.values()):
if isinstance(arg, ExpandedWeight) and arg.batch_size != batch_size:
if not isinstance(arg, ExpandedWeight):
continue
batch_size = input.shape[0] if arg.batch_first else input.shape[1]
if arg.batch_size != batch_size:
raise RuntimeError("Expected ExpandedWeights to have batch size matching input but got "
f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}")

torch/nn/utils/_expanded_weights/linear_expanded_weights.py

@@ -2,7 +2,7 @@ import torch
import torch.nn.functional as F
from .expanded_weights_impl import implements_per_sample_grads
from .expanded_weights_utils import \
forward_helper, set_grad_sample_if_exists, unpack_expanded_weight_or_tensor
forward_helper, set_grad_sample_if_exists, unpack_expanded_weight_or_tensor, is_batch_first
from typing import List, Optional
@implements_per_sample_grads(F.linear)
@@ -14,6 +14,7 @@ class LinearPerSampleGrad(torch.autograd.Function):
f"of at least rank 2, got of rank {len(expanded_args_and_kwargs[0].shape)}")
expanded_kwargs = {'bias': expanded_args_and_kwargs[2] if len(expanded_args_and_kwargs) == 3 else None}
expanded_args = expanded_args_and_kwargs[:2]
ctx.batch_first = is_batch_first(expanded_args_and_kwargs)
output = forward_helper(F.linear, expanded_args, expanded_kwargs)
ctx.args = expanded_args
ctx.kwargs = expanded_kwargs
@@ -33,6 +34,10 @@ class LinearPerSampleGrad(torch.autograd.Function):
results.append(None)
results.extend([None] * 2) # weight and bias don't compute batched gradients
if not ctx.batch_first:
grad_output = grad_output.transpose(0, 1)
input = input.transpose(0, 1)
# weight and bias get their grad_sample fields set directly if they exist
set_grad_sample_if_exists(weight, lambda _: torch.einsum("n...i,n...j->nij", grad_output, input))
set_grad_sample_if_exists(bias, lambda _: torch.einsum("n...k->nk", grad_output))
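The einsum in the backward above is what produces per-sample weight gradients for F.linear: one outer product of grad_output and input per sample. A small sanity-check sketch with made-up shapes (summing the per-sample grads over the batch recovers the ordinary aggregate weight gradient):

import torch
import torch.nn.functional as F

inp = torch.randn(4, 3)                      # (batch, in_features)
weight = torch.randn(5, 3, requires_grad=True)
grad_output = torch.randn(4, 5)              # upstream gradient, (batch, out_features)

per_sample = torch.einsum("n...i,n...j->nij", grad_output, inp)   # (4, 5, 3)

out = F.linear(inp, weight)
(out * grad_output).sum().backward()         # aggregate gradient via autograd
assert torch.allclose(per_sample.sum(0), weight.grad)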

torch/nn/utils/_per_sample_grad.py

@@ -6,9 +6,11 @@ from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight
from torch.utils._pytree import tree_flatten
def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum"):
# dependency on `functional_call` means that this can't be exposed in utils
# without creating circular dependency
def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum", batch_first=True):
r"""
call_for_per_sample_grads(module, batch_size=None, loss_reduction="sum")
call_for_per_sample_grads(module, batch_size=None, loss_reduction="sum", batch_first=True)
``call_for_per_sample_grads`` returns a function that is invoked like the forward
function of ``module`` and will produce the same result. Then, when backward is invoked,
the parameters of ``module`` will have a ``grad_sample`` field populated with the per sample
@@ -24,6 +26,8 @@ def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum"):
loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. If
"mean", per sample gradients will be scaled by the batch size to offset the crossbatch interaction from
running mean across a batch. Must be "mean" or "sum". Default: "sum"
batch_first: Indicates whether the batch dimension is the first dimension of the inputs. If False, the batch
dimension is the second dimension. Default: True.
Examples::
>>> # xdoctest: +SKIP
@@ -64,7 +68,7 @@ def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum"):
if not isinstance(arg, torch.Tensor):
continue
arg_batch_size = arg.shape[0] # we assume batch size is the first dim
arg_batch_size = arg.shape[0] if batch_first else arg.shape[1]
if batch_size is not None and batch_size != arg_batch_size:
raise RuntimeError("When computing batch size, found at least one input with batch size "
f"{batch_size} and one with batch size {arg_batch_size}. Please specify it "

torch/testing/_internal/common_modules.py

@@ -1059,6 +1059,10 @@ rnn_gru_lstm_module_info_decorators = (
unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
),
DecorateInfo(
skipCUDAVersionIn([(11, 7)]), "TestExpandedWeightModule", "test_module",
device_type='cuda'
),
DecorateInfo(
skipCUDAVersionIn([(11, 7)]), "TestDecomp", "test_rnn_decomp_module",
device_type='cuda'