mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via:
```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129759
Approved by: https://github.com/justinchuby, https://github.com/ezyang
163 lines
5.5 KiB
Python
163 lines
5.5 KiB
Python
# Owner(s): ["oncall: package/deploy"]
|
|
|
|
from io import BytesIO
|
|
|
|
import torch
|
|
from torch.package import (
|
|
Importer,
|
|
OrderedImporter,
|
|
PackageExporter,
|
|
PackageImporter,
|
|
sys_importer,
|
|
)
|
|
from torch.testing._internal.common_utils import run_tests
|
|
|
|
|
|
try:
|
|
from .common import PackageTestCase
|
|
except ImportError:
|
|
# Support the case where we run this file directly.
|
|
from common import PackageTestCase
|
|
|
|
|
|
class TestImporter(PackageTestCase):
    """Tests for the ``Importer`` interface and its implementations.

    Covers ``sys_importer``, ``OrderedImporter`` and ``PackageImporter``:
    resolving modules by name, mapping an object back to a
    ``(module_name, attribute_name)`` pair, and deciding which importer
    answers when several are chained together.
    """

    def test_sys_importer(self):
        """``sys_importer`` hands back the very module objects Python imported."""
        import package_a
        import package_a.subpackage

        expectations = (
            ("package_a", package_a),
            ("package_a.subpackage", package_a.subpackage),
        )
        for module_name, expected_module in expectations:
            self.assertIs(sys_importer.import_module(module_name), expected_module)

    def test_sys_importer_roundtrip(self):
        """``get_name`` followed by ``import_module`` recovers the original type."""
        import package_a
        import package_a.subpackage

        target = package_a.subpackage.PackageASubpackageObject
        owner_name, attr_name = sys_importer.get_name(target)

        resolved_module = sys_importer.import_module(owner_name)
        self.assertIs(getattr(resolved_module, attr_name), target)

    def test_single_ordered_importer(self):
        """An ``OrderedImporter`` wrapping only a package sees only that package."""
        # module_a is imported through the normal import system; the
        # package-only importer built below must still fail to find it.
        import module_a  # noqa: F401
        import package_a

        archive = BytesIO()
        with PackageExporter(archive) as exporter:
            exporter.save_module(package_a.__name__)
        archive.seek(0)
        package_importer = PackageImporter(archive)

        # Nothing but the package importer is consulted by this environment.
        package_only = OrderedImporter(package_importer)

        # Lookups are served from the package...
        self.assertIs(
            package_only.import_module("package_a"),
            package_importer.import_module("package_a"),
        )
        # ...and so differ from the copy already imported into this process.
        self.assertIsNot(package_only.import_module("package_a"), package_a)

        # module_a was never saved into the archive, so the lookup must fail.
        with self.assertRaises(ModuleNotFoundError):
            package_only.import_module("module_a")

    def test_ordered_importer_basic(self):
        """The first importer in the chain that can provide a module wins."""
        import package_a

        archive = BytesIO()
        with PackageExporter(archive) as exporter:
            exporter.save_module(package_a.__name__)
        archive.seek(0)
        package_importer = PackageImporter(archive)

        # System importer consulted first: the in-process module is returned.
        sys_then_package = OrderedImporter(sys_importer, package_importer)
        self.assertIs(sys_then_package.import_module("package_a"), package_a)

        # Package importer consulted first: the packaged copy is returned.
        package_then_sys = OrderedImporter(package_importer, sys_importer)
        self.assertIs(
            package_then_sys.import_module("package_a"),
            package_importer.import_module("package_a"),
        )

    def test_ordered_importer_whichmodule(self):
        """``OrderedImporter.whichmodule`` asks each wrapped importer in order.

        An answer of ``"__main__"`` counts as "not found", so the next
        importer in the chain gets to answer instead.
        """

        class FixedAnswerImporter(Importer):
            # Stub importer whose ``whichmodule`` always reports a preset name.
            def __init__(self, answer):
                self._answer = answer

            def import_module(self, module_name):
                raise NotImplementedError

            def whichmodule(self, obj, name):
                return self._answer

        class Anything:
            pass

        answers_foo = FixedAnswerImporter("foo")
        answers_bar = FixedAnswerImporter("bar")
        # CPython uses "__main__" as its stand-in for "not found".
        answers_nothing = FixedAnswerImporter("__main__")

        for chain, expected in (
            ((answers_foo, answers_bar), "foo"),
            ((answers_bar, answers_foo), "bar"),
            ((answers_nothing, answers_foo), "foo"),
        ):
            ordered = OrderedImporter(*chain)
            self.assertEqual(ordered.whichmodule(Anything(), ""), expected)

    def test_package_importer_whichmodule_no_dunder_module(self):
        """Pickle an object with no ``__module__`` because it comes from C code.

        ``torch.float16`` is such an object. The default pickler finds these
        by traversing ``sys.modules`` and looking the name up on each module
        (see ``pickle.whichmodule``). ``PackageImporter`` must emulate that
        behavior. This test saves the dtype and loads it back, then re-saves
        it using only the loaded ``PackageImporter`` and loads it once more.
        """
        original_dtype = torch.float16

        # First round trip, using the default importer.
        first_archive = BytesIO()
        with PackageExporter(first_archive) as exporter:
            exporter.save_pickle("foo", "foo.pkl", original_dtype)
        first_archive.seek(0)

        first_importer = PackageImporter(first_archive)
        first_copy = first_importer.load_pickle("foo", "foo.pkl")

        # Second round trip: the exporter may consult only first_importer.
        second_archive = BytesIO()
        with PackageExporter(second_archive, importer=first_importer) as exporter:
            exporter.save_pickle("foo", "foo.pkl", first_copy)
        second_archive.seek(0)

        second_importer = PackageImporter(second_archive)
        second_copy = second_importer.load_pickle("foo", "foo.pkl")

        # Both round trips must yield the very same object we started with.
        self.assertIs(original_dtype, first_copy)
        self.assertIs(original_dtype, second_copy)
|
|
|
|
|
|
if __name__ == "__main__":
    # Only reached when this file is executed as a script rather than
    # imported; hands control to the shared PyTorch test runner.
    run_tests()
|