Files
pytorch/test/dynamo/cpython/3_13/test_heapq.diff
Guilherme Leobas d0e8a0ec4c Add CPython test for heapq (#159370)
`heapq` is not used directly, but it is used internally by `collections.Counter`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159370
Approved by: https://github.com/zou3519, https://github.com/Skylion007
2025-07-30 18:43:06 +00:00

105 lines
3.3 KiB
Diff

diff --git a/test/dynamo/cpython/3_13/test_heapq.py b/test/dynamo/cpython/3_13/test_heapq.py
index 1aa8e4e2897..94315fa68b4 100644
--- a/test/dynamo/cpython/3_13/test_heapq.py
+++ b/test/dynamo/cpython/3_13/test_heapq.py
@@ -1,3 +1,23 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_heapq.py
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+# ======= END DYNAMO PATCH =======
+
"""Unittests for heapq."""
import random
@@ -16,7 +36,7 @@ c_heapq = import_helper.import_fresh_module('heapq', fresh=['_heapq'])
func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace',
'_heappop_max', '_heapreplace_max', '_heapify_max']
-class TestModules(TestCase):
+class TestModules(__TestCase):
def test_py_functions(self):
for fname in func_names:
self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq')
@@ -27,24 +47,7 @@ class TestModules(TestCase):
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
-def load_tests(loader, tests, ignore):
- # The 'merge' function has examples in its docstring which we should test
- # with 'doctest'.
- #
- # However, doctest can't easily find all docstrings in the module (loading
- # it through import_fresh_module seems to confuse it), so we specifically
- # create a finder which returns the doctests from the merge method.
-
- class HeapqMergeDocTestFinder:
- def find(self, *args, **kwargs):
- dtf = doctest.DocTestFinder()
- return dtf.find(py_heapq.merge)
-
- tests.addTests(doctest.DocTestSuite(py_heapq,
- test_finder=HeapqMergeDocTestFinder()))
- return tests
-
-class TestHeap:
+class _TestHeap:
def test_push_pop(self):
# 1) Push 256 random numbers and pop them off, verifying all's OK.
@@ -264,12 +267,12 @@ class TestHeap:
self.assertRaises(TypeError, data, LE)
-class TestHeapPython(TestHeap, TestCase):
+class TestHeapPython(_TestHeap, __TestCase):
module = py_heapq
@skipUnless(c_heapq, 'requires _heapq')
-class TestHeapC(TestHeap, TestCase):
+class TestHeapC(_TestHeap, __TestCase):
module = c_heapq
@@ -374,7 +377,7 @@ class SideEffectLT:
return self.value < other.value
-class TestErrorHandling:
+class _TestErrorHandling:
def test_non_sequence(self):
for f in (self.module.heapify, self.module.heappop):
@@ -464,13 +467,13 @@ class TestErrorHandling:
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
-class TestErrorHandlingPython(TestErrorHandling, TestCase):
+class TestErrorHandlingPython(_TestErrorHandling, __TestCase):
module = py_heapq
@skipUnless(c_heapq, 'requires _heapq')
-class TestErrorHandlingC(TestErrorHandling, TestCase):
+class TestErrorHandlingC(_TestErrorHandling, __TestCase):
module = c_heapq
if __name__ == "__main__":
- unittest.main()
+ run_tests()