[custom ops] Add register_vmap for custom ops (#130589)

Fixes #130284
Fixes #130653

- Add `torch.library.register_vmap` to custom ops
- Add `register_vmap` for operators in ops in custom_op_db.
- Make `torch.autograd.Function` support kwarg-only kwargs for vmap
- test operators in op_db with `tests/test_vmap`.
- change `test_vmap` to allow custom `out_dim` and allow "None" in `out_dim` when testing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130589
Approved by: https://github.com/zou3519
This commit is contained in:
Shangdi Yu
2024-07-23 17:48:36 +00:00
committed by PyTorch MergeBot
parent 404d640c39
commit 68c725a094
9 changed files with 680 additions and 15 deletions

View File

@ -1617,6 +1617,28 @@ class TestAutogradFunctionVmapAPI(TestCase):
with self.assertRaisesRegex(RuntimeError, "returned an incompatible"):
result = vmap(Zeros.apply)(x)
def test_kwarg_only_tensors(self, device):
with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):
class MyClass(torch.autograd.Function):
@staticmethod
def forward(x, *, y):
return x + y
@staticmethod
def setup_context(ctx, inputs, output):
pass
@staticmethod
def vmap(info, in_dims, x, *, y):
assert in_dims == (0,)
return x + y, 0
x = torch.randn(3)
y = torch.randn(3)
vmap(MyClass.apply)(x, y=y)
@markDynamoStrictTest
class TestVmapOfGrad(TestCase):