Make DispatchKeySet serializable; add __eq__ (#152732)

These seem like reasonable things to add. Also fixes a bug in vLLM for
me.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152732
Approved by: https://github.com/bdhirsh
This commit is contained in:
rzou
2025-05-02 14:00:25 -07:00
committed by PyTorch MergeBot
parent 792736f9ac
commit 762844355e
3 changed files with 31 additions and 1 deletions

View File

@ -2,6 +2,7 @@
# ruff: noqa: F841
import logging
import pickle
import sys
import tempfile
import unittest
@ -226,6 +227,20 @@ class TestPythonRegistration(TestCase):
torch.ops.custom.sum.default(a)
self.assertTrue(meta_is_called)
def test_dispatchkeyset_pickle(self) -> None:
keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
serialized = pickle.dumps(keyset)
new_keyset = pickle.loads(serialized)
self.assertEqual(new_keyset, keyset)
def test_dispatchkeyset_eq(self) -> None:
a = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
b = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
c = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
self.assertTrue(a == b)
self.assertFalse(a != b)
self.assertTrue(a != c)
def test_override_aten_ops_with_multiple_libraries(self) -> None:
x = torch.tensor([1, 2])
with _scoped_library("aten", "IMPL") as my_lib2: