Remove functorch dispatch keys in legacyExtractDispatchKey (#133018)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133018
Approved by: https://github.com/zou3519
This commit is contained in:
Guilherme Leobas
2024-08-12 21:36:20 +00:00
committed by PyTorch MergeBot
parent cd565bc455
commit c518b50c4c
2 changed files with 14 additions and 0 deletions

View File

@ -911,6 +911,9 @@ inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
DispatchKeySet(
{DispatchKey::Functionalize,
DispatchKey::PythonTLSSnapshot,
DispatchKey::FuncTorchGradWrapper,
DispatchKey::FuncTorchVmapMode,
DispatchKey::FuncTorchBatched,
DispatchKey::Python}))
.highestPriorityTypeId();
}

View File

@ -2637,6 +2637,17 @@ class TestJvp(TestCase):
self.assertTrue(isinstance(result, tuple))
self.assertEqual(result, expected)
def test_jvp_new_tensor(self):
def f(x):
y = x.new_tensor(0.5)
return x + y
x = torch.rand(10, 10)
tangents = torch.zeros_like(x)
actual = jvp(f, (x,), (tangents,))
expected = (f(x), torch.zeros_like(x))
self.assertEqual(actual, expected)
def test_primals_tangents_length_mismatch(self, device):
x = torch.randn(2, 3, device=device)
t = torch.randn(2, 3, device=device)