[functorch] fixed dot batching rule and added vmap tests using opinfo
@@ -4,6 +4,7 @@ from . import _C
from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers
from ._src.python_key import key_wrap, ModuleWrap

# Monkeypatching lol
functorch/functorch/_src/python_key.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import functools
import torch._C.key as key
from torch.fx import PythonTensor
import torch


class ModuleWrap(torch.nn.Module):
    def __init__(self, mod, inps):
        super().__init__()
        self.mod = mod
        self.inps = inps

        @functools.wraps(mod.forward)
        def forward_wrapped(self, *args):
            new_args = []
            for inp, arg in zip(inps, args):
                if isinstance(inp, torch.Tensor):
                    new_arg = key.addKey(PythonTensor(inp.shape, arg))
                else:
                    new_arg = inp
                new_args.append(new_arg)
            out = self.mod(*new_args)
            return key.removeKey(out).proxy

        type(self).forward = forward_wrapped


def key_wrap(f, inps):
    @functools.wraps(f)
    def wrapped(*args):
        new_args = []
        for inp, arg in zip(inps, args):
            if isinstance(inp, torch.Tensor):
                new_arg = key.addKey(PythonTensor(inp.shape, arg))
            else:
                new_arg = inp
            new_args.append(new_arg)
        out = f(*new_args)
        if key.hasKey(out):
            return key.removeKey(out).proxy
        else:
            return out
    return wrapped
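For orientation (not part of the commit): ModuleWrap rebinds forward on the class from inside __init__, so that nn.Module.__call__ dispatches to a wrapper that can substitute each tensor argument before calling the original module. A minimal, runnable sketch of that forward-rebinding pattern, with the experimental key/PythonTensor argument wrapping replaced by a pass-through (torch._C.key is not assumed to be available):

# Sketch of the forward-rebinding pattern used by ModuleWrap; the per-argument
# key.addKey(PythonTensor(...)) wrapping is replaced by an identity transform.
import functools
import torch

class Wrapped(torch.nn.Module):
    def __init__(self, mod, inps):
        super().__init__()
        self.mod = mod
        self.inps = inps

        @functools.wraps(mod.forward)
        def forward_wrapped(self, *args):
            # python_key.py would wrap tensor args in key-tagged PythonTensors here.
            return self.mod(*args)

        # Rebind forward on the class so nn.Module.__call__ picks up the wrapper.
        type(self).forward = forward_wrapped

lin = torch.nn.Linear(3, 4)
m = Wrapped(lin, (torch.randn(2, 3),))
print(m(torch.randn(2, 3)).shape)  # torch.Size([2, 4])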
@@ -24,6 +24,9 @@ slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {

std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<int64_t> A_bdim, const Tensor& B, optional<int64_t> B_bdim) {
  auto A_ = moveBatchDimToFront(A, A_bdim);
  auto B_ = moveBatchDimToFront(B, B_bdim);
  if (A_bdim && B_bdim) {
    return {at::matmul(A_.unsqueeze(-2), B_.unsqueeze(-1)).squeeze(-1).squeeze(-1), 0};
  }
  return {at::matmul(A_, B_.t()), 0};
}
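A quick standalone sanity check of the rule above (plain PyTorch, no functorch required): when both operands carry a batch dimension, per-example dot becomes a batched (1, N) @ (N, 1) matmul; when only one does, it reduces to a matrix-vector product.

# Standalone numerical sketch of the dot batching rule above, compared
# against an explicit per-example loop.
import torch

B0, N = 3, 5
A = torch.randn(B0, N)   # batched along dim 0
A2 = torch.randn(B0, N)  # second batched operand
b = torch.randn(N)       # unbatched operand

# Both operands batched: per-example dot as a batched (1, N) @ (N, 1) matmul.
both = torch.matmul(A.unsqueeze(-2), A2.unsqueeze(-1)).squeeze(-1).squeeze(-1)
assert torch.allclose(both, torch.stack([torch.dot(a, a2) for a, a2 in zip(A, A2)]))

# Only one operand batched: the rule degenerates to a matrix-vector product.
one = torch.matmul(A, b)
assert torch.allclose(one, torch.stack([torch.dot(a, b) for a in A]))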
@@ -8,6 +8,8 @@ import warnings
import unittest
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
    skipCUDAIfNoMagma
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import op_db
import types

from functorch import vmap
@@ -1675,7 +1677,7 @@ class TestVmapOperators(Namespace.TestVmapBase):
        # unsqueeze dim 0
        test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None))
        test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None))

        # unsqueeze last dim (positive)
        test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
        test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None))
@@ -2587,6 +2589,58 @@ class TestVmapBatchedGradient(Namespace.TestVmapBase):
        result = vmap(vjp)(gy)
        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))


class TestVmapOperators(TestCase):
    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_normalize_operator_exhaustive(self, device, dtype, op):
        op_skip = {'__getitem__', 'broadcast_to', 'dsplit', 'hsplit', 'vsplit', 'moveaxis', 'positive', 'tensor_split', 'unfold'}
        # Unsupported input types
        if op.name in op_skip:
            return

        # entries in here don't work and need to be fixed.
        vmap_fail = {'repeat'}
        if op.name in vmap_fail:
            return
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)

        def add_batch_dim(arg, bdim, batch_size=3):
            if isinstance(arg, torch.Tensor):
                shape = [1] * len(arg.shape)
                shape.insert(bdim, batch_size)
                return (arg.repeat(shape), bdim)
            else:
                return (arg, None)

        for sample_input in sample_inputs_itr:
            arg_values = [sample_input.input] + list(sample_input.args)
            kwarg_values = sample_input.kwargs
            if len(kwarg_values) > 0:
                continue
            batch_size = 3
            out_dim = 0
            batch_choices = [(add_batch_dim(a, 0, batch_size), (a, None)) if isinstance(a, torch.Tensor) else ((a, None),) for a in arg_values]
            for batched_values in itertools.product(*batch_choices):
                batched_args, in_dims = zip(*batched_values)
                if all([i is None for i in in_dims]):
                    continue
                outs = []
                for idx in range(batch_size):
                    idx_args = [a.select(in_dim, idx) if in_dim is not None else a for a, in_dim in zip(batched_args, in_dims)]
                    out = op.op(*idx_args, **kwarg_values)
                    outs.append(out)
                loop_out = []
                if isinstance(outs[0], torch.Tensor):
                    loop_out = torch.stack(outs)
                else:
                    for idx in range(len(outs[0])):
                        loop_out.append(torch.stack([i[idx] for i in outs], out_dim))

                batched_out = vmap(op.op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **kwarg_values)
                self.assertEqual(loop_out, batched_out)
instantiate_device_type_tests(TestVmapOperators, globals())

only_for = ("cpu", "cuda")
instantiate_device_type_tests(
    TestVmapBatchedGradient,
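A minimal standalone version of the loop-vs-vmap comparison the OpInfo test above performs, shown for a single op (torch.add) instead of iterating over op_db (assumes functorch is importable):

# Loop-vs-vmap check for one op: compute per-example results in a Python loop,
# stack them, and compare against a single vmapped call.
import torch
from functorch import vmap

batch_size = 3
x = torch.randn(batch_size, 2, 5)   # batched along dim 0
y = torch.randn(2, 5)               # left unbatched (in_dim None)

loop_out = torch.stack([torch.add(x[i], y) for i in range(batch_size)])
batched_out = vmap(torch.add, in_dims=(0, None), out_dims=0)(x, y)
assert torch.allclose(loop_out, batched_out)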