mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] fixed dot batching rule and added vmap tests using opinfo
This commit is contained in:
@ -4,6 +4,7 @@ from . import _C
|
|||||||
from ._src.vmap import vmap
|
from ._src.vmap import vmap
|
||||||
from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
|
from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
|
||||||
from ._src.make_functional import make_functional, make_functional_with_buffers
|
from ._src.make_functional import make_functional, make_functional_with_buffers
|
||||||
|
from ._src.python_key import key_wrap, ModuleWrap
|
||||||
|
|
||||||
|
|
||||||
# Monkeypatching lol
|
# Monkeypatching lol
|
||||||
|
|||||||
40
functorch/functorch/_src/python_key.py
Normal file
40
functorch/functorch/_src/python_key.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import functools
|
||||||
|
import torch._C.key as key
|
||||||
|
from torch.fx import PythonTensor
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class ModuleWrap(torch.nn.Module):
    """Module wrapper that feeds ``mod`` key-wrapped ``PythonTensor`` inputs.

    Args:
        mod: The module whose ``forward`` is being wrapped.
        inps: Example inputs, matched positionally against the call
            arguments. Where the example is a ``torch.Tensor`` the matching
            call argument is wrapped as
            ``key.addKey(PythonTensor(example.shape, arg))``; otherwise the
            example value itself is passed to ``mod`` and the call argument
            is ignored.

    The wrapped forward returns ``key.removeKey(out).proxy`` for the output
    of ``mod``.
    """

    def __init__(self, mod, inps):
        super().__init__()
        self.mod = mod
        self.inps = inps

        @functools.wraps(mod.forward)
        def forward_wrapped(self, *args):
            # This function is installed on the class (see below), so it is
            # shared by every ModuleWrap instance. Read the example inputs
            # from the instance instead of closing over the `inps` local:
            # otherwise each instance would zip against the inputs of
            # whichever ModuleWrap was constructed most recently.
            new_args = [
                key.addKey(PythonTensor(inp.shape, arg))
                if isinstance(inp, torch.Tensor)
                else inp
                for inp, arg in zip(self.inps, args)
            ]
            out = self.mod(*new_args)
            return key.removeKey(out).proxy

        # NOTE(review): assigning on the type replaces `forward` for all
        # instances, and the functools.wraps metadata reflects the most
        # recently wrapped `mod.forward`. Presumably this is done so that
        # code inspecting `type(module).forward` sees mod's signature --
        # confirm before changing it.
        type(self).forward = forward_wrapped
|
||||||
|
|
||||||
|
def key_wrap(f, inps):
    """Return a wrapper around ``f`` that calls it with key-wrapped tensors.

    Args:
        f: The callable to wrap.
        inps: Example inputs, matched positionally against the wrapper's
            call arguments. Where the example is a ``torch.Tensor`` the
            matching call argument is wrapped as
            ``key.addKey(PythonTensor(example.shape, arg))``; otherwise the
            example value itself is forwarded and the call argument is
            ignored.

    Returns:
        A function that forwards to ``f``. If ``key.hasKey`` reports that the
        result carries the key, ``key.removeKey(result).proxy`` is returned;
        otherwise the result is returned unchanged.
    """

    def _lift(example, actual):
        # Non-tensor examples are passed through as-is.
        if not isinstance(example, torch.Tensor):
            return example
        return key.addKey(PythonTensor(example.shape, actual))

    @functools.wraps(f)
    def wrapped(*args):
        lifted = [_lift(example, actual) for example, actual in zip(inps, args)]
        result = f(*lifted)
        return key.removeKey(result).proxy if key.hasKey(result) else result

    return wrapped
