mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make OpaqueUnaryFn pickleable (#138395)
Fixes https://github.com/pytorch/pytorch/issues/138070 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/138395 Approved by: https://github.com/XuehaiPan, https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
4d9b5a87e4
commit
8274dadac5
@ -5,6 +5,7 @@ import itertools
|
||||
import math
|
||||
import sys
|
||||
from typing import Callable, List, Tuple, Type
|
||||
import pickle
|
||||
|
||||
import sympy
|
||||
|
||||
@ -30,6 +31,7 @@ from torch.utils._sympy.reference import (
|
||||
from torch.utils._sympy.singleton_int import SingletonInt
|
||||
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
||||
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
||||
from torch.utils._sympy.functions import OpaqueUnaryFn_cos
|
||||
|
||||
|
||||
UNARY_OPS = [
|
||||
@ -811,6 +813,13 @@ class TestSympySolve(TestCase):
|
||||
self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1)
|
||||
|
||||
|
||||
class TestSympyFunctions(TestCase):
|
||||
def test_pickle(self):
|
||||
x = OpaqueUnaryFn_cos(sympy.Symbol('a'))
|
||||
r = pickle.loads(pickle.dumps(x))
|
||||
self.assertEqual(x, r)
|
||||
|
||||
|
||||
class TestSingletonInt(TestCase):
|
||||
def test_basic(self):
|
||||
j1 = SingletonInt(1, coeff=1)
|
||||
|
@ -1202,7 +1202,9 @@ def make_opaque_unary_fn(name):
|
||||
return getattr(sympy, name)(a)
|
||||
return None
|
||||
|
||||
OpaqueUnaryFn.__name__ = "OpaqueUnaryFn_" + name
|
||||
nm = "OpaqueUnaryFn_" + name
|
||||
OpaqueUnaryFn.__name__ = nm
|
||||
OpaqueUnaryFn.__qualname__ = nm
|
||||
|
||||
return OpaqueUnaryFn
|
||||
|
||||
|
Reference in New Issue
Block a user