mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Remove functorch dispatch keys in legacyExtractDispatchKey
(#133018)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133018 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
cd565bc455
commit
c518b50c4c
@ -911,6 +911,9 @@ inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
|
|||||||
DispatchKeySet(
|
DispatchKeySet(
|
||||||
{DispatchKey::Functionalize,
|
{DispatchKey::Functionalize,
|
||||||
DispatchKey::PythonTLSSnapshot,
|
DispatchKey::PythonTLSSnapshot,
|
||||||
|
DispatchKey::FuncTorchGradWrapper,
|
||||||
|
DispatchKey::FuncTorchVmapMode,
|
||||||
|
DispatchKey::FuncTorchBatched,
|
||||||
DispatchKey::Python}))
|
DispatchKey::Python}))
|
||||||
.highestPriorityTypeId();
|
.highestPriorityTypeId();
|
||||||
}
|
}
|
||||||
|
@ -2637,6 +2637,17 @@ class TestJvp(TestCase):
|
|||||||
self.assertTrue(isinstance(result, tuple))
|
self.assertTrue(isinstance(result, tuple))
|
||||||
self.assertEqual(result, expected)
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_jvp_new_tensor(self):
|
||||||
|
def f(x):
|
||||||
|
y = x.new_tensor(0.5)
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
x = torch.rand(10, 10)
|
||||||
|
tangents = torch.zeros_like(x)
|
||||||
|
actual = jvp(f, (x,), (tangents,))
|
||||||
|
expected = (f(x), torch.zeros_like(x))
|
||||||
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
def test_primals_tangents_length_mismatch(self, device):
|
def test_primals_tangents_length_mismatch(self, device):
|
||||||
x = torch.randn(2, 3, device=device)
|
x = torch.randn(2, 3, device=device)
|
||||||
t = torch.randn(2, 3, device=device)
|
t = torch.randn(2, 3, device=device)
|
||||||
|
Reference in New Issue
Block a user