|
||||||
@ -24,6 +24,9 @@ slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
|
|||||||
// Batching rule for dot: computes the product of A and B where at least one
// of them carries a batch dimension, and reports the output batch dim as 0.
//
// NOTE(review): assumes moveBatchDimToFront leaves the batch dimension (when
// present) at dim 0 and returns the tensor unchanged otherwise -- confirm
// against its definition.
std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<int64_t> A_bdim, const Tensor& B, optional<int64_t> B_bdim) {
  const auto lhs = moveBatchDimToFront(A, A_bdim);
  const auto rhs = moveBatchDimToFront(B, B_bdim);

  if (!(A_bdim && B_bdim)) {
    // At most one operand is batched: a single matmul against the transposed
    // right-hand side covers this case.
    return {at::matmul(lhs, rhs.t()), 0};
  }

  // Both operands are batched: insert singleton dims so matmul performs one
  // row-by-column product per batch entry, then drop the two trailing
  // singleton dims of the result.
  auto product = at::matmul(lhs.unsqueeze(-2), rhs.unsqueeze(-1));
  return {product.squeeze(-1).squeeze(-1), 0};
}
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,8 @@ import warnings
|
|||||||
import unittest
|
import unittest
|
||||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
|
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
|
||||||
skipCUDAIfNoMagma
|
skipCUDAIfNoMagma
|
||||||
|
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
|
||||||
|
from torch.testing._internal.common_methods_invocations import op_db
|
||||||
import types
|
import types
|
||||||
|
|
||||||
from functorch import vmap
|
from functorch import vmap
|
||||||
@ -1675,7 +1677,7 @@ class TestVmapOperators(Namespace.TestVmapBase):
|
|||||||
# unsqueeze dim 0
|
# unsqueeze dim 0
|
||||||
test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None))
|
test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None))
|
||||||
test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None))
|
test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None))
|
||||||
|
|
||||||
# unsqueeze last dim (positive)
|
# unsqueeze last dim (positive)
|
||||||
test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
|
test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
|
||||||
test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None))
|
test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None))
|
||||||
@ -2587,6 +2589,58 @@ class TestVmapBatchedGradient(Namespace.TestVmapBase):
|
|||||||
result = vmap(vjp)(gy)
|
result = vmap(vjp)(gy)
|
||||||
self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
|
self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
|
||||||
|
|
||||||
|
# NOTE(review): a class named TestVmapOperators (deriving from
# Namespace.TestVmapBase) already appears earlier in this file. This
# definition rebinds that name; confirm whether a distinct name is wanted so
# the earlier test class is not shadowed.
class TestVmapOperators(TestCase):
    """OpInfo-driven check that ``vmap(op)`` matches a per-example loop."""

    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_normalize_operator_exhaustive(self, device, dtype, op):
        """Compare ``vmap(op.op)`` with stacking ``op.op`` over each batch entry.

        For every sample input without kwargs, each tensor argument is tried
        both batched (along dim 0) and unbatched; every combination with at
        least one batched argument is checked.
        """
        # Unsupported input types
        op_skip = {'__getitem__', 'broadcast_to', 'dsplit', 'hsplit', 'vsplit', 'moveaxis', 'positive', 'tensor_split', 'unfold'}
        # Entries in here don't work yet and need to be fixed.
        vmap_fail = {'repeat'}

        # Report these as skipped instead of returning silently, so they do
        # not show up as passing tests.
        if op.name in op_skip:
            self.skipTest(f"{op.name}: unsupported input types")
        if op.name in vmap_fail:
            self.skipTest(f"{op.name}: known vmap failure")

        def add_batch_dim(arg, bdim, batch_size=3):
            """Return ``(arg, bdim)`` with tensors repeated along a new ``bdim``."""
            if not isinstance(arg, torch.Tensor):
                return (arg, None)
            shape = [1] * len(arg.shape)
            shape.insert(bdim, batch_size)
            return (arg.repeat(shape), bdim)

        batch_size = 3
        out_dim = 0

        for sample_input in op.sample_inputs(device, dtype, requires_grad=False):
            arg_values = [sample_input.input] + list(sample_input.args)
            kwarg_values = sample_input.kwargs
            # Samples that carry kwargs are not exercised by this test.
            if kwarg_values:
                continue

            # Per argument: the (value, in_dim) alternatives to try.
            batch_choices = [
                (add_batch_dim(a, 0, batch_size), (a, None))
                if isinstance(a, torch.Tensor)
                else ((a, None),)
                for a in arg_values
            ]

            for batched_values in itertools.product(*batch_choices):
                batched_args, in_dims = zip(*batched_values)
                # vmap needs at least one batched input.
                if all(in_dim is None for in_dim in in_dims):
                    continue

                # Reference result: run the op once per batch entry.
                outs = []
                for idx in range(batch_size):
                    idx_args = [
                        a.select(in_dim, idx) if in_dim is not None else a
                        for a, in_dim in zip(batched_args, in_dims)
                    ]
                    outs.append(op.op(*idx_args, **kwarg_values))

                if isinstance(outs[0], torch.Tensor):
                    # Stack along out_dim, the same dim vmap is asked to use.
                    loop_out = torch.stack(outs, out_dim)
                else:
                    # Multi-output op: stack each output position separately.
                    loop_out = [
                        torch.stack([out[pos] for out in outs], out_dim)
                        for pos in range(len(outs[0]))
                    ]

                batched_out = vmap(op.op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **kwarg_values)
                self.assertEqual(loop_out, batched_out)


instantiate_device_type_tests(TestVmapOperators, globals())
|
||||||
|
|
||||||
only_for = ("cpu", "cuda")
|
only_for = ("cpu", "cuda")
|
||||||
instantiate_device_type_tests(
|
instantiate_device_type_tests(
|
||||||
TestVmapBatchedGradient,
|
TestVmapBatchedGradient,
|
||||||
|
|||||||
Reference in New Issue
Block a user