mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[custom ops] Add register_vmap for custom ops (#130589)
Fixes #130284. Fixes #130653. - Add `torch.library.register_vmap` for custom ops. - Add `register_vmap` rules for the operators in custom_op_db. - Make `torch.autograd.Function` support kwarg-only kwargs for vmap. - Test operators in op_db with `tests/test_vmap`. - Change `test_vmap` to allow a custom `out_dim` and to allow `None` in `out_dim` when testing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130589 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
404d640c39
commit
68c725a094
@ -42,6 +42,7 @@ via PyTorch's C++ operator registration APIs).
|
||||
.. autofunction:: register_kernel
|
||||
.. autofunction:: register_autograd
|
||||
.. autofunction:: register_fake
|
||||
.. autofunction:: register_vmap
|
||||
.. autofunction:: impl_abstract
|
||||
.. autofunction:: get_ctx
|
||||
.. autofunction:: register_torch_dispatch
|
||||
|
@ -18,6 +18,7 @@ from torch.testing._internal.autograd_function_db import autograd_function_db
|
||||
from torch.testing._internal.common_device_type import toleranceOverride
|
||||
from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
|
||||
from torch.testing._internal.common_modules import module_db
|
||||
from torch.testing._internal.custom_op_db import custom_op_db
|
||||
|
||||
|
||||
IS_FBCODE = os.getenv("FUNCTORCH_TEST_FBCODE") == "1"
|
||||
@ -38,8 +39,26 @@ def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
|
||||
flat_out, out_spec = pytree.tree_flatten(out)
|
||||
outs.append(flat_out)
|
||||
|
||||
# use the same out_dim for all outputs
|
||||
if isinstance(out_dim, int):
|
||||
flat_out_dim = [out_dim for _ in flat_out]
|
||||
else:
|
||||
flat_out_dim, _ = pytree.tree_flatten(out_dim)
|
||||
|
||||
outs = zip(*outs)
|
||||
result = [torch.stack(out_lst) for out_lst in outs]
|
||||
|
||||
result = []
|
||||
for i, out_lst in enumerate(outs):
|
||||
if flat_out_dim[i] is not None:
|
||||
if not all(isinstance(x, torch.Tensor) for x in out_lst):
|
||||
raise ValueError(
|
||||
f"vmap `{op}` must only return "
|
||||
"Tensors. Did you mean to set out_dims= to None for output?"
|
||||
)
|
||||
result.append(torch.stack(out_lst))
|
||||
else:
|
||||
# not batched over, result should be the same for all batches
|
||||
result.append(out_lst[0])
|
||||
return pytree.tree_unflatten(result, out_spec)
|
||||
|
||||
|
||||
@ -317,9 +336,9 @@ def _compute_quantities_for_vmap_test(
|
||||
inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims)
|
||||
outer_in_dims = (0,) + in_dims
|
||||
batched_args, kwarg_values = maybe_clone_inputs()
|
||||
vmapvmap_output = vmap(vmap(f, inner_in_dims), outer_in_dims)(
|
||||
dummy, *batched_args, **kwarg_values
|
||||
)
|
||||
vmapvmap_output = vmap(
|
||||
vmap(f, inner_in_dims, out_dims=out_dim), outer_in_dims, out_dims=out_dim
|
||||
)(dummy, *batched_args, **kwarg_values)
|
||||
|
||||
yield (batched_out, loop_out, vmapvmap_output, vmapvmap_expected)
|
||||
|
||||
@ -440,7 +459,7 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
|
||||
|
||||
|
||||
def skipOps(test_case_name, base_test_name, to_skip):
|
||||
all_opinfos = op_db + additional_op_db + autograd_function_db
|
||||
all_opinfos = op_db + additional_op_db + autograd_function_db + custom_op_db
|
||||
for decorate_meta in to_skip:
|
||||
matching_opinfos = [
|
||||
o
|
||||
|
@ -1617,6 +1617,28 @@ class TestAutogradFunctionVmapAPI(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "returned an incompatible"):
|
||||
result = vmap(Zeros.apply)(x)
|
||||
|
||||
def test_kwarg_only_tensors(self, device):
    """vmap over an autograd.Function must reject Tensors passed as kwarg-only args."""

    class KwargOnlyFunction(torch.autograd.Function):
        @staticmethod
        def forward(x, *, y):
            return x + y

        @staticmethod
        def setup_context(ctx, inputs, output):
            pass

        @staticmethod
        def vmap(info, in_dims, x, *, y):
            assert in_dims == (0,)
            return x + y, 0

    positional = torch.randn(3)
    keyword_only = torch.randn(3)

    # Only the vmapped call is expected to raise; the definitions above are inert.
    with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
        vmap(KwargOnlyFunction.apply)(positional, y=keyword_only)
|
||||
|
||||
|
||||
@markDynamoStrictTest
|
||||
class TestVmapOfGrad(TestCase):
|
||||
|
@ -68,6 +68,7 @@ from torch.testing._internal.common_utils import (
|
||||
unMarkDynamoStrictTest,
|
||||
xfailIfTorchDynamo,
|
||||
)
|
||||
from torch.testing._internal.custom_op_db import custom_op_db
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
|
||||
@ -3937,10 +3938,17 @@ def discover_variants(opinfo):
|
||||
@unMarkDynamoStrictTest
|
||||
class TestVmapOperatorsOpInfo(TestCase):
|
||||
def vmap_outplace_test(
|
||||
self, func, args, kwargs, in_dims, check_shape_only=False, postprocess_fn=None
|
||||
self,
|
||||
func,
|
||||
args,
|
||||
kwargs,
|
||||
in_dims,
|
||||
check_shape_only=False,
|
||||
postprocess_fn=None,
|
||||
out_dim=0,
|
||||
):
|
||||
for vmap_out, loop_out in compute_quantities_for_vmap_test(
|
||||
func, args, kwargs, in_dims
|
||||
func, args, kwargs, in_dims, out_dim=out_dim
|
||||
):
|
||||
if postprocess_fn is not None:
|
||||
loop_out = postprocess_fn(loop_out)
|
||||
@ -3950,7 +3958,9 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
continue
|
||||
self.assertEqual(vmap_out, loop_out)
|
||||
|
||||
def vmap_inplace_test(self, func, args, kwargs, in_dims, postprocess_fn=None):
|
||||
def vmap_inplace_test(
|
||||
self, func, args, kwargs, in_dims, postprocess_fn=None, out_dim=0
|
||||
):
|
||||
# NB: This test assumes that the first argument is being modified.
|
||||
# This is OK because it's what every other OpInfo-based test assumes,
|
||||
# but it is going to need a more robust solution eventually.
|
||||
@ -3963,13 +3973,19 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
args,
|
||||
kwargs,
|
||||
in_dims,
|
||||
out_dim=out_dim,
|
||||
compute_loop_out=False,
|
||||
clone_inputs=True,
|
||||
):
|
||||
pass
|
||||
return
|
||||
for vmap_out, loop_out in compute_quantities_for_vmap_test(
|
||||
func, args, kwargs, in_dims, clone_inputs=True
|
||||
func,
|
||||
args,
|
||||
kwargs,
|
||||
in_dims,
|
||||
clone_inputs=True,
|
||||
out_dim=out_dim,
|
||||
):
|
||||
if postprocess_fn is not None:
|
||||
loop_out = postprocess_fn(loop_out)
|
||||
@ -4027,6 +4043,13 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
continue
|
||||
kwargs = sample_input.kwargs
|
||||
is_batch_norm_and_training = is_batch_norm_training(op.name, kwargs)
|
||||
out_dim = 0
|
||||
if op.name == "NumpySplitCopyWithIntCustomOp":
|
||||
# special case for this custom op
|
||||
def sample_vmap_out_dim_numpy_split_copy_with_int(x, splits, dim):
|
||||
return [0 for _ in range(len(splits) + 1)], None
|
||||
|
||||
out_dim = sample_vmap_out_dim_numpy_split_copy_with_int(*args)
|
||||
for batched_args, in_dims, _ in generate_vmap_inputs(
|
||||
args, {}, is_batch_norm_and_training=is_batch_norm_and_training
|
||||
):
|
||||
@ -4038,6 +4061,7 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
in_dims,
|
||||
check_shape_only,
|
||||
postprocess_fn,
|
||||
out_dim=out_dim,
|
||||
)
|
||||
if op.name in skip_inplace:
|
||||
continue
|
||||
@ -4109,6 +4133,9 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
"linalg.eigh", ""
|
||||
), # not always return the same result for the same input, see test_linalg_eigh for manual test
|
||||
skip("to"), # RuntimeError: required rank 4 tensor to use channels_last format
|
||||
# UnimplementedError: data-dependent operators cannot be vmapped
|
||||
xfail("NumpyNonzeroCustomOp"),
|
||||
xfail("NumpyNMSCustomOp"),
|
||||
# ----------------------------------------------------------------------
|
||||
# ---------------------------- BUGS ------------------------------------
|
||||
# entries in here don't work and need to be fixed.
|
||||
@ -4187,7 +4214,10 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
}
|
||||
|
||||
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
|
||||
@ops(op_db + additional_op_db + autograd_function_db, dtypes=OpDTypes.any_one)
|
||||
@ops(
|
||||
op_db + additional_op_db + autograd_function_db + custom_op_db,
|
||||
dtypes=OpDTypes.any_one,
|
||||
)
|
||||
@opsToleranceOverride(
|
||||
"TestVmapOperatorsOpInfo",
|
||||
"test_vmap_exhaustive",
|
||||
@ -4248,7 +4278,10 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
)
|
||||
|
||||
@with_tf32_off
|
||||
@ops(op_db + additional_op_db + autograd_function_db, dtypes=OpDTypes.any_one)
|
||||
@ops(
|
||||
op_db + additional_op_db + autograd_function_db + custom_op_db,
|
||||
dtypes=OpDTypes.any_one,
|
||||
)
|
||||
@opsToleranceOverride(
|
||||
"TestVmapOperatorsOpInfo",
|
||||
"test_op_has_batch_rule",
|
||||
|
@ -2327,6 +2327,12 @@ class TestCustomOpAPI(TestCase):
|
||||
setup_context=lambda ctx, inputs, keyword_only_inputs, output: None,
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
|
||||
torch.library.register_vmap(
|
||||
"_torch_testing::foo",
|
||||
lambda info, in_dims, x, *, y: (x, 0),
|
||||
)
|
||||
|
||||
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
|
||||
def test_register_autograd_kwargonly_low_level(self):
|
||||
with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
|
||||
@ -3382,6 +3388,246 @@ Please use `add.register_fake` to add an fake impl.""",
|
||||
with f.set_kernel_enabled("cpu", enabled=False):
|
||||
self.assertEqual(f(x), x + 1)
|
||||
|
||||
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
|
||||
def test_register_vmap_kwargonly_low_level(self):
    """A vmap rule registered on a lib.define'd op receives non-Tensor kwarg-only args."""
    with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
        lib.define("foo(Tensor x, *, float y) -> Tensor")
        lib.impl("foo", lambda x, *, y: x * y, "CPU")

        # Mutable record of whether the registered rule actually ran.
        state = {"called": False}

        def foo_vmap(info, in_dims, x, *, y):
            state["called"] = True
            return x * y, 0

        torch.library.register_vmap("_torch_testing::foo", foo_vmap, lib=lib)

        ones = torch.ones(3)
        batched = torch.vmap(torch.ops._torch_testing.foo)(ones, y=3.14)

        self.assertTrue(state["called"])
        self.assertEqual(batched, torch.tensor([3.14, 3.14, 3.14]))
|
||||
|
||||
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
|
||||
def test_register_vmap_defaults(self):
    """A vmap rule works when the op schema has defaulted positional and kwarg-only args."""
    with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
        lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor")
        lib.impl("foo", lambda w, x=2, *, y=3, z: w * x * y * z, "CPU")

        # Mutable record of whether the registered rule actually ran.
        state = {"called": False}

        def foo_vmap(info, in_dims, w, x=2, *, y=3, z):
            state["called"] = True
            return w * x * y * z, 0

        torch.library.register_vmap("_torch_testing::foo", foo_vmap, lib=lib)

        w = torch.ones(3)
        batched = torch.vmap(torch.ops._torch_testing.foo)(w, z=42)

        self.assertTrue(state["called"])
        self.assertEqual(batched, w * 2 * 3 * 42)
|
||||
|
||||
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
|
||||
def test_library_register_vmap(self):
    """The vmap rule is used no matter how the op is identified at registration time."""
    for mode in ["function", "qualname", "opoverload", "c_opdef"]:

        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor, y: Tensor) -> Tensor:
            return x * y

        called = False

        def fvmap(info, in_dims, x, y):
            nonlocal called
            called = True

            def batch_last(t, bdim):
                # Put the batch dim last, creating a size-1 one for unbatched inputs.
                if bdim is None:
                    return t.unsqueeze(-1)
                return t.movedim(bdim, -1)

            x_bdim, y_bdim = in_dims
            product = batch_last(x, x_bdim) * batch_last(y, y_bdim)
            return product.movedim(-1, 0), 0

        if mode == "c_opdef":
            f.register_vmap(fvmap)
        elif mode == "function":
            torch.library.register_vmap(f, fvmap)
        elif mode == "qualname":
            torch.library.register_vmap("mylib::f", fvmap)
        elif mode == "opoverload":
            torch.library.register_vmap(torch.ops.mylib.f.default, fvmap)

        x = torch.randn(2, 2)
        y = torch.randn(2, 2)

        cases = (
            ({}, x * y),
            ({"out_dims": 1}, (x * y).T),
            ({"in_dims": 1}, (x * y).T),
        )
        for vmap_kwargs, expected in cases:
            called = False
            result = torch.vmap(f, **vmap_kwargs)(x, y)
            self.assertTrue(called)
            self.assertEqual(result, expected)
