Make OpaqueUnaryFn pickleable (#138395)

Fixes https://github.com/pytorch/pytorch/issues/138070

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138395
Approved by: https://github.com/XuehaiPan, https://github.com/bobrenjc93
This commit is contained in:
Edward Z. Yang
2024-10-19 10:22:20 -07:00
committed by PyTorch MergeBot
parent 4d9b5a87e4
commit 8274dadac5
2 changed files with 12 additions and 1 deletions

View File

@ -5,6 +5,7 @@ import itertools
import math
import sys
from typing import Callable, List, Tuple, Type
import pickle
import sympy
@ -30,6 +31,7 @@ from torch.utils._sympy.reference import (
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.functions import OpaqueUnaryFn_cos
UNARY_OPS = [
@ -811,6 +813,13 @@ class TestSympySolve(TestCase):
self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1)
class TestSympyFunctions(TestCase):
def test_pickle(self):
x = OpaqueUnaryFn_cos(sympy.Symbol('a'))
r = pickle.loads(pickle.dumps(x))
self.assertEqual(x, r)
class TestSingletonInt(TestCase):
def test_basic(self):
j1 = SingletonInt(1, coeff=1)

View File

@ -1202,7 +1202,9 @@ def make_opaque_unary_fn(name):
return getattr(sympy, name)(a)
return None
OpaqueUnaryFn.__name__ = "OpaqueUnaryFn_" + name
nm = "OpaqueUnaryFn_" + name
OpaqueUnaryFn.__name__ = nm
OpaqueUnaryFn.__qualname__ = nm
return OpaqueUnaryFn