Files
pytorch/test/package/test_misc.py
Meghan Lele d58c00a5d8 [package] Make exporters write to buffer in fbcode (#54303)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54303

**Summary**
Creating temporary files can cause problems in fbcode. This commit
updates the packaging tests so that exporters write to a memory
buffer when tests run in fbcode.

**Test Plan**
Continuous integration.

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D27180839

Pulled By: SplitInfinity

fbshipit-source-id: 75689d59448de2cd1595ef0ecec69e1bbcf9a96f
2021-03-19 19:59:35 -07:00

140 lines
4.6 KiB
Python

import inspect
from io import BytesIO
from sys import version_info
from textwrap import dedent
from unittest import skipIf
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase # type: ignore
class TestMisc(PackageTestCase):
    """Tests for one-off or random functionality. Try not to add to this!"""

    @staticmethod
    def _tree_without_root(file_structure) -> str:
        """Return the printed tree of ``file_structure`` without its first line.

        The first line of the rendered tree refers to the buffer the package
        was written to, and its text differs across operating systems, so it is
        dropped. The remaining lines are dedented so they can be compared with
        the expected tree literals used by the tests below.

        Args:
            file_structure: value returned by ``PackageExporter.file_structure``
                or ``PackageImporter.file_structure``; only its ``str()`` form
                is used.

        Returns:
            The dedented tree text, minus the root line.
        """
        lines = str(file_structure).split("\n")
        return dedent("\n".join(lines[1:]))

    def test_file_structure(self):
        """Check the tree printed by ``file_structure()``.

        Covers an exporter with no filter, an exporter with an ``include``
        filter, and an importer with an ``exclude`` filter.
        """
        buffer = BytesIO()

        export_plain = dedent(
            """\
                ├── main
                │   └── main
                ├── obj
                │   └── obj.pkl
                ├── package_a
                │   ├── __init__.py
                │   └── subpackage.py
                └── module_a.py
            """
        )
        export_include = dedent(
            """\
                ├── obj
                │   └── obj.pkl
                └── package_a
                    └── subpackage.py
            """
        )
        import_exclude = dedent(
            """\
                ├── .data
                │   ├── extern_modules
                │   └── version
                ├── main
                │   └── main
                ├── obj
                │   └── obj.pkl
                ├── package_a
                │   ├── __init__.py
                │   └── subpackage.py
                └── module_a.py
            """
        )

        with PackageExporter(buffer, verbose=False) as he:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

            # The exporter's tree is inspected while the exporter is still open.
            # The ".data" entries only appear in the importer's view below.
            self.assertEqual(
                self._tree_without_root(he.file_structure()),
                export_plain,
            )
            self.assertEqual(
                self._tree_without_root(
                    he.file_structure(include=["**/subpackage.py", "**/*.pkl"])
                ),
                export_include,
            )

        buffer.seek(0)
        hi = PackageImporter(buffer)
        self.assertEqual(
            self._tree_without_root(hi.file_structure(exclude="**/*.storage")),
            import_exclude,
        )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_custom_requires(self):
        """A ``PackageExporter`` subclass can override ``require_module``."""
        buffer = BytesIO()

        class Custom(PackageExporter):
            # ``dependencies`` is part of the overridden signature and is
            # intentionally unused here.
            def require_module(self, name, dependencies):
                if name == "module_a":
                    self.save_mock_module("module_a")
                elif name == "package_a":
                    self.save_source_string(
                        "package_a", "import module_a\nresult = 5\n"
                    )
                else:
                    raise NotImplementedError("wat")

        with Custom(buffer, verbose=False) as he:
            # Saving "main" pulls in package_a (and, through it, module_a) via
            # the overridden require_module above.
            he.save_source_string("main", "import package_a\n")

        buffer.seek(0)
        hi = PackageImporter(buffer)
        # Bare attribute access on purpose: the check is that looking up an
        # arbitrary attribute on the mocked module does not raise.
        hi.import_module("module_a").should_be_mocked
        bar = hi.import_module("package_a")
        self.assertEqual(bar.result, 5)

    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        packaged_class = pi.import_module(
            "package_a.subpackage"
        ).PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        # The class loaded from the package should report the same source lines
        # as the class imported normally.
        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)
if __name__ == "__main__":
    # Executing this module directly delegates to PyTorch's shared test runner.
    run_tests()