Fixes for CPython int/float tests (#155978)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155978
Approved by: https://github.com/zou3519
This commit is contained in:
Guilherme Leobas
2025-06-26 14:00:50 +00:00
committed by PyTorch MergeBot
parent d0cfa3e5bf
commit 216bd6091e
61 changed files with 55 additions and 13 deletions

View File

@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_int.py b/test/dynamo/cpython/3_13/test_int.py
index 48825f46911..ac7aeacbc01 100644
index 48825f46911..4ab200372ea 100644
--- a/test/dynamo/cpython/3_13/test_int.py
+++ b/test/dynamo/cpython/3_13/test_int.py
@@ -1,13 +1,137 @@
@ -153,7 +153,15 @@ index 48825f46911..ac7aeacbc01 100644
def test_basic(self):
self.assertEqual(int(314), 314)
@@ -607,7 +731,7 @@ class IntTestCases(unittest.TestCase):
@@ -566,6 +690,7 @@ class IntTestCases(unittest.TestCase):
self.assertEqual(n, 1)
self.assertIs(type(n), IntSubclass)
+ @skipIfTorchDynamo("flaky under dynamo")
def test_error_message(self):
def check(s, base=None):
with self.assertRaises(ValueError,
@@ -607,7 +732,7 @@ class IntTestCases(unittest.TestCase):
self.assertEqual(int('1_2_3_4_5_6_7', 32), 1144132807)
@ -162,7 +170,7 @@ index 48825f46911..ac7aeacbc01 100644
int_class = int # Override this in subclasses to reuse the suite.
@@ -818,7 +942,7 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests):
@@ -818,7 +943,7 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests):
int_class = IntSubclass
@ -171,7 +179,7 @@ index 48825f46911..ac7aeacbc01 100644
# Tests of the functions in _pylong.py. Those get used when the
# number of digits in the input values are large enough.
@@ -922,4 +1046,4 @@ class PyLongModuleTests(unittest.TestCase):
@@ -922,4 +1047,4 @@ class PyLongModuleTests(unittest.TestCase):
bits <<= 1
if __name__ == "__main__":

View File

@ -9,7 +9,7 @@ import torch
import torch._dynamo.test_case
import unittest
from torch._dynamo.test_case import CPythonTestCase
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
__TestCase = CPythonTestCase
@ -690,6 +690,7 @@ class IntTestCases(__TestCase):
self.assertEqual(n, 1)
self.assertIs(type(n), IntSubclass)
@skipIfTorchDynamo("flaky under dynamo")
def test_error_message(self):
def check(s, base=None):
with self.assertRaises(ValueError,

View File

@ -734,6 +734,7 @@ class TestVmapAPI(TestCase):
# warning, not a warning from the vmap fallback path.
self.assertEqual(len(wa), 1)
@skipIfTorchDynamo("Flaky test")
@unittest.expectedFailure
def test_fallback_warns_when_warnings_are_enabled(self):
# NB: One day we will implement a batching rule for torch.atan2.

View File

@ -374,7 +374,7 @@ def raise_observed_exception(
# stack and raise the exception.
exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type]
tx.exn_vt_stack.set_current_exception(exception_vt)
raise observed_exception_map[exc_type]
raise get_dynamo_observed_exception(exc_type)
def handle_observed_exception(tx: Any) -> None:

View File

@ -186,6 +186,15 @@ def set_difference_update(set1, *others):
set1.update(result)
def assert_multi_line_equal(self_, first, second, msg=None):
return self_.assertTrue(first == second, msg)
# The original impl. uses difflib
def assert_sequence_equal(self_, seq1, seq2, msg=None, seq_type=None):
return self_.assertTrue(seq1 == seq2, msg)
def getattr_and_trace(*args, **kwargs):
wrapper_obj = args[0]
attr_name = args[1]

View File

@ -23,3 +23,8 @@ def intern(string: str, /) -> str:
@substitute_in_graph(sys.getrecursionlimit, can_constant_fold_through=True)
def getrecursionlimit() -> int:
return sys.getrecursionlimit()
@substitute_in_graph(sys.get_int_max_str_digits, can_constant_fold_through=True)
def get_int_max_str_digits() -> int:
return sys.get_int_max_str_digits()

View File

@ -22,6 +22,7 @@ from typing import Union
import torch
import torch.testing
from torch._dynamo import polyfills
from torch._logging._internal import trace_log
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
IS_WINDOWS,
@ -136,8 +137,8 @@ class CPythonTestCase(TestCase):
assertRegex = unittest.TestCase.assertRegex
assertNotRegex = unittest.TestCase.assertNotRegex
assertCountEqual = unittest.TestCase.assertCountEqual
assertMultiLineEqual = unittest.TestCase.assertMultiLineEqual
assertSequenceEqual = unittest.TestCase.assertSequenceEqual
assertMultiLineEqual = polyfills.assert_multi_line_equal
assertSequenceEqual = polyfills.assert_sequence_equal
assertListEqual = unittest.TestCase.assertListEqual
assertTupleEqual = unittest.TestCase.assertTupleEqual
assertSetEqual = unittest.TestCase.assertSetEqual

View File

@ -1277,6 +1277,12 @@ class BuiltinVariable(VariableTracker):
if isinstance(args[0], ConstantVariable):
return args[0].call_method(tx, name, args[1:], kwargs)
if self.fn is float and len(args) >= 1:
if isinstance(args[0], ConstantVariable):
return ConstantVariable.create(
getattr(float, name)(args[0].as_python_constant())
)
return super().call_method(tx, name, args, kwargs)
def _call_int_float(self, tx: "InstructionTranslator", arg):
@ -2062,7 +2068,6 @@ class BuiltinVariable(VariableTracker):
"assertNotWarns",
"assertWarnsRegex",
"assertDictEqual",
"assertSequenceEqual",
"assertWarns",
)
):

View File

@ -173,7 +173,14 @@ its type to `common_constant_types`.
raise_observed_exception(type(e), tx)
elif isinstance(self.value, (float, int)):
if not (args or kwargs):
return ConstantVariable.create(getattr(self.value, name)())
try:
return ConstantVariable.create(getattr(self.value, name)())
except (OverflowError, ValueError) as exc:
raise_observed_exception(
type(exc),
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
if (
hasattr(operator, name)
and len(args) == 1
@ -203,9 +210,14 @@ its type to `common_constant_types`.
if name == "__len__" and not (args or kwargs):
return ConstantVariable.create(len(self.value))
elif name == "__round__" and len(args) == 1 and args[0].is_python_constant():
return ConstantVariable.create(
round(self.value, args[0].as_python_constant())
)
try:
return ConstantVariable.create(
round(self.value, args[0].as_python_constant())
)
except Exception as e:
raise_observed_exception(
type(e), tx, args=list(map(ConstantVariable.create, e.args))
)
elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
assert not kwargs
search = args[0].as_python_constant()