mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make OpaqueUnaryFn pickleable (#138395)
Fixes https://github.com/pytorch/pytorch/issues/138070 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/138395 Approved by: https://github.com/XuehaiPan, https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
4d9b5a87e4
commit
8274dadac5
@ -5,6 +5,7 @@ import itertools
|
|||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
from typing import Callable, List, Tuple, Type
|
from typing import Callable, List, Tuple, Type
|
||||||
|
import pickle
|
||||||
|
|
||||||
import sympy
|
import sympy
|
||||||
|
|
||||||
@ -30,6 +31,7 @@ from torch.utils._sympy.reference import (
|
|||||||
from torch.utils._sympy.singleton_int import SingletonInt
|
from torch.utils._sympy.singleton_int import SingletonInt
|
||||||
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
||||||
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
||||||
|
from torch.utils._sympy.functions import OpaqueUnaryFn_cos
|
||||||
|
|
||||||
|
|
||||||
UNARY_OPS = [
|
UNARY_OPS = [
|
||||||
@ -811,6 +813,13 @@ class TestSympySolve(TestCase):
|
|||||||
self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1)
|
self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSympyFunctions(TestCase):
|
||||||
|
def test_pickle(self):
|
||||||
|
x = OpaqueUnaryFn_cos(sympy.Symbol('a'))
|
||||||
|
r = pickle.loads(pickle.dumps(x))
|
||||||
|
self.assertEqual(x, r)
|
||||||
|
|
||||||
|
|
||||||
class TestSingletonInt(TestCase):
|
class TestSingletonInt(TestCase):
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
j1 = SingletonInt(1, coeff=1)
|
j1 = SingletonInt(1, coeff=1)
|
||||||
|
@ -1202,7 +1202,9 @@ def make_opaque_unary_fn(name):
|
|||||||
return getattr(sympy, name)(a)
|
return getattr(sympy, name)(a)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
OpaqueUnaryFn.__name__ = "OpaqueUnaryFn_" + name
|
nm = "OpaqueUnaryFn_" + name
|
||||||
|
OpaqueUnaryFn.__name__ = nm
|
||||||
|
OpaqueUnaryFn.__qualname__ = nm
|
||||||
|
|
||||||
return OpaqueUnaryFn
|
return OpaqueUnaryFn
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user