|
||||
|
||||
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
|
||||
def test_library_register_vmap_library_decorator(self):
    """torch.library.register_vmap can be applied as a decorator given a qualname."""

    @torch.library.custom_op("mylib::f", mutates_args=())
    def f(x: Tensor, y: Tensor) -> Tensor:
        return x * y

    called = False

    @torch.library.register_vmap("mylib::f")
    def fvmap(info, in_dims, x, y):
        nonlocal called
        called = True

        def batch_last(t, bdim):
            # Put the batch dim last, creating a size-1 one for unbatched inputs.
            if bdim is None:
                return t.unsqueeze(-1)
            return t.movedim(bdim, -1)

        x_bdim, y_bdim = in_dims
        product = batch_last(x, x_bdim) * batch_last(y, y_bdim)
        return product.movedim(-1, 0), 0

    x = torch.randn(2, 2)
    y = torch.randn(2, 2)

    batched = torch.vmap(f)(x, y)
    self.assertTrue(called)
    self.assertEqual(batched, x * y)
|
||||
|
||||
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
|
||||
def test_library_register_vmap_op_decorator(self):
    """CustomOpDef.register_vmap can be applied directly as a decorator."""

    @torch.library.custom_op("mylib::f", mutates_args=())
    def f(x: Tensor, y: Tensor) -> Tensor:
        return x * y

    called = False

    @f.register_vmap
    def fvmap(info, in_dims, x, y):
        nonlocal called
        called = True

        def batch_last(t, bdim):
            # Put the batch dim last, creating a size-1 one for unbatched inputs.
            if bdim is None:
                return t.unsqueeze(-1)
            return t.movedim(bdim, -1)

        x_bdim, y_bdim = in_dims
        product = batch_last(x, x_bdim) * batch_last(y, y_bdim)
        return product.movedim(-1, 0), 0

    x = torch.randn(2, 2)
    y = torch.randn(2, 2)

    batched = torch.vmap(f)(x, y)
    self.assertTrue(called)
    self.assertEqual(batched, x * y)
|
||||
|
||||
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
|
||||
def test_library_register_vmap_register_multiple_times(self):
    """Re-registering through CustomOpDef.register_vmap replaces the earlier rule."""

    @torch.library.custom_op("mylib::f", mutates_args=())
    def f(x: Tensor, y: Tensor) -> Tensor:
        return x * y

    called = False

    def make_rule(combine):
        # Build a vmap rule that differs only in how the two operands are combined.
        def rule(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            lhs = x.unsqueeze(-1) if x_bdim is None else x.movedim(x_bdim, -1)
            rhs = y.unsqueeze(-1) if y_bdim is None else y.movedim(y_bdim, -1)
            return combine(lhs, rhs).movedim(-1, 0), 0

        return rule

    f.register_vmap(make_rule(lambda a, b: a * b))

    x = torch.randn(2, 2)
    y = torch.randn(2, 2)

    result = torch.vmap(f)(x, y)
    self.assertTrue(called)
    self.assertEqual(result, x * y)

    # The second registration must take over: results now come from addition.
    called = False
    f.register_vmap(make_rule(lambda a, b: a + b))

    result = torch.vmap(f)(x, y)
    self.assertTrue(called)
    self.assertEqual(result, x + y)
|
||||
|
||||
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
|
||||
def test_library_register_vmap_register_multiple_times_2(self):
    """Re-registering through torch.library.register_vmap replaces the earlier rule."""

    @torch.library.custom_op("mylib::f", mutates_args=())
    def f(x: Tensor, y: Tensor) -> Tensor:
        return x * y

    called = False

    def make_rule(combine):
        # Build a vmap rule that differs only in how the two operands are combined.
        def rule(info, in_dims, x, y):
            nonlocal called
            called = True
            x_bdim, y_bdim = in_dims
            lhs = x.unsqueeze(-1) if x_bdim is None else x.movedim(x_bdim, -1)
            rhs = y.unsqueeze(-1) if y_bdim is None else y.movedim(y_bdim, -1)
            return combine(lhs, rhs).movedim(-1, 0), 0

        return rule

    # Go through the decorator form (func omitted) exactly as `@register_vmap(...)` would.
    torch.library.register_vmap("mylib::f")(make_rule(lambda a, b: a * b))

    x = torch.randn(2, 2)
    y = torch.randn(2, 2)

    result = torch.vmap(f)(x, y)
    self.assertTrue(called)
    self.assertEqual(result, x * y)

    # The second registration must take over: results now come from addition.
    called = False
    torch.library.register_vmap("mylib::f")(make_rule(lambda a, b: a + b))

    result = torch.vmap(f)(x, y)
    self.assertTrue(called)
    self.assertEqual(result, x + y)
|
||||
|
||||
|
||||
class MiniOpTestOther(CustomOpTestCaseBase):
|
||||
test_ns = "mini_op_test"
|
||||
|
@ -274,7 +274,17 @@ def validate_vmap_returns_tuple_of_two_elements(result):
|
||||
|
||||
|
||||
@custom_function_call.py_impl(TransformType.Vmap)
|
||||
def custom_function_call_vmap(interpreter, autograd_function, *operands):
|
||||
def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs):
|
||||
if any(
|
||||
isinstance(val, torch.Tensor)
|
||||
for val in torch.utils._pytree.tree_flatten(kwargs)[0]
|
||||
):
|
||||
raise NotImplementedError(
|
||||
f"Run vmap on autograd.Function with kwarg-only Tensor args. "
|
||||
f"Please do not pass kwarg-only Tensors to autograd.Function. "
|
||||
f"Got: {kwargs}"
|
||||
)
|
||||
|
||||
if autograd_function.generate_vmap_rule:
|
||||
if has_overriden_vmap_rule(autograd_function):
|
||||
# TODO: Update link to stable once that's out
|
||||
@ -302,22 +312,32 @@ def custom_function_call_vmap(interpreter, autograd_function, *operands):
|
||||
f"https://pytorch.org/docs/main/notes/extending.func.html"
|
||||
)
|
||||
|
||||
return custom_function_call_vmap_helper(
|
||||
interpreter, autograd_function.vmap, autograd_function, *operands, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def custom_function_call_vmap_helper(
|
||||
interpreter, vmap_function, op, *operands, **kwargs
|
||||
):
|
||||
current_level = interpreter.level()
|
||||
info = VmapInfo(
|
||||
batch_size=interpreter.batch_size(),
|
||||
randomness=interpreter.randomness(),
|
||||
)
|
||||
unwrapped_operands, in_dims = unwrap_batched(operands, current_level)
|
||||
|
||||
# If none of the tensors are batched at the current level, then we skip the
|
||||
# current level. This saves the user from needing to handle this case in
|
||||
# their vmap staticmethod (and is consistent with our C++ batching rule API)
|
||||
if pytree.tree_all(lambda dim: dim is None, in_dims):
|
||||
with interpreter.lower():
|
||||
return custom_function_call(autograd_function, *operands)
|
||||
if isinstance(op, torch.autograd.function.FunctionMeta):
|
||||
return custom_function_call(op, *operands)
|
||||
else:
|
||||
return op(*operands, **kwargs)
|
||||
|
||||
with interpreter.lower():
|
||||
result = autograd_function.vmap(info, in_dims, *unwrapped_operands)
|
||||
result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs)
|
||||
validate_vmap_returns_tuple_of_two_elements(result)
|
||||
unwrapped_output, out_dims = result
|
||||
|
||||
|
@ -180,6 +180,7 @@ class CustomOpDef:
|
||||
self._setup_context_fn: Optional[Callable] = None
|
||||
self._backward_fn: Optional[Callable] = None
|
||||
self._torch_dispatch_fns: Dict[type, Callable] = {}
|
||||
self._vmap_fn: Optional[Callable] = None
|
||||
|
||||
self._lib = get_library_allowing_overwrite(self._namespace, self._name)
|
||||
self._register_to_dispatcher()
|
||||
@ -662,6 +663,103 @@ class CustomOpDef:
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self._opoverload(*args, **kwargs)
|
||||
|
||||
def register_vmap(
|
||||
self,
|
||||
func: Optional[Callable] = None,
|
||||
):
|
||||
r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
|
||||
|
||||
This API may be used as a decorator.
|
||||
|
||||
In order for an operator to work with :func:`torch.vmap`, you may need to register a
|
||||
vmap implementation in the following signature:
|
||||
|
||||
``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,
|
||||
|
||||
where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
|
||||
|
||||
It specifies how do we compute the batched version of ``op`` given inputs with an additional
|
||||
dimension (specified by ``in_dims``).
|
||||
|
||||
For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
|
||||
if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
|
||||
specifying what dimension of the Tensor is being vmapped over.
|
||||
|
||||
``info`` is a collection of additional metadata that may be helpful:
|
||||
``info.batch_size`` specifies the size of the dimension being vmapped over, while
|
||||
``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
|
||||
|
||||
The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
|
||||
``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
|
||||
per output that specifies if the output has the vmapped dimension and what index it is in.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
>>> from torch import Tensor
|
||||
>>> from typing import Tuple
|
||||
>>>
|
||||
>>> def to_numpy(tensor):
|
||||
>>> return tensor.cpu().numpy()
|
||||
>>>
|
||||
>>> lib = torch.library.Library("mylib", "FRAGMENT")
|
||||
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
|
||||
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
>>> x_np = to_numpy(x)
|
||||
>>> dx = torch.tensor(3 * x_np ** 2, device=x.device)
|
||||
>>> return torch.tensor(x_np ** 3, device=x.device), dx
|
||||
>>>
|
||||
>>> def numpy_cube_vmap(info, in_dims, x):
|
||||
>>> result = numpy_cube(x)
|
||||
>>> return result, (in_dims[0], in_dims[0])
|
||||
>>>
|
||||
>>> numpy_cube.register_vmap(numpy_cube_vmap)
|
||||
>>>
|
||||
>>> x = torch.randn(3)
|
||||
>>> torch.vmap(numpy_cube)(x)
|
||||
>>>
|
||||
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
|
||||
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
|
||||
>>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
|
||||
>>>
|
||||
>>> @numpy_mul.register_vmap
|
||||
>>> def numpy_mul_vmap(info, in_dims, x, y):
|
||||
>>> x_bdim, y_bdim = in_dims
|
||||
>>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
|
||||
>>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
|
||||
>>> result = x * y
|
||||
>>> result = result.movedim(-1, 0)
|
||||
>>> return result, 0
|
||||
>>>
|
||||
>>>
|
||||
>>> x = torch.randn(3)
|
||||
>>> y = torch.randn(3)
|
||||
>>> torch.vmap(numpy_mul)(x, y)
|
||||
"""
|
||||
from torch._functorch.autograd_function import custom_function_call_vmap_helper
|
||||
from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
|
||||
|
||||
def register(func):
|
||||
need_register = self._vmap_fn is None
|
||||
self._vmap_fn = func
|
||||
|
||||
if need_register:
|
||||
|
||||
def wrapped_func(keyset, *args, **kwargs):
|
||||
interpreter = retrieve_current_functorch_interpreter()
|
||||
return custom_function_call_vmap_helper(
|
||||
interpreter, self._vmap_fn, self._opoverload, *args, **kwargs
|
||||
)
|
||||
|
||||
self._lib.impl(
|
||||
self._name, wrapped_func, "FuncTorchBatched", with_keyset=True
|
||||
)
|
||||
|
||||
if func is None:
|
||||
return register
|
||||
else:
|
||||
return register(func)
|
||||
|
||||
|
||||
# NOTE: [Supporting decorator and non-decorator usage]
|
||||
#
|
||||
|
130
torch/library.py
130
torch/library.py
@ -954,6 +954,136 @@ def register_torch_dispatch(
|
||||
return register(func)
|
||||
|
||||
|
||||
def register_vmap(
    op: _op_identifier,
    func: Optional[Callable] = None,
    /,
    *,
    lib=None,
):
    r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.

    This API may be used as a decorator (see examples).

    In order for an operator to work with :func:`torch.vmap`, you may need to register a
    vmap implementation in the following signature:

        ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,

    where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
    We do not support kwarg-only Tensor args.

    It specifies how to compute the batched version of ``op`` given inputs with an additional
    dimension (specified by ``in_dims``).

    For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
    if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
    specifying what dimension of the Tensor is being vmapped over.

    ``info`` is a collection of additional metadata that may be helpful:
    ``info.batch_size`` specifies the size of the dimension being vmapped over, while
    ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.

    The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
    ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
    per output that specifies if the output has the vmapped dimension and what index it is in.

    Raises:
        ValueError: if ``op`` is not a ``str``, an ``OpOverload`` or a ``CustomOpDef``.
        NotImplementedError: if the schema of ``op`` has kwarg-only Tensor arguments.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>> from typing import Tuple
        >>>
        >>> def to_numpy(tensor):
        >>>     return tensor.cpu().numpy()
        >>>
        >>> lib = torch.library.Library("mylib", "FRAGMENT")
        >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
        >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
        >>>     x_np = to_numpy(x)
        >>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
        >>>     return torch.tensor(x_np ** 3, device=x.device), dx
        >>>
        >>> def numpy_cube_vmap(info, in_dims, x):
        >>>     result = numpy_cube(x)
        >>>     return result, (in_dims[0], in_dims[0])
        >>>
        >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
        >>>
        >>> x = torch.randn(3)
        >>> torch.vmap(numpy_cube)(x)
        >>>
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
        >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
        >>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
        >>>
        >>> @torch.library.register_vmap("mylib::numpy_mul")
        >>> def numpy_mul_vmap(info, in_dims, x, y):
        >>>     x_bdim, y_bdim = in_dims
        >>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        >>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        >>>     result = x * y
        >>>     result = result.movedim(-1, 0)
        >>>     return result, 0
        >>>
        >>>
        >>> x = torch.randn(3)
        >>> y = torch.randn(3)
        >>> torch.vmap(numpy_mul)(x, y)

    .. note::
        The vmap function should aim to preserve the semantics of the entire custom operator.
        That is, ``grad(vmap(op))`` should be replaceable with a ``grad(map(op))``.

        If your custom operator has any custom behavior in the backward pass, please
        keep this in mind.

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_vmap(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    # Ops created through the custom-op API own their registration state; delegate to them.
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_vmap(func)
    assert isinstance(op, str)
    qualname = op
    op = torch._library.utils.lookup_op(qualname)
    schema = op._schema
    if _library.utils.has_kwarg_only_tensors(schema):
        raise NotImplementedError(
            "register_vmap with kwarg-only Tensor args. In the original "
            "definition of the op, please make your tensors not kwarg-only. "
            f"Got: {schema}"
        )

    def register(func):
        nonlocal op, lib

        namespace, opname = torch._library.utils.parse_namespace(qualname)
        if lib is None:
            # No library supplied: create a fragment and keep it alive in the
            # module-level list so the registration is not dropped with it.
            lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(lib)

        from torch._functorch.autograd_function import custom_function_call_vmap_helper
        from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter

        def wrapped_func(keyset, *args, **kwargs):
            interpreter = retrieve_current_functorch_interpreter()
            return custom_function_call_vmap_helper(
                interpreter, func, op, *args, **kwargs
            )

        lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True)

        # Hand the rule back so that decorator usage does not rebind the name to None.
        return func

    if func is None:
        return register
    else:
        return register(func)
|
||||
|
||||
|
||||
# If the op was defined in C++, then we want to make sure there was an
|
||||
# m.set_python_module(module, ...) call and that the module is the
|
||||
# same as the module that called torch.library.register_fake.
|
||||
|
@ -50,6 +50,12 @@ def numpy_cube_backward(ctx, grad_out, grad_dx):
|
||||
|
||||
numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context)
|
||||
|
||||
def numpy_cube_vmap(info, in_dims, x):
    """vmap rule for numpy_cube: call the op on ``x`` as given and report both
    outputs as batched along the same dimension as the input."""
    outputs = numpy_cube(x)
    x_bdim = in_dims[0]
    return outputs, (x_bdim, x_bdim)
|
||||
|
||||
numpy_cube.register_vmap(numpy_cube_vmap)
|
||||
|
||||
@torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=())
|
||||
def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
|
||||
return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
|
||||
@ -70,6 +76,16 @@ def numpy_mul_backward(ctx, grad_out):
|
||||
|
||||
numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context)
|
||||
|
||||
def numpy_mul_vmap(info, in_dims, x, y):
    """vmap rule for numpy_mul: multiply with the batch dim placed last on both
    operands, then return the product with the batch dim moved to the front."""

    def batch_last(t, bdim):
        # Unbatched operands get a trailing size-1 dim so they broadcast over the batch.
        if bdim is None:
            return t.unsqueeze(-1)
        return t.movedim(bdim, -1)

    x_bdim, y_bdim = in_dims
    product = batch_last(x, x_bdim) * batch_last(y, y_bdim)
    return product.movedim(-1, 0), 0
|
||||
|
||||
numpy_mul.register_vmap(numpy_mul_vmap)
|
||||
|
||||
@torch.library.custom_op("_torch_testing::numpy_mul_scalar", mutates_args=())
|
||||
def numpy_mul_scalar(x: Tensor, *, scalar: float) -> Tensor:
|
||||
return torch.tensor(to_numpy(x) * scalar, device=x.device)
|
||||
@ -87,6 +103,15 @@ def numpy_mul_scalar_backward(ctx, grad_out):
|
||||
|
||||
numpy_mul_scalar.register_autograd(numpy_mul_scalar_backward, setup_context=numpy_mul_scalar_setup_context)
|
||||
|
||||
def numpy_mul_scalar_vmap(info, in_dims, x, *, scalar):
    """vmap rule for numpy_mul_scalar: scale ``x`` with its batch dim placed last,
    then return the result with the batch dim moved to the front."""
    (x_bdim,) = in_dims
    if x_bdim is None:
        batch_last = x.unsqueeze(-1)
    else:
        batch_last = x.movedim(x_bdim, -1)
    scaled = batch_last * scalar
    return scaled.movedim(-1, 0), 0
|
||||
|
||||
numpy_mul_scalar.register_vmap(numpy_mul_scalar_vmap)
|
||||
|
||||
@torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=())
|
||||
def numpy_sort(x: Tensor, dim: int) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
device = x.device
|
||||
@ -116,6 +141,14 @@ def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv):
|
||||
|
||||
numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context)
|
||||
|
||||
def numpy_sort_vmap(info, in_dims, x, dim):
    """vmap rule for ``numpy_sort``.

    Moves ``x``'s batch dimension to the front, wraps a negative ``dim``
    against the per-example rank (physical rank minus the batch dimension),
    and calls ``numpy_sort`` one dimension further right to skip the batch
    dimension.

    Args:
        info: vmap metadata; not used by this rule.
        in_dims: ``(x_bdim, dim_bdim)``; only ``x_bdim`` is used.
        x: the physical input tensor.
        dim: the per-example dimension to sort along; may be negative.

    Returns:
        ``(outputs, (0, 0, 0))`` -- all three outputs batched at index 0.
    """
    x_bdim, _dim_bdim = in_dims
    batched = x.movedim(x_bdim, 0)
    if dim < 0:
        # Rank of a single example: drop the leading batch dimension.
        dim += batched.dim() - 1
    return numpy_sort(batched, dim + 1), (0, 0, 0)
|
||||
|
||||
numpy_sort.register_vmap(numpy_sort_vmap)
|
||||
|
||||
@torch.library.custom_op("_torch_testing::numpy_take", mutates_args=())
|
||||
def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor:
|
||||
@ -144,6 +177,26 @@ def numpy_take_backward(ctx, grad_out):
|
||||
|
||||
numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context)
|
||||
|
||||
def numpy_take_vmap(info, in_dims, x, ind, ind_inv, dim):
    """vmap rule for ``numpy_take``.

    Every tensor operand is brought to a layout with the batch dimension at
    index 0 (unbatched operands are expanded to ``info.batch_size``), and
    ``numpy_take`` is then called on the dimension shifted right by one to
    skip the batch dimension.

    Args:
        info: vmap metadata; ``info.batch_size`` is used to expand unbatched operands.
        in_dims: ``(x_bdim, ind_bdim, ind_inv_bdim, dim_bdim)``; a tensor
            entry is ``None`` when that operand is not batched.
        x: the physical input tensor.
        ind: the physical index tensor.
        ind_inv: the physical inverse-index tensor.
        dim: the per-example dimension to take along; may be negative.

    Returns:
        ``(result, 0)`` -- the result with its batch dimension at index 0.
    """
    x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

    # Wrap a negative ``dim`` against the per-example rank of ``x``. When
    # ``x`` is batched that rank is the physical rank minus one; it does not
    # depend on where the batch dimension sits (the previous ``x_bdim - 1``
    # only gave the right answer when ``x_bdim`` was 0).
    logical_rank = x.dim() if x_bdim is None else x.dim() - 1
    if dim < 0:
        dim += logical_rank

    def expand_bdim(tensor, bdim):
        # Unbatched operands are broadcast along a new leading batch dim.
        if bdim is None:
            return tensor.expand(info.batch_size, *tensor.shape)
        return tensor.movedim(bdim, 0)

    x = expand_bdim(x, x_bdim)
    ind = expand_bdim(ind, ind_bdim)
    ind_inv = expand_bdim(ind_inv, ind_inv_bdim)

    return numpy_take(x, ind, ind_inv, dim + 1), 0
|
||||
|
||||
numpy_take.register_vmap(numpy_take_vmap)
|
||||
|
||||
@torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=())
|
||||
def numpy_nonzero(x: Tensor) -> Tensor:
|
||||
x_np = to_numpy(x)
|
||||
@ -170,6 +223,11 @@ def sample_inputs_numpy_nonzero(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
|
||||
yield SampleInput(result, args=())
|
||||
|
||||
def numpy_nonzero_vmap(info, in_dims, x):
    """vmap rule for ``numpy_nonzero``; always refuses.

    Raises:
        NotImplementedError: unconditionally, because the operator is
            data-dependent and cannot be vmapped.
    """
    message = "Operator is data-dependent and cannot be vmapped."
    raise NotImplementedError(message)
|
||||
|
||||
numpy_nonzero.register_vmap(numpy_nonzero_vmap)
|
||||
|
||||
@torch.library.custom_op("_torch_testing::numpy_view_copy", mutates_args=())
def numpy_view_copy(x: Tensor, shape: Sequence[int]) -> Tensor:
    """Return a copy of ``x`` reshaped to ``shape``.

    ``x`` is converted with ``to_numpy``, reshaped, copied with ``np.copy``
    and wrapped in a new tensor placed on ``x``'s device.
    """
    reshaped = to_numpy(x).reshape(shape)
    return torch.tensor(np.copy(reshaped), device=x.device)
|
||||
@ -186,6 +244,16 @@ def numpy_view_copy_backward(ctx, grad_out):
|
||||
|
||||
numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context)
|
||||
|
||||
def numpy_view_copy_vmap(info, in_dims, x, shape):
    """vmap rule for ``numpy_view_copy``.

    Moves ``x``'s batch dimension to the front and reshapes to
    ``(batch_size, *shape)`` so each example ends up with ``shape``.

    Args:
        info: vmap metadata; not used by this rule.
        in_dims: ``(x_bdim, shape_bdim)``; only ``x_bdim`` is used.
        x: the physical input tensor.
        shape: the per-example target shape.

    Returns:
        ``(result, 0)`` -- the result with its batch dimension at index 0.
    """
    x_bdim, _shape_bdim = in_dims
    batched = x.movedim(x_bdim, 0)
    batch_size = batched.shape[0]
    return numpy_view_copy(batched, (batch_size, *shape)), 0
|
||||
|
||||
numpy_view_copy.register_vmap(numpy_view_copy_vmap)
|
||||
|
||||
def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
result = make_arg(2, 3, 4, low=0.9, high=2)
|
||||
@ -222,6 +290,13 @@ def numpy_cat_backward(ctx, grad_out):
|
||||
|
||||
numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context)
|
||||
|
||||
def numpy_cat_vmap(info, in_dims, x, dim):
    """vmap rule for ``numpy_cat``.

    NOTE(review): assumes ``x`` is a sequence of tensors and that the first
    entry of ``in_dims`` is a matching sequence of per-tensor batch
    dimensions (``None`` for an unbatched tensor) -- confirm against the
    ``numpy_cat`` signature.

    The rule takes two positional operands, so ``in_dims`` has two entries;
    the previous single-element unpacking ``x_bdim, = in_dims`` could not
    succeed. Each tensor is brought to a layout with the batch dimension at
    index 0 (unbatched tensors are expanded to ``info.batch_size``), and the
    concatenation runs one dimension further right to skip the batch
    dimension.

    Args:
        info: vmap metadata; ``info.batch_size`` is used to expand unbatched tensors.
        in_dims: ``(x_bdims, dim_bdim)``; only ``x_bdims`` is used.
        x: the physical tensors to concatenate.
        dim: the per-example dimension to concatenate along; may be negative.

    Returns:
        ``(result, 0)`` -- the result with its batch dimension at index 0.
    """
    x_bdims, _ = in_dims

    batched = []
    for tensor, bdim in zip(x, x_bdims):
        if bdim is None:
            # Broadcast unbatched tensors along a new leading batch dim.
            batched.append(tensor.expand(info.batch_size, *tensor.shape))
        else:
            batched.append(tensor.movedim(bdim, 0))

    if dim < 0 and batched:
        # Wrap against the per-example rank (physical rank minus batch dim).
        dim += batched[0].dim() - 1

    return numpy_cat(batched, dim + 1), 0
|
||||
|
||||
numpy_cat.register_vmap(numpy_cat_vmap)
|
||||
|
||||
def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
r0 = make_arg(2, 3, 4, low=0.9, high=2)
|
||||
@ -249,6 +324,14 @@ def numpy_split_copy_backward(ctx, grad_out):
|
||||
|
||||
numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context)
|
||||
|
||||
def numpy_split_copy_vmap(info, in_dims, x, splits, dim):
    """vmap rule for ``numpy_split_copy``.

    Moves ``x``'s batch dimension to the front and splits one dimension
    further right to skip the batch dimension.

    Args:
        info: vmap metadata; not used by this rule.
        in_dims: ``(x_bdim, splits_bdim, dim_bdim)``; only ``x_bdim`` is used.
        x: the physical input tensor.
        splits: forwarded unchanged to ``numpy_split_copy``.
        dim: the per-example dimension to split along; may be negative.

    Returns:
        ``(result, 0)`` -- every output is batched at index 0.
    """
    x_bdim, _, _ = in_dims
    x = x.movedim(x_bdim, 0)

    # Wrap a negative ``dim`` against the per-example rank before shifting
    # past the batch dimension; otherwise ``dim=-1`` would become 0 and split
    # along the batch dimension itself.
    if dim < 0:
        dim += x.dim() - 1

    result = numpy_split_copy(x, splits, dim + 1)
    return result, 0
|
||||
|
||||
numpy_split_copy.register_vmap(numpy_split_copy_vmap)
|
||||
|
||||
def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
x = make_arg(2, 9, low=0.9, high=2)
|
||||
@ -275,6 +358,14 @@ numpy_split_copy_with_int.register_autograd(
|
||||
numpy_split_copy_with_int_backward,
|
||||
setup_context=numpy_split_copy_with_int_setup_context)
|
||||
|
||||
def numpy_split_copy_with_int_vmap(info, in_dims, x, splits, dim):
    """vmap rule for ``numpy_split_copy_with_int``.

    Moves ``x``'s batch dimension to the front and splits one dimension
    further right to skip the batch dimension. The op returns a pair; the
    second element is passed through with an out-dim of ``None`` (not batched).

    Args:
        info: vmap metadata; not used by this rule.
        in_dims: ``(x_bdim, splits_bdim, dim_bdim)``; only ``x_bdim`` is used.
        x: the physical input tensor.
        splits: forwarded unchanged to ``numpy_split_copy_with_int``.
        dim: the per-example dimension to split along; may be negative.

    Returns:
        ``((result, len_split), (out_dims, None))`` where ``out_dims`` holds
        one ``0`` per element of ``result``.
    """
    x_bdim, _, _ = in_dims
    x = x.movedim(x_bdim, 0)

    # Wrap a negative ``dim`` against the per-example rank before shifting
    # past the batch dimension; otherwise ``dim=-1`` would become 0 and split
    # along the batch dimension itself.
    if dim < 0:
        dim += x.dim() - 1

    result, len_split = numpy_split_copy_with_int(x, splits, dim + 1)
    return (result, len_split), ([0] * len(result), None)
|
||||
|
||||
numpy_split_copy_with_int.register_vmap(numpy_split_copy_with_int_vmap)
|
||||
|
||||
@torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=())
|
||||
def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
|
||||
# Adapted from Ross Girshick's fast-rcnn implementation at
|
||||
@ -331,6 +422,11 @@ def _(boxes, scores, iou_threshold):
|
||||
result = boxes.new_empty([i0], dtype=torch.int64)
|
||||
return result
|
||||
|
||||
def numpy_nms_vmap(info, in_dims, boxes, scores, iou_threshold):
    """vmap rule for ``numpy_nms``; always refuses.

    Raises:
        NotImplementedError: unconditionally, because the operator is
            data-dependent and cannot be vmapped.
    """
    message = "Operator is data-dependent and cannot be vmapped."
    raise NotImplementedError(message)
|
||||
|
||||
numpy_nms.register_vmap(numpy_nms_vmap)
|
||||
|
||||
def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = functools.partial(make_tensor, device=device, dtype=dtype)
|
||||
N = 64
|
||||
|
Reference in New Issue
Block a user