[custom ops] Add register_vmap for custom ops (#130589)

Fixes #130284
Fixes #130653

- Add `torch.library.register_vmap` for custom ops (usage sketch below).
- Add `register_vmap` implementations for the operators in `custom_op_db`.
- Make `torch.autograd.Function` support kwarg-only arguments for vmap (kwarg-only Tensor args remain unsupported and raise `NotImplementedError`).
- Test the operators in `custom_op_db` with the OpInfo-based tests in `test_vmap`.
- Change `test_vmap` to allow a custom `out_dim` and to accept `None` entries in `out_dim` when testing.
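
For reference, a minimal usage sketch of the new API (adapted from the tests and docstring examples in this PR; the `mylib::mul` op name is illustrative):

    import torch
    from torch import Tensor

    @torch.library.custom_op("mylib::mul", mutates_args=())
    def mul(x: Tensor, y: Tensor) -> Tensor:
        return x * y

    # The vmap rule receives (info, in_dims, *args, **kwargs) and returns
    # a tuple of (output, out_dims).
    @torch.library.register_vmap("mylib::mul")
    def mul_vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        # Move each batched dim to the end (or broadcast if unbatched),
        # compute, then move the batch dim of the result to the front.
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = (x * y).movedim(-1, 0)
        return result, 0

    x, y = torch.randn(2, 3), torch.randn(2, 3)
    assert torch.allclose(torch.vmap(mul)(x, y), x * y)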

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130589
Approved by: https://github.com/zou3519
Shangdi Yu authored on 2024-07-23 17:48:36 +00:00; committed by PyTorch MergeBot
parent 404d640c39
commit 68c725a094
9 changed files with 680 additions and 15 deletions


@@ -42,6 +42,7 @@ via PyTorch's C++ operator registration APIs).
.. autofunction:: register_kernel
.. autofunction:: register_autograd
.. autofunction:: register_fake
.. autofunction:: register_vmap
.. autofunction:: impl_abstract
.. autofunction:: get_ctx
.. autofunction:: register_torch_dispatch


@@ -18,6 +18,7 @@ from torch.testing._internal.autograd_function_db import autograd_function_db
from torch.testing._internal.common_device_type import toleranceOverride
from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
from torch.testing._internal.common_modules import module_db
from torch.testing._internal.custom_op_db import custom_op_db
IS_FBCODE = os.getenv("FUNCTORCH_TEST_FBCODE") == "1"
@@ -38,8 +39,26 @@ def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
flat_out, out_spec = pytree.tree_flatten(out)
outs.append(flat_out)
+    # use the same out_dim for all outputs
+    if isinstance(out_dim, int):
+        flat_out_dim = [out_dim for _ in flat_out]
+    else:
+        flat_out_dim, _ = pytree.tree_flatten(out_dim)
     outs = zip(*outs)
-    result = [torch.stack(out_lst) for out_lst in outs]
+    result = []
+    for i, out_lst in enumerate(outs):
+        if flat_out_dim[i] is not None:
+            if not all(isinstance(x, torch.Tensor) for x in out_lst):
+                raise ValueError(
+                    f"vmap `{op}` must only return "
+                    "Tensors. Did you mean to set out_dims=None for this output?"
+                )
+            result.append(torch.stack(out_lst))
+        else:
+            # not batched over, result should be the same for all batches
+            result.append(out_lst[0])
return pytree.tree_unflatten(result, out_spec)
@@ -317,9 +336,9 @@ def _compute_quantities_for_vmap_test(
inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims)
outer_in_dims = (0,) + in_dims
batched_args, kwarg_values = maybe_clone_inputs()
-    vmapvmap_output = vmap(vmap(f, inner_in_dims), outer_in_dims)(
-        dummy, *batched_args, **kwarg_values
-    )
+    vmapvmap_output = vmap(
+        vmap(f, inner_in_dims, out_dims=out_dim), outer_in_dims, out_dims=out_dim
+    )(dummy, *batched_args, **kwarg_values)
yield (batched_out, loop_out, vmapvmap_output, vmapvmap_expected)
@@ -440,7 +459,7 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
def skipOps(test_case_name, base_test_name, to_skip):
-    all_opinfos = op_db + additional_op_db + autograd_function_db
+    all_opinfos = op_db + additional_op_db + autograd_function_db + custom_op_db
for decorate_meta in to_skip:
matching_opinfos = [
o


@@ -1617,6 +1617,28 @@ class TestAutogradFunctionVmapAPI(TestCase):
with self.assertRaisesRegex(RuntimeError, "returned an incompatible"):
result = vmap(Zeros.apply)(x)
def test_kwarg_only_tensors(self, device):
with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
class MyClass(torch.autograd.Function):
@staticmethod
def forward(x, *, y):
return x + y
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def vmap(info, in_dims, x, *, y):
assert in_dims == (0,)
return x + y, 0
x = torch.randn(3)
y = torch.randn(3)
vmap(MyClass.apply)(x, y=y)
@markDynamoStrictTest
class TestVmapOfGrad(TestCase):


@@ -68,6 +68,7 @@ from torch.testing._internal.common_utils import (
unMarkDynamoStrictTest,
xfailIfTorchDynamo,
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.utils import _pytree as pytree
@@ -3937,10 +3938,17 @@ def discover_variants(opinfo):
@unMarkDynamoStrictTest
class TestVmapOperatorsOpInfo(TestCase):
def vmap_outplace_test(
-        self, func, args, kwargs, in_dims, check_shape_only=False, postprocess_fn=None
+        self,
+        func,
+        args,
+        kwargs,
+        in_dims,
+        check_shape_only=False,
+        postprocess_fn=None,
+        out_dim=0,
):
for vmap_out, loop_out in compute_quantities_for_vmap_test(
-            func, args, kwargs, in_dims
+            func, args, kwargs, in_dims, out_dim=out_dim
):
if postprocess_fn is not None:
loop_out = postprocess_fn(loop_out)
@@ -3950,7 +3958,9 @@ class TestVmapOperatorsOpInfo(TestCase):
continue
self.assertEqual(vmap_out, loop_out)
-    def vmap_inplace_test(self, func, args, kwargs, in_dims, postprocess_fn=None):
+    def vmap_inplace_test(
+        self, func, args, kwargs, in_dims, postprocess_fn=None, out_dim=0
+    ):
# NB: This test assumes that the first argument is being modified.
# This is OK because it's what every other OpInfo-based test assumes,
# but it is going to need a more robust solution eventually.
@@ -3963,13 +3973,19 @@ class TestVmapOperatorsOpInfo(TestCase):
args,
kwargs,
in_dims,
out_dim=out_dim,
compute_loop_out=False,
clone_inputs=True,
):
pass
return
for vmap_out, loop_out in compute_quantities_for_vmap_test(
-            func, args, kwargs, in_dims, clone_inputs=True
+            func,
+            args,
+            kwargs,
+            in_dims,
+            clone_inputs=True,
+            out_dim=out_dim,
):
if postprocess_fn is not None:
loop_out = postprocess_fn(loop_out)
@@ -4027,6 +4043,13 @@ class TestVmapOperatorsOpInfo(TestCase):
continue
kwargs = sample_input.kwargs
is_batch_norm_and_training = is_batch_norm_training(op.name, kwargs)
out_dim = 0
if op.name == "NumpySplitCopyWithIntCustomOp":
# special case for this custom op
def sample_vmap_out_dim_numpy_split_copy_with_int(x, splits, dim):
return [0 for _ in range(len(splits) + 1)], None
out_dim = sample_vmap_out_dim_numpy_split_copy_with_int(*args)
for batched_args, in_dims, _ in generate_vmap_inputs(
args, {}, is_batch_norm_and_training=is_batch_norm_and_training
):
@@ -4038,6 +4061,7 @@ class TestVmapOperatorsOpInfo(TestCase):
in_dims,
check_shape_only,
postprocess_fn,
out_dim=out_dim,
)
if op.name in skip_inplace:
continue
@@ -4109,6 +4133,9 @@ class TestVmapOperatorsOpInfo(TestCase):
"linalg.eigh", ""
), # not always return the same result for the same input, see test_linalg_eigh for manual test
skip("to"), # RuntimeError: required rank 4 tensor to use channels_last format
# NotImplementedError: data-dependent operators cannot be vmapped
xfail("NumpyNonzeroCustomOp"),
xfail("NumpyNMSCustomOp"),
# ----------------------------------------------------------------------
# ---------------------------- BUGS ------------------------------------
# entries in here don't work and need to be fixed.
@@ -4187,7 +4214,10 @@ class TestVmapOperatorsOpInfo(TestCase):
}
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
-    @ops(op_db + additional_op_db + autograd_function_db, dtypes=OpDTypes.any_one)
+    @ops(
+        op_db + additional_op_db + autograd_function_db + custom_op_db,
+        dtypes=OpDTypes.any_one,
+    )
@opsToleranceOverride(
"TestVmapOperatorsOpInfo",
"test_vmap_exhaustive",
@@ -4248,7 +4278,10 @@ class TestVmapOperatorsOpInfo(TestCase):
)
@with_tf32_off
-    @ops(op_db + additional_op_db + autograd_function_db, dtypes=OpDTypes.any_one)
+    @ops(
+        op_db + additional_op_db + autograd_function_db + custom_op_db,
+        dtypes=OpDTypes.any_one,
+    )
@opsToleranceOverride(
"TestVmapOperatorsOpInfo",
"test_op_has_batch_rule",


@@ -2327,6 +2327,12 @@ class TestCustomOpAPI(TestCase):
setup_context=lambda ctx, inputs, keyword_only_inputs, output: None,
)
with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
torch.library.register_vmap(
"_torch_testing::foo",
lambda info, in_dims, x, *, y: (x, 0),
)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_register_autograd_kwargonly_low_level(self):
with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
@@ -3382,6 +3388,246 @@ Please use `add.register_fake` to add a fake impl.""",
with f.set_kernel_enabled("cpu", enabled=False):
self.assertEqual(f(x), x + 1)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_register_vmap_kwargonly_low_level(self):
with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
lib.define("foo(Tensor x, *, float y) -> Tensor")
called = False
def foo_impl(x, *, y):
return x * y
lib.impl("foo", foo_impl, "CPU")
def vmap(info, in_dims, x, *, y):
nonlocal called
called = True
return x * y, 0
torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib)
x = torch.ones(3)
result = torch.vmap(torch.ops._torch_testing.foo)(x, y=3.14)
self.assertTrue(called)
self.assertEqual(result, torch.tensor([3.14, 3.14, 3.14]))
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_register_vmap_defaults(self):
with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor")
def foo_impl(w, x=2, *, y=3, z):
return w * x * y * z
lib.impl("foo", foo_impl, "CPU")
called = False
def vmap(info, in_dims, w, x=2, *, y=3, z):
nonlocal called
called = True
return w * x * y * z, 0
torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib)
w = torch.ones(3)
result = torch.vmap(torch.ops._torch_testing.foo)(w, z=42)
self.assertTrue(called)
self.assertEqual(result, w * 2 * 3 * 42)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_library_register_vmap(self):
for mode in ["function", "qualname", "opoverload", "c_opdef"]:
@torch.library.custom_op("mylib::f", mutates_args=())
def f(x: Tensor, y: Tensor) -> Tensor:
return x * y
called = False
def fvmap(info, in_dims, x, y):
nonlocal called
called = True
x_bdim, y_bdim = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = x * y
result = result.movedim(-1, 0)
return result, 0
if mode == "function":
torch.library.register_vmap(
f,
fvmap,
)
elif mode == "qualname":
torch.library.register_vmap(
"mylib::f",
fvmap,
)
elif mode == "opoverload":
torch.library.register_vmap(
torch.ops.mylib.f.default,
fvmap,
)
elif mode == "c_opdef":
f.register_vmap(
fvmap,
)
x = torch.randn(2, 2)
y = torch.randn(2, 2)
result = torch.vmap(f)(x, y)
self.assertTrue(called)
self.assertEqual(result, x * y)
called = False
result = torch.vmap(f, out_dims=1)(x, y)
self.assertEqual(result, (x * y).T)
self.assertTrue(called)
called = False
result = torch.vmap(f, in_dims=1)(x, y)
self.assertEqual(result, (x * y).T)
self.assertTrue(called)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_library_register_vmap_library_decorator(self):
@torch.library.custom_op("mylib::f", mutates_args=())
def f(x: Tensor, y: Tensor) -> Tensor:
return x * y
called = False
@torch.library.register_vmap("mylib::f")
def fvmap(info, in_dims, x, y):
nonlocal called
called = True
x_bdim, y_bdim = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = x * y
result = result.movedim(-1, 0)
return result, 0
x = torch.randn(2, 2)
y = torch.randn(2, 2)
result = torch.vmap(f)(x, y)
self.assertTrue(called)
self.assertEqual(result, x * y)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_library_register_vmap_op_decorator(self):
@torch.library.custom_op("mylib::f", mutates_args=())
def f(x: Tensor, y: Tensor) -> Tensor:
return x * y
called = False
@f.register_vmap
def fvmap(info, in_dims, x, y):
nonlocal called
called = True
x_bdim, y_bdim = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = x * y
result = result.movedim(-1, 0)
return result, 0
x = torch.randn(2, 2)
y = torch.randn(2, 2)
result = torch.vmap(f)(x, y)
self.assertTrue(called)
self.assertEqual(result, x * y)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_library_register_vmap_register_multiple_times(self):
@torch.library.custom_op("mylib::f", mutates_args=())
def f(x: Tensor, y: Tensor) -> Tensor:
return x * y
called = False
@f.register_vmap
def fvmap(info, in_dims, x, y):
nonlocal called
called = True
x_bdim, y_bdim = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = x * y
result = result.movedim(-1, 0)
return result, 0
x = torch.randn(2, 2)
y = torch.randn(2, 2)
result = torch.vmap(f)(x, y)
self.assertTrue(called)
self.assertEqual(result, x * y)
called = False
@f.register_vmap
def fvmap2(info, in_dims, x, y):
nonlocal called
called = True
x_bdim, y_bdim = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = x + y
result = result.movedim(-1, 0)
return result, 0
result = torch.vmap(f)(x, y)
self.assertTrue(called)
self.assertEqual(result, x + y)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_library_register_vmap_register_multiple_times_2(self):
@torch.library.custom_op("mylib::f", mutates_args=())
def f(x: Tensor, y: Tensor) -> Tensor:
return x * y
called = False
@torch.library.register_vmap("mylib::f")
def fvmap(info, in_dims, x, y):
nonlocal called
called = True
x_bdim, y_bdim = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = x * y
result = result.movedim(-1, 0)
return result, 0
x = torch.randn(2, 2)
y = torch.randn(2, 2)
result = torch.vmap(f)(x, y)
self.assertTrue(called)
self.assertEqual(result, x * y)
called = False
@torch.library.register_vmap("mylib::f")
def fvmap2(info, in_dims, x, y):
nonlocal called
called = True
x_bdim, y_bdim = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = x + y
result = result.movedim(-1, 0)
return result, 0
result = torch.vmap(f)(x, y)
self.assertTrue(called)
self.assertEqual(result, x + y)
class MiniOpTestOther(CustomOpTestCaseBase):
test_ns = "mini_op_test"


@@ -274,7 +274,17 @@ def validate_vmap_returns_tuple_of_two_elements(result):
@custom_function_call.py_impl(TransformType.Vmap)
-def custom_function_call_vmap(interpreter, autograd_function, *operands):
+def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs):
if any(
isinstance(val, torch.Tensor)
for val in torch.utils._pytree.tree_flatten(kwargs)[0]
):
raise NotImplementedError(
f"Run vmap on autograd.Function with kwarg-only Tensor args. "
f"Please do not pass kwarg-only Tensors to autograd.Function. "
f"Got: {kwargs}"
)
if autograd_function.generate_vmap_rule:
if has_overriden_vmap_rule(autograd_function):
# TODO: Update link to stable once that's out
@@ -302,22 +312,32 @@ def custom_function_call_vmap(interpreter, autograd_function, *operands):
f"https://pytorch.org/docs/main/notes/extending.func.html"
)
return custom_function_call_vmap_helper(
interpreter, autograd_function.vmap, autograd_function, *operands, **kwargs
)
def custom_function_call_vmap_helper(
interpreter, vmap_function, op, *operands, **kwargs
):
current_level = interpreter.level()
info = VmapInfo(
batch_size=interpreter.batch_size(),
randomness=interpreter.randomness(),
)
unwrapped_operands, in_dims = unwrap_batched(operands, current_level)
# If none of the tensors are batched at the current level, then we skip the
# current level. This saves the user from needing to handle this case in
# their vmap staticmethod (and is consistent with our C++ batching rule API)
if pytree.tree_all(lambda dim: dim is None, in_dims):
with interpreter.lower():
-            return custom_function_call(autograd_function, *operands)
+            if isinstance(op, torch.autograd.function.FunctionMeta):
+                return custom_function_call(op, *operands)
+            else:
+                return op(*operands, **kwargs)
with interpreter.lower():
-        result = autograd_function.vmap(info, in_dims, *unwrapped_operands)
+        result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs)
validate_vmap_returns_tuple_of_two_elements(result)
unwrapped_output, out_dims = result


@@ -180,6 +180,7 @@ class CustomOpDef:
self._setup_context_fn: Optional[Callable] = None
self._backward_fn: Optional[Callable] = None
self._torch_dispatch_fns: Dict[type, Callable] = {}
self._vmap_fn: Optional[Callable] = None
self._lib = get_library_allowing_overwrite(self._namespace, self._name)
self._register_to_dispatcher()
@@ -662,6 +663,103 @@ class CustomOpDef:
def __call__(self, *args, **kwargs):
return self._opoverload(*args, **kwargs)
def register_vmap(
self,
func: Optional[Callable] = None,
):
r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
This API may be used as a decorator.
In order for an operator to work with :func:`torch.vmap`, you may need to register a
vmap implementation with the following signature:
``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,
where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
It specifies how to compute the batched version of ``op`` given inputs with an additional
dimension (specified by ``in_dims``).
For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
if the arg is not a Tensor or is not being vmapped over; otherwise, it is an integer
specifying which dimension of the Tensor is being vmapped over.
``info`` is a collection of additional metadata that may be helpful:
``info.batch_size`` specifies the size of the dimension being vmapped over, while
``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
The function ``func`` returns a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
``out_dims`` should have the same structure as ``output`` and contain one ``out_dim``
per output, specifying whether that output has a vmapped dimension and, if so, its index.
Examples:
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>> from typing import Tuple
>>>
>>> def to_numpy(tensor):
>>> return tensor.cpu().numpy()
>>>
>>> lib = torch.library.Library("mylib", "FRAGMENT")
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
>>> x_np = to_numpy(x)
>>> dx = torch.tensor(3 * x_np ** 2, device=x.device)
>>> return torch.tensor(x_np ** 3, device=x.device), dx
>>>
>>> def numpy_cube_vmap(info, in_dims, x):
>>> result = numpy_cube(x)
>>> return result, (in_dims[0], in_dims[0])
>>>
>>> numpy_cube.register_vmap(numpy_cube_vmap)
>>>
>>> x = torch.randn(3)
>>> torch.vmap(numpy_cube)(x)
>>>
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
>>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
>>>
>>> @numpy_mul.register_vmap
>>> def numpy_mul_vmap(info, in_dims, x, y):
>>> x_bdim, y_bdim = in_dims
>>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
>>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
>>> result = x * y
>>> result = result.movedim(-1, 0)
>>> return result, 0
>>>
>>>
>>> x = torch.randn(3)
>>> y = torch.randn(3)
>>> torch.vmap(numpy_mul)(x, y)
"""
from torch._functorch.autograd_function import custom_function_call_vmap_helper
from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
def register(func):
need_register = self._vmap_fn is None
self._vmap_fn = func
if need_register:
def wrapped_func(keyset, *args, **kwargs):
interpreter = retrieve_current_functorch_interpreter()
return custom_function_call_vmap_helper(
interpreter, self._vmap_fn, self._opoverload, *args, **kwargs
)
self._lib.impl(
self._name, wrapped_func, "FuncTorchBatched", with_keyset=True
)
if func is None:
return register
else:
return register(func)
# NOTE: [Supporting decorator and non-decorator usage]
#


@@ -954,6 +954,136 @@ def register_torch_dispatch(
return register(func)
def register_vmap(
op: _op_identifier,
func: Optional[Callable] = None,
/,
*,
lib=None,
):
r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
This API may be used as a decorator (see examples).
In order for an operator to work with :func:`torch.vmap`, you may need to register a
vmap implementation with the following signature:
``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,
where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
We do not support kwarg-only Tensor args.
It specifies how to compute the batched version of ``op`` given inputs with an additional
dimension (specified by ``in_dims``).
For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
if the arg is not a Tensor or is not being vmapped over; otherwise, it is an integer
specifying which dimension of the Tensor is being vmapped over.
``info`` is a collection of additional metadata that may be helpful:
``info.batch_size`` specifies the size of the dimension being vmapped over, while
``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
The function ``func`` returns a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
``out_dims`` should have the same structure as ``output`` and contain one ``out_dim``
per output, specifying whether that output has a vmapped dimension and, if so, its index.
Examples:
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>> from typing import Tuple
>>>
>>> def to_numpy(tensor):
>>> return tensor.cpu().numpy()
>>>
>>> lib = torch.library.Library("mylib", "FRAGMENT")
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
>>> x_np = to_numpy(x)
>>> dx = torch.tensor(3 * x_np ** 2, device=x.device)
>>> return torch.tensor(x_np ** 3, device=x.device), dx
>>>
>>> def numpy_cube_vmap(info, in_dims, x):
>>> result = numpy_cube(x)
>>> return result, (in_dims[0], in_dims[0])
>>>
>>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
>>>
>>> x = torch.randn(3)
>>> torch.vmap(numpy_cube)(x)
>>>
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
>>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
>>>
>>> @torch.library.register_vmap("mylib::numpy_mul")
>>> def numpy_mul_vmap(info, in_dims, x, y):
>>> x_bdim, y_bdim = in_dims
>>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
>>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
>>> result = x * y
>>> result = result.movedim(-1, 0)
>>> return result, 0
>>>
>>>
>>> x = torch.randn(3)
>>> y = torch.randn(3)
>>> torch.vmap(numpy_mul)(x, y)
.. note::
    The vmap function should aim to preserve the semantics of the entire custom operator.
    That is, ``grad(vmap(op))`` should be replaceable with ``grad(map(op))``.
    If your custom operator has any custom behavior in the backward pass, please
    keep this in mind.
"""
if not isinstance(
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
):
raise ValueError(f"register_vmap(op): got unexpected type for op: {type(op)}")
if isinstance(op, torch._ops.OpOverload):
op = op._name
opdef = _maybe_get_opdef(op)
if opdef is not None:
return opdef.register_vmap(func)
assert isinstance(op, str)
qualname = op
op = torch._library.utils.lookup_op(qualname)
schema = op._schema
if _library.utils.has_kwarg_only_tensors(schema):
raise NotImplementedError(
f"register_vmap with kwarg-only Tensor args. In the original "
f"definition of the op, please make your tensors not kwarg-only. "
f"Got: {schema}"
)
def register(func):
nonlocal op, lib
namespace, opname = torch._library.utils.parse_namespace(qualname)
if lib is None:
lib = Library(namespace, "FRAGMENT")
_keep_alive.append(lib)
from torch._functorch.autograd_function import custom_function_call_vmap_helper
from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
def wrapped_func(keyset, *args, **kwargs):
interpreter = retrieve_current_functorch_interpreter()
return custom_function_call_vmap_helper(
interpreter, func, op, *args, **kwargs
)
lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True)
if func is None:
return register
else:
return register(func)
# If the op was defined in C++, then we want to make sure there was an
# m.set_python_module(module, ...) call and that the module is the
# same as the module that called torch.library.register_fake.


@@ -50,6 +50,12 @@ def numpy_cube_backward(ctx, grad_out, grad_dx):
numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context)
def numpy_cube_vmap(info, in_dims, x):
result = numpy_cube(x)
return result, (in_dims[0], in_dims[0])
numpy_cube.register_vmap(numpy_cube_vmap)
@torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=())
def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
@@ -70,6 +76,16 @@ def numpy_mul_backward(ctx, grad_out):
numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context)
def numpy_mul_vmap(info, in_dims, x, y):
x_bdim, y_bdim = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = x * y
result = result.movedim(-1, 0)
return result, 0
numpy_mul.register_vmap(numpy_mul_vmap)
@torch.library.custom_op("_torch_testing::numpy_mul_scalar", mutates_args=())
def numpy_mul_scalar(x: Tensor, *, scalar: float) -> Tensor:
return torch.tensor(to_numpy(x) * scalar, device=x.device)
@@ -87,6 +103,15 @@ def numpy_mul_scalar_backward(ctx, grad_out):
numpy_mul_scalar.register_autograd(numpy_mul_scalar_backward, setup_context=numpy_mul_scalar_setup_context)
def numpy_mul_scalar_vmap(info, in_dims, x, *, scalar):
x_bdim, = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
result = x * scalar
result = result.movedim(-1, 0)
return result, 0
numpy_mul_scalar.register_vmap(numpy_mul_scalar_vmap)
@torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=())
def numpy_sort(x: Tensor, dim: int) -> Tuple[Tensor, Tensor, Tensor]:
device = x.device
@@ -116,6 +141,14 @@ def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv):
numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context)
def numpy_sort_vmap(info, in_dims, x, dim):
x_bdim, _ = in_dims
x = x.movedim(x_bdim, 0)
dim = dim if dim >= 0 else dim + x.dim() - 1
result = numpy_sort(x, dim + 1)
return result, (0, 0, 0)
numpy_sort.register_vmap(numpy_sort_vmap)
@torch.library.custom_op("_torch_testing::numpy_take", mutates_args=())
def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor:
@@ -144,6 +177,26 @@ def numpy_take_backward(ctx, grad_out):
numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context)
def numpy_take_vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# wrap dim
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def expand_bdim(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
x = expand_bdim(x, x_bdim)
ind = expand_bdim(ind, ind_bdim)
ind_inv = expand_bdim(ind_inv, ind_inv_bdim)
return numpy_take(x, ind, ind_inv, dim + 1), 0
numpy_take.register_vmap(numpy_take_vmap)
@torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=())
def numpy_nonzero(x: Tensor) -> Tensor:
x_np = to_numpy(x)
@@ -170,6 +223,11 @@ def sample_inputs_numpy_nonzero(opinfo, device, dtype, requires_grad, **kwargs):
yield SampleInput(result, args=())
def numpy_nonzero_vmap(info, in_dims, x):
raise NotImplementedError("Operator is data-dependent and cannot be vmapped.")
numpy_nonzero.register_vmap(numpy_nonzero_vmap)
@torch.library.custom_op("_torch_testing::numpy_view_copy", mutates_args=())
def numpy_view_copy(x: Tensor, shape: Sequence[int]) -> Tensor:
return torch.tensor(np.copy(to_numpy(x).reshape(shape)), device=x.device)
@@ -186,6 +244,16 @@ def numpy_view_copy_backward(ctx, grad_out):
numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context)
def numpy_view_copy_vmap(info, in_dims, x, shape):
x_bdim, _ = in_dims
x = x.movedim(x_bdim, 0)
x_shape = x.shape[0]
batch_shape = (x_shape, *shape)
result = numpy_view_copy(x, batch_shape)
return result, 0
numpy_view_copy.register_vmap(numpy_view_copy_vmap)
def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
result = make_arg(2, 3, 4, low=0.9, high=2)
@@ -222,6 +290,13 @@ def numpy_cat_backward(ctx, grad_out):
numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context)
def numpy_cat_vmap(info, in_dims, x, dim):
x_bdim, = in_dims
result = numpy_cat(x, dim)
return result, x_bdim
numpy_cat.register_vmap(numpy_cat_vmap)
def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
r0 = make_arg(2, 3, 4, low=0.9, high=2)
@@ -249,6 +324,14 @@ def numpy_split_copy_backward(ctx, grad_out):
numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context)
def numpy_split_copy_vmap(info, in_dims, x, splits, dim):
x_bdim, _, _ = in_dims
x = x.movedim(x_bdim, 0)
result = numpy_split_copy(x, splits, dim + 1)
return result, 0
numpy_split_copy.register_vmap(numpy_split_copy_vmap)
def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
x = make_arg(2, 9, low=0.9, high=2)
@@ -275,6 +358,14 @@ numpy_split_copy_with_int.register_autograd(
numpy_split_copy_with_int_backward,
setup_context=numpy_split_copy_with_int_setup_context)
def numpy_split_copy_with_int_vmap(info, in_dims, x, splits, dim):
x_bdim, _, _ = in_dims
x = x.movedim(x_bdim, 0)
result, len_split = numpy_split_copy_with_int(x, splits, dim + 1)
return (result, len_split), ([0 for _ in range(len(result))], None)
numpy_split_copy_with_int.register_vmap(numpy_split_copy_with_int_vmap)
@torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=())
def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
# Adapted from Ross Girshick's fast-rcnn implementation at
@@ -331,6 +422,11 @@ def _(boxes, scores, iou_threshold):
result = boxes.new_empty([i0], dtype=torch.int64)
return result
def numpy_nms_vmap(info, in_dims, boxes, scores, iou_threshold):
raise NotImplementedError("Operator is data-dependent and cannot be vmapped.")
numpy_nms.register_vmap(numpy_nms_vmap)
def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype)
N = 64