mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: OrderedImporters is supposed to be an importer which tries out every single importer in self._importers. However the get_name API does not follow this behavior and only uses the get_name from the basic Importer class. This change is to update the OrderedImporters get_name API so that it tries the get_name API of every single importers. Differential Revision: D76463252 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155743 Approved by: https://github.com/jcwchen, https://github.com/jingsh
287 lines
9.8 KiB
Python
287 lines
9.8 KiB
Python
# Owner(s): ["oncall: package/deploy"]
|
|
|
|
import pickle
|
|
from io import BytesIO
|
|
from sys import version_info
|
|
from textwrap import dedent
|
|
from unittest import skipIf
|
|
|
|
import torch
|
|
from torch.package import PackageExporter, PackageImporter, sys_importer
|
|
from torch.testing._internal.common_utils import run_tests
|
|
|
|
|
|
try:
|
|
from .common import PackageTestCase
|
|
except ImportError:
|
|
# Support the case where we run this file directly.
|
|
from common import PackageTestCase
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
packaging_directory = Path(__file__).parent
|
|
|
|
|
|
class TestSaveLoad(PackageTestCase):
|
|
"""Core save_* and loading API tests."""
|
|
|
|
def test_saving_source(self):
|
|
buffer = BytesIO()
|
|
with PackageExporter(buffer) as he:
|
|
he.save_source_file("foo", str(packaging_directory / "module_a.py"))
|
|
he.save_source_file("foodir", str(packaging_directory / "package_a"))
|
|
buffer.seek(0)
|
|
hi = PackageImporter(buffer)
|
|
foo = hi.import_module("foo")
|
|
s = hi.import_module("foodir.subpackage")
|
|
self.assertEqual(foo.result, "module_a")
|
|
self.assertEqual(s.result, "package_a.subpackage")
|
|
|
|
def test_saving_string(self):
|
|
buffer = BytesIO()
|
|
with PackageExporter(buffer) as he:
|
|
src = dedent(
|
|
"""\
|
|
import math
|
|
the_math = math
|
|
"""
|
|
)
|
|
he.save_source_string("my_mod", src)
|
|
buffer.seek(0)
|
|
hi = PackageImporter(buffer)
|
|
m = hi.import_module("math")
|
|
import math
|
|
|
|
self.assertIs(m, math)
|
|
my_mod = hi.import_module("my_mod")
|
|
self.assertIs(my_mod.math, math)
|
|
|
|
def test_save_module(self):
|
|
buffer = BytesIO()
|
|
with PackageExporter(buffer) as he:
|
|
import module_a
|
|
import package_a
|
|
|
|
he.save_module(module_a.__name__)
|
|
he.save_module(package_a.__name__)
|
|
buffer.seek(0)
|
|
hi = PackageImporter(buffer)
|
|
module_a_i = hi.import_module("module_a")
|
|
self.assertEqual(module_a_i.result, "module_a")
|
|
self.assertIsNot(module_a, module_a_i)
|
|
package_a_i = hi.import_module("package_a")
|
|
self.assertEqual(package_a_i.result, "package_a")
|
|
self.assertIsNot(package_a_i, package_a)
|
|
|
|
def test_dunder_imports(self):
|
|
buffer = BytesIO()
|
|
with PackageExporter(buffer) as he:
|
|
import package_b
|
|
|
|
obj = package_b.PackageBObject
|
|
he.intern("**")
|
|
he.save_pickle("res", "obj.pkl", obj)
|
|
|
|
buffer.seek(0)
|
|
hi = PackageImporter(buffer)
|
|
hi.load_pickle("res", "obj.pkl")
|
|
|
|
package_b = hi.import_module("package_b")
|
|
self.assertEqual(package_b.result, "package_b")
|
|
|
|
math = hi.import_module("math")
|
|
self.assertEqual(math.__name__, "math")
|
|
|
|
xml_sub_sub_package = hi.import_module("xml.sax.xmlreader")
|
|
self.assertEqual(xml_sub_sub_package.__name__, "xml.sax.xmlreader")
|
|
|
|
subpackage_1 = hi.import_module("package_b.subpackage_1")
|
|
self.assertEqual(subpackage_1.result, "subpackage_1")
|
|
|
|
subpackage_2 = hi.import_module("package_b.subpackage_2")
|
|
self.assertEqual(subpackage_2.result, "subpackage_2")
|
|
|
|
subsubpackage_0 = hi.import_module("package_b.subpackage_0.subsubpackage_0")
|
|
self.assertEqual(subsubpackage_0.result, "subsubpackage_0")
|
|
|
|
def test_bad_dunder_imports(self):
|
|
"""Test to ensure bad __imports__ don't cause PackageExporter to fail."""
|
|
buffer = BytesIO()
|
|
with PackageExporter(buffer) as e:
|
|
e.save_source_string(
|
|
"m", '__import__(these, unresolvable, "things", wont, crash, me)'
|
|
)
|
|
|
|
def test_save_module_binary(self):
|
|
f = BytesIO()
|
|
with PackageExporter(f) as he:
|
|
import module_a
|
|
import package_a
|
|
|
|
he.save_module(module_a.__name__)
|
|
he.save_module(package_a.__name__)
|
|
f.seek(0)
|
|
hi = PackageImporter(f)
|
|
module_a_i = hi.import_module("module_a")
|
|
self.assertEqual(module_a_i.result, "module_a")
|
|
self.assertIsNot(module_a, module_a_i)
|
|
package_a_i = hi.import_module("package_a")
|
|
self.assertEqual(package_a_i.result, "package_a")
|
|
self.assertIsNot(package_a_i, package_a)
|
|
|
|
def test_pickle(self):
|
|
import package_a.subpackage
|
|
|
|
obj = package_a.subpackage.PackageASubpackageObject()
|
|
obj2 = package_a.PackageAObject(obj)
|
|
|
|
buffer = BytesIO()
|
|
with PackageExporter(buffer) as he:
|
|
he.intern("**")
|
|
he.save_pickle("obj", "obj.pkl", obj2)
|
|
buffer.seek(0)
|
|
hi = PackageImporter(buffer)
|
|
|
|
# check we got dependencies
|
|
sp = hi.import_module("package_a.subpackage")
|
|
# check we didn't get other stuff
|
|
with self.assertRaises(ImportError):
|
|
hi.import_module("module_a")
|
|
|
|
obj_loaded = hi.load_pickle("obj", "obj.pkl")
|
|
self.assertIsNot(obj2, obj_loaded)
|
|
self.assertIsInstance(obj_loaded.obj, sp.PackageASubpackageObject)
|
|
self.assertIsNot(
|
|
package_a.subpackage.PackageASubpackageObject, sp.PackageASubpackageObject
|
|
)
|
|
|
|
def test_pickle_long_name_with_protocol_4(self):
|
|
import package_a.long_name
|
|
|
|
container = []
|
|
|
|
# Indirectly grab the function to avoid pasting a 256 character
|
|
# function into the test
|
|
package_a.long_name.add_function(container)
|
|
|
|
buffer = BytesIO()
|
|
with PackageExporter(buffer) as exporter:
|
|
exporter.intern("**")
|
|
exporter.save_pickle(
|
|
"container", "container.pkl", container, pickle_protocol=4
|
|
)
|
|
|
|
buffer.seek(0)
|
|
importer = PackageImporter(buffer)
|
|
unpickled_container = importer.load_pickle("container", "container.pkl")
|
|
self.assertIsNot(container, unpickled_container)
|
|
self.assertEqual(len(unpickled_container), 1)
|
|
self.assertEqual(container[0](), unpickled_container[0]())
|
|
|
|
def test_exporting_mismatched_code(self):
|
|
"""
|
|
If an object with the same qualified name is loaded from different
|
|
packages, the user should get an error if they try to re-save the
|
|
object with the wrong package's source code.
|
|
"""
|
|
import package_a.subpackage
|
|
|
|
obj = package_a.subpackage.PackageASubpackageObject()
|
|
obj2 = package_a.PackageAObject(obj)
|
|
|
|
b1 = BytesIO()
|
|
with PackageExporter(b1) as pe:
|
|
pe.intern("**")
|
|
pe.save_pickle("obj", "obj.pkl", obj2)
|
|
|
|
b1.seek(0)
|
|
importer1 = PackageImporter(b1)
|
|
loaded1 = importer1.load_pickle("obj", "obj.pkl")
|
|
|
|
b1.seek(0)
|
|
importer2 = PackageImporter(b1)
|
|
loaded2 = importer2.load_pickle("obj", "obj.pkl")
|
|
|
|
def make_exporter():
|
|
pe = PackageExporter(BytesIO(), importer=[importer1, sys_importer])
|
|
# Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first.
|
|
return pe
|
|
|
|
# This succeeds because OrderedImporter.get_name() properly
|
|
# falls back to sys_importer which can find the original PackageAObject
|
|
pe = make_exporter()
|
|
pe.save_pickle("obj", "obj.pkl", obj2)
|
|
|
|
# This should also fail. The 'PackageAObject' type defined from 'importer1'
|
|
# is not necessarily the same as the one defined from 'importer2'
|
|
pe = make_exporter()
|
|
with self.assertRaises(pickle.PicklingError):
|
|
pe.save_pickle("obj", "obj.pkl", loaded2)
|
|
|
|
# This should succeed. The 'PackageAObject' type defined from
|
|
# 'importer1' is a match for the one used by loaded1.
|
|
pe = make_exporter()
|
|
pe.save_pickle("obj", "obj.pkl", loaded1)
|
|
|
|
def test_save_imported_module(self):
|
|
"""Saving a module that came from another PackageImporter should work."""
|
|
import package_a.subpackage
|
|
|
|
obj = package_a.subpackage.PackageASubpackageObject()
|
|
obj2 = package_a.PackageAObject(obj)
|
|
|
|
buffer = BytesIO()
|
|
with PackageExporter(buffer) as exporter:
|
|
exporter.intern("**")
|
|
exporter.save_pickle("model", "model.pkl", obj2)
|
|
|
|
buffer.seek(0)
|
|
|
|
importer = PackageImporter(buffer)
|
|
imported_obj2 = importer.load_pickle("model", "model.pkl")
|
|
imported_obj2_module = imported_obj2.__class__.__module__
|
|
|
|
# Should export without error.
|
|
buffer2 = BytesIO()
|
|
with PackageExporter(buffer2, importer=(importer, sys_importer)) as exporter:
|
|
exporter.intern("**")
|
|
exporter.save_module(imported_obj2_module)
|
|
|
|
def test_save_imported_module_using_package_importer(self):
|
|
"""Exercise a corner case: re-packaging a module that uses `torch_package_importer`"""
|
|
import package_a.use_torch_package_importer # noqa: F401
|
|
|
|
buffer = BytesIO()
|
|
with PackageExporter(buffer) as exporter:
|
|
exporter.intern("**")
|
|
exporter.save_module("package_a.use_torch_package_importer")
|
|
|
|
buffer.seek(0)
|
|
|
|
importer = PackageImporter(buffer)
|
|
|
|
# Should export without error.
|
|
buffer2 = BytesIO()
|
|
with PackageExporter(buffer2, importer=(importer, sys_importer)) as exporter:
|
|
exporter.intern("**")
|
|
exporter.save_module("package_a.use_torch_package_importer")
|
|
|
|
@skipIf(version_info >= (3, 13), "https://github.com/pytorch/pytorch/issues/142170")
|
|
def test_save_load_fp8(self):
|
|
tensor = torch.rand(20, 20).to(torch.float8_e4m3fn)
|
|
|
|
buffer = BytesIO()
|
|
with PackageExporter(buffer) as exporter:
|
|
exporter.save_pickle("fp8_model", "model.pkl", tensor)
|
|
|
|
buffer.seek(0)
|
|
|
|
importer = PackageImporter(buffer)
|
|
loaded_tensor = importer.load_pickle("fp8_model", "model.pkl")
|
|
self.assertTrue(torch.equal(tensor, loaded_tensor))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|