mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Remove functorch dispatch keys in legacyExtractDispatchKey
(#133018)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133018 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
cd565bc455
commit
c518b50c4c
@ -911,6 +911,9 @@ inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
|
||||
DispatchKeySet(
|
||||
{DispatchKey::Functionalize,
|
||||
DispatchKey::PythonTLSSnapshot,
|
||||
DispatchKey::FuncTorchGradWrapper,
|
||||
DispatchKey::FuncTorchVmapMode,
|
||||
DispatchKey::FuncTorchBatched,
|
||||
DispatchKey::Python}))
|
||||
.highestPriorityTypeId();
|
||||
}
|
||||
|
@ -2637,6 +2637,17 @@ class TestJvp(TestCase):
|
||||
self.assertTrue(isinstance(result, tuple))
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_jvp_new_tensor(self):
|
||||
def f(x):
|
||||
y = x.new_tensor(0.5)
|
||||
return x + y
|
||||
|
||||
x = torch.rand(10, 10)
|
||||
tangents = torch.zeros_like(x)
|
||||
actual = jvp(f, (x,), (tangents,))
|
||||
expected = (f(x), torch.zeros_like(x))
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_primals_tangents_length_mismatch(self, device):
|
||||
x = torch.randn(2, 3, device=device)
|
||||
t = torch.randn(2, 3, device=device)
|
||||
|
Reference in New Issue
Block a user