[functorch] fixed dot batching rule and added vmap tests using opinfo

Horace He
2021-04-28 18:42:16 -07:00
committed by Jon Janzen
parent 918ede7a85
commit 0d94ae66a7
4 changed files with 99 additions and 1 deletion

View File

@@ -4,6 +4,7 @@ from . import _C
from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers
from ._src.python_key import key_wrap, ModuleWrap
# Monkeypatching lol
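
This first hunk only re-exports the new python_key helpers from the package root. Assuming a build of this branch is installed, the new names become importable directly (a minimal check, not part of the commit):

# Minimal import check; assumes this functorch branch (with python_key) is installed.
from functorch import key_wrap, ModuleWrap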

View File

@@ -0,0 +1,40 @@
import functools
import torch._C.key as key
from torch.fx import PythonTensor
import torch


class ModuleWrap(torch.nn.Module):
    def __init__(self, mod, inps):
        super().__init__()
        self.mod = mod
        self.inps = inps

        @functools.wraps(mod.forward)
        def forward_wrapped(self, *args):
            new_args = []
            for inp, arg in zip(inps, args):
                if isinstance(inp, torch.Tensor):
                    new_arg = key.addKey(PythonTensor(inp.shape, arg))
                else:
                    new_arg = inp
                new_args.append(new_arg)
            out = self.mod(*new_args)
            return key.removeKey(out).proxy

        type(self).forward = forward_wrapped


def key_wrap(f, inps):
    @functools.wraps(f)
    def wrapped(*args):
        new_args = []
        for inp, arg in zip(inps, args):
            if isinstance(inp, torch.Tensor):
                new_arg = key.addKey(PythonTensor(inp.shape, arg))
            else:
                new_arg = inp
            new_args.append(new_arg)
        out = f(*new_args)
        if key.hasKey(out):
            return key.removeKey(out).proxy
        else:
            return out
    return wrapped
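
Both wrappers lean on the same trick: wrap each tensor argument in a keyed PythonTensor before calling the original callable, and in ModuleWrap's case install the wrapper by assigning it to type(self).forward so that nn.Module.__call__ dispatches to it. The snippet below is a self-contained illustration of that monkeypatching pattern only, using a made-up LoggingWrap class with plain tensors; it does not touch torch._C.key and is not part of the commit.

import functools
import torch

class LoggingWrap(torch.nn.Module):
    # Illustrative stand-in for ModuleWrap: same pattern of binding a wrapper
    # onto the class, but the wrapper only records argument shapes.
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

        @functools.wraps(mod.forward)
        def forward_wrapped(self, *args):
            self.seen_shapes = [tuple(a.shape) for a in args if isinstance(a, torch.Tensor)]
            return self.mod(*args)

        type(self).forward = forward_wrapped

wrapped = LoggingWrap(torch.nn.Linear(4, 2))
out = wrapped(torch.randn(3, 4))
print(wrapped.seen_shapes)  # [(3, 4)]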

View File

@@ -24,6 +24,9 @@ slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<int64_t> A_bdim, const Tensor& B, optional<int64_t> B_bdim) {
  auto A_ = moveBatchDimToFront(A, A_bdim);
  auto B_ = moveBatchDimToFront(B, B_bdim);
  if (A_bdim && B_bdim) {
    return {at::matmul(A_.unsqueeze(-2), B_.unsqueeze(-1)).squeeze(-1).squeeze(-1), 0};
  }
  return {at::matmul(A_, B_.t()), 0};
}
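
As a quick sanity check of what the corrected rule computes (my own illustration, not part of this commit): vmapping torch.dot should give per-example dot products when both inputs carry a batch dimension, and fall back to a matrix-vector product when only one of them does.

# Illustrative check of the dot batching rule's two cases; assumes functorch's vmap.
import torch
from functorch import vmap

B0 = 3
A = torch.randn(B0, 5)
B = torch.randn(B0, 5)
b = torch.randn(5)

# Both inputs batched: one dot product per example, i.e. (A * B).sum(-1).
assert torch.allclose(vmap(torch.dot)(A, B), (A * B).sum(-1))

# Only the first input batched: equivalent to the matrix-vector product A @ b.
assert torch.allclose(vmap(torch.dot, in_dims=(0, None))(A, b), A @ b)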

View File

@@ -8,6 +8,8 @@ import warnings
import unittest
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
    skipCUDAIfNoMagma
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import op_db
import types
from functorch import vmap
@@ -1675,7 +1677,7 @@ class TestVmapOperators(Namespace.TestVmapBase):
        # unsqueeze dim 0
        test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None))
        test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None))
        # unsqueeze last dim (positive)
        test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
        test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None))
@@ -2587,6 +2589,58 @@ class TestVmapBatchedGradient(Namespace.TestVmapBase):
        result = vmap(vjp)(gy)
        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))


class TestVmapOperators(TestCase):
    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_normalize_operator_exhaustive(self, device, dtype, op):
        op_skip = {'__getitem__', 'broadcast_to', 'dsplit', 'hsplit', 'vsplit',
                   'moveaxis', 'positive', 'tensor_split', 'unfold'}
        # Unsupported input types
        if op.name in op_skip:
            return
        # Entries in here don't work and need to be fixed.
        vmap_fail = {'repeat'}
        if op.name in vmap_fail:
            return
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)

        def add_batch_dim(arg, bdim, batch_size=3):
            if isinstance(arg, torch.Tensor):
                shape = [1] * len(arg.shape)
                shape.insert(bdim, batch_size)
                return (arg.repeat(shape), bdim)
            else:
                return (arg, None)

        for sample_input in sample_inputs_itr:
            arg_values = [sample_input.input] + list(sample_input.args)
            kwarg_values = sample_input.kwargs
            if len(kwarg_values) > 0:
                continue
            batch_size = 3
            out_dim = 0
            batch_choices = [(add_batch_dim(a, 0, batch_size), (a, None)) if isinstance(a, torch.Tensor)
                             else ((a, None),) for a in arg_values]
            for batched_values in itertools.product(*batch_choices):
                batched_args, in_dims = zip(*batched_values)
                if all([i is None for i in in_dims]):
                    continue
                outs = []
                for idx in range(batch_size):
                    idx_args = [a.select(in_dim, idx) if in_dim is not None else a
                                for a, in_dim in zip(batched_args, in_dims)]
                    out = op.op(*idx_args, **kwarg_values)
                    outs.append(out)
                loop_out = []
                if isinstance(outs[0], torch.Tensor):
                    loop_out = torch.stack(outs)
                else:
                    for idx in range(len(outs[0])):
                        loop_out.append(torch.stack([i[idx] for i in outs], out_dim))
                batched_out = vmap(op.op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **kwarg_values)
                self.assertEqual(loop_out, batched_out)


instantiate_device_type_tests(TestVmapOperators, globals())

only_for = ("cpu", "cuda")
instantiate_device_type_tests(
    TestVmapBatchedGradient,