[expanded weights] add RNN support via decomp (#91807)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91807 Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: c2a92687e0
Commit: 20d01d2dc9
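For orientation before the diff: the change routes `torch.rnn_relu` / `torch.rnn_tanh` through a decomposition when ExpandedWeight parameters are present, so per-sample gradients work for `nn.RNN`. Below is a minimal usage sketch based on the `call_for_per_sample_grads` signature exercised in the tests; the module sizes and shapes are illustrative assumptions, not taken from the PR.

    import torch
    import torch.nn as nn
    from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

    # Illustrative sizes (assumed, not from the PR): batch of 8 sequences, length 5, feature size 4.
    rnn = nn.RNN(input_size=4, hidden_size=6, batch_first=True)
    inp = torch.randn(8, 5, 4)

    # nn.RNN's forward returns (output, h_n); summing the output gives a scalar loss.
    out, _ = call_for_per_sample_grads(rnn, loss_reduction="sum", batch_first=True)(inp)
    out.sum().backward()

    for name, param in rnn.named_parameters():
        # each parameter now carries a per-sample gradient with a leading batch dim of 8
        print(name, param.grad_sample.shape)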
@@ -11,12 +11,15 @@ from torch.nn import CrossEntropyLoss
 from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
+from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_nn import TestBase, module_tests, new_module_tests
 from torch.testing._internal.common_utils import TestCase, freeze_rng_state, make_tensor, run_tests, parametrize
 from torch.testing._internal.common_methods_invocations import SampleInput, op_db
 from torch.nn.utils._expanded_weights import ExpandedWeight
 from torch.nn.utils._expanded_weights.expanded_weights_utils import forward_helper, set_grad_sample_if_exists, \
     unpack_expanded_weight_or_tensor, sum_over_all_but_batch_and_last_n, standard_kwargs
+from torch.utils._pytree import tree_map_only


 class TestContext:
     pass
@@ -383,14 +386,22 @@ class TestExpandedWeightFunctional(TestCase):
             F.group_norm(inp, 2)  # 5 is not divisible by 2

 class TestExpandedWeightModule(TestCase):
-    def _do_test(self, module, input):
-        batch_size = input.shape[0]
+    def _do_test(self, module, input, args=None, kwargs=None, batch_first=True):
+        args = args or ()
+        kwargs = kwargs or {}
+
+        batch_dim = 0 if batch_first else 1
+        batch_size = input.shape[batch_dim]
         diff_input = input.dtype == torch.float or input.dtype == torch.double
         if diff_input:
             input.requires_grad_()

         with freeze_rng_state():
             # get per sample grads with ExpandedWeights context manager
-            actual_res = call_for_per_sample_grads(module, loss_reduction="sum")(input).sum()
+            actual_res = call_for_per_sample_grads(module,
+                                                   batch_size=batch_size,
+                                                   loss_reduction="sum",
+                                                   batch_first=batch_first)(input, *args, **kwargs).sum()
             actual_res.backward()
             actual_grads = []
             for param in module.parameters():
@@ -401,18 +412,24 @@ class TestExpandedWeightModule(TestCase):
                 input.grad = torch.zeros_like(input.grad)

             # get per sample grads with a for loop
-            expected_res = torch.tensor(0., device=input.device, dtype=torch.double)
+            expected_res = torch.tensor(0., device=input.device, dtype=actual_res.dtype)
             expected_grads = []
             for i in range(batch_size):
-                input_slice = input[i]
+                input_slice = input.narrow(batch_dim, i, 1)
+                input_slice = input_slice.squeeze(batch_dim)
+
+                # h's batch dim is always the first dim. Must be contiguous for CUDA
+                sliced_args = tree_map_only(torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args)
                 diff_params = module.parameters()
                 if diff_input:
                     diff_params = chain(diff_params, (input_slice,))
-                res = module(input_slice.unsqueeze(0)).sum()
+                res = module(input_slice.unsqueeze(batch_dim).contiguous(), *sliced_args, **kwargs).sum()
                 out_grads = torch.autograd.grad(res, diff_params, torch.ones_like(res), allow_unused=True)
                 expected_grads.append(out_grads)
                 expected_res += res
-            expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
+            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
+            if not batch_first:
+                expected_grads[-1] = expected_grads[-1].transpose(0, 1)
             self.assertEqual(actual_res, expected_res)
             [self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
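The switch from `input[i]` to `narrow`/`squeeze` above lets the reference loop slice along either batch position. A small sketch of the equivalence, with illustrative shapes (not from the diff):

    import torch

    x = torch.randn(5, 3, 4)                               # e.g. (seq_len, batch, features) when batch_first=False
    batch_dim, i = 1, 2
    sample = x.narrow(batch_dim, i, 1).squeeze(batch_dim)  # shape (5, 4)
    assert torch.equal(sample, x[:, i])                    # same slice as direct indexing on dim 1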
@@ -457,6 +474,52 @@ class TestExpandedWeightModule(TestCase):
         expected_grads = tuple(expected_grad for expected_grad in expected_grads if expected_grad is not None)
         assert [self.assertEqual(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads)]

+    @modules(filter(lambda m_info: m_info.module_cls == torch.nn.RNN, module_db))
+    def test_module(self, device, dtype, module_info, training):
+        class RNNWrapper(torch.nn.Module):
+            def __init__(self, m_cons, args, kwargs):
+                super().__init__()
+                self.m = m_cons(*args, **kwargs)
+
+            def forward(self, *inps):
+                ret = self.m(*inps)
+                assert isinstance(ret, tuple)
+                return ret[0]
+
+        module_cls = module_info.module_cls
+        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
+                                                       requires_grad=True, training=training)
+        for module_input in module_inputs:
+            if module_input.forward_input is None:
+                continue
+            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+            m = RNNWrapper(module_cls, args, kwargs)
+            batch_first = m.m.batch_first
+            m.to(device).to(dtype)
+
+            args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+
+            # if the RNN tests use unbatched inputs--batch the inputs
+            input = args[0].detach()
+            if input.dim() == 2:
+                new_input_shape = [1] * (len(input.shape) + 1)
+                if batch_first:
+                    new_input_shape[0] = 2
+                    input = input.repeat(new_input_shape)
+                else:
+                    new_input_shape[1] = 2
+                    input = input.unsqueeze(1).repeat(new_input_shape)
+
+                h = args[1] if len(args) > 1 else None
+                if h is not None:
+                    new_h_shape = [1] * (len(h.shape) + 1)
+                    new_h_shape[1] = 2
+                    h = h.unsqueeze(1).repeat(new_h_shape)
+                    args = list(args)
+                    args[1] = h
+
+            self._do_test(m, input, args[1:], kwargs, batch_first=batch_first)
+
     def test_per_sample_api_failing(self):
         module = nn.Linear(10, 10)
         input = torch.randn(64, 10)
@@ -665,5 +728,6 @@ def clone_if_tensor(t):

 instantiate_device_type_tests(TestExpandedWeightHelperFunction, globals())
 instantiate_device_type_tests(TestExpandedWeightFunctional, globals())
+instantiate_device_type_tests(TestExpandedWeightModule, globals())
 if __name__ == '__main__':
     run_tests()
@@ -2157,15 +2157,15 @@ def one_layer_rnn(inp, hidden, params, has_biases, nonlinearity, reverse=False):

     precomputed_input = F.linear(inp, ih_weight, ih_bias)
     precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
-    cur_hidden = hidden
+    cur_hidden = hidden.unsqueeze(0)
     step_output = []
     for inp in precomputed_input:
         cur_hidden = nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + inp)
         step_output.append(cur_hidden)

-    out = torch.stack(step_output, 0)
+    out = torch.cat(step_output, 0)

-    return out, cur_hidden
+    return out, cur_hidden.squeeze(0)


 def _rnn_helper(
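The hunk above (apparently from the RNN decomposition in torch/_decomp/decompositions.py) keeps `cur_hidden` with an explicit leading dimension inside the loop, so the per-step outputs can be concatenated along dim 0 without gaining an extra dimension. A standalone shape sketch with plain tensors and illustrative sizes, not the decomposition's real parameter handling:

    import torch
    import torch.nn.functional as F

    seq_len, batch, input_size, hidden_size = 3, 2, 4, 5
    inp = torch.randn(seq_len, batch, input_size)
    hidden = torch.zeros(batch, hidden_size)
    w_ih = torch.randn(hidden_size, input_size)
    w_hh = torch.randn(hidden_size, hidden_size)

    precomputed_input = F.linear(inp, w_ih)   # (seq_len, batch, hidden_size)
    cur_hidden = hidden.unsqueeze(0)          # (1, batch, hidden_size)
    step_output = []
    for x_t in precomputed_input:             # x_t: (batch, hidden_size)
        cur_hidden = torch.tanh(F.linear(cur_hidden, w_hh) + x_t)
        step_output.append(cur_hidden)        # each step is (1, batch, hidden_size)

    out = torch.cat(step_output, 0)           # cat, not stack: (seq_len, batch, hidden_size)
    assert out.shape == (seq_len, batch, hidden_size)
    assert cur_hidden.squeeze(0).shape == (batch, hidden_size)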
@@ -1,11 +1,35 @@
+from contextlib import contextmanager

 from torch._C import _TensorBase
 import torch
 import functools
+from torch._decomp import decomposition_table

 from typing import Callable, Dict, cast

+from torch.utils._pytree import tree_map_only

 HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {}

+# __torch_function__ runs before the pydispatcher so we need to use the same
+# decompositions indexed by their torch equivalent
+expanded_weights_rnn_decomps = {
+    # func: (input_decomp, data_decomp)
+    torch.rnn_relu: (decomposition_table[torch._ops.ops.aten.rnn_relu.input], None),
+    torch.rnn_tanh: (decomposition_table[torch._ops.ops.aten.rnn_tanh.input], None)
+}

+@contextmanager
+def batch_second(args, kwargs):
+    tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=False), args)
+    tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=False), kwargs)
+    try:
+        yield
+    finally:
+        tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=True), args)
+        tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=True), kwargs)


 def implements_per_sample_grads(torch_function):
     @functools.wraps(torch_function)
     def decorator(autograd_func):
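`batch_second` relies on `tree_map_only` to touch only the `ExpandedWeight` leaves of the argument pytree while leaving everything else alone. A tiny, generic illustration of that helper (illustrative values, unrelated to ExpandedWeight internals):

    import torch
    from torch.utils._pytree import tree_map_only

    args = (torch.tensor([1., 2.]), "a string", [torch.tensor(3.)])
    # only the torch.Tensor leaves are transformed; the string passes through untouched
    doubled = tree_map_only(torch.Tensor, lambda t: t * 2, args)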
@@ -28,6 +52,7 @@ def implements_per_sample_grads(torch_function):
 class ExpandedWeight(torch.Tensor):
     def __init__(self, orig_weight, batch_size, loss_reduction):
         self.batch_size = batch_size
+        self.batch_first = True
         self.orig_weight = orig_weight
         self.loss_reduction = loss_reduction
@@ -45,6 +70,18 @@ class ExpandedWeight(torch.Tensor):
     def __torch_function__(cls, func, _, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
+        if func in expanded_weights_rnn_decomps:
+            # in aten, choosing the input or data variants is done by parsing logic. This mimics some of that
+            decomp_opts = expanded_weights_rnn_decomps[func]
+            use_input_variant = isinstance(args[1], torch.Tensor)  # data variant uses a list here
+            decomp = decomp_opts[0] if use_input_variant else decomp_opts[1]
+
+            if decomp is not None:
+                with batch_second(args, kwargs):
+                    return decomp(*args, **kwargs)
+        if func == torch._cudnn_rnn_flatten_weight:
+            # since we aren't using the fused cuda kernels for RNNs, don't do this
+            return
         if func in cls.handled_functions:
             return cls.handled_functions[func].apply(tuple(kwargs.keys()), func, *(args + tuple(kwargs.values())))
         # We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs,
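The routing above happens inside `__torch_function__`, which intercepts the call before regular dispatch. For readers unfamiliar with the mechanism, a minimal standalone sketch (not ExpandedWeight code):

    import torch

    class Logged(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print(f"intercepted {func.__name__}")   # runs before the op executes
            return super().__torch_function__(func, types, args, kwargs)

    t = torch.randn(3).as_subclass(Logged)
    torch.add(t, 1)   # prints "intercepted add", then computes normally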
@@ -55,6 +92,27 @@ class ExpandedWeight(torch.Tensor):
     def dtype(self):
         return self.orig_weight.dtype

     @property
+    def data(self):
+        return self.orig_weight.data
+
+    @property
     def shape(self):
         return self.orig_weight.shape
+
+    @property
+    def device(self):
+        return self.orig_weight.device
+
+    @property
+    def is_cuda(self):
+        return self.orig_weight.is_cuda
+
+    def data_ptr(self):
+        return self.orig_weight.data_ptr()
+
+    def get_device(self):
+        return self.orig_weight.get_device()
+
+    def set_batch_first(self, is_batch_first=True):
+        self.batch_first = is_batch_first
@@ -3,6 +3,18 @@ from typing import Optional
 import torch
 from .expanded_weights_impl import ExpandedWeight

+def is_batch_first(expanded_args_and_kwargs):
+    batch_first = None
+    for arg in expanded_args_and_kwargs:
+        if not isinstance(arg, ExpandedWeight):
+            continue
+
+        if not batch_first:
+            batch_first = arg.batch_first
+        elif arg.batch_first != batch_first:
+            raise RuntimeError("Got conflicting batch_first arguments in the same layer")
+    return batch_first
+
 def standard_kwargs(kwarg_names, expanded_args):
     r'''Most `__torch_function__`s standardize the kwargs that they give, so this will separate
     the args and kwargs they pass. Functions that don't are linear and convND
@@ -46,9 +58,11 @@ def _check_and_unexpand_args(func, expanded_args, expanded_kwargs):
     if input.shape[0] == 0:
         raise RuntimeError("0 is not a valid batch size for Expanded Weights but got input tensor of "
                            f"{input} in function {func.__name__}")
-    batch_size = input.shape[0]
     for arg in expanded_args + tuple(expanded_kwargs.values()):
-        if isinstance(arg, ExpandedWeight) and arg.batch_size != batch_size:
+        if not isinstance(arg, ExpandedWeight):
+            continue
+        batch_size = input.shape[0] if arg.batch_first else input.shape[1]
+        if arg.batch_size != batch_size:
             raise RuntimeError("Expected ExpandedWeights to have batch size matching input but got "
                                f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}")
@@ -2,7 +2,7 @@ import torch
 import torch.nn.functional as F
 from .expanded_weights_impl import implements_per_sample_grads
 from .expanded_weights_utils import \
-    forward_helper, set_grad_sample_if_exists, unpack_expanded_weight_or_tensor
+    forward_helper, set_grad_sample_if_exists, unpack_expanded_weight_or_tensor, is_batch_first
 from typing import List, Optional

 @implements_per_sample_grads(F.linear)
@@ -14,6 +14,7 @@ class LinearPerSampleGrad(torch.autograd.Function):
                                f"of at least rank 2, got of rank {len(expanded_args_and_kwargs[0].shape)}")
         expanded_kwargs = {'bias': expanded_args_and_kwargs[2] if len(expanded_args_and_kwargs) == 3 else None}
         expanded_args = expanded_args_and_kwargs[:2]
+        ctx.batch_first = is_batch_first(expanded_args_and_kwargs)
         output = forward_helper(F.linear, expanded_args, expanded_kwargs)
         ctx.args = expanded_args
         ctx.kwargs = expanded_kwargs
@@ -33,6 +34,10 @@ class LinearPerSampleGrad(torch.autograd.Function):
             results.append(None)
         results.extend([None] * 2)  # weight and bias don't compute batched gradients

+        if not ctx.batch_first:
+            grad_output = grad_output.transpose(0, 1)
+            input = input.transpose(0, 1)
+
         # weight and bias get their grad_sample fields set directly if they exist
         set_grad_sample_if_exists(weight, lambda _: torch.einsum("n...i,n...j->nij", grad_output, input))
         set_grad_sample_if_exists(bias, lambda _: torch.einsum("n...k->nk", grad_output))
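For reference, a small sketch of what the `n...i,n...j->nij` einsum above computes in the 2-D case: one outer product of grad_output and input per sample. Shapes are illustrative, not from the diff:

    import torch

    batch, out_features, in_features = 4, 2, 3
    inp = torch.randn(batch, in_features)
    grad_output = torch.randn(batch, out_features)

    per_sample = torch.einsum("n...i,n...j->nij", grad_output, inp)        # (4, 2, 3)
    manual = torch.stack([torch.outer(grad_output[k], inp[k]) for k in range(batch)])
    assert torch.allclose(per_sample, manual)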
@@ -6,9 +6,11 @@ from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight
 from torch.utils._pytree import tree_flatten


-def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum"):
+# dependency on `functional_call` means that this can't be exposed in utils
+# without creating circular dependency
+def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum", batch_first=True):
     r"""
-    call_for_per_sample_grads(module, batch_size=None, loss_reduction="sum")
+    call_for_per_sample_grads(module, batch_size=None, loss_reduction="sum", batch_first=True)
     ``call_for_per_sample_grads`` returns a function that is invoked like the forward
     function of ``module`` and will produce the same result. Then, when backward is invoked,
     the parameters of ``module`` will have a ``grad_sample`` field populated with the per sample
@@ -24,6 +26,8 @@ def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum"):
         loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. If
             "mean", per sample gradients will be scaled by the batch size to offset the crossbatch interaction from
             running mean across a batch. Must be "mean" or "sum". Default: "sum"
+        batch_first: Indicates if the batch dimension is the first dimension. If True, the batch dimension is the first
+            dimension. If False, it's the second dimension. Default: True.

     Examples::
         >>> # xdoctest: +SKIP
@@ -64,7 +68,7 @@ def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum"):
         if not isinstance(arg, torch.Tensor):
             continue

-        arg_batch_size = arg.shape[0]  # we assume batch size is the first dim
+        arg_batch_size = arg.shape[0] if batch_first else arg.shape[1]
         if batch_size is not None and batch_size != arg_batch_size:
             raise RuntimeError("When computing batch size, found at least one input with batch size "
                                f"{batch_size} and one with batch size {arg_batch_size}. Please specify it "
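With `batch_first=False`, the batch size is read from the second dimension, matching PyTorch's default sequence-first RNN layout. A short hedged usage sketch (module and shapes are illustrative assumptions):

    import torch
    import torch.nn as nn
    from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

    rnn = nn.RNN(input_size=4, hidden_size=6)     # nn.RNN defaults to batch_first=False
    inp = torch.randn(5, 8, 4)                    # (seq_len, batch, features)

    out, _ = call_for_per_sample_grads(rnn, loss_reduction="sum", batch_first=False)(inp)
    out.sum().backward()                          # grad_sample's leading dim is the batch size, 8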
@@ -1059,6 +1059,10 @@ rnn_gru_lstm_module_info_decorators = (
         unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
         active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
     ),
+    DecorateInfo(
+        skipCUDAVersionIn([(11, 7)]), "TestExpandedWeightModule", "test_module",
+        device_type='cuda'
+    ),
     DecorateInfo(
         skipCUDAVersionIn([(11, 7)]), "TestDecomp", "test_rnn_decomp_module",
         device_type='cuda'