Revert D33970688: [pkg] add generic ZipFile Reader/Writer

Test Plan: revert-hammer
Differential Revision: D33970688 (c2c260bfc3)
Original commit changeset: 8a524867e62a
Original Phabricator Diff: D33970688 (c2c260bfc3)
fbshipit-source-id: 18b4aa4e221b86a498fc434c1b453356fc47cfbf
(cherry picked from commit a295c2b58d3d9cfacfc9d11d36fd80aabd97675c)

Commit 00e2c14b78, parent 20266f054b, committed by PyTorch MergeBot.
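Most of the hunks below undo one mechanical change: the reverted diff had parameterized every torch.package test over the importer/exporter implementation by stashing the classes on `self`, so that each `*NoTorch` subclass could re-run the whole suite against the torch-free variants. A minimal sketch of the removed pattern, assuming a plain unittest base class (`test_roundtrip` is an illustrative stand-in for the real test methods in the diff):

from io import BytesIO
import unittest

from torch.package import PackageExporter, PackageImporter


class TestDependencyAPI(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Subclasses swap these to re-run every test against another backend.
        self.PackageImporter = PackageImporter
        self.PackageExporter = PackageExporter

    def test_roundtrip(self):
        buffer = BytesIO()
        with self.PackageExporter(buffer) as exporter:  # not PackageExporter directly
            exporter.save_source_string("foo", "x = 1")
        buffer.seek(0)
        self.assertEqual(self.PackageImporter(buffer).import_module("foo").x, 1)

The revert replaces every `self.PackageExporter`/`self.PackageImporter` use with the concrete class and deletes the `*NoTorch` subclasses.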
@@ -6,19 +6,8 @@ from sys import version_info
 from textwrap import dedent
 from unittest import skipIf
 
-from torch.package import (
-    EmptyMatchError,
-    Importer,
-    PackageExporter,
-    PackageImporter,
-    PackagingError,
-)
-from torch.package.package_exporter_no_torch import (
-    PackageExporter as PackageExporterNoTorch,
-)
-from torch.package.package_importer_no_torch import (
-    PackageImporter as PackageImporterNoTorch,
-)
+from torch.package import EmptyMatchError, Importer, PackageExporter, PackageImporter
+from torch.package.package_exporter import PackagingError
 from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
 
 try:
@@ -35,18 +24,13 @@ class TestDependencyAPI(PackageTestCase):
     - deny()
     """
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporter
-        self.PackageExporter = PackageExporter
-
     def test_extern(self):
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as he:
+        with PackageExporter(buffer) as he:
             he.extern(["package_a.subpackage", "module_a"])
             he.save_source_string("foo", "import package_a.subpackage; import module_a")
         buffer.seek(0)
-        hi = self.PackageImporter(buffer)
+        hi = PackageImporter(buffer)
         import module_a
         import package_a.subpackage
 
@@ -60,7 +44,7 @@ class TestDependencyAPI(PackageTestCase):
 
     def test_extern_glob(self):
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as he:
+        with PackageExporter(buffer) as he:
            he.extern(["package_a.*", "module_*"])
            he.save_module("package_a")
            he.save_source_string(
@@ -73,7 +57,7 @@ class TestDependencyAPI(PackageTestCase):
                ),
            )
         buffer.seek(0)
-        hi = self.PackageImporter(buffer)
+        hi = PackageImporter(buffer)
         import module_a
         import package_a.subpackage
 
@@ -94,7 +78,7 @@ class TestDependencyAPI(PackageTestCase):
 
         buffer = BytesIO()
         with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
-            with self.PackageExporter(buffer) as exporter:
+            with PackageExporter(buffer) as exporter:
                 exporter.extern(include=["package_b.*"], allow_empty=False)
                 exporter.save_module("package_a.subpackage")
 
@@ -105,7 +89,7 @@ class TestDependencyAPI(PackageTestCase):
         buffer = BytesIO()
 
         with self.assertRaisesRegex(PackagingError, "denied"):
-            with self.PackageExporter(buffer) as exporter:
+            with PackageExporter(buffer) as exporter:
                 exporter.deny(["package_a.subpackage", "module_a"])
                 exporter.save_source_string("foo", "import package_a.subpackage")
 
@@ -115,7 +99,7 @@ class TestDependencyAPI(PackageTestCase):
         """
         buffer = BytesIO()
         with self.assertRaises(PackagingError):
-            with self.PackageExporter(buffer) as exporter:
+            with PackageExporter(buffer) as exporter:
                 exporter.deny(["package_a.*", "module_*"])
                 exporter.save_source_string(
                     "test_module",
@@ -130,12 +114,12 @@ class TestDependencyAPI(PackageTestCase):
     @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
     def test_mock(self):
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as he:
+        with PackageExporter(buffer) as he:
             he.mock(["package_a.subpackage", "module_a"])
             # Import something that depends on package_a.subpackage
             he.save_source_string("foo", "import package_a.subpackage")
         buffer.seek(0)
-        hi = self.PackageImporter(buffer)
+        hi = PackageImporter(buffer)
         import package_a.subpackage
 
         _ = package_a.subpackage
@@ -151,7 +135,7 @@ class TestDependencyAPI(PackageTestCase):
     @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
     def test_mock_glob(self):
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as he:
+        with PackageExporter(buffer) as he:
             he.mock(["package_a.*", "module*"])
             he.save_module("package_a")
             he.save_source_string(
@@ -164,7 +148,7 @@ class TestDependencyAPI(PackageTestCase):
                 ),
             )
         buffer.seek(0)
-        hi = self.PackageImporter(buffer)
+        hi = PackageImporter(buffer)
         import package_a.subpackage
 
         _ = package_a.subpackage
@@ -186,7 +170,7 @@ class TestDependencyAPI(PackageTestCase):
 
         buffer = BytesIO()
         with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
-            with self.PackageExporter(buffer) as exporter:
+            with PackageExporter(buffer) as exporter:
                 exporter.mock(include=["package_b.*"], allow_empty=False)
                 exporter.save_module("package_a.subpackage")
 
@@ -199,7 +183,7 @@ class TestDependencyAPI(PackageTestCase):
 
         buffer = BytesIO()
         with self.assertRaises(PackagingError):
-            with self.PackageExporter(buffer) as he:
+            with PackageExporter(buffer) as he:
                 he.mock(include="package_a.subpackage")
                 he.intern("**")
                 he.save_pickle("obj", "obj.pkl", obj2)
@@ -212,7 +196,7 @@ class TestDependencyAPI(PackageTestCase):
         obj2 = package_a.PackageAObject(obj)
 
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as he:
+        with PackageExporter(buffer) as he:
             he.intern(include="package_a.**")
             he.mock("**")
             he.save_pickle("obj", "obj.pkl", obj2)
@@ -221,7 +205,7 @@ class TestDependencyAPI(PackageTestCase):
         """If an error occurs during packaging, it should not be shadowed by the allow_empty error."""
         buffer = BytesIO()
         with self.assertRaises(ModuleNotFoundError):
-            with self.PackageExporter(buffer) as pe:
+            with PackageExporter(buffer) as pe:
                 # Even though we did not extern a module that matches this
                 # pattern, we want to show the save_module error, not the allow_empty error.
 
@@ -238,7 +222,7 @@ class TestDependencyAPI(PackageTestCase):
         import package_a  # noqa: F401
 
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as he:
+        with PackageExporter(buffer) as he:
             he.save_module("package_a")
 
     def test_intern_error(self):
@@ -251,7 +235,7 @@ class TestDependencyAPI(PackageTestCase):
         buffer = BytesIO()
 
         with self.assertRaises(PackagingError) as e:
-            with self.PackageExporter(buffer) as he:
+            with PackageExporter(buffer) as he:
                 he.save_pickle("obj", "obj.pkl", obj2)
 
         self.assertEqual(
@@ -266,7 +250,7 @@ class TestDependencyAPI(PackageTestCase):
         )
 
         # Interning all dependencies should work
-        with self.PackageExporter(buffer) as he:
+        with PackageExporter(buffer) as he:
             he.intern(["package_a", "package_a.subpackage"])
             he.save_pickle("obj", "obj.pkl", obj2)
 
@@ -297,7 +281,7 @@ class TestDependencyAPI(PackageTestCase):
         buffer = BytesIO()
 
         with self.assertRaises(PackagingError) as e:
-            with self.PackageExporter(buffer, importer=BrokenImporter()) as exporter:
+            with PackageExporter(buffer, importer=BrokenImporter()) as exporter:
                 exporter.intern(["foo", "bar"])
                 exporter.save_source_string("my_module", "import foo; import bar")
 
@@ -316,7 +300,7 @@ class TestDependencyAPI(PackageTestCase):
         """An incorrectly-formed import should raise a PackagingError."""
         buffer = BytesIO()
         with self.assertRaises(PackagingError) as e:
-            with self.PackageExporter(buffer) as exporter:
+            with PackageExporter(buffer) as exporter:
                 # This import will fail to load.
                 exporter.save_source_string("foo", "from ........ import lol")
 
@@ -335,12 +319,12 @@ class TestDependencyAPI(PackageTestCase):
     def test_repackage_mocked_module(self):
         """Re-packaging a package that contains a mocked module should work correctly."""
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as exporter:
+        with PackageExporter(buffer) as exporter:
             exporter.mock("package_a")
             exporter.save_source_string("foo", "import package_a")
 
         buffer.seek(0)
-        importer = self.PackageImporter(buffer)
+        importer = PackageImporter(buffer)
         foo = importer.import_module("foo")
 
         # "package_a" should be mocked out.
@@ -350,13 +334,13 @@ class TestDependencyAPI(PackageTestCase):
         # Re-package the model, but intern the previously-mocked module and mock
         # everything else.
         buffer2 = BytesIO()
-        with self.PackageExporter(buffer2, importer=importer) as exporter:
+        with PackageExporter(buffer2, importer=importer) as exporter:
             exporter.intern("package_a")
             exporter.mock("**")
             exporter.save_source_string("foo", "import package_a")
 
         buffer2.seek(0)
-        importer2 = self.PackageImporter(buffer2)
+        importer2 = PackageImporter(buffer2)
         foo2 = importer2.import_module("foo")
 
         # "package_a" should still be mocked out.
@@ -364,12 +348,5 @@ class TestDependencyAPI(PackageTestCase):
             foo2.package_a.get_something()
 
 
-class TestDependencyAPINoTorch(TestDependencyAPI):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporterNoTorch
-        self.PackageExporter = PackageExporterNoTorch
-
-
 if __name__ == "__main__":
     run_tests()
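For orientation, the four dependency actions exercised above behave roughly as follows: `intern` copies a module's source into the archive, `extern` defers it to the loading environment, `mock` swaps in a stub that errors on use, and `deny` turns any match into a PackagingError. A condensed sketch (the fixture modules `module_a`/`package_a` come from the test suite and are assumed importable):

from io import BytesIO

from torch.package import PackageExporter

buffer = BytesIO()
with PackageExporter(buffer) as exporter:
    exporter.extern(["module_a"])   # loaded from the importing environment later
    exporter.mock(["package_a.*"])  # replaced with a stub that raises on use
    exporter.intern("**")           # everything else is packaged into the archive
    exporter.save_source_string("foo", "import module_a; import package_a.subpackage")

Patterns are consulted in registration order, so the catch-all `intern("**")` only applies to modules the earlier rules did not match.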
@@ -5,9 +5,6 @@ from io import BytesIO
 from torch.package import (
     PackageExporter,
 )
-from torch.package.package_exporter_no_torch import (
-    PackageExporter as PackageExporterNoTorch,
-)
 from torch.testing._internal.common_utils import run_tests
 
 try:
@@ -23,10 +20,6 @@ class TestDependencyHooks(PackageTestCase):
     - register_extern_hook()
     """
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageExporter = PackageExporter
-
     def test_single_hook(self):
         buffer = BytesIO()
 
@@ -35,7 +28,7 @@ class TestDependencyHooks(PackageTestCase):
         def my_extern_hook(package_exporter, module_name):
             my_externs.add(module_name)
 
-        with self.PackageExporter(buffer) as exporter:
+        with PackageExporter(buffer) as exporter:
             exporter.extern(["package_a.subpackage", "module_a"])
             exporter.register_extern_hook(my_extern_hook)
             exporter.save_source_string("foo", "import module_a")
@@ -54,7 +47,7 @@ class TestDependencyHooks(PackageTestCase):
         def my_extern_hook2(package_exporter, module_name):
             my_externs.remove(module_name)
 
-        with self.PackageExporter(buffer) as exporter:
+        with PackageExporter(buffer) as exporter:
             exporter.extern(["package_a.subpackage", "module_a"])
             exporter.register_extern_hook(my_extern_hook)
             exporter.register_extern_hook(my_extern_hook2)
@@ -74,7 +67,7 @@ class TestDependencyHooks(PackageTestCase):
         def my_mock_hook2(package_exporter, module_name):
             my_mocks.remove(module_name)
 
-        with self.PackageExporter(buffer) as exporter:
+        with PackageExporter(buffer) as exporter:
             exporter.mock(["package_a.subpackage", "module_a"])
             exporter.register_mock_hook(my_mock_hook)
             exporter.register_mock_hook(my_mock_hook2)
@@ -94,7 +87,7 @@ class TestDependencyHooks(PackageTestCase):
         def my_extern_hook2(package_exporter, module_name):
             my_externs2.add(module_name)
 
-        with self.PackageExporter(buffer) as exporter:
+        with PackageExporter(buffer) as exporter:
             exporter.extern(["package_a.subpackage", "module_a"])
             handle = exporter.register_extern_hook(my_extern_hook)
             exporter.register_extern_hook(my_extern_hook2)
@@ -116,7 +109,7 @@ class TestDependencyHooks(PackageTestCase):
         def my_mock_hook(package_exporter, module_name):
             my_mocks.add(module_name)
 
-        with self.PackageExporter(buffer) as exporter:
+        with PackageExporter(buffer) as exporter:
             exporter.extern("module_a")
             exporter.mock("package_a")
             exporter.register_extern_hook(my_extern_hook)
@@ -127,11 +120,5 @@ class TestDependencyHooks(PackageTestCase):
         self.assertEqual(my_mocks, set(["package_a"]))
 
 
-class TestDependencyHooksNoTorch(TestDependencyHooks):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageExporter = PackageExporterNoTorch
-
-
 if __name__ == "__main__":
     run_tests()
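The hook tests rely on `register_extern_hook`/`register_mock_hook` returning a removable handle. A condensed sketch of the register-then-remove pattern they exercise (again assuming the fixture module `module_a` is importable):

from io import BytesIO

from torch.package import PackageExporter

my_externs = set()

def add_hook(package_exporter, module_name):
    # Called once for each module externed during export.
    my_externs.add(module_name)

with PackageExporter(BytesIO()) as exporter:
    exporter.extern(["module_a"])
    handle = exporter.register_extern_hook(add_hook)
    handle.remove()  # unregistered before export runs, so my_externs stays empty
    exporter.save_source_string("foo", "import module_a")

assert my_externs == set()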
@@ -116,7 +116,7 @@ class TestDiGraph(PackageTestCase):
 
         result = g.all_paths("1", "3")
         # to get rid of indeterminism
-        actual = {i.strip("\n") for i in result.split(";")[2:-1]}
+        actual = set([i.strip("\n") for i in result.split(";")[2:-1]])
         expected = {
             '"2" -> "3"',
             '"1" -> "7"',
@@ -10,12 +10,6 @@ from unittest import skipIf
 
 import torch
 from torch.package import PackageExporter, PackageImporter
-from torch.package.package_exporter_no_torch import (
-    PackageExporter as PackageExporterNoTorch,
-)
-from torch.package.package_importer_no_torch import (
-    PackageImporter as PackageImporterNoTorch,
-)
 from torch.testing._internal.common_utils import (
     run_tests,
     IS_FBCODE,
@@ -50,11 +44,6 @@ packaging_directory = Path(__file__).parent
 class DirectoryReaderTest(PackageTestCase):
     """Tests use of DirectoryReader as accessor for opened packages."""
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageExporter = PackageExporter
-        self.PackageImporter = PackageImporter
-
     @skipIfNoTorchVision
     def test_loading_pickle(self):
         """
@@ -63,7 +52,7 @@ class DirectoryReaderTest(PackageTestCase):
         resnet = resnet18()
 
         filename = self.temp()
-        with self.PackageExporter(filename) as e:
+        with PackageExporter(filename) as e:
             e.intern("**")
             e.save_pickle("model", "model.pkl", resnet)
 
@@ -71,7 +60,7 @@ class DirectoryReaderTest(PackageTestCase):
 
         with TemporaryDirectory() as temp_dir:
             zip_file.extractall(path=temp_dir)
-            importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
+            importer = PackageImporter(Path(temp_dir) / Path(filename).name)
             dir_mod = importer.load_pickle("model", "model.pkl")
             input = torch.rand(1, 3, 224, 224)
             self.assertEqual(dir_mod(input), resnet(input))
@@ -83,14 +72,14 @@ class DirectoryReaderTest(PackageTestCase):
         import package_a
 
         filename = self.temp()
-        with self.PackageExporter(filename) as e:
+        with PackageExporter(filename) as e:
             e.save_module("package_a")
 
         zip_file = zipfile.ZipFile(filename, "r")
 
         with TemporaryDirectory() as temp_dir:
             zip_file.extractall(path=temp_dir)
-            dir_importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
+            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
             dir_mod = dir_importer.import_module("package_a")
             self.assertEqual(dir_mod.result, package_a.result)
 
@@ -101,14 +90,14 @@ class DirectoryReaderTest(PackageTestCase):
         import package_a  # noqa: F401
 
         filename = self.temp()
-        with self.PackageExporter(filename) as e:
+        with PackageExporter(filename) as e:
             e.save_module("package_a")
 
         zip_file = zipfile.ZipFile(filename, "r")
 
         with TemporaryDirectory() as temp_dir:
             zip_file.extractall(path=temp_dir)
-            dir_importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
+            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
             self.assertTrue(dir_importer.zip_reader.has_record("package_a/__init__.py"))
             self.assertFalse(dir_importer.zip_reader.has_record("package_a"))
 
@@ -116,7 +105,7 @@ class DirectoryReaderTest(PackageTestCase):
     def test_resource_reader(self):
         """Tests DirectoryReader as the base for get_resource_reader."""
         filename = self.temp()
-        with self.PackageExporter(filename) as pe:
+        with PackageExporter(filename) as pe:
             # Layout looks like:
             #    package
             #    ├── one/
@@ -143,7 +132,7 @@ class DirectoryReaderTest(PackageTestCase):
 
         with TemporaryDirectory() as temp_dir:
             zip_file.extractall(path=temp_dir)
-            importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
+            importer = PackageImporter(Path(temp_dir) / Path(filename).name)
             reader_one = importer.get_resource_reader("one")
 
             # Different behavior from still zipped archives
@@ -198,7 +187,7 @@ class DirectoryReaderTest(PackageTestCase):
             """
         )
         filename = self.temp()
-        with self.PackageExporter(filename) as pe:
+        with PackageExporter(filename) as pe:
             pe.save_source_string("foo.bar", mod_src)
             pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")
 
@@ -206,7 +195,7 @@ class DirectoryReaderTest(PackageTestCase):
 
         with TemporaryDirectory() as temp_dir:
             zip_file.extractall(path=temp_dir)
-            dir_importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
+            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
             self.assertEqual(
                 dir_importer.import_module("foo.bar").secret_message(),
                 "my sekrit plays",
@@ -215,7 +204,7 @@ class DirectoryReaderTest(PackageTestCase):
     @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
     def test_importer_access(self):
         filename = self.temp()
-        with self.PackageExporter(filename) as he:
+        with PackageExporter(filename) as he:
             he.save_text("main", "main", "my string")
             he.save_binary("main", "main_binary", "my string".encode("utf-8"))
             src = dedent(
@@ -233,7 +222,7 @@ class DirectoryReaderTest(PackageTestCase):
 
         with TemporaryDirectory() as temp_dir:
             zip_file.extractall(path=temp_dir)
-            dir_importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
+            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
             m = dir_importer.import_module("main")
             self.assertEqual(m.t, "my string")
             self.assertEqual(m.b, "my string".encode("utf-8"))
@@ -244,7 +233,7 @@ class DirectoryReaderTest(PackageTestCase):
         Tests that packaged code can use importlib.resources.path.
         """
         filename = self.temp()
-        with self.PackageExporter(filename) as e:
+        with PackageExporter(filename) as e:
             e.save_binary("string_module", "my_string", "my string".encode("utf-8"))
             src = dedent(
                 """\
@@ -262,7 +251,7 @@ class DirectoryReaderTest(PackageTestCase):
 
         with TemporaryDirectory() as temp_dir:
             zip_file.extractall(path=temp_dir)
-            dir_importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
+            dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
             m = dir_importer.import_module("main")
             self.assertEqual(m.s, "my string")
 
@@ -271,16 +260,12 @@ class DirectoryReaderTest(PackageTestCase):
         Test basic saving and loading of a ScriptModule in a directory.
         Currently not supported.
         """
-
-        if self.PackageExporter != PackageExporter:
-            return
-
         from package_a.test_module import ModWithTensor
 
         scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
 
         filename = self.temp()
-        with self.PackageExporter(filename) as e:
+        with PackageExporter(filename) as e:
             e.save_pickle("res", "mod.pkl", scripted_mod)
 
         zip_file = zipfile.ZipFile(filename, "r")
@@ -292,18 +277,9 @@ class DirectoryReaderTest(PackageTestCase):
         ):
             with TemporaryDirectory() as temp_dir:
                 zip_file.extractall(path=temp_dir)
-                dir_importer = self.PackageImporter(
-                    Path(temp_dir) / Path(filename).name
-                )
+                dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
                 dir_mod = dir_importer.load_pickle("res", "mod.pkl")
 
 
-class DirectoryReaderTestNoTorch(DirectoryReaderTest):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageExporter = PackageExporterNoTorch
-        self.PackageImporter = PackageImporterNoTorch
-
-
 if __name__ == "__main__":
     run_tests()
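Every DirectoryReader test above follows the same recipe, worth stating once: export a normal zip-backed package, unzip it, and hand the resulting directory to `PackageImporter`, which then routes reads through `DirectoryReader`. In outline (the package path is a placeholder for `self.temp()`):

import zipfile
from pathlib import Path
from tempfile import TemporaryDirectory

from torch.package import PackageExporter, PackageImporter

filename = "my_package.pt"  # placeholder path
with PackageExporter(filename) as exporter:
    exporter.save_source_string("foo", "result = 'module_a'")

zip_file = zipfile.ZipFile(filename, "r")
with TemporaryDirectory() as temp_dir:
    zip_file.extractall(path=temp_dir)
    # A directory path makes PackageImporter use DirectoryReader internally.
    importer = PackageImporter(Path(temp_dir) / Path(filename).name)
    assert importer.import_module("foo").result == "module_a"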
@@ -10,12 +10,6 @@ from torch.package import (
     PackageImporter,
     sys_importer,
 )
-from torch.package.package_exporter_no_torch import (
-    PackageExporter as PackageExporterNoTorch,
-)
-from torch.package.package_importer_no_torch import (
-    PackageImporter as PackageImporterNoTorch,
-)
 from torch.testing._internal.common_utils import run_tests
 
 try:
@@ -28,11 +22,6 @@ except ImportError:
 class TestImporter(PackageTestCase):
     """Tests for Importer and derived classes."""
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporter
-        self.PackageExporter = PackageExporter
-
     def test_sys_importer(self):
         import package_a
         import package_a.subpackage
@@ -58,11 +47,11 @@ class TestImporter(PackageTestCase):
         import package_a
 
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as pe:
+        with PackageExporter(buffer) as pe:
             pe.save_module(package_a.__name__)
 
         buffer.seek(0)
-        importer = self.PackageImporter(buffer)
+        importer = PackageImporter(buffer)
 
         # Construct an importer-only environment.
         ordered_importer = OrderedImporter(importer)
@@ -84,11 +73,11 @@ class TestImporter(PackageTestCase):
         import package_a
 
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as pe:
+        with PackageExporter(buffer) as pe:
             pe.save_module(package_a.__name__)
 
         buffer.seek(0)
-        importer = self.PackageImporter(buffer)
+        importer = PackageImporter(buffer)
 
         ordered_importer_sys_first = OrderedImporter(sys_importer, importer)
         self.assertIs(ordered_importer_sys_first.import_module("package_a"), package_a)
@@ -148,32 +137,25 @@ class TestImporter(PackageTestCase):
 
         # Set up a PackageImporter which has a torch.float16 object pickled:
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as exporter:
+        with PackageExporter(buffer) as exporter:
             exporter.save_pickle("foo", "foo.pkl", my_dtype)
         buffer.seek(0)
 
-        importer = self.PackageImporter(buffer)
+        importer = PackageImporter(buffer)
         my_loaded_dtype = importer.load_pickle("foo", "foo.pkl")
 
         # Re-save a package with only our PackageImporter as the importer
         buffer2 = BytesIO()
-        with self.PackageExporter(buffer2, importer=importer) as exporter:
+        with PackageExporter(buffer2, importer=importer) as exporter:
             exporter.save_pickle("foo", "foo.pkl", my_loaded_dtype)
 
         buffer2.seek(0)
 
-        importer2 = self.PackageImporter(buffer2)
+        importer2 = PackageImporter(buffer2)
         my_loaded_dtype2 = importer2.load_pickle("foo", "foo.pkl")
         self.assertIs(my_dtype, my_loaded_dtype)
         self.assertIs(my_dtype, my_loaded_dtype2)
 
 
-class TestImporterNoTorch(TestImporter):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporterNoTorch
-        self.PackageExporter = PackageExporterNoTorch
-
-
 if __name__ == "__main__":
     run_tests()
@@ -9,12 +9,6 @@ from torch.package._mangling import (
     get_mangle_prefix,
     is_mangled,
 )
-from torch.package.package_exporter_no_torch import (
-    PackageExporter as PackageExporterNoTorch,
-)
-from torch.package.package_importer_no_torch import (
-    PackageImporter as PackageImporterNoTorch,
-)
 from torch.testing._internal.common_utils import run_tests
 
 try:
@@ -25,11 +19,6 @@ except ImportError:
 
 
 class TestMangling(PackageTestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporter
-        self.PackageExporter = PackageExporter
-
     def test_unique_manglers(self):
         """
         Each mangler instance should generate a unique mangled name for a given input.
@@ -94,11 +83,11 @@ class TestMangling(PackageTestCase):
         obj = package_a.subpackage.PackageASubpackageObject()
         obj2 = package_a.PackageAObject(obj)
         f1 = BytesIO()
-        with self.PackageExporter(f1) as pe:
+        with PackageExporter(f1) as pe:
             pe.intern("**")
             pe.save_pickle("obj", "obj.pkl", obj2)
         f1.seek(0)
-        importer1 = self.PackageImporter(f1)
+        importer1 = PackageImporter(f1)
         loaded1 = importer1.load_pickle("obj", "obj.pkl")
         f1.seek(0)
         importer2 = PackageImporter(f1)
@@ -119,12 +108,5 @@ class TestMangling(PackageTestCase):
         self.assertEqual(b.demangle(a_mangled), a_mangled)
 
 
-class TestManglingNoTorch(TestMangling):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporterNoTorch
-        self.PackageExporter = PackageExporterNoTorch
-
-
 if __name__ == "__main__":
     run_tests()
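The mangling tests hinge on one property: each `PackageImporter` gets its own mangle prefix, so the "same" module loaded through two importers yields distinct, independent module objects. A toy illustration (the exact prefix format is an internal detail):

from io import BytesIO

from torch.package import PackageExporter, PackageImporter

f1 = BytesIO()
with PackageExporter(f1) as exporter:
    exporter.save_source_string("foo", "class A:\n    pass")

f1.seek(0)
importer1 = PackageImporter(f1)
f1.seek(0)
importer2 = PackageImporter(f1)

# Same archive, but each importer mangles module names under its own prefix
# (something like "<torch_package_0>.foo" vs "<torch_package_1>.foo"), so the
# two loaded modules and the classes inside them are independent.
assert importer1.import_module("foo") is not importer2.import_module("foo")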
@@ -8,19 +8,9 @@ from pathlib import Path
 from textwrap import dedent
 from unittest import skipIf
 
-from torch.package import (
-    PackageExporter,
-    PackageImporter,
-    is_from_package,
-    PackagingError,
-)
-from torch.package.package_exporter_no_torch import (
-    PackageExporter as PackageExporterNoTorch,
-)
-from torch.package.package_importer_no_torch import (
-    PackageImporter as PackageImporterNoTorch,
-)
-from torch.testing._internal.common_utils import run_tests, IS_FBCODE, IS_SANDCASTLE
+from torch.package import PackageExporter, PackageImporter, is_from_package
+from torch.package.package_exporter import PackagingError
+from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests
 
 try:
     from .common import PackageTestCase
@@ -32,11 +22,6 @@ except ImportError:
 class TestMisc(PackageTestCase):
     """Tests for one-off or random functionality. Try not to add to this!"""
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporter
-        self.PackageExporter = PackageExporter
-
     def test_file_structure(self):
         """
         Tests package's Directory structure representation of a zip file. Ensures
@@ -86,7 +71,7 @@ class TestMisc(PackageTestCase):
             """
         )
 
-        with self.PackageExporter(buffer) as he:
+        with PackageExporter(buffer) as he:
             import module_a
             import package_a
             import package_a.subpackage
@@ -99,7 +84,7 @@ class TestMisc(PackageTestCase):
             he.save_text("main", "main", "my string")
 
         buffer.seek(0)
-        hi = self.PackageImporter(buffer)
+        hi = PackageImporter(buffer)
 
         file_structure = hi.file_structure()
         # remove first line from testing because WINDOW/iOS/Unix treat the buffer differently
@@ -154,7 +139,7 @@ class TestMisc(PackageTestCase):
         Test Directory's has_file() method.
         """
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as he:
+        with PackageExporter(buffer) as he:
             import package_a.subpackage
 
             he.intern("**")
@@ -163,7 +148,7 @@ class TestMisc(PackageTestCase):
 
         buffer.seek(0)
 
-        importer = self.PackageImporter(buffer)
+        importer = PackageImporter(buffer)
         file_structure = importer.file_structure()
         self.assertTrue(file_structure.has_file("package_a/subpackage.py"))
         self.assertFalse(file_structure.has_file("package_a/subpackage"))
@@ -173,7 +158,7 @@ class TestMisc(PackageTestCase):
         Test content list API for PackageExporter's contained modules.
         """
 
-        with self.PackageExporter(BytesIO()) as he:
+        with PackageExporter(BytesIO()) as he:
             import package_b
 
             he.extern("package_b.subpackage_1")
@@ -189,7 +174,7 @@ class TestMisc(PackageTestCase):
             self.assertEqual(he.get_rdeps("package_b.subpackage_2"), ["package_b"])
 
         with self.assertRaises(PackagingError) as e:
-            with self.PackageExporter(BytesIO()) as he:
+            with PackageExporter(BytesIO()) as he:
                 import package_b
 
                 he.deny("package_b")
@@ -203,12 +188,12 @@ class TestMisc(PackageTestCase):
         buffer = BytesIO()
         obj = package_a.subpackage.PackageASubpackageObject()
 
-        with self.PackageExporter(buffer) as pe:
+        with PackageExporter(buffer) as pe:
             pe.intern("**")
             pe.save_pickle("obj", "obj.pkl", obj)
 
         buffer.seek(0)
-        pi = self.PackageImporter(buffer)
+        pi = PackageImporter(buffer)
         mod = pi.import_module("package_a.subpackage")
         loaded_obj = pi.load_pickle("obj", "obj.pkl")
 
@@ -225,12 +210,12 @@ class TestMisc(PackageTestCase):
         buffer = BytesIO()
         obj = package_a.subpackage.PackageASubpackageObject()
 
-        with self.PackageExporter(buffer) as pe:
+        with PackageExporter(buffer) as pe:
             pe.intern("**")
             pe.save_pickle("obj", "obj.pkl", obj)
 
         buffer.seek(0)
-        pi = self.PackageImporter(buffer)
+        pi = PackageImporter(buffer)
         packaged_class = pi.import_module(
             "package_a.subpackage"
         ).PackageASubpackageObject
@@ -249,12 +234,12 @@ class TestMisc(PackageTestCase):
         buffer = BytesIO()
         obj = package_a.subpackage.PackageASubpackageObject()
 
-        with self.PackageExporter(buffer) as pe:
+        with PackageExporter(buffer) as pe:
             pe.intern("**")
             pe.save_pickle("obj", "obj.pkl", obj)
 
         buffer.seek(0)
-        pi = self.PackageImporter(buffer)
+        pi = PackageImporter(buffer)
         mod = pi.import_module("package_a.subpackage")
         self.assertTrue(hasattr(mod, "__torch_package__"))
 
@@ -268,12 +253,12 @@ class TestMisc(PackageTestCase):
 
         buffer = BytesIO()
 
-        with self.PackageExporter(buffer) as pe:
+        with PackageExporter(buffer) as pe:
             pe.intern("**")
             pe.save_module(mod.__name__)
 
         buffer.seek(0)
-        pi = self.PackageImporter(buffer)
+        pi = PackageImporter(buffer)
         imported_mod = pi.import_module(mod.__name__)
         self.assertTrue(imported_mod.is_from_package())
         self.assertFalse(mod.is_from_package())
@@ -289,22 +274,15 @@ class TestMisc(PackageTestCase):
         buffer = BytesIO()
         mod = package_a.std_sys_module_hacks.Module()
 
-        with self.PackageExporter(buffer) as pe:
+        with PackageExporter(buffer) as pe:
             pe.intern("**")
             pe.save_pickle("obj", "obj.pkl", mod)
 
         buffer.seek(0)
-        pi = self.PackageImporter(buffer)
+        pi = PackageImporter(buffer)
         mod = pi.load_pickle("obj", "obj.pkl")
         mod()
 
 
-class TestMiscNoTorch(TestMisc):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporterNoTorch
-        self.PackageExporter = PackageExporterNoTorch
-
-
 if __name__ == "__main__":
     run_tests()
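`file_structure()` used above returns a Directory object whose `has_file` takes archive-relative paths, including the `.py` extension for module sources. A compressed sketch of the behavior the tests check:

from io import BytesIO

from torch.package import PackageExporter, PackageImporter

buffer = BytesIO()
with PackageExporter(buffer) as exporter:
    exporter.save_source_string("foo", "x = 1")
    exporter.save_text("main", "main", "my string")

buffer.seek(0)
file_structure = PackageImporter(buffer).file_structure()
assert file_structure.has_file("foo.py")     # module source is stored as foo.py
assert not file_structure.has_file("foo")    # the bare module name is not a file
assert file_structure.has_file("main/main")  # save_text(package, resource, ...)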
@@ -624,6 +624,7 @@ class TestPackageScript(PackageTestCase):
         buffer_1.seek(0)
         importer = PackageImporter(buffer_1)
         loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")
 
         self.assertEqual(
             loaded_mod_1.tensor.storage()._cdata,
             loaded_mod_1.sub_mod_0.tensor.storage()._cdata,
@@ -7,12 +7,6 @@ from torch.package import (
     PackageImporter,
     sys_importer,
 )
-from torch.package.package_exporter_no_torch import (
-    PackageExporter as PackageExporterNoTorch,
-)
-from torch.package.package_importer_no_torch import (
-    PackageImporter as PackageImporterNoTorch,
-)
 from torch.testing._internal.common_utils import run_tests
 
 try:
@@ -25,28 +19,23 @@ except ImportError:
 class TestRepackage(PackageTestCase):
     """Tests for repackaging."""
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporter
-        self.PackageExporter = PackageExporter
-
     def test_repackage_import_indirectly_via_parent_module(self):
         from package_d.imports_directly import ImportsDirectlyFromSubSubPackage
         from package_d.imports_indirectly import ImportsIndirectlyFromSubPackage
 
         model_a = ImportsDirectlyFromSubSubPackage()
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as pe:
+        with PackageExporter(buffer) as pe:
             pe.intern("**")
             pe.save_pickle("default", "model.py", model_a)
 
         buffer.seek(0)
-        pi = self.PackageImporter(buffer)
+        pi = PackageImporter(buffer)
         loaded_model = pi.load_pickle("default", "model.py")
 
         model_b = ImportsIndirectlyFromSubPackage()
         buffer = BytesIO()
-        with self.PackageExporter(
+        with PackageExporter(
             buffer,
             importer=(
                 pi,
@@ -57,12 +46,5 @@ class TestRepackage(PackageTestCase):
             pe.save_pickle("default", "model_b.py", model_b)
 
 
-class TestRepackageNoTorch(TestRepackage):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporterNoTorch
-        self.PackageExporter = PackageExporterNoTorch
-
-
 if __name__ == "__main__":
     run_tests()
@@ -7,12 +7,6 @@ from textwrap import dedent
 from unittest import skipIf
 
 from torch.package import PackageExporter, PackageImporter
-from torch.package.package_exporter_no_torch import (
-    PackageExporter as PackageExporterNoTorch,
-)
-from torch.package.package_importer_no_torch import (
-    PackageImporter as PackageImporterNoTorch,
-)
 from torch.testing._internal.common_utils import run_tests
 
 try:
@@ -26,15 +20,10 @@ except ImportError:
 class TestResources(PackageTestCase):
     """Tests for access APIs for packaged resources."""
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporter
-        self.PackageExporter = PackageExporter
-
     def test_resource_reader(self):
         """Test compliance with the get_resource_reader importlib API."""
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as pe:
+        with PackageExporter(buffer) as pe:
             # Layout looks like:
             #    package
             #    ├── one/
@@ -58,7 +47,7 @@ class TestResources(PackageTestCase):
             pe.save_text("two", "g.txt", "hello, g!")
 
         buffer.seek(0)
-        importer = self.PackageImporter(buffer)
+        importer = PackageImporter(buffer)
 
         reader_one = importer.get_resource_reader("one")
         with self.assertRaises(FileNotFoundError):
@@ -102,19 +91,19 @@ class TestResources(PackageTestCase):
             """
         )
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as pe:
+        with PackageExporter(buffer) as pe:
             pe.save_source_string("foo.bar", mod_src)
             pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")
 
         buffer.seek(0)
-        importer = self.PackageImporter(buffer)
+        importer = PackageImporter(buffer)
         self.assertEqual(
             importer.import_module("foo.bar").secret_message(), "my sekrit plays"
         )
 
     def test_importer_access(self):
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as he:
+        with PackageExporter(buffer) as he:
             he.save_text("main", "main", "my string")
             he.save_binary("main", "main_binary", "my string".encode("utf-8"))
             src = dedent(
@@ -128,7 +117,7 @@ class TestResources(PackageTestCase):
             )
             he.save_source_string("main", src, is_package=True)
         buffer.seek(0)
-        hi = self.PackageImporter(buffer)
+        hi = PackageImporter(buffer)
         m = hi.import_module("main")
         self.assertEqual(m.t, "my string")
         self.assertEqual(m.b, "my string".encode("utf-8"))
@@ -138,7 +127,7 @@ class TestResources(PackageTestCase):
         Tests that packaged code can use importlib.resources.path.
         """
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as he:
+        with PackageExporter(buffer) as he:
             he.save_binary("string_module", "my_string", "my string".encode("utf-8"))
             src = dedent(
                 """\
@@ -152,17 +141,10 @@ class TestResources(PackageTestCase):
             )
             he.save_source_string("main", src, is_package=True)
         buffer.seek(0)
-        hi = self.PackageImporter(buffer)
+        hi = PackageImporter(buffer)
         m = hi.import_module("main")
         self.assertEqual(m.s, "my string")
 
 
-class TestResourcesNoTorch(TestResources):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporterNoTorch
-        self.PackageExporter = PackageExporterNoTorch
-
-
 if __name__ == "__main__":
     run_tests()
@@ -6,12 +6,6 @@ from textwrap import dedent
 from unittest import skipIf
 
 from torch.package import PackageExporter, PackageImporter, sys_importer
-from torch.package.package_exporter_no_torch import (
-    PackageExporter as PackageExporterNoTorch,
-)
-from torch.package.package_importer_no_torch import (
-    PackageImporter as PackageImporterNoTorch,
-)
 from torch.testing._internal.common_utils import run_tests, IS_FBCODE, IS_SANDCASTLE
 
 try:
@@ -28,21 +22,16 @@ packaging_directory = Path(__file__).parent
 class TestSaveLoad(PackageTestCase):
     """Core save_* and loading API tests."""
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporter
-        self.PackageExporter = PackageExporter
-
     @skipIf(
         IS_FBCODE or IS_SANDCASTLE,
         "Tests that use temporary files are disabled in fbcode",
     )
     def test_saving_source(self):
         filename = self.temp()
-        with self.PackageExporter(filename) as he:
+        with PackageExporter(filename) as he:
             he.save_source_file("foo", str(packaging_directory / "module_a.py"))
             he.save_source_file("foodir", str(packaging_directory / "package_a"))
-        hi = self.PackageImporter(filename)
+        hi = PackageImporter(filename)
         foo = hi.import_module("foo")
         s = hi.import_module("foodir.subpackage")
         self.assertEqual(foo.result, "module_a")
@@ -54,7 +43,7 @@ class TestSaveLoad(PackageTestCase):
     )
     def test_saving_string(self):
         filename = self.temp()
-        with self.PackageExporter(filename) as he:
+        with PackageExporter(filename) as he:
             src = dedent(
                 """\
                 import math
@@ -62,7 +51,7 @@ class TestSaveLoad(PackageTestCase):
                 """
             )
             he.save_source_string("my_mod", src)
-        hi = self.PackageImporter(filename)
+        hi = PackageImporter(filename)
         m = hi.import_module("math")
         import math
 
@@ -76,13 +65,13 @@ class TestSaveLoad(PackageTestCase):
     )
     def test_save_module(self):
         filename = self.temp()
-        with self.PackageExporter(filename) as he:
+        with PackageExporter(filename) as he:
             import module_a
             import package_a
 
             he.save_module(module_a.__name__)
             he.save_module(package_a.__name__)
-        hi = self.PackageImporter(filename)
+        hi = PackageImporter(filename)
         module_a_i = hi.import_module("module_a")
         self.assertEqual(module_a_i.result, "module_a")
         self.assertIsNot(module_a, module_a_i)
@@ -92,7 +81,7 @@ class TestSaveLoad(PackageTestCase):
 
     def test_dunder_imports(self):
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as he:
+        with PackageExporter(buffer) as he:
             import package_b
 
             obj = package_b.PackageBObject
@@ -100,7 +89,7 @@ class TestSaveLoad(PackageTestCase):
             he.save_pickle("res", "obj.pkl", obj)
 
         buffer.seek(0)
-        hi = self.PackageImporter(buffer)
+        hi = PackageImporter(buffer)
         loaded_obj = hi.load_pickle("res", "obj.pkl")
 
         package_b = hi.import_module("package_b")
@@ -124,21 +113,21 @@ class TestSaveLoad(PackageTestCase):
     def test_bad_dunder_imports(self):
         """Test to ensure bad __imports__ don't cause PackageExporter to fail."""
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as e:
+        with PackageExporter(buffer) as e:
             e.save_source_string(
                 "m", '__import__(these, unresolvable, "things", wont, crash, me)'
             )
 
     def test_save_module_binary(self):
         f = BytesIO()
-        with self.PackageExporter(f) as he:
+        with PackageExporter(f) as he:
             import module_a
             import package_a
 
             he.save_module(module_a.__name__)
             he.save_module(package_a.__name__)
         f.seek(0)
-        hi = self.PackageImporter(f)
+        hi = PackageImporter(f)
         module_a_i = hi.import_module("module_a")
         self.assertEqual(module_a_i.result, "module_a")
         self.assertIsNot(module_a, module_a_i)
@@ -157,10 +146,10 @@ class TestSaveLoad(PackageTestCase):
         obj2 = package_a.PackageAObject(obj)
 
         filename = self.temp()
-        with self.PackageExporter(filename) as he:
+        with PackageExporter(filename) as he:
             he.intern("**")
             he.save_pickle("obj", "obj.pkl", obj2)
-        hi = self.PackageImporter(filename)
+        hi = PackageImporter(filename)
 
         # check we got dependencies
         sp = hi.import_module("package_a.subpackage")
@@ -190,19 +179,19 @@ class TestSaveLoad(PackageTestCase):
         obj = package_a.subpackage.PackageASubpackageObject()
         obj2 = package_a.PackageAObject(obj)
         f1 = self.temp()
-        with self.PackageExporter(f1) as pe:
+        with PackageExporter(f1) as pe:
             pe.intern("**")
             pe.save_pickle("obj", "obj.pkl", obj2)
 
-        importer1 = self.PackageImporter(f1)
+        importer1 = PackageImporter(f1)
         loaded1 = importer1.load_pickle("obj", "obj.pkl")
-        importer2 = self.PackageImporter(f1)
+        importer2 = PackageImporter(f1)
         loaded2 = importer2.load_pickle("obj", "obj.pkl")
 
         f2 = self.temp()
 
         def make_exporter():
-            pe = self.PackageExporter(f2, importer=[importer1, sys_importer])
+            pe = PackageExporter(f2, importer=[importer1, sys_importer])
             # Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first.
             return pe
 
@@ -231,21 +220,19 @@ class TestSaveLoad(PackageTestCase):
         obj2 = package_a.PackageAObject(obj)
 
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as exporter:
+        with PackageExporter(buffer) as exporter:
             exporter.intern("**")
             exporter.save_pickle("model", "model.pkl", obj2)
 
         buffer.seek(0)
 
-        importer = self.PackageImporter(buffer)
+        importer = PackageImporter(buffer)
         imported_obj2 = importer.load_pickle("model", "model.pkl")
         imported_obj2_module = imported_obj2.__class__.__module__
 
         # Should export without error.
         buffer2 = BytesIO()
-        with self.PackageExporter(
-            buffer2, importer=(importer, sys_importer)
-        ) as exporter:
+        with PackageExporter(buffer2, importer=(importer, sys_importer)) as exporter:
             exporter.intern("**")
             exporter.save_module(imported_obj2_module)
 
@@ -254,29 +241,20 @@ class TestSaveLoad(PackageTestCase):
         import package_a.use_torch_package_importer  # noqa: F401
 
         buffer = BytesIO()
-        with self.PackageExporter(buffer) as exporter:
+        with PackageExporter(buffer) as exporter:
             exporter.intern("**")
             exporter.save_module("package_a.use_torch_package_importer")
 
         buffer.seek(0)
 
-        importer = self.PackageImporter(buffer)
+        importer = PackageImporter(buffer)
 
         # Should export without error.
         buffer2 = BytesIO()
-        with self.PackageExporter(
-            buffer2, importer=(importer, sys_importer)
-        ) as exporter:
+        with PackageExporter(buffer2, importer=(importer, sys_importer)) as exporter:
            exporter.intern("**")
            exporter.save_module("package_a.use_torch_package_importer")
 
 
-class TestSaveLoadNoTorch(TestSaveLoad):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.PackageImporter = PackageImporterNoTorch
-        self.PackageExporter = PackageExporterNoTorch
-
-
 if __name__ == "__main__":
     run_tests()
@@ -3,7 +3,6 @@ import torch
 from torch.package._package_pickler import create_pickler
 from torch.package._package_unpickler import PackageUnpickler
 from torch.package import sys_importer, OrderedImporter, PackageImporter, Importer
-from torch.package._zip_file_torchscript import TorchScriptPackageZipFileReader
 from torch.serialization import _maybe_decode_ascii
 
 def _save_storages(importer, obj):
@@ -49,10 +48,7 @@ def _save_storages(importer, obj):
     pickler.persistent_id = persistent_id
     pickler.dump(obj)
     data_value = data_buf.getvalue()
-
-    assert (not importer or isinstance(importer.zip_reader, TorchScriptPackageZipFileReader)), \
-        f'importer {importer}\'s zip reader is of type {type(importer.zip_reader)} not TorchScriptPackageZipFileReader'
-    return data_value, serialized_storages, serialized_dtypes, importer.zip_reader.zip_reader if importer else None
+    return data_value, serialized_storages, serialized_dtypes, importer.zip_reader if importer else None
 
 def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
 
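The `_save_storages` hunk is the heart of the revert at the serialization boundary: with the generic reader/writer classes gone, `importer.zip_reader` is once again the object the C++ layer expects, so it is returned directly rather than unwrapped from a wrapper's inner `.zip_reader` attribute. The TorchScript packaging hunks below undo the same indirection.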
@@ -44,7 +44,6 @@ from torch.jit._monkeytype_config import (
     JitTypeTraceStore
 )
 from torch._classes import classes
-from torch.package._zip_file_torchscript import TorchScriptPackageZipFileWriter, TorchScriptPackageZipFileReader
 
 type_trace_db = JitTypeTraceStore()  # DB to hold all call traces from MonkeyType
 
@@ -343,21 +342,15 @@ def unpackage_script_module(importer: PackageImporter, script_module_id: str) ->
     Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function.
     Performs work of loading and returning a ScriptModule from a ``torch.package`` archive.
     """
-
-    if not isinstance(importer.zip_reader, TorchScriptPackageZipFileReader):
-        raise RuntimeError(
-            f"Loading ScriptObjects from a PackageImporter must be done using a TorchScriptPackageZipFileReader"
-            f"not an object of type {type(importer.zip_reader)}"
-        )
-    if importer.zip_reader.is_directory():
+    if not isinstance(importer.zip_reader, torch._C.PyTorchFileReader):
         raise RuntimeError(
             "Loading ScriptObjects from a PackageImporter created from a "
-            f"directory is not supported. Use a package archive file instead. is of type {type(importer.zip_reader)}"
+            "directory is not supported. Use a package archive file instead."
         )
     cu = torch._C.CompilationUnit()
     cpp_module = torch._C._import_ir_module_from_package(
         cu,
-        importer.zip_reader.zip_reader,  # type: ignore[arg-type]
+        importer.zip_reader,
        importer.storage_context,
        validate_map_location(importer.last_map_location),
        script_module_id,
@@ -542,10 +535,8 @@ if _enabled:
         Returns method to load the ScriptModule from a ``torch.package.PackageImporter``'s
         Pickler's ``persistent_load`` function.
         """
-        assert isinstance(exporter.zip_file, TorchScriptPackageZipFileWriter)
-        script_module_serializer = exporter.zip_file.script_module_serializer
         script_module_id = exporter.get_unique_id()
-        script_module_serializer.serialize(self._c, int(script_module_id))
+        exporter.script_module_serializer.serialize(self._c, int(script_module_id))
         return (unpackage_script_module, (script_module_id,))
 
 class RecursiveScriptModule(ScriptModule):
@@ -8,6 +8,5 @@ from .importer import (
     OrderedImporter,
     sys_importer,
 )
-from .package_exporter import PackageExporter
-from .package_exporter_no_torch import EmptyMatchError, PackagingError
+from .package_exporter import EmptyMatchError, PackageExporter, PackagingError
 from .package_importer import PackageImporter
@@ -1,5 +1,17 @@
 import os.path
 from glob import glob
+from typing import cast
 
+import torch
+from torch.types import Storage
+
+# because get_storage_from_record returns a tensor!?
+class _HasStorage(object):
+    def __init__(self, storage):
+        self._storage = storage
+
+    def storage(self):
+        return self._storage
+
 
 class DirectoryReader(object):
@@ -7,6 +19,9 @@ class DirectoryReader(object):
     Class to allow PackageImporter to operate on unzipped packages. Methods
     copy the behavior of the internal PyTorchFileReader class (which is used for
     accessing packages in all other cases).
+
+    N.B.: ScriptObjects are not depickleable or accessible via this DirectoryReader
+    class due to ScriptObjects requiring an actual PyTorchFileReader instance.
     """
 
     def __init__(self, directory):
@@ -17,6 +32,12 @@ class DirectoryReader(object):
         with open(filename, "rb") as f:
             return f.read()
 
+    def get_storage_from_record(self, name, numel, dtype):
+        filename = f"{self.directory}/{name}"
+        nbytes = torch._utils._element_size(dtype) * numel
+        storage = cast(Storage, torch._UntypedStorage)
+        return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))
+
     def has_record(self, path):
         full_path = os.path.join(self.directory, path)
         return os.path.isfile(full_path)
@@ -29,6 +50,3 @@ class DirectoryReader(object):
             if not os.path.isdir(filename):
                 files.append(filename[len(self.directory) + 1 :])
         return files
-
-    def close(self):
-        pass
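The restored DirectoryReader mirrors the record-oriented interface of the internal PyTorchFileReader. A brief usage sketch (the directory path is a hypothetical extracted archive):

from torch.package._directory_reader import DirectoryReader

reader = DirectoryReader("/tmp/extracted_package")  # hypothetical path
print(reader.get_all_records())           # archive-relative paths of all files
if reader.has_record("foo.py"):
    source = reader.get_record("foo.py")  # raw bytes of the stored file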
@@ -1,34 +0,0 @@
-from typing import cast
-
-import torch
-from torch.types import Storage
-
-from ._directory_reader import DirectoryReader
-
-# because get_storage_from_record returns a tensor!?
-class _HasStorage(object):
-    def __init__(self, storage):
-        self._storage = storage
-
-    def storage(self):
-        return self._storage
-
-
-class TorchScriptDirectoryReader(DirectoryReader):
-    """
-    Class to allow PackageImporter to operate on unzipped packages which include
-    torchscript modules. Methods copy the behavior of the internal PyTorchFileReader
-    class (which is used for accessing packages in all other cases).
-
-    N.B.: ScriptObjects are not depickleable or accessible via this TorchScriptDirectoryReader
-    class due to ScriptObjects requiring an actual PyTorchFileReader instance.
-    """
-
-    def __init__(self, directory):
-        super().__init__(directory)
-
-    def get_storage_from_record(self, name, numel, dtype):
-        filename = f"{self.directory}/{name}"
-        nbytes = torch._utils._element_size(dtype) * numel
-        storage = cast(Storage, torch._UntypedStorage)
-        return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))
@@ -1,38 +0,0 @@
-import weakref
-from collections import OrderedDict
-from typing import Any
-
-
-class RemovableHandle:
-    """A handle which provides the capability to remove a hook."""
-
-    id: int
-    next_id: int = 0
-
-    def __init__(self, hooks_dict: Any) -> None:
-        self.hooks_dict_ref = weakref.ref(hooks_dict)
-        self.id = RemovableHandle.next_id
-        RemovableHandle.next_id += 1
-
-    def remove(self) -> None:
-        hooks_dict = self.hooks_dict_ref()
-        if hooks_dict is not None and self.id in hooks_dict:
-            del hooks_dict[self.id]
-
-    def __getstate__(self):
-        return (self.hooks_dict_ref(), self.id)
-
-    def __setstate__(self, state) -> None:
-        if state[0] is None:
-            # create a dead reference
-            self.hooks_dict_ref = weakref.ref(OrderedDict())
-        else:
-            self.hooks_dict_ref = weakref.ref(state[0])
-        self.id = state[1]
-        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)
-
-    def __enter__(self) -> "RemovableHandle":
-        return self
-
-    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
-        self.remove()
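RemovableHandle (deleted above; the same utility lives on in torch.utils.hooks) is what the dependency-hook tests' `handle.remove()` calls go through. Its context-manager support scopes a hook to a block, roughly:

# Assumes the RemovableHandle class exactly as defined in the deleted file above.
from collections import OrderedDict

hooks = OrderedDict()
handle = RemovableHandle(hooks)                 # allocates a fresh id for this hook
hooks[handle.id] = lambda exporter, name: None  # the caller installs the hook itself

with handle:
    assert handle.id in hooks      # hook active inside the block
assert handle.id not in hooks      # __exit__ invoked remove()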
@ -1,163 +0,0 @@
import os.path
import zipfile
from abc import ABC, abstractmethod
from io import BytesIO
from pathlib import Path
from typing import List, Union, BinaryIO, Optional

from ._directory_reader import DirectoryReader


class PackageZipFileReader(ABC):
    """
    Interface that allows PackageImporter to operate on zip-file-like objects.
    To create a custom zip file reader for PackageImporter, simply inherit
    from this class.
    """

    def __init__(self, file_or_buffer: Union[str, Path, BinaryIO]):
        raise NotImplementedError(
            f"__init__(self, file_or_buffer) is not implemented in {type(self)}"
        )

    @abstractmethod
    def get_record(self, name: str) -> bytes:
        raise NotImplementedError(
            f"get_record(self, name: str) is not implemented in {type(self)}"
        )

    @abstractmethod
    def has_record(self, path: str) -> bool:
        raise NotImplementedError(
            f"has_record(self, path: str) is not implemented in {type(self)}"
        )

    @abstractmethod
    def get_all_records(self) -> List[str]:
        raise NotImplementedError(
            f"get_all_records(self) is not implemented in {type(self)}"
        )

    @abstractmethod
    def get_filename(self) -> str:
        raise NotImplementedError(
            f"get_filename(self) is not implemented in {type(self)}"
        )

    def is_directory(self) -> bool:
        raise NotImplementedError(
            f"is_directory(self) is not implemented in {type(self)}"
        )

    @abstractmethod
    def close(self):
        raise NotImplementedError(f"close(self) is not implemented in {type(self)}")


class PackageZipFileWriter(ABC):
    """
    Interface that allows PackageExporter to operate on zip-file-like objects.
    To create a custom zip file writer for PackageExporter, simply inherit
    from this class.
    """

    def __init__(self, f: Union[str, Path, BinaryIO]):
        raise NotImplementedError(
            f"__init__(self, f) is not implemented in {type(self)}"
        )

    @abstractmethod
    def write_record(self, f, str_or_bytes: Union[str, bytes], size: int):
        raise NotImplementedError(
            f"write_record(self, f, str_or_bytes, size) is not implemented in {type(self)}"
        )

    @abstractmethod
    def close(self):
        raise NotImplementedError(f"close(self) is not implemented in {type(self)}")


class DefaultPackageZipFileWriter(zipfile.ZipFile, PackageZipFileWriter):
    """
    Default zip file writer, allowing PackageExporter to operate on general
    objects. This is effectively a wrapper around ZipFile that exposes an
    API similar to torch._C.PyTorchFileWriter.
    """

    def __init__(self, f: Union[str, Path, BinaryIO]):
        if isinstance(f, (Path, str)):
            f = str(f)
            self.buffer: Optional[BinaryIO] = None
        else:  # is a byte buffer
            self.buffer = f

        super().__init__(f, mode="w")

        self.prefix: str = "archive"
        if isinstance(f, (Path, str)):
            self.prefix = "/".join(str(f).strip("/").split("/")[1:])
        super().writestr(f"{self.prefix}/.data/version", "6\n")

    def write_record(self, f: str, str_or_bytes: Union[str, bytes], size: Optional[int] = None):
        super().writestr(f"{self.prefix}/{f}", str_or_bytes)

    def close(self):
        if self.buffer:
            self.buffer.flush()
        super().close()
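
# A usage sketch, assuming an in-memory buffer: every record lands under the
# "archive/" prefix, next to the ".data/version" entry written in __init__.
buf = BytesIO()
writer = DefaultPackageZipFileWriter(buf)
writer.write_record("my_module.py", "x = 1\n", size=6)
writer.close()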


class DefaultPackageZipFileReader(PackageZipFileReader):
    """
    Default zip file reader, allowing PackageImporter to operate on general
    objects. This is effectively a wrapper around ZipFile that exposes an
    API similar to torch._C.PyTorchFileReader.
    """

    def __init__(self, file_or_buffer: Union[str, Path, BinaryIO]):
        if isinstance(file_or_buffer, (Path, str)):
            self.filename = str(file_or_buffer)
            if not os.path.isdir(self.filename):
                self.zip_reader: Union[
                    zipfile.ZipFile, DirectoryReader
                ] = zipfile.ZipFile(self.filename)
            else:
                self.zip_reader = DirectoryReader(self.filename)
        else:
            self.filename = "<binary>"
            self.zip_reader = zipfile.ZipFile(file_or_buffer)

        if isinstance(self.zip_reader, DirectoryReader):
            self.records = self.zip_reader.get_all_records()
        elif isinstance(self.zip_reader, zipfile.ZipFile):
            prefixed_records = self.zip_reader.namelist()

            self.records = []
            if isinstance(file_or_buffer, BytesIO):
                self.prefix = "archive"
            else:
                self.prefix = "/".join(str(file_or_buffer).strip("/").split("/")[1:])
            for record in prefixed_records:
                self.records.append(record[len(self.prefix) + 1 :])

    def get_record(self, name: str) -> bytes:
        if isinstance(self.zip_reader, DirectoryReader):
            return self.zip_reader.get_record(f"{name}")
        else:
            return self.zip_reader.read(f"{self.prefix}/{name}")

    def has_record(self, path: str) -> bool:
        return path in self.records

    def get_all_records(self) -> List[str]:
        return list(self.records)

    def get_filename(self) -> str:
        return self.filename

    def is_directory(self) -> bool:
        return isinstance(self.zip_reader, DirectoryReader)

    def close(self):
        self.zip_reader.close()
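
# A round-trip sketch: records written through the default writer come back
# unprefixed through the default reader over the same in-memory buffer.
buf = BytesIO()
writer = DefaultPackageZipFileWriter(buf)
writer.write_record("foo.txt", "hello", size=5)
writer.close()
buf.seek(0)
reader = DefaultPackageZipFileReader(buf)
assert reader.has_record("foo.txt")
assert reader.get_record("foo.txt") == b"hello"
reader.close()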
@ -1,91 +0,0 @@
import os.path
from pathlib import Path
from typing import List, BinaryIO, Union, Optional

import torch

from ._directory_reader_torchscript import TorchScriptDirectoryReader, _HasStorage
from ._zip_file import PackageZipFileReader, PackageZipFileWriter


class TorchScriptPackageZipFileWriter(PackageZipFileWriter):
    """
    Class to allow PackageExporter to operate on torchscript objects. This
    is a wrapper around the PyTorchFileWriter and ScriptModuleSerializer classes.
    """

    def __init__(self, f: Union[str, Path, BinaryIO]):
        if isinstance(f, (Path, str)):
            f = str(f)
            self.buffer: Optional[BinaryIO] = None
        else:  # is a byte buffer
            self.buffer = f

        self.zip_file_writer = torch._C.PyTorchFileWriter(f)
        self.zip_file_writer.set_min_version(6)
        self.script_module_serializer = torch._C.ScriptModuleSerializer(
            self.zip_file_writer
        )
        self.storage_context = self.script_module_serializer.storage_context()

    def write_record(self, f: str, str_or_bytes: Union[str, bytes], size: int):
        if isinstance(str_or_bytes, str):
            str_or_bytes = str_or_bytes.encode()
        self.zip_file_writer.write_record(f, str_or_bytes, size)

    def close(self):
        self.script_module_serializer.write_files()
        if self.buffer:
            self.buffer.flush()


class TorchScriptPackageZipFileReader(PackageZipFileReader):
    """
    Class to allow PackageImporter to operate on torchscript objects. This
    is a wrapper around the PyTorchFileReader class.
    """

    def __init__(
        self, file_or_buffer: Union[str, torch._C.PyTorchFileReader, Path, BinaryIO]
    ):
        if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
            self.filename = "<pytorch_file_reader>"
            self.zip_reader: Union[
                torch._C.PyTorchFileReader, TorchScriptDirectoryReader
            ] = file_or_buffer
        elif isinstance(file_or_buffer, (Path, str)):
            self.filename = str(file_or_buffer)
            if not os.path.isdir(self.filename):
                self.zip_reader = torch._C.PyTorchFileReader(self.filename)
            else:
                self.zip_reader = TorchScriptDirectoryReader(self.filename)
        else:
            self.filename = "<binary>"
            self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)

    def get_record(self, name: str) -> bytes:
        return self.zip_reader.get_record(name)

    # NOTE: for has_record, get_all_records, and get_storage_from_record, pybind
    # does not expose these attributes of PyTorchFileReader to mypy, so mypy
    # reports an error; it is suppressed with the `type: ignore` comments below.
    def has_record(self, path: str) -> bool:
        return self.zip_reader.has_record(path)  # type: ignore[union-attr]

    def get_all_records(self) -> List[str]:
        return self.zip_reader.get_all_records()  # type: ignore[union-attr]

    def get_storage_from_record(
        self, name: str, numel: int, dtype: torch.dtype
    ) -> _HasStorage:
        return self.zip_reader.get_storage_from_record(name, numel, dtype)  # type: ignore[union-attr]

    def get_filename(self) -> str:
        return self.filename

    def is_directory(self) -> bool:
        return isinstance(self.zip_reader, TorchScriptDirectoryReader)

    def close(self):
        pass
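
# A usage sketch over a hypothetical saved package path: the reader picks a
# PyTorchFileReader for a zip archive and a TorchScriptDirectoryReader for an
# unzipped directory, behind the same interface.
reader = TorchScriptPackageZipFileReader("my_package.pt")
extern_modules = reader.get_record(".data/extern_modules").decode("utf-8").splitlines()
reader.close()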
@ -1,6 +1,6 @@
from typing import Dict, List

from ..package_exporter_no_torch import PackagingError
from ..package_exporter import PackagingError


def find_first_use_of_broken_modules(exc: PackagingError) -> Dict[str, List[str]]:
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -1,21 +1,55 @@
import builtins
import importlib
import inspect
import io
import linecache
import os.path
import types
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Union, BinaryIO, Callable, Dict
from typing import cast, Any, BinaryIO, Callable, Dict, List, Optional, Union
from weakref import WeakValueDictionary

import torch
from torch.serialization import _get_restore_location
from torch.serialization import _get_restore_location, _maybe_decode_ascii

from ._directory_reader_torchscript import TorchScriptDirectoryReader
from ._zip_file_torchscript import TorchScriptPackageZipFileReader
from .package_importer_no_torch import _maybe_decode_ascii
from .package_importer_no_torch import PackageImporter as DefaultPackageImporter
from ._directory_reader import DirectoryReader
from ._importlib import (
    _calc___package__,
    _normalize_line_endings,
    _normalize_path,
    _resolve_name,
    _sanity_check,
)
from ._mangling import PackageMangler, demangle
from ._package_unpickler import PackageUnpickler
from .file_structure_representation import Directory, _create_directory_from_file_list
from .glob_group import GlobPattern
from .importer import Importer


class PackageImporter(DefaultPackageImporter):
class PackageImporter(Importer):
    """Importers allow you to load code written to packages by :class:`PackageExporter`.
    Code is loaded in a hermetic way, using files from the package
    rather than the normal python import system. This allows
    for the packaging of PyTorch model code and data so that it can be run
    on a server or used in the future for transfer learning.

    The importer for packages ensures that code in the module can only be loaded from
    within the package, except for modules explicitly listed as external during export.
    The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
    This prevents "implicit" dependencies where the package runs locally because it is importing
    a locally-installed package, but then fails when the package is copied to another machine.
    """

    """The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but
    local to this importer.
    """
    modules: Dict[str, types.ModuleType]

    def __init__(
        self,
        file_or_buffer: Union[str, Path, BinaryIO],
        file_or_buffer: Union[str, torch._C.PyTorchFileReader, Path, BinaryIO],
        module_allowed: Callable[[str], bool] = lambda module_name: True,
    ):
        """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
@ -31,69 +65,107 @@ class PackageImporter(DefaultPackageImporter):
        Raises:
            ImportError: If the package will use a disallowed module.
        """
        super(PackageImporter, self).__init__(
            file_or_buffer,
            module_allowed,
            zip_file_reader_type=TorchScriptPackageZipFileReader,
        )

    def persistent_load(self, typename, data):
        assert isinstance(
            self.zip_reader,
            (TorchScriptDirectoryReader, TorchScriptPackageZipFileReader),
        )

        def load_tensor(dtype, size, key, location, restore_location):
            assert self.loaded_storages is not None
            name = f"{key}.storage"

            if self.storage_context.has_storage(name):
                storage = self.storage_context.get_storage(name, dtype).storage()
        self.zip_reader: Any
        if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
            self.filename = "<pytorch_file_reader>"
            self.zip_reader = file_or_buffer
        elif isinstance(file_or_buffer, (Path, str)):
            self.filename = str(file_or_buffer)
            if not os.path.isdir(self.filename):
                self.zip_reader = torch._C.PyTorchFileReader(self.filename)
            else:
                tensor = self.zip_reader.get_storage_from_record(  # type: ignore[attr-defined]
                    ".data/" + name, size, dtype
                self.zip_reader = DirectoryReader(self.filename)
        else:
            self.filename = "<binary>"
            self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)

        self.root = _PackageNode(None)
        self.modules = {}
        self.extern_modules = self._read_extern()

        for extern_module in self.extern_modules:
            if not module_allowed(extern_module):
                raise ImportError(
                    f"package '{file_or_buffer}' needs the external module '{extern_module}' "
                    f"but that module has been disallowed"
                )
                if not self.zip_reader.is_directory():
                    self.storage_context.add_storage(name, tensor)
                storage = tensor.storage()
            self.loaded_storages[key] = restore_location(storage, location)
            self._add_extern(extern_module)

        if typename == "storage":
            storage_type, key, location, size = data
            dtype = storage_type.dtype
            assert self.loaded_storages is not None
            if key not in self.loaded_storages:
                load_tensor(
                    dtype,
                    size,
                    key,
                    _maybe_decode_ascii(location),
                    self.restore_location,
                )
            storage = self.loaded_storages[key]
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            return torch.storage._TypedStorage(
                wrap_storage=storage._untyped(), dtype=dtype
            )
        return None
        for fname in self.zip_reader.get_all_records():
            self._add_file(fname)

    @contextmanager
    def set_torch_deserialization_context(self, map_location):
        # to let reduce_package access deserialization context
        self.storage_context = torch._C.DeserializationStorageContext()
        self.last_map_location = map_location
        self.restore_location = _get_restore_location(map_location)
        self.loaded_storages: Union[Dict[int, Any], None] = {}
        try:
            yield
        finally:
            self.storage_context = None
            self.last_map_location = None
            self.restore_location = None
            self.loaded_storages = None
        self.patched_builtins = builtins.__dict__.copy()
        self.patched_builtins["__import__"] = self.__import__
        # Allow packaged modules to reference their PackageImporter
        self.modules["torch_package_importer"] = self  # type: ignore[assignment]

        self._mangler = PackageMangler()

        # used for reduce deserialization
        self.storage_context: Any = None
        self.last_map_location = None

        # used for torch.serialization._load
        self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)

    def import_module(self, name: str, package=None):
        """Load a module from the package if it hasn't already been loaded, and then return
        the module. Modules are loaded locally
        to the importer and will appear in ``self.modules`` rather than ``sys.modules``.

        Args:
            name (str): Fully qualified name of the module to load.
            package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``.

        Returns:
            types.ModuleType: The (possibly already) loaded module.
        """
        # We should always be able to support importing modules from this package.
        # This is to support something like:
        #   obj = importer.load_pickle(...)
        #   importer.import_module(obj.__module__)  <- this string will be mangled
        #
        # Note that _mangler.demangle will not demangle any module names
        # produced by a different PackageImporter instance.
        name = self._mangler.demangle(name)

        return self._gcd_import(name)

    def load_binary(self, package: str, resource: str) -> bytes:
        """Load raw bytes.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.

        Returns:
            bytes: The loaded data.
        """

        path = self._zipfile_path(package, resource)
        return self.zip_reader.get_record(path)

    def load_text(
        self,
        package: str,
        resource: str,
        encoding: str = "utf-8",
        errors: str = "strict",
    ) -> str:
        """Load a string.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
            errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.

        Returns:
            str: The loaded text.
        """
        data = self.load_binary(package, resource)
        return data.decode(encoding, errors)

    # TODO: load_pickle to reduce the repeated code between this and the non-torch version
    def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
        """Unpickles the resource from the package, loading any modules that are needed to construct the objects
        using :meth:`import_module`.
@ -107,16 +179,49 @@ class PackageImporter(DefaultPackageImporter):
            Any: The unpickled object.
        """
        pickle_file = self._zipfile_path(package, resource)
        restore_location = _get_restore_location(map_location)
        loaded_storages = {}
        loaded_reduces = {}
        storage_context = torch._C.DeserializationStorageContext()

        def _persistent_load(saved_id):
        def load_tensor(dtype, size, key, location, restore_location):
            name = f"{key}.storage"

            if storage_context.has_storage(name):
                storage = storage_context.get_storage(name, dtype).storage()
            else:
                tensor = self.zip_reader.get_storage_from_record(
                    ".data/" + name, size, dtype
                )
                if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
                    storage_context.add_storage(name, tensor)
                storage = tensor.storage()
            loaded_storages[key] = restore_location(storage, location)

        def persistent_load(saved_id):
            assert isinstance(saved_id, tuple)
            typename = _maybe_decode_ascii(saved_id[0])
            data = saved_id[1:]
            module = self.persistent_load(typename, data)
            if module is not None:
                return module
            if typename == "reduce_package":

            if typename == "storage":
                storage_type, key, location, size = data
                dtype = storage_type.dtype

                if key not in loaded_storages:
                    load_tensor(
                        dtype,
                        size,
                        key,
                        _maybe_decode_ascii(location),
                        restore_location,
                    )
                storage = loaded_storages[key]
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with _TypedStorage
                return torch.storage._TypedStorage(
                    wrap_storage=storage._untyped(), dtype=dtype
                )
            elif typename == "reduce_package":
                # to fix BC breaking change, objects on this load path
                # will be loaded multiple times erroneously
                if len(data) == 2:
@ -132,13 +237,466 @@ class PackageImporter(DefaultPackageImporter):
        # Load the data (which may in turn use `persistent_load` to load tensors)
        data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
        unpickler = self.Unpickler(data_file)
        unpickler.persistent_load = _persistent_load
        with self.set_torch_deserialization_context(map_location):
        unpickler.persistent_load = persistent_load

        @contextmanager
        def set_deserialization_context():
            # to let reduce_package access deserialization context
            self.storage_context = storage_context
            self.last_map_location = map_location
            try:
                yield
            finally:
                self.storage_context = None
                self.last_map_location = None

        with set_deserialization_context():
            result = unpickler.load()
            # TODO from zdevito:
            # This stateful weird function will need to be removed in our efforts
            # to unify the format. It has a race condition if multiple python
            # threads try to read independent files
            torch._utils._validate_loaded_sparse_tensors()

        # TODO from zdevito:
        # This stateful weird function will need to be removed in our efforts
        # to unify the format. It has a race condition if multiple python
        # threads try to read independent files
        torch._utils._validate_loaded_sparse_tensors()

        return result

    def id(self):
        """
        Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances.
        Looks like::

            <torch_package_0>
        """
        return self._mangler.parent_name()

    def file_structure(
        self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
    ) -> Directory:
        """Returns a file structure representation of package's zipfile.

        Args:
            include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings
                for the names of the files to be included in the zipfile representation. This can also be
                a glob-style pattern, as described in :meth:`PackageExporter.mock`

            exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern.

        Returns:
            :class:`Directory`
        """
        return _create_directory_from_file_list(
            self.filename, self.zip_reader.get_all_records(), include, exclude
        )

    def python_version(self):
        """Returns the version of python that was used to create this package.

        Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock
        file later on.

        Returns:
            :class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package
        """
        python_version_path = ".data/python_version"
        return (
            self.zip_reader.get_record(python_version_path).decode("utf-8").strip()
            if self.zip_reader.has_record(python_version_path)
            else None
        )
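
    # A usage sketch (hypothetical package path):
    #
    #     importer = PackageImporter("my_package.pt")
    #     importer.python_version()  # e.g. "3.8.9", or None for older packages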

    def _read_extern(self):
        return (
            self.zip_reader.get_record(".data/extern_modules")
            .decode("utf-8")
            .splitlines(keepends=False)
        )

    def _make_module(
        self, name: str, filename: Optional[str], is_package: bool, parent: str
    ):
        mangled_filename = self._mangler.mangle(filename) if filename else None
        spec = importlib.machinery.ModuleSpec(
            name,
            self,  # type: ignore[arg-type]
            origin="<package_importer>",
            is_package=is_package,
        )
        module = importlib.util.module_from_spec(spec)
        self.modules[name] = module
        module.__name__ = self._mangler.mangle(name)
        ns = module.__dict__
        ns["__spec__"] = spec
        ns["__loader__"] = self
        ns["__file__"] = mangled_filename
        ns["__cached__"] = None
        ns["__builtins__"] = self.patched_builtins
        ns["__torch_package__"] = True

        # Add this module to our private global registry. It should be unique due to mangling.
        assert module.__name__ not in _package_imported_modules
        _package_imported_modules[module.__name__] = module

        # pre-emptively install on the parent to prevent IMPORT_FROM from trying to
        # access sys.modules
        self._install_on_parent(parent, name, module)

        if filename is not None:
            assert mangled_filename is not None
            # pre-emptively install the source in `linecache` so that stack traces,
            # `inspect`, etc. work.
            assert filename not in linecache.cache  # type: ignore[attr-defined]
            linecache.lazycache(mangled_filename, ns)

            code = self._compile_source(filename, mangled_filename)
            exec(code, ns)

        return module

    def _load_module(self, name: str, parent: str):
        cur: _PathNode = self.root
        for atom in name.split("."):
            if not isinstance(cur, _PackageNode) or atom not in cur.children:
                raise ModuleNotFoundError(
                    f'No module named "{name}" in self-contained archive "{self.filename}"'
                    f" and the module is also not in the list of allowed external modules: {self.extern_modules}",
                    name=name,
                )
            cur = cur.children[atom]
        if isinstance(cur, _ExternNode):
            module = self.modules[name] = importlib.import_module(name)
            return module
        return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent)  # type: ignore[attr-defined]

    def _compile_source(self, fullpath: str, mangled_filename: str):
        source = self.zip_reader.get_record(fullpath)
        source = _normalize_line_endings(source)
        return compile(source, mangled_filename, "exec", dont_inherit=True)

    # note: named `get_source` so that linecache can find the source
    # when this is the __loader__ of a module.
    def get_source(self, module_name) -> str:
        # linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here.
        module = self.import_module(demangle(module_name))
        return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8")

    # note: named `get_resource_reader` so that importlib.resources can find it.
    # This is otherwise considered an internal method.
    def get_resource_reader(self, fullname):
        try:
            package = self._get_package(fullname)
        except ImportError:
            return None
        if package.__loader__ is not self:
            return None
        return _PackageResourceReader(self, fullname)

    def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
        if not parent:
            return
        # Set the module as an attribute on its parent.
        parent_module = self.modules[parent]
        if parent_module.__loader__ is self:
            setattr(parent_module, name.rpartition(".")[2], module)

    # note: copied from cpython's import code, with call to create module replaced with _make_module
    def _do_find_and_load(self, name):
        path = None
        parent = name.rpartition(".")[0]
        if parent:
            if parent not in self.modules:
                self._gcd_import(parent)
            # Crazy side-effects!
            if name in self.modules:
                return self.modules[name]
            parent_module = self.modules[parent]
            try:
                path = parent_module.__path__  # type: ignore[attr-defined]
            except AttributeError:
                msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
                raise ModuleNotFoundError(msg, name=name) from None

        module = self._load_module(name, parent)

        self._install_on_parent(parent, name, module)

        return module

    # note: copied from cpython's import code
    def _find_and_load(self, name):
        module = self.modules.get(name, _NEEDS_LOADING)
        if module is _NEEDS_LOADING:
            return self._do_find_and_load(name)

        if module is None:
            message = "import of {} halted; None in sys.modules".format(name)
            raise ModuleNotFoundError(message, name=name)

        # To handle https://github.com/pytorch/pytorch/issues/57490, where std's
        # creation of fake submodules via the hacking of sys.modules is not import
        # friendly
        if name == "os":
            self.modules["os.path"] = cast(Any, module).path
        elif name == "typing":
            self.modules["typing.io"] = cast(Any, module).io
            self.modules["typing.re"] = cast(Any, module).re

        return module

    def _gcd_import(self, name, package=None, level=0):
        """Import and return the module based on its name, the package the call is
        being made from, and the level adjustment.

        This function represents the greatest common denominator of functionality
        between import_module and __import__. This includes setting __package__ if
        the loader did not.

        """
        _sanity_check(name, package, level)
        if level > 0:
            name = _resolve_name(name, package, level)

        return self._find_and_load(name)

    # note: copied from cpython's import code
    def _handle_fromlist(self, module, fromlist, *, recursive=False):
        """Figure out what __import__ should return.

        The import_ parameter is a callable which takes the name of module to
        import. It is required to decouple the function from assuming importlib's
        import implementation is desired.

        """
        module_name = demangle(module.__name__)
        # The hell that is fromlist ...
        # If a package was imported, try to import stuff from fromlist.
        if hasattr(module, "__path__"):
            for x in fromlist:
                if not isinstance(x, str):
                    if recursive:
                        where = module_name + ".__all__"
                    else:
                        where = "``from list''"
                    raise TypeError(
                        f"Item in {where} must be str, not {type(x).__name__}"
                    )
                elif x == "*":
                    if not recursive and hasattr(module, "__all__"):
                        self._handle_fromlist(module, module.__all__, recursive=True)
                elif not hasattr(module, x):
                    from_name = "{}.{}".format(module_name, x)
                    try:
                        self._gcd_import(from_name)
                    except ModuleNotFoundError as exc:
                        # Backwards-compatibility dictates we ignore failed
                        # imports triggered by fromlist for modules that don't
                        # exist.
                        if (
                            exc.name == from_name
                            and self.modules.get(from_name, _NEEDS_LOADING) is not None
                        ):
                            continue
                        raise
        return module

    def __import__(self, name, globals=None, locals=None, fromlist=(), level=0):
        if level == 0:
            module = self._gcd_import(name)
        else:
            globals_ = globals if globals is not None else {}
            package = _calc___package__(globals_)
            module = self._gcd_import(name, package, level)
        if not fromlist:
            # Return up to the first dot in 'name'. This is complicated by the fact
            # that 'name' may be relative.
            if level == 0:
                return self._gcd_import(name.partition(".")[0])
            elif not name:
                return module
            else:
                # Figure out where to slice the module's name up to the first dot
                # in 'name'.
                cut_off = len(name) - len(name.partition(".")[0])
                # Slice end needs to be positive to alleviate need to special-case
                # when ``'.' not in name``.
                module_name = demangle(module.__name__)
                return self.modules[module_name[: len(module_name) - cut_off]]
        else:
            return self._handle_fromlist(module, fromlist)

    def _get_package(self, package):
        """Take a package name or module object and return the module.

        If a name, the module is imported. If the passed or imported module
        object is not a package, raise an exception.
        """
        if hasattr(package, "__spec__"):
            if package.__spec__.submodule_search_locations is None:
                raise TypeError("{!r} is not a package".format(package.__spec__.name))
            else:
                return package
        else:
            module = self.import_module(package)
            if module.__spec__.submodule_search_locations is None:
                raise TypeError("{!r} is not a package".format(package))
            else:
                return module

    def _zipfile_path(self, package, resource=None):
        package = self._get_package(package)
        assert package.__loader__ is self
        name = demangle(package.__name__)
        if resource is not None:
            resource = _normalize_path(resource)
            return f"{name.replace('.', '/')}/{resource}"
        else:
            return f"{name.replace('.', '/')}"

    def _get_or_create_package(
        self, atoms: List[str]
    ) -> "Union[_PackageNode, _ExternNode]":
        cur = self.root
        for i, atom in enumerate(atoms):
            node = cur.children.get(atom, None)
            if node is None:
                node = cur.children[atom] = _PackageNode(None)
            if isinstance(node, _ExternNode):
                return node
            if isinstance(node, _ModuleNode):
                name = ".".join(atoms[:i])
                raise ImportError(
                    f"inconsistent module structure. module {name} is not a package, but has submodules"
                )
            assert isinstance(node, _PackageNode)
            cur = node
        return cur

    def _add_file(self, filename: str):
        """Assembles a Python module out of the given file. Will ignore files in the .data directory.

        Args:
            filename (str): the name of the file inside of the package archive to be added
        """
        *prefix, last = filename.split("/")
        if len(prefix) > 1 and prefix[0] == ".data":
            return
        package = self._get_or_create_package(prefix)
        if isinstance(package, _ExternNode):
            raise ImportError(
                f"inconsistent module structure. package contains a module file {filename}"
                f" that is a subpackage of a module marked external."
            )
        if last == "__init__.py":
            package.source_file = filename
        elif last.endswith(".py"):
            package_name = last[: -len(".py")]
            package.children[package_name] = _ModuleNode(filename)

    def _add_extern(self, extern_name: str):
        *prefix, last = extern_name.split(".")
        package = self._get_or_create_package(prefix)
        if isinstance(package, _ExternNode):
            return  # the shorter extern covers this extern case
        package.children[last] = _ExternNode()
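
    # A worked sketch of the tree these two helpers build: adding the records
    # "foo/__init__.py" and "foo/bar.py" plus the extern "numpy" yields
    #
    #     root (_PackageNode)
    #     +-- foo (_PackageNode, source_file="foo/__init__.py")
    #     |   +-- bar (_ModuleNode, source_file="foo/bar.py")
    #     +-- numpy (_ExternNode)
    #
    # so _load_module("foo.bar") walks root -> foo -> bar and compiles bar's source.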


_NEEDS_LOADING = object()
_ERR_MSG_PREFIX = "No module named "
_ERR_MSG = _ERR_MSG_PREFIX + "{!r}"


class _PathNode:
    pass


class _PackageNode(_PathNode):
    def __init__(self, source_file: Optional[str]):
        self.source_file = source_file
        self.children: Dict[str, _PathNode] = {}


class _ModuleNode(_PathNode):
    __slots__ = ["source_file"]

    def __init__(self, source_file: str):
        self.source_file = source_file


class _ExternNode(_PathNode):
    pass


# A private global registry of all modules that have been package-imported.
_package_imported_modules: WeakValueDictionary = WeakValueDictionary()

# `inspect` by default only looks in `sys.modules` to find source files for classes.
# Patch it to check our private registry of package-imported modules as well.
_orig_getfile = inspect.getfile


def patched_getfile(object):
    if inspect.isclass(object):
        if object.__module__ in _package_imported_modules:
            return _package_imported_modules[object.__module__].__file__
    return _orig_getfile(object)


inspect.getfile = patched_getfile


class _PackageResourceReader:
    """Private class used to support PackageImporter.get_resource_reader().

    Conforms to the importlib.abc.ResourceReader interface. Allowed to access
    the innards of PackageImporter.
    """

    def __init__(self, importer, fullname):
        self.importer = importer
        self.fullname = fullname

    def open_resource(self, resource):
        from io import BytesIO

        return BytesIO(self.importer.load_binary(self.fullname, resource))

    def resource_path(self, resource):
        # The contract for resource_path is that it either returns a concrete
        # file system path or raises FileNotFoundError.
        if isinstance(
            self.importer.zip_reader, DirectoryReader
        ) and self.importer.zip_reader.has_record(
            os.path.join(self.fullname, resource)
        ):
            return os.path.join(
                self.importer.zip_reader.directory, self.fullname, resource
            )
        raise FileNotFoundError

    def is_resource(self, name):
        path = self.importer._zipfile_path(self.fullname, name)
        return self.importer.zip_reader.has_record(path)

    def contents(self):
        from pathlib import Path

        filename = self.fullname.replace(".", "/")

        fullname_path = Path(self.importer._zipfile_path(self.fullname))
        files = self.importer.zip_reader.get_all_records()
        subdirs_seen = set()
        for filename in files:
            try:
                relative = Path(filename).relative_to(fullname_path)
            except ValueError:
                continue
            # If the path of the file (which is relative to the top of the zip
            # namespace), relative to the package given when the resource
            # reader was created, has a parent, then it's a name in a
            # subdirectory and thus we skip it.
            parent_name = relative.parent.name
            if len(parent_name) == 0:
                yield relative.name
            elif parent_name not in subdirs_seen:
                subdirs_seen.add(parent_name)
                yield parent_name

@ -1,672 +0,0 @@
import builtins
import importlib
import inspect
import io
import linecache
import os.path
import types
from pathlib import Path
from typing import cast, Any, BinaryIO, Callable, Dict, List, Optional, Union, Type
from weakref import WeakValueDictionary

from ._directory_reader import DirectoryReader
from ._importlib import (
    _calc___package__,
    _normalize_line_endings,
    _normalize_path,
    _resolve_name,
    _sanity_check,
)
from ._mangling import PackageMangler, demangle
from ._package_unpickler import PackageUnpickler
from ._zip_file import PackageZipFileReader, DefaultPackageZipFileReader
from .file_structure_representation import Directory, _create_directory_from_file_list
from .glob_group import GlobPattern
from .importer import Importer


def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
    # When using encoding='bytes' in Py3, some **internal** keys stored as
    # strings in Py2 are loaded as bytes. This function decodes them with
    # ascii encoding, the one that Py3 uses by default.
    #
    # NOTE: This should only be used on internal keys (e.g., `typename` and
    # `location` in `persistent_load` below!)
    if isinstance(bytes_str, bytes):
        return bytes_str.decode("ascii")
    return bytes_str
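
# Example behavior, assuming the inputs below:
assert _maybe_decode_ascii(b"cpu") == "cpu"
assert _maybe_decode_ascii("cuda:0") == "cuda:0"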


class PackageImporter(Importer):
    """Importers allow you to load code written to packages by :class:`PackageExporter`.
    Code is loaded in a hermetic way, using files from the package
    rather than the normal python import system. This allows
    for the packaging of PyTorch model code and data so that it can be run
    on a server or used in the future for transfer learning.

    The importer for packages ensures that code in the module can only be loaded from
    within the package, except for modules explicitly listed as external during export.
    The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
    This prevents "implicit" dependencies where the package runs locally because it is importing
    a locally-installed package, but then fails when the package is copied to another machine.
    """

    """The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but
    local to this importer.
    """
    modules: Dict[str, types.ModuleType]

    def get_zip_reader(self, file_or_buffer):
        zip_reader: Any
        if isinstance(file_or_buffer, DefaultPackageZipFileReader):
            filename = "<pytorch_file_reader>"
            zip_reader = file_or_buffer
        elif isinstance(file_or_buffer, (Path, str)):
            filename = str(file_or_buffer)
            if not os.path.isdir(filename):
                zip_reader = DefaultPackageZipFileReader(filename)
            else:
                zip_reader = DirectoryReader(filename)
        else:
            filename = "<binary>"
            zip_reader = DefaultPackageZipFileReader(file_or_buffer)
        return filename, zip_reader

    def __init__(
        self,
        file_or_buffer: Union[str, Path, BinaryIO],
        module_allowed: Callable[[str], bool] = lambda module_name: True,
        zip_file_reader_type: Type[PackageZipFileReader] = DefaultPackageZipFileReader,
    ):
        """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
        allowed by ``module_allowed``.
        Args:
            file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
                a string, or an ``os.PathLike`` object containing a filename.
            module_allowed (Callable[[str], bool], optional): A method to determine if an externally provided module
                should be allowed. Can be used to ensure packages loaded do not depend on modules that the server
                does not support. Defaults to allowing anything.
            zip_file_reader_type: A subclass of PackageZipFileReader which is used to instantiate the zip file reader.
        Raises:
            ImportError: If the package will use a disallowed module.
        """

        self.zip_reader = zip_file_reader_type(file_or_buffer)

        self.root = _PackageNode(None)
        self.modules = {}
        self.extern_modules = self._read_extern()

        for extern_module in self.extern_modules:
            if not module_allowed(extern_module):
                raise ImportError(
                    f"package '{file_or_buffer}' needs the external module '{extern_module}' "
                    f"but that module has been disallowed"
                )
            self._add_extern(extern_module)

        for fname in self.zip_reader.get_all_records():
            self._add_file(fname)

        self.patched_builtins = builtins.__dict__.copy()
        self.patched_builtins["__import__"] = self.__import__
        # Allow packaged modules to reference their PackageImporter
        self.modules["torch_package_importer"] = self  # type: ignore[assignment]

        self._mangler = PackageMangler()

        # used for reduce deserialization
        self.storage_context: Any = None
        self.last_map_location = None

        # used for torch.serialization._load
        self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)

    def import_module(self, name: str, package=None):
        """Load a module from the package if it hasn't already been loaded, and then return
        the module. Modules are loaded locally
        to the importer and will appear in ``self.modules`` rather than ``sys.modules``.

        Args:
            name (str): Fully qualified name of the module to load.
            package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``.

        Returns:
            types.ModuleType: The (possibly already) loaded module.
        """
        # We should always be able to support importing modules from this package.
        # This is to support something like:
        #   obj = importer.load_pickle(...)
        #   importer.import_module(obj.__module__)  <- this string will be mangled
        #
        # Note that _mangler.demangle will not demangle any module names
        # produced by a different PackageImporter instance.
        name = self._mangler.demangle(name)

        return self._gcd_import(name)

    def load_binary(self, package: str, resource: str) -> bytes:
        """Load raw bytes.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.

        Returns:
            bytes: The loaded data.
        """

        path = self._zipfile_path(package, resource)
        return self.zip_reader.get_record(path)

    def load_text(
        self,
        package: str,
        resource: str,
        encoding: str = "utf-8",
        errors: str = "strict",
    ) -> str:
        """Load a string.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
            errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.

        Returns:
            str: The loaded text.
        """
        data = self.load_binary(package, resource)
        return data.decode(encoding, errors)

    def persistent_load(self, typename, data):
        # meant to be overwritten
        return None

    def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
        """Unpickles the resource from the package, loading any modules that are needed to construct the objects
        using :meth:`import_module`.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            map_location: Passed to `torch.load` to determine how tensors are mapped to devices. Defaults to ``None``.

        Returns:
            Any: The unpickled object.
        """
        pickle_file = self._zipfile_path(package, resource)
        loaded_reduces = {}

        def _persistent_load(saved_id):
            assert isinstance(saved_id, tuple)
            typename = _maybe_decode_ascii(saved_id[0])
            data = saved_id[1:]
            module = self.persistent_load(typename, data)
            if module is not None:
                return module
            if typename == "reduce_package":
                # to fix BC breaking change, objects on this load path
                # will be loaded multiple times erroneously
                if len(data) == 2:
                    func, args = data
                    return func(self, *args)
                reduce_id, func, args = data
                if reduce_id not in loaded_reduces:
                    loaded_reduces[reduce_id] = func(self, *args)
                return loaded_reduces[reduce_id]
            else:
                raise RuntimeError(
                    f"Unknown typename for persistent_load, got '{typename}'"
                )

        # Load the data (which may in turn use `persistent_load` to load tensors)
        data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
        unpickler = self.Unpickler(data_file)
        unpickler.persistent_load = _persistent_load
        result = unpickler.load()

        return result
|
||||
|
||||
def id(self):
|
||||
"""
|
||||
Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances.
|
||||
Looks like::
|
||||
|
||||
<torch_package_0>
|
||||
"""
|
||||
return self._mangler.parent_name()
|
||||
|
||||
def file_structure(
|
||||
self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
|
||||
) -> Directory:
|
||||
"""Returns a file structure representation of package's zipfile.
|
||||
|
||||
Args:
|
||||
include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings
|
||||
for the names of the files to be inluded in the zipfile representation. This can also be
|
||||
a glob-style pattern, as described in :meth:`PackageExporter.mock`
|
||||
|
||||
exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern.
|
||||
|
||||
Returns:
|
||||
:class:`Directory`
|
||||
"""
|
||||
return _create_directory_from_file_list(
|
||||
self.zip_reader.get_filename(),
|
||||
self.zip_reader.get_all_records(),
|
||||
include,
|
||||
exclude,
|
||||
)
|
||||
|
||||
def python_version(self):
|
||||
"""Returns the version of python that was used to create this package.
|
||||
|
||||
Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock
|
||||
file later on.
|
||||
|
||||
Returns:
|
||||
:class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package
|
||||
"""
|
||||
python_version_path = ".data/python_version"
|
||||
return (
|
||||
self.zip_reader.get_record(python_version_path).decode("utf-8").strip()
|
||||
if self.zip_reader.has_record(python_version_path)
|
||||
else None
|
||||
)
|
||||
|
||||
def _read_extern(self):
|
||||
return (
|
||||
self.zip_reader.get_record(".data/extern_modules")
|
||||
.decode("utf-8")
|
||||
.splitlines(keepends=False)
|
||||
)
|
||||
|
||||
def _make_module(
|
||||
self, name: str, filename: Optional[str], is_package: bool, parent: str
|
||||
):
|
||||
mangled_filename = self._mangler.mangle(filename) if filename else None
|
||||
spec = importlib.machinery.ModuleSpec(
|
||||
name,
|
||||
self, # type: ignore[arg-type]
|
||||
origin="<package_importer>",
|
||||
is_package=is_package,
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
self.modules[name] = module
|
||||
module.__name__ = self._mangler.mangle(name)
|
||||
ns = module.__dict__
|
||||
ns["__spec__"] = spec
|
||||
ns["__loader__"] = self
|
||||
ns["__file__"] = mangled_filename
|
||||
ns["__cached__"] = None
|
||||
ns["__builtins__"] = self.patched_builtins
|
||||
ns["__torch_package__"] = True
|
||||
|
||||
# Add this module to our private global registry. It should be unique due to mangling.
|
||||
assert module.__name__ not in _package_imported_modules
|
||||
_package_imported_modules[module.__name__] = module
|
||||
|
||||
# pre-emptively install on the parent to prevent IMPORT_FROM from trying to
|
||||
# access sys.modules
|
||||
self._install_on_parent(parent, name, module)
|
||||
|
||||
if filename is not None:
|
||||
assert mangled_filename is not None
|
||||
# pre-emptively install the source in `linecache` so that stack traces,
|
||||
# `inspect`, etc. work.
|
||||
assert filename not in linecache.cache # type: ignore[attr-defined]
|
||||
linecache.lazycache(mangled_filename, ns)
|
||||
|
||||
code = self._compile_source(filename, mangled_filename)
|
||||
exec(code, ns)
|
||||
|
||||
return module
|
||||
|
||||
def _load_module(self, name: str, parent: str):
|
||||
cur: _PathNode = self.root
|
||||
for atom in name.split("."):
|
||||
if not isinstance(cur, _PackageNode) or atom not in cur.children:
|
||||
raise ModuleNotFoundError(
|
||||
f'No module named "{name}" in self-contained archive "{self.zip_reader.get_filename()}"'
|
||||
f" and the module is also not in the list of allowed external modules: {self.extern_modules}",
|
||||
name=name,
|
||||
)
|
||||
cur = cur.children[atom]
|
||||
if isinstance(cur, _ExternNode):
|
||||
module = self.modules[name] = importlib.import_module(name)
|
||||
return module
|
||||
return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined]
|
||||
|
||||
def _compile_source(self, fullpath: str, mangled_filename: str):
|
||||
source = self.zip_reader.get_record(fullpath)
|
||||
source = _normalize_line_endings(source)
|
||||
return compile(source, mangled_filename, "exec", dont_inherit=True)
|
||||
|
||||
# note: named `get_source` so that linecache can find the source
|
||||
# when this is the __loader__ of a module.
|
||||
def get_source(self, module_name) -> str:
|
||||
# linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here.
|
||||
module = self.import_module(demangle(module_name))
|
||||
return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8")
|
||||
|
||||
# note: named `get_resource_reader` so that importlib.resources can find it.
|
||||
# This is otherwise considered an internal method.
|
||||
def get_resource_reader(self, fullname):
|
||||
try:
|
||||
package = self._get_package(fullname)
|
||||
except ImportError:
|
||||
return None
|
||||
if package.__loader__ is not self:
|
||||
return None
|
||||
return _PackageResourceReader(self, fullname)
|
||||
|
||||
def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
|
||||
if not parent:
|
||||
return
|
||||
# Set the module as an attribute on its parent.
|
||||
parent_module = self.modules[parent]
|
||||
if parent_module.__loader__ is self:
|
||||
setattr(parent_module, name.rpartition(".")[2], module)
|
||||
|
||||
# note: copied from cpython's import code, with call to create module replaced with _make_module
|
||||
def _do_find_and_load(self, name):
|
||||
path = None
|
||||
parent = name.rpartition(".")[0]
|
||||
if parent:
|
||||
if parent not in self.modules:
|
||||
self._gcd_import(parent)
|
||||
# Crazy side-effects!
|
||||
if name in self.modules:
|
||||
return self.modules[name]
|
||||
parent_module = self.modules[parent]
|
||||
try:
|
||||
path = parent_module.__path__ # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
|
||||
raise ModuleNotFoundError(msg, name=name) from None
|
||||
|
||||
module = self._load_module(name, parent)
|
||||
|
||||
self._install_on_parent(parent, name, module)
|
||||
|
||||
return module
|
||||
|
||||
# note: copied from cpython's import code
|
||||
def _find_and_load(self, name):
|
||||
module = self.modules.get(name, _NEEDS_LOADING)
|
||||
if module is _NEEDS_LOADING:
|
||||
return self._do_find_and_load(name)
|
||||
|
||||
if module is None:
|
||||
message = "import of {} halted; " "None in sys.modules".format(name)
|
||||
raise ModuleNotFoundError(message, name=name)
|
||||
|
||||
# To handle https://github.com/pytorch/pytorch/issues/57490, where std's
|
||||
# creation of fake submodules via the hacking of sys.modules is not import
|
||||
# friendly
|
||||
if name == "os":
|
||||
self.modules["os.path"] = cast(Any, module).path
|
||||
elif name == "typing":
|
||||
self.modules["typing.io"] = cast(Any, module).io
|
||||
self.modules["typing.re"] = cast(Any, module).re
|
||||
|
||||
return module
|
||||
|
||||
def _gcd_import(self, name, package=None, level=0):
|
||||
"""Import and return the module based on its name, the package the call is
|
||||
being made from, and the level adjustment.
|
||||
|
||||
This function represents the greatest common denominator of functionality
|
||||
between import_module and __import__. This includes setting __package__ if
|
||||
the loader did not.
|
||||
|
||||
"""
|
||||
_sanity_check(name, package, level)
|
||||
if level > 0:
|
||||
name = _resolve_name(name, package, level)
|
||||
|
||||
return self._find_and_load(name)
|
||||
|
||||
# note: copied from cpython's import code
|
||||
def _handle_fromlist(self, module, fromlist, *, recursive=False):
|
||||
"""Figure out what __import__ should return.
|
||||
|
||||
The import_ parameter is a callable which takes the name of module to
|
||||
import. It is required to decouple the function from assuming importlib's
|
||||
import implementation is desired.
|
||||
|
||||
"""
|
||||
module_name = demangle(module.__name__)
|
||||
# The hell that is fromlist ...
|
||||
# If a package was imported, try to import stuff from fromlist.
|
||||
if hasattr(module, "__path__"):
|
||||
for x in fromlist:
|
||||
if not isinstance(x, str):
|
||||
if recursive:
|
||||
where = module_name + ".__all__"
|
||||
else:
|
||||
where = "``from list''"
|
||||
raise TypeError(
|
||||
f"Item in {where} must be str, " f"not {type(x).__name__}"
|
||||
)
|
||||
elif x == "*":
|
||||
if not recursive and hasattr(module, "__all__"):
|
||||
self._handle_fromlist(module, module.__all__, recursive=True)
|
||||
elif not hasattr(module, x):
|
||||
from_name = "{}.{}".format(module_name, x)
|
||||
try:
|
||||
self._gcd_import(from_name)
|
||||
except ModuleNotFoundError as exc:
|
||||
# Backwards-compatibility dictates we ignore failed
|
||||
# imports triggered by fromlist for modules that don't
|
||||
# exist.
|
||||
if (
|
||||
exc.name == from_name
|
||||
and self.modules.get(from_name, _NEEDS_LOADING) is not None
|
||||
):
|
||||
continue
|
||||
raise
|
||||
return module
|
||||
|
||||
    def __import__(self, name, globals=None, locals=None, fromlist=(), level=0):
        if level == 0:
            module = self._gcd_import(name)
        else:
            globals_ = globals if globals is not None else {}
            package = _calc___package__(globals_)
            module = self._gcd_import(name, package, level)
        if not fromlist:
            # Return up to the first dot in 'name'. This is complicated by the fact
            # that 'name' may be relative.
            if level == 0:
                return self._gcd_import(name.partition(".")[0])
            elif not name:
                return module
            else:
                # Figure out where to slice the module's name up to the first dot
                # in 'name'.
                cut_off = len(name) - len(name.partition(".")[0])
                # Slice end needs to be positive to alleviate need to special-case
                # when ``'.' not in name``.
                module_name = demangle(module.__name__)
                return self.modules[module_name[: len(module_name) - cut_off]]
        else:
            return self._handle_fromlist(module, fromlist)

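    # Example (illustrative names): ``import foo.bar.baz`` reaches this method
    # with fromlist=() and must bind the top-level name, so ``foo`` is
    # returned; ``from foo.bar import baz`` passes fromlist=("baz",) and gets
    # back ``foo.bar`` via _handle_fromlist instead.
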
    def _get_package(self, package):
        """Take a package name or module object and return the module.

        If a name, the module is imported. If the passed or imported module
        object is not a package, raise an exception.
        """
        if hasattr(package, "__spec__"):
            if package.__spec__.submodule_search_locations is None:
                raise TypeError("{!r} is not a package".format(package.__spec__.name))
            else:
                return package
        else:
            module = self.import_module(package)
            if module.__spec__.submodule_search_locations is None:
                raise TypeError("{!r} is not a package".format(package))
            else:
                return module

    def _zipfile_path(self, package, resource=None):
        package = self._get_package(package)
        assert package.__loader__ is self
        name = demangle(package.__name__)
        if resource is not None:
            resource = _normalize_path(resource)
            return f"{name.replace('.', '/')}/{resource}"
        else:
            return f"{name.replace('.', '/')}"

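    # Example (illustrative names): _zipfile_path("my_pkg.sub", "data.txt")
    # returns "my_pkg/sub/data.txt", the record path inside the zip archive;
    # with resource=None it returns just the package directory "my_pkg/sub".
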
    def _get_or_create_package(
        self, atoms: List[str]
    ) -> "Union[_PackageNode, _ExternNode]":
        cur = self.root
        for i, atom in enumerate(atoms):
            node = cur.children.get(atom, None)
            if node is None:
                node = cur.children[atom] = _PackageNode(None)
            if isinstance(node, _ExternNode):
                return node
            if isinstance(node, _ModuleNode):
                # Include the current atom so the error names the offending
                # module itself, not just its parent.
                name = ".".join(atoms[: i + 1])
                raise ImportError(
                    f"inconsistent module structure. module {name} is not a package, but has submodules"
                )
            assert isinstance(node, _PackageNode)
            cur = node
        return cur

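    # Example (illustrative names): _get_or_create_package(["foo", "bar"])
    # walks or creates the nodes root -> "foo" -> "bar" and returns the
    # _PackageNode for "foo.bar" (or the nearest _ExternNode hit on the way).
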
    def _add_file(self, filename: str):
        """Assembles a Python module out of the given file. Will ignore files in the .data directory.

        Args:
            filename (str): the name of the file inside of the package archive to be added
        """
        *prefix, last = filename.split("/")
        if len(prefix) > 1 and prefix[0] == ".data":
            return
        package = self._get_or_create_package(prefix)
        if isinstance(package, _ExternNode):
            raise ImportError(
                f"inconsistent module structure. package contains a module file {filename}"
                f" that is a subpackage of a module marked external."
            )
        if last == "__init__.py":
            package.source_file = filename
        elif last.endswith(".py"):
            package_name = last[: -len(".py")]
            package.children[package_name] = _ModuleNode(filename)

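    # Example (illustrative archive layout): _add_file("foo/__init__.py") marks
    # the _PackageNode "foo" with its source file, while _add_file("foo/bar.py")
    # hangs a _ModuleNode named "bar" off that node, mirroring the archive's
    # directory layout as an import tree rooted at self.root.
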
    def _add_extern(self, extern_name: str):
        *prefix, last = extern_name.split(".")
        package = self._get_or_create_package(prefix)
        if isinstance(package, _ExternNode):
            return  # the shorter extern covers this extern case
        package.children[last] = _ExternNode()


_NEEDS_LOADING = object()
_ERR_MSG_PREFIX = "No module named "
_ERR_MSG = _ERR_MSG_PREFIX + "{!r}"


class _PathNode:
    pass


class _PackageNode(_PathNode):
    def __init__(self, source_file: Optional[str]):
        self.source_file = source_file
        self.children: Dict[str, _PathNode] = {}


class _ModuleNode(_PathNode):
    __slots__ = ["source_file"]

    def __init__(self, source_file: str):
        self.source_file = source_file


class _ExternNode(_PathNode):
    pass


# A private global registry of all modules that have been package-imported.
_package_imported_modules: WeakValueDictionary = WeakValueDictionary()

# `inspect` by default only looks in `sys.modules` to find source files for classes.
# Patch it to check our private registry of package-imported modules as well.
_orig_getfile = inspect.getfile


def patched_getfile(object):
    if inspect.isclass(object):
        if object.__module__ in _package_imported_modules:
            return _package_imported_modules[object.__module__].__file__
    return _orig_getfile(object)


inspect.getfile = patched_getfile

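# Illustrative consequence of the patch above: inspect.getfile works on a
# class whose defining module was loaded from a package archive, because the
# lookup falls through to the registry instead of requiring the module to be
# present in sys.modules.

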
class _PackageResourceReader:
    """Private class used to support PackageImporter.get_resource_reader().

    Conforms to the importlib.abc.ResourceReader interface. Allowed to access
    the innards of PackageImporter.
    """

    def __init__(self, importer, fullname):
        self.importer = importer
        self.fullname = fullname

    def open_resource(self, resource):
        from io import BytesIO

        return BytesIO(self.importer.load_binary(self.fullname, resource))

    def resource_path(self, resource):
        # The contract for resource_path is that it either returns a concrete
        # file system path or raises FileNotFoundError.
        if (
            self.importer.zip_reader.is_directory()
            and self.importer.zip_reader.has_record(
                os.path.join(self.fullname, resource)
            )
        ):
            return os.path.join(
                self.importer.zip_reader.get_filename(), self.fullname, resource
            )
        raise FileNotFoundError

    def is_resource(self, name):
        path = self.importer._zipfile_path(self.fullname, name)
        return self.importer.zip_reader.has_record(path)

    def contents(self):
        from pathlib import Path

        fullname_path = Path(self.importer._zipfile_path(self.fullname))
        files = self.importer.zip_reader.get_all_records()
        subdirs_seen = set()
        for filename in files:
            try:
                relative = Path(filename).relative_to(fullname_path)
            except ValueError:
                continue
            # If the record's path relative to the package given when the
            # resource reader was created has a parent, the record lives in a
            # subdirectory: yield that subdirectory's name once and skip the
            # record itself.
            parent_name = relative.parent.name
            if len(parent_name) == 0:
                yield relative.name
            elif parent_name not in subdirs_seen:
                subdirs_seen.add(parent_name)
                yield parent_name
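

# Minimal usage sketch (hypothetical archive and resource names):
#
#     importer = PackageImporter("archive.pt")
#     reader = importer.get_resource_reader("my_pkg")
#     assert reader.is_resource("data.txt")
#     payload = reader.open_resource("data.txt").read()
#     print(list(reader.contents()))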