mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54303 **Summary** Creating temporary files can cause problems in fbcode. This commit updates the packaging tests so that exporters write to a memory buffer when tests run in fbcode. **Test Plan** Continuous integration. Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D27180839 Pulled By: SplitInfinity fbshipit-source-id: 75689d59448de2cd1595ef0ecec69e1bbcf9a96f
140 lines
4.6 KiB
Python
140 lines
4.6 KiB
Python
import inspect
|
|
from io import BytesIO
|
|
from sys import version_info
|
|
from textwrap import dedent
|
|
from unittest import skipIf
|
|
|
|
from torch.package import PackageExporter, PackageImporter
|
|
from torch.testing._internal.common_utils import run_tests
|
|
|
|
try:
|
|
from .common import PackageTestCase
|
|
except ImportError:
|
|
# Support the case where we run this file directly.
|
|
from common import PackageTestCase # type: ignore
|
|
|
|
|
|
class TestMisc(PackageTestCase):
    """Tests for one-off or random functionality. Try not to add to this!"""

    def test_file_structure(self):
        """Check the rendered file tree of a package, both on export and on import.

        Covers the unfiltered exporter tree, an ``include``-filtered exporter
        tree, and an ``exclude``-filtered importer tree. All packages are written
        to an in-memory buffer rather than a temporary file.
        """
        buffer = BytesIO()

        def tree_without_root(file_structure):
            # Render the tree and drop its first line before comparing: that
            # line depends on how the buffer is represented, which differs
            # between platforms (Windows/macOS/Unix). Every comparison below
            # goes through this one helper so they all normalize identically.
            return dedent("\n".join(str(file_structure).split("\n")[1:]))

        export_plain = dedent(
            """\
                ├── main
                │   └── main
                ├── obj
                │   └── obj.pkl
                ├── package_a
                │   ├── __init__.py
                │   └── subpackage.py
                └── module_a.py
            """
        )
        export_include = dedent(
            """\
                ├── obj
                │   └── obj.pkl
                └── package_a
                    └── subpackage.py
            """
        )
        # The importer view additionally contains the `.data` entries, which
        # are absent from the exporter view taken before the exporter closes.
        import_exclude = dedent(
            """\
                ├── .data
                │   ├── extern_modules
                │   └── version
                ├── main
                │   └── main
                ├── obj
                │   └── obj.pkl
                ├── package_a
                │   ├── __init__.py
                │   └── subpackage.py
                └── module_a.py
            """
        )

        with PackageExporter(buffer, verbose=False) as he:
            # Test-local packages; imported here so they are only loaded when
            # this test runs.
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

            # Inspect the exporter while it is still open (before `__exit__`).
            self.assertEqual(tree_without_root(he.file_structure()), export_plain)
            self.assertEqual(
                tree_without_root(
                    he.file_structure(include=["**/subpackage.py", "**/*.pkl"])
                ),
                export_include,
            )

        # Rewind so the importer reads the package from the beginning.
        buffer.seek(0)
        hi = PackageImporter(buffer)
        self.assertEqual(
            tree_without_root(hi.file_structure(exclude="**/*.storage")),
            import_exclude,
        )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_custom_requires(self):
        """A PackageExporter subclass can override `require_module`.

        The override mocks `module_a` and supplies inline source for
        `package_a`; any other dependency raises NotImplementedError.
        """
        buffer = BytesIO()

        class Custom(PackageExporter):
            def require_module(self, name, dependencies):
                if name == "module_a":
                    self.save_mock_module("module_a")
                elif name == "package_a":
                    self.save_source_string(
                        "package_a", "import module_a\nresult = 5\n"
                    )
                else:
                    raise NotImplementedError("wat")

        with Custom(buffer, verbose=False) as he:
            # `main` imports package_a, which in turn triggers module_a.
            he.save_source_string("main", "import package_a\n")

        buffer.seek(0)
        hi = PackageImporter(buffer)
        # Deliberate bare attribute access: the test only requires that looking
        # up an arbitrary attribute on the mocked module does not raise.
        hi.import_module("module_a").should_be_mocked
        bar = hi.import_module("package_a")
        self.assertEqual(bar.result, 5)

    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        packaged_class = pi.import_module(
            "package_a.subpackage"
        ).PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        # The class loaded from the package must expose the same source lines
        # (and starting line number) as the one imported normally.
        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script (not imported by a test collector): hand control to
    # PyTorch's shared test runner.
    run_tests()
|