mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[custom ops] Add register_vmap for custom ops (#130589)
Fixes #130284. Fixes #130653.

- Add `torch.library.register_vmap` to custom ops.
- Add `register_vmap` for operators in ops in custom_op_db.
- Make `torch.autograd.Function` support kwarg-only kwargs for vmap.
- Test operators in op_db with `tests/test_vmap`.
- Change `test_vmap` to allow custom `out_dim` and allow "None" in `out_dim` when testing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130589
Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
1e5ecc4277
commit
074b420641
@ -1617,6 +1617,28 @@ class TestAutogradFunctionVmapAPI(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "returned an incompatible"):
|
||||
result = vmap(Zeros.apply)(x)
|
||||
|
||||
def test_kwarg_only_tensors(self, device):
    # vmap of an autograd.Function whose forward takes a kwarg-only Tensor
    # argument is not supported: applying vmap must raise NotImplementedError
    # whose message mentions "kwarg-only Tensor args".
    with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"):

        class MyClass(torch.autograd.Function):
            @staticmethod
            def forward(x, *, y):
                # `y` is kwarg-only — this is the unsupported case under test.
                return x + y

            @staticmethod
            def setup_context(ctx, inputs, output):
                # Nothing needs saving; the error fires before backward matters.
                pass

            @staticmethod
            def vmap(info, in_dims, x, *, y):
                # NOTE(review): this staticmethod should never actually run,
                # since vmap is expected to reject the kwarg-only Tensor arg
                # first — confirm against the vmap dispatch implementation.
                assert in_dims == (0,)
                return x + y, 0

        x = torch.randn(3)
        y = torch.randn(3)

        # Passing the Tensor `y` as a keyword argument triggers the
        # NotImplementedError asserted above.
        vmap(MyClass.apply)(x, y=y)
||||
@markDynamoStrictTest
|
||||
class TestVmapOfGrad(TestCase):
|
||||
|
Reference in New Issue
Block a user