fix pickling for BitwiseFn (#163571)

Summary:
ran into AttributeError: Can't get local object 'make_opaque_bitwise_fn.<locals>.BitwiseFn'

looks like it was fixed for UnaryFn but not BitwiseFn in https://github.com/pytorch/pytorch/pull/138395

Fixes #147841

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163571
Approved by: https://github.com/jamesjwu
This commit is contained in:
dolpm
2025-09-25 04:52:11 +00:00
committed by PyTorch MergeBot
parent 783a9dcb6d
commit cde5c9aebd
3 changed files with 17 additions and 5 deletions

View File

@ -206,7 +206,8 @@ class TestSubprocess(TestCase):
start = time.time()
last_report = start
while _AsyncFxCompile._stat_compiled_runs < 4:
while True:
start_stat_compiled_runs = _AsyncFxCompile._stat_compiled_runs
# Sleep a bit so we don't drive the CPU unnecessarily.
time.sleep(0.25)
@ -219,6 +220,9 @@ class TestSubprocess(TestCase):
# Backward pass
output.sum().backward()
if _AsyncFxCompile._stat_compiled_runs - start_stat_compiled_runs == 2:
break
# DEBUGGING: Print a periodic message so we know we're still
# running...
now = time.time()
@ -231,12 +235,12 @@ class TestSubprocess(TestCase):
"Test timed out before producing a compiled artifact."
)
self.assertEqual(_AsyncFxCompile._stat_compiled_runs, 4)
self.assertGreater(_AsyncFxCompile._stat_compiled_runs, 1)
# Make sure we ran eager at least once. Normally this will be
# something like 80.
self.assertGreater(_AsyncFxCompile._stat_eager_runs, 0)
self.assertEqual(_AsyncFxCompile._stat_bg_started, 1)
self.assertEqual(_AsyncFxCompile._stat_bg_finished, 1)
self.assertEqual(_AsyncFxCompile._stat_bg_started, 2)
self.assertEqual(_AsyncFxCompile._stat_bg_finished, 2)
if RUN_CPU:

View File

@ -24,6 +24,7 @@ from torch.utils._sympy.functions import (
FloorDiv,
Identity,
OpaqueUnaryFn_cos,
BitwiseFn_bitwise_and,
simple_floordiv_gcd,
)
from torch.utils._sympy.interp import sympy_interp
@ -873,6 +874,10 @@ class TestSympyFunctions(TestCase):
r = pickle.loads(pickle.dumps(x))
self.assertEqual(x, r)
x = BitwiseFn_bitwise_and(sympy.Symbol("a"), sympy.Symbol("b"))
r = pickle.loads(pickle.dumps(x))
self.assertEqual(x, r)
class TestSingletonInt(TestCase):
def test_basic(self):

View File

@ -1411,7 +1411,10 @@ def make_opaque_bitwise_fn(name, real_op_name):
return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b)))
return None
BitwiseFn.__name__ = "BitwiseFn_" + name
nm = "BitwiseFn_" + name
BitwiseFn.__name__ = nm
BitwiseFn.__qualname__ = nm
return BitwiseFn