diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 81ed1126dcbb..06dd86711673 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -5,6 +5,7 @@ import itertools import math import sys from typing import Callable, List, Tuple, Type +import pickle import sympy @@ -30,6 +31,7 @@ from torch.utils._sympy.reference import ( from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges +from torch.utils._sympy.functions import OpaqueUnaryFn_cos UNARY_OPS = [ @@ -811,6 +813,13 @@ class TestSympySolve(TestCase): self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1) +class TestSympyFunctions(TestCase): + def test_pickle(self): + x = OpaqueUnaryFn_cos(sympy.Symbol('a')) + r = pickle.loads(pickle.dumps(x)) + self.assertEqual(x, r) + + class TestSingletonInt(TestCase): def test_basic(self): j1 = SingletonInt(1, coeff=1) diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index b369673d9213..480be602efa6 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1202,7 +1202,9 @@ def make_opaque_unary_fn(name): return getattr(sympy, name)(a) return None - OpaqueUnaryFn.__name__ = "OpaqueUnaryFn_" + name + nm = "OpaqueUnaryFn_" + name + OpaqueUnaryFn.__name__ = nm + OpaqueUnaryFn.__qualname__ = nm return OpaqueUnaryFn