mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Make DispatchKeySet serializable; add __eq__
(#152732)
These seem like reasonable things to add. Also fixes a bug in vLLM for me. Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/152732 Approved by: https://github.com/bdhirsh
This commit is contained in:
@ -2,6 +2,7 @@
|
||||
# ruff: noqa: F841
|
||||
|
||||
import logging
|
||||
import pickle
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -226,6 +227,20 @@ class TestPythonRegistration(TestCase):
|
||||
torch.ops.custom.sum.default(a)
|
||||
self.assertTrue(meta_is_called)
|
||||
|
||||
def test_dispatchkeyset_pickle(self) -> None:
|
||||
keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
|
||||
serialized = pickle.dumps(keyset)
|
||||
new_keyset = pickle.loads(serialized)
|
||||
self.assertEqual(new_keyset, keyset)
|
||||
|
||||
def test_dispatchkeyset_eq(self) -> None:
|
||||
a = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
|
||||
b = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
|
||||
c = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
|
||||
self.assertTrue(a == b)
|
||||
self.assertFalse(a != b)
|
||||
self.assertTrue(a != c)
|
||||
|
||||
def test_override_aten_ops_with_multiple_libraries(self) -> None:
|
||||
x = torch.tensor([1, 2])
|
||||
with _scoped_library("aten", "IMPL") as my_lib2:
|
||||
|
Reference in New Issue
Block a user