mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352 Approved by: https://github.com/ezyang ghstack dependencies: #132335, #132351
44 lines
1.1 KiB
Python
44 lines
1.1 KiB
Python
# Owner(s): ["oncall: package/deploy"]
|
|
|
|
import textwrap
|
|
import types
|
|
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
from torch.utils._freeze import Freezer, PATH_MARKER
|
|
|
|
|
|
class TestFreezer(TestCase):
|
|
"""Tests the freeze.py script"""
|
|
|
|
def test_compile_string(self):
|
|
freezer = Freezer(True)
|
|
code_str = textwrap.dedent(
|
|
"""
|
|
class MyCls:
|
|
def __init__(self) -> None:
|
|
pass
|
|
"""
|
|
)
|
|
co = freezer.compile_string(code_str)
|
|
num_co = 0
|
|
|
|
def verify_filename(co: types.CodeType):
|
|
nonlocal num_co
|
|
|
|
if not isinstance(co, types.CodeType):
|
|
return
|
|
|
|
self.assertEqual(PATH_MARKER, co.co_filename)
|
|
num_co += 1
|
|
|
|
for nested_co in co.co_consts:
|
|
verify_filename(nested_co)
|
|
|
|
verify_filename(co)
|
|
# there is at least one nested code object besides the top level one
|
|
self.assertTrue(num_co >= 2)
|
|
|
|
|
|
# Hand control to the imported run_tests() helper when executed as a script.
if __name__ == "__main__":
    run_tests()
|