Revert D33970688: [pkg] add generic ZipFile Reader/Writer

Test Plan: revert-hammer

Differential Revision: D33970688 (c2c260bfc3)

Original commit changeset: 8a524867e62a

Original Phabricator Diff: D33970688 (c2c260bfc3)

fbshipit-source-id: 18b4aa4e221b86a498fc434c1b453356fc47cfbf
(cherry picked from commit a295c2b58d3d9cfacfc9d11d36fd80aabd97675c)
Authored by: Natalia Gimelshein (2022-04-05 22:46:00 -07:00)
Committed by: PyTorch MergeBot
Parent: 20266f054b
Commit: 00e2c14b78
24 changed files with 1882 additions and 2514 deletions

View File

@ -6,19 +6,8 @@ from sys import version_info
from textwrap import dedent
from unittest import skipIf
from torch.package import (
EmptyMatchError,
Importer,
PackageExporter,
PackageImporter,
PackagingError,
)
from torch.package.package_exporter_no_torch import (
PackageExporter as PackageExporterNoTorch,
)
from torch.package.package_importer_no_torch import (
PackageImporter as PackageImporterNoTorch,
)
from torch.package import EmptyMatchError, Importer, PackageExporter, PackageImporter
from torch.package.package_exporter import PackagingError
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
try:
@ -35,18 +24,13 @@ class TestDependencyAPI(PackageTestCase):
- deny()
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporter
self.PackageExporter = PackageExporter
def test_extern(self):
buffer = BytesIO()
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
he.extern(["package_a.subpackage", "module_a"])
he.save_source_string("foo", "import package_a.subpackage; import module_a")
buffer.seek(0)
hi = self.PackageImporter(buffer)
hi = PackageImporter(buffer)
import module_a
import package_a.subpackage
@ -60,7 +44,7 @@ class TestDependencyAPI(PackageTestCase):
def test_extern_glob(self):
buffer = BytesIO()
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
he.extern(["package_a.*", "module_*"])
he.save_module("package_a")
he.save_source_string(
@ -73,7 +57,7 @@ class TestDependencyAPI(PackageTestCase):
),
)
buffer.seek(0)
hi = self.PackageImporter(buffer)
hi = PackageImporter(buffer)
import module_a
import package_a.subpackage
@ -94,7 +78,7 @@ class TestDependencyAPI(PackageTestCase):
buffer = BytesIO()
with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.extern(include=["package_b.*"], allow_empty=False)
exporter.save_module("package_a.subpackage")
@ -105,7 +89,7 @@ class TestDependencyAPI(PackageTestCase):
buffer = BytesIO()
with self.assertRaisesRegex(PackagingError, "denied"):
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.deny(["package_a.subpackage", "module_a"])
exporter.save_source_string("foo", "import package_a.subpackage")
@ -115,7 +99,7 @@ class TestDependencyAPI(PackageTestCase):
"""
buffer = BytesIO()
with self.assertRaises(PackagingError):
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.deny(["package_a.*", "module_*"])
exporter.save_source_string(
"test_module",
@ -130,12 +114,12 @@ class TestDependencyAPI(PackageTestCase):
@skipIf(version_info < (3, 7), "mock uses __getattr__, a 3.7 feature")
def test_mock(self):
buffer = BytesIO()
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
he.mock(["package_a.subpackage", "module_a"])
# Import something that depends on package_a.subpackage
he.save_source_string("foo", "import package_a.subpackage")
buffer.seek(0)
hi = self.PackageImporter(buffer)
hi = PackageImporter(buffer)
import package_a.subpackage
_ = package_a.subpackage
@ -151,7 +135,7 @@ class TestDependencyAPI(PackageTestCase):
@skipIf(version_info < (3, 7), "mock uses __getattr__, a 3.7 feature")
def test_mock_glob(self):
buffer = BytesIO()
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
he.mock(["package_a.*", "module*"])
he.save_module("package_a")
he.save_source_string(
@ -164,7 +148,7 @@ class TestDependencyAPI(PackageTestCase):
),
)
buffer.seek(0)
hi = self.PackageImporter(buffer)
hi = PackageImporter(buffer)
import package_a.subpackage
_ = package_a.subpackage
@ -186,7 +170,7 @@ class TestDependencyAPI(PackageTestCase):
buffer = BytesIO()
with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.mock(include=["package_b.*"], allow_empty=False)
exporter.save_module("package_a.subpackage")
@ -199,7 +183,7 @@ class TestDependencyAPI(PackageTestCase):
buffer = BytesIO()
with self.assertRaises(PackagingError):
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
he.mock(include="package_a.subpackage")
he.intern("**")
he.save_pickle("obj", "obj.pkl", obj2)
@ -212,7 +196,7 @@ class TestDependencyAPI(PackageTestCase):
obj2 = package_a.PackageAObject(obj)
buffer = BytesIO()
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
he.intern(include="package_a.**")
he.mock("**")
he.save_pickle("obj", "obj.pkl", obj2)
@ -221,7 +205,7 @@ class TestDependencyAPI(PackageTestCase):
"""If an error occurs during packaging, it should not be shadowed by the allow_empty error."""
buffer = BytesIO()
with self.assertRaises(ModuleNotFoundError):
with self.PackageExporter(buffer) as pe:
with PackageExporter(buffer) as pe:
# Even though we did not extern a module that matches this
# pattern, we want to show the save_module error, not the allow_empty error.
@ -238,7 +222,7 @@ class TestDependencyAPI(PackageTestCase):
import package_a # noqa: F401
buffer = BytesIO()
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
he.save_module("package_a")
def test_intern_error(self):
@ -251,7 +235,7 @@ class TestDependencyAPI(PackageTestCase):
buffer = BytesIO()
with self.assertRaises(PackagingError) as e:
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
he.save_pickle("obj", "obj.pkl", obj2)
self.assertEqual(
@ -266,7 +250,7 @@ class TestDependencyAPI(PackageTestCase):
)
# Interning all dependencies should work
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
he.intern(["package_a", "package_a.subpackage"])
he.save_pickle("obj", "obj.pkl", obj2)
@ -297,7 +281,7 @@ class TestDependencyAPI(PackageTestCase):
buffer = BytesIO()
with self.assertRaises(PackagingError) as e:
with self.PackageExporter(buffer, importer=BrokenImporter()) as exporter:
with PackageExporter(buffer, importer=BrokenImporter()) as exporter:
exporter.intern(["foo", "bar"])
exporter.save_source_string("my_module", "import foo; import bar")
@ -316,7 +300,7 @@ class TestDependencyAPI(PackageTestCase):
"""An incorrectly-formed import should raise a PackagingError."""
buffer = BytesIO()
with self.assertRaises(PackagingError) as e:
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
# This import will fail to load.
exporter.save_source_string("foo", "from ........ import lol")
@ -335,12 +319,12 @@ class TestDependencyAPI(PackageTestCase):
def test_repackage_mocked_module(self):
"""Re-packaging a package that contains a mocked module should work correctly."""
buffer = BytesIO()
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.mock("package_a")
exporter.save_source_string("foo", "import package_a")
buffer.seek(0)
importer = self.PackageImporter(buffer)
importer = PackageImporter(buffer)
foo = importer.import_module("foo")
# "package_a" should be mocked out.
@ -350,13 +334,13 @@ class TestDependencyAPI(PackageTestCase):
# Re-package the model, but intern the previously-mocked module and mock
# everything else.
buffer2 = BytesIO()
with self.PackageExporter(buffer2, importer=importer) as exporter:
with PackageExporter(buffer2, importer=importer) as exporter:
exporter.intern("package_a")
exporter.mock("**")
exporter.save_source_string("foo", "import package_a")
buffer2.seek(0)
importer2 = self.PackageImporter(buffer2)
importer2 = PackageImporter(buffer2)
foo2 = importer2.import_module("foo")
# "package_a" should still be mocked out.
@ -364,12 +348,5 @@ class TestDependencyAPI(PackageTestCase):
foo2.package_a.get_something()
class TestDependencyAPINoTorch(TestDependencyAPI):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporterNoTorch
self.PackageExporter = PackageExporterNoTorch
if __name__ == "__main__":
run_tests()
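For orientation, the extern round trip these tests exercise boils down to the following condensed sketch (it mirrors test_extern above; module_a is one of the test fixtures, so the snippet assumes the test suite's import path):

from io import BytesIO

from torch.package import PackageExporter, PackageImporter

buffer = BytesIO()
with PackageExporter(buffer) as exporter:
    # Record module_a as an external dependency instead of bundling its source.
    exporter.extern(["module_a"])
    exporter.save_source_string("foo", "import module_a")
buffer.seek(0)

# The packaged "foo" resolves module_a from the host environment.
foo = PackageImporter(buffer).import_module("foo")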

View File

@ -5,9 +5,6 @@ from io import BytesIO
from torch.package import (
PackageExporter,
)
from torch.package.package_exporter_no_torch import (
PackageExporter as PackageExporterNoTorch,
)
from torch.testing._internal.common_utils import run_tests
try:
@ -23,10 +20,6 @@ class TestDependencyHooks(PackageTestCase):
- register_extern_hook()
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageExporter = PackageExporter
def test_single_hook(self):
buffer = BytesIO()
@ -35,7 +28,7 @@ class TestDependencyHooks(PackageTestCase):
def my_extern_hook(package_exporter, module_name):
my_externs.add(module_name)
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.extern(["package_a.subpackage", "module_a"])
exporter.register_extern_hook(my_extern_hook)
exporter.save_source_string("foo", "import module_a")
@ -54,7 +47,7 @@ class TestDependencyHooks(PackageTestCase):
def my_extern_hook2(package_exporter, module_name):
my_externs.remove(module_name)
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.extern(["package_a.subpackage", "module_a"])
exporter.register_extern_hook(my_extern_hook)
exporter.register_extern_hook(my_extern_hook2)
@ -74,7 +67,7 @@ class TestDependencyHooks(PackageTestCase):
def my_mock_hook2(package_exporter, module_name):
my_mocks.remove(module_name)
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.mock(["package_a.subpackage", "module_a"])
exporter.register_mock_hook(my_mock_hook)
exporter.register_mock_hook(my_mock_hook2)
@ -94,7 +87,7 @@ class TestDependencyHooks(PackageTestCase):
def my_extern_hook2(package_exporter, module_name):
my_externs2.add(module_name)
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.extern(["package_a.subpackage", "module_a"])
handle = exporter.register_extern_hook(my_extern_hook)
exporter.register_extern_hook(my_extern_hook2)
@ -116,7 +109,7 @@ class TestDependencyHooks(PackageTestCase):
def my_mock_hook(package_exporter, module_name):
my_mocks.add(module_name)
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.extern("module_a")
exporter.mock("package_a")
exporter.register_extern_hook(my_extern_hook)
@ -127,11 +120,5 @@ class TestDependencyHooks(PackageTestCase):
self.assertEqual(my_mocks, set(["package_a"]))
class TestDependencyHooksNoTorch(TestDependencyHooks):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageExporter = PackageExporterNoTorch
if __name__ == "__main__":
run_tests()
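Condensed, the hook API covered by these tests works as in this sketch (mirroring test_single_hook above; module_a is a test fixture, assumed importable):

from io import BytesIO

from torch.package import PackageExporter

my_externs = set()

def my_extern_hook(package_exporter, module_name):
    # Called once for every module that is recorded as extern.
    my_externs.add(module_name)

with PackageExporter(BytesIO()) as exporter:
    exporter.extern(["module_a"])
    handle = exporter.register_extern_hook(my_extern_hook)
    exporter.save_source_string("foo", "import module_a")

assert my_externs == {"module_a"}
# handle.remove() would unregister the hook, as the handle test above exercises.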

View File

@ -116,7 +116,7 @@ class TestDiGraph(PackageTestCase):
result = g.all_paths("1", "3")
# to get rid of indeterminism
actual = {i.strip("\n") for i in result.split(";")[2:-1]}
actual = set([i.strip("\n") for i in result.split(";")[2:-1]])
expected = {
'"2" -> "3"',
'"1" -> "7"',

View File

@ -10,12 +10,6 @@ from unittest import skipIf
import torch
from torch.package import PackageExporter, PackageImporter
from torch.package.package_exporter_no_torch import (
PackageExporter as PackageExporterNoTorch,
)
from torch.package.package_importer_no_torch import (
PackageImporter as PackageImporterNoTorch,
)
from torch.testing._internal.common_utils import (
run_tests,
IS_FBCODE,
@ -50,11 +44,6 @@ packaging_directory = Path(__file__).parent
class DirectoryReaderTest(PackageTestCase):
"""Tests use of DirectoryReader as accessor for opened packages."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageExporter = PackageExporter
self.PackageImporter = PackageImporter
@skipIfNoTorchVision
def test_loading_pickle(self):
"""
@ -63,7 +52,7 @@ class DirectoryReaderTest(PackageTestCase):
resnet = resnet18()
filename = self.temp()
with self.PackageExporter(filename) as e:
with PackageExporter(filename) as e:
e.intern("**")
e.save_pickle("model", "model.pkl", resnet)
@ -71,7 +60,7 @@ class DirectoryReaderTest(PackageTestCase):
with TemporaryDirectory() as temp_dir:
zip_file.extractall(path=temp_dir)
importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
importer = PackageImporter(Path(temp_dir) / Path(filename).name)
dir_mod = importer.load_pickle("model", "model.pkl")
input = torch.rand(1, 3, 224, 224)
self.assertEqual(dir_mod(input), resnet(input))
@ -83,14 +72,14 @@ class DirectoryReaderTest(PackageTestCase):
import package_a
filename = self.temp()
with self.PackageExporter(filename) as e:
with PackageExporter(filename) as e:
e.save_module("package_a")
zip_file = zipfile.ZipFile(filename, "r")
with TemporaryDirectory() as temp_dir:
zip_file.extractall(path=temp_dir)
dir_importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
dir_mod = dir_importer.import_module("package_a")
self.assertEqual(dir_mod.result, package_a.result)
@ -101,14 +90,14 @@ class DirectoryReaderTest(PackageTestCase):
import package_a # noqa: F401
filename = self.temp()
with self.PackageExporter(filename) as e:
with PackageExporter(filename) as e:
e.save_module("package_a")
zip_file = zipfile.ZipFile(filename, "r")
with TemporaryDirectory() as temp_dir:
zip_file.extractall(path=temp_dir)
dir_importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
self.assertTrue(dir_importer.zip_reader.has_record("package_a/__init__.py"))
self.assertFalse(dir_importer.zip_reader.has_record("package_a"))
@ -116,7 +105,7 @@ class DirectoryReaderTest(PackageTestCase):
def test_resource_reader(self):
"""Tests DirectoryReader as the base for get_resource_reader."""
filename = self.temp()
with self.PackageExporter(filename) as pe:
with PackageExporter(filename) as pe:
# Layout looks like:
# package
# ├── one/
@ -143,7 +132,7 @@ class DirectoryReaderTest(PackageTestCase):
with TemporaryDirectory() as temp_dir:
zip_file.extractall(path=temp_dir)
importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
importer = PackageImporter(Path(temp_dir) / Path(filename).name)
reader_one = importer.get_resource_reader("one")
# Different behavior from still zipped archives
@ -198,7 +187,7 @@ class DirectoryReaderTest(PackageTestCase):
"""
)
filename = self.temp()
with self.PackageExporter(filename) as pe:
with PackageExporter(filename) as pe:
pe.save_source_string("foo.bar", mod_src)
pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")
@ -206,7 +195,7 @@ class DirectoryReaderTest(PackageTestCase):
with TemporaryDirectory() as temp_dir:
zip_file.extractall(path=temp_dir)
dir_importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
self.assertEqual(
dir_importer.import_module("foo.bar").secret_message(),
"my sekrit plays",
@ -215,7 +204,7 @@ class DirectoryReaderTest(PackageTestCase):
@skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
def test_importer_access(self):
filename = self.temp()
with self.PackageExporter(filename) as he:
with PackageExporter(filename) as he:
he.save_text("main", "main", "my string")
he.save_binary("main", "main_binary", "my string".encode("utf-8"))
src = dedent(
@ -233,7 +222,7 @@ class DirectoryReaderTest(PackageTestCase):
with TemporaryDirectory() as temp_dir:
zip_file.extractall(path=temp_dir)
dir_importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
m = dir_importer.import_module("main")
self.assertEqual(m.t, "my string")
self.assertEqual(m.b, "my string".encode("utf-8"))
@ -244,7 +233,7 @@ class DirectoryReaderTest(PackageTestCase):
Tests that packaged code can use importlib.resources.path.
"""
filename = self.temp()
with self.PackageExporter(filename) as e:
with PackageExporter(filename) as e:
e.save_binary("string_module", "my_string", "my string".encode("utf-8"))
src = dedent(
"""\
@ -262,7 +251,7 @@ class DirectoryReaderTest(PackageTestCase):
with TemporaryDirectory() as temp_dir:
zip_file.extractall(path=temp_dir)
dir_importer = self.PackageImporter(Path(temp_dir) / Path(filename).name)
dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
m = dir_importer.import_module("main")
self.assertEqual(m.s, "my string")
@ -271,16 +260,12 @@ class DirectoryReaderTest(PackageTestCase):
Test basic saving and loading of a ScriptModule in a directory.
Currently not supported.
"""
if self.PackageExporter != PackageExporter:
return
from package_a.test_module import ModWithTensor
scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
filename = self.temp()
with self.PackageExporter(filename) as e:
with PackageExporter(filename) as e:
e.save_pickle("res", "mod.pkl", scripted_mod)
zip_file = zipfile.ZipFile(filename, "r")
@ -292,18 +277,9 @@ class DirectoryReaderTest(PackageTestCase):
):
with TemporaryDirectory() as temp_dir:
zip_file.extractall(path=temp_dir)
dir_importer = self.PackageImporter(
Path(temp_dir) / Path(filename).name
)
dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
dir_mod = dir_importer.load_pickle("res", "mod.pkl")
class DirectoryReaderTestNoTorch(DirectoryReaderTest):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageExporter = PackageExporterNoTorch
self.PackageImporter = PackageImporterNoTorch
if __name__ == "__main__":
run_tests()
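The extract-then-import workflow these tests repeat can be condensed to the sketch below (the archive path is hypothetical and, like self.temp(), has no extension, so the extracted directory name matches Path(filename).name):

import zipfile
from pathlib import Path
from tempfile import TemporaryDirectory

from torch.package import PackageExporter, PackageImporter

filename = "/tmp/pkg_demo"  # hypothetical, extensionless archive path
with PackageExporter(filename) as exporter:
    exporter.save_source_string("foo", "result = 'hello'")

zip_file = zipfile.ZipFile(filename, "r")
with TemporaryDirectory() as temp_dir:
    zip_file.extractall(path=temp_dir)
    # Opening the extracted directory routes the importer through DirectoryReader.
    dir_importer = PackageImporter(Path(temp_dir) / Path(filename).name)
    assert dir_importer.import_module("foo").result == "hello"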

View File

@ -10,12 +10,6 @@ from torch.package import (
PackageImporter,
sys_importer,
)
from torch.package.package_exporter_no_torch import (
PackageExporter as PackageExporterNoTorch,
)
from torch.package.package_importer_no_torch import (
PackageImporter as PackageImporterNoTorch,
)
from torch.testing._internal.common_utils import run_tests
try:
@ -28,11 +22,6 @@ except ImportError:
class TestImporter(PackageTestCase):
"""Tests for Importer and derived classes."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporter
self.PackageExporter = PackageExporter
def test_sys_importer(self):
import package_a
import package_a.subpackage
@ -58,11 +47,11 @@ class TestImporter(PackageTestCase):
import package_a
buffer = BytesIO()
with self.PackageExporter(buffer) as pe:
with PackageExporter(buffer) as pe:
pe.save_module(package_a.__name__)
buffer.seek(0)
importer = self.PackageImporter(buffer)
importer = PackageImporter(buffer)
# Construct an importer-only environment.
ordered_importer = OrderedImporter(importer)
@ -84,11 +73,11 @@ class TestImporter(PackageTestCase):
import package_a
buffer = BytesIO()
with self.PackageExporter(buffer) as pe:
with PackageExporter(buffer) as pe:
pe.save_module(package_a.__name__)
buffer.seek(0)
importer = self.PackageImporter(buffer)
importer = PackageImporter(buffer)
ordered_importer_sys_first = OrderedImporter(sys_importer, importer)
self.assertIs(ordered_importer_sys_first.import_module("package_a"), package_a)
@ -148,32 +137,25 @@ class TestImporter(PackageTestCase):
# Set up a PackageImporter which has a torch.float16 object pickled:
buffer = BytesIO()
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.save_pickle("foo", "foo.pkl", my_dtype)
buffer.seek(0)
importer = self.PackageImporter(buffer)
importer = PackageImporter(buffer)
my_loaded_dtype = importer.load_pickle("foo", "foo.pkl")
# Re-save a package with only our PackageImporter as the importer
buffer2 = BytesIO()
with self.PackageExporter(buffer2, importer=importer) as exporter:
with PackageExporter(buffer2, importer=importer) as exporter:
exporter.save_pickle("foo", "foo.pkl", my_loaded_dtype)
buffer2.seek(0)
importer2 = self.PackageImporter(buffer2)
importer2 = PackageImporter(buffer2)
my_loaded_dtype2 = importer2.load_pickle("foo", "foo.pkl")
self.assertIs(my_dtype, my_loaded_dtype)
self.assertIs(my_dtype, my_loaded_dtype2)
class TestImporterNoTorch(TestImporter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporterNoTorch
self.PackageExporter = PackageExporterNoTorch
if __name__ == "__main__":
run_tests()
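The importer-ordering behavior verified above, condensed (package_a is a test fixture, assumed importable):

from io import BytesIO

import package_a
from torch.package import OrderedImporter, PackageExporter, PackageImporter, sys_importer

buffer = BytesIO()
with PackageExporter(buffer) as exporter:
    exporter.save_module("package_a")
buffer.seek(0)
importer = PackageImporter(buffer)

# With sys_importer first, the locally installed package_a wins over the packaged copy.
ordered = OrderedImporter(sys_importer, importer)
assert ordered.import_module("package_a") is package_a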

View File

@ -9,12 +9,6 @@ from torch.package._mangling import (
get_mangle_prefix,
is_mangled,
)
from torch.package.package_exporter_no_torch import (
PackageExporter as PackageExporterNoTorch,
)
from torch.package.package_importer_no_torch import (
PackageImporter as PackageImporterNoTorch,
)
from torch.testing._internal.common_utils import run_tests
try:
@ -25,11 +19,6 @@ except ImportError:
class TestMangling(PackageTestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporter
self.PackageExporter = PackageExporter
def test_unique_manglers(self):
"""
Each mangler instance should generate a unique mangled name for a given input.
@ -94,11 +83,11 @@ class TestMangling(PackageTestCase):
obj = package_a.subpackage.PackageASubpackageObject()
obj2 = package_a.PackageAObject(obj)
f1 = BytesIO()
with self.PackageExporter(f1) as pe:
with PackageExporter(f1) as pe:
pe.intern("**")
pe.save_pickle("obj", "obj.pkl", obj2)
f1.seek(0)
importer1 = self.PackageImporter(f1)
importer1 = PackageImporter(f1)
loaded1 = importer1.load_pickle("obj", "obj.pkl")
f1.seek(0)
importer2 = PackageImporter(f1)
@ -119,12 +108,5 @@ class TestMangling(PackageTestCase):
self.assertEqual(b.demangle(a_mangled), a_mangled)
class TestManglingNoTorch(TestMangling):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporterNoTorch
self.PackageExporter = PackageExporterNoTorch
if __name__ == "__main__":
run_tests()
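The mangling contract these tests pin down, in a condensed sketch (demangling is a no-op for names produced by a different mangler, as the final assertion above shows):

from torch.package._mangling import PackageMangler, is_mangled

a = PackageMangler()
b = PackageMangler()

a_mangled = a.mangle("package_a.subpackage")
assert is_mangled(a_mangled)
# A mangler can undo its own mangling...
assert a.demangle(a_mangled) == "package_a.subpackage"
# ...but passes through names minted by another mangler.
assert b.demangle(a_mangled) == a_mangled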

View File

@ -8,19 +8,9 @@ from pathlib import Path
from textwrap import dedent
from unittest import skipIf
from torch.package import (
PackageExporter,
PackageImporter,
is_from_package,
PackagingError,
)
from torch.package.package_exporter_no_torch import (
PackageExporter as PackageExporterNoTorch,
)
from torch.package.package_importer_no_torch import (
PackageImporter as PackageImporterNoTorch,
)
from torch.testing._internal.common_utils import run_tests, IS_FBCODE, IS_SANDCASTLE
from torch.package import PackageExporter, PackageImporter, is_from_package
from torch.package.package_exporter import PackagingError
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests
try:
from .common import PackageTestCase
@ -32,11 +22,6 @@ except ImportError:
class TestMisc(PackageTestCase):
"""Tests for one-off or random functionality. Try not to add to this!"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporter
self.PackageExporter = PackageExporter
def test_file_structure(self):
"""
Tests package's Directory structure representation of a zip file. Ensures
@ -86,7 +71,7 @@ class TestMisc(PackageTestCase):
"""
)
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
import module_a
import package_a
import package_a.subpackage
@ -99,7 +84,7 @@ class TestMisc(PackageTestCase):
he.save_text("main", "main", "my string")
buffer.seek(0)
hi = self.PackageImporter(buffer)
hi = PackageImporter(buffer)
file_structure = hi.file_structure()
# remove first line from testing because Windows/iOS/Unix treat the buffer differently
@ -154,7 +139,7 @@ class TestMisc(PackageTestCase):
Test Directory's has_file() method.
"""
buffer = BytesIO()
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
import package_a.subpackage
he.intern("**")
@ -163,7 +148,7 @@ class TestMisc(PackageTestCase):
buffer.seek(0)
importer = self.PackageImporter(buffer)
importer = PackageImporter(buffer)
file_structure = importer.file_structure()
self.assertTrue(file_structure.has_file("package_a/subpackage.py"))
self.assertFalse(file_structure.has_file("package_a/subpackage"))
@ -173,7 +158,7 @@ class TestMisc(PackageTestCase):
Test content list API for PackageExporter's contained modules.
"""
with self.PackageExporter(BytesIO()) as he:
with PackageExporter(BytesIO()) as he:
import package_b
he.extern("package_b.subpackage_1")
@ -189,7 +174,7 @@ class TestMisc(PackageTestCase):
self.assertEqual(he.get_rdeps("package_b.subpackage_2"), ["package_b"])
with self.assertRaises(PackagingError) as e:
with self.PackageExporter(BytesIO()) as he:
with PackageExporter(BytesIO()) as he:
import package_b
he.deny("package_b")
@ -203,12 +188,12 @@ class TestMisc(PackageTestCase):
buffer = BytesIO()
obj = package_a.subpackage.PackageASubpackageObject()
with self.PackageExporter(buffer) as pe:
with PackageExporter(buffer) as pe:
pe.intern("**")
pe.save_pickle("obj", "obj.pkl", obj)
buffer.seek(0)
pi = self.PackageImporter(buffer)
pi = PackageImporter(buffer)
mod = pi.import_module("package_a.subpackage")
loaded_obj = pi.load_pickle("obj", "obj.pkl")
@ -225,12 +210,12 @@ class TestMisc(PackageTestCase):
buffer = BytesIO()
obj = package_a.subpackage.PackageASubpackageObject()
with self.PackageExporter(buffer) as pe:
with PackageExporter(buffer) as pe:
pe.intern("**")
pe.save_pickle("obj", "obj.pkl", obj)
buffer.seek(0)
pi = self.PackageImporter(buffer)
pi = PackageImporter(buffer)
packaged_class = pi.import_module(
"package_a.subpackage"
).PackageASubpackageObject
@ -249,12 +234,12 @@ class TestMisc(PackageTestCase):
buffer = BytesIO()
obj = package_a.subpackage.PackageASubpackageObject()
with self.PackageExporter(buffer) as pe:
with PackageExporter(buffer) as pe:
pe.intern("**")
pe.save_pickle("obj", "obj.pkl", obj)
buffer.seek(0)
pi = self.PackageImporter(buffer)
pi = PackageImporter(buffer)
mod = pi.import_module("package_a.subpackage")
self.assertTrue(hasattr(mod, "__torch_package__"))
@ -268,12 +253,12 @@ class TestMisc(PackageTestCase):
buffer = BytesIO()
with self.PackageExporter(buffer) as pe:
with PackageExporter(buffer) as pe:
pe.intern("**")
pe.save_module(mod.__name__)
buffer.seek(0)
pi = self.PackageImporter(buffer)
pi = PackageImporter(buffer)
imported_mod = pi.import_module(mod.__name__)
self.assertTrue(imported_mod.is_from_package())
self.assertFalse(mod.is_from_package())
@ -289,22 +274,15 @@ class TestMisc(PackageTestCase):
buffer = BytesIO()
mod = package_a.std_sys_module_hacks.Module()
with self.PackageExporter(buffer) as pe:
with PackageExporter(buffer) as pe:
pe.intern("**")
pe.save_pickle("obj", "obj.pkl", mod)
buffer.seek(0)
pi = self.PackageImporter(buffer)
pi = PackageImporter(buffer)
mod = pi.load_pickle("obj", "obj.pkl")
mod()
class TestMiscNoTorch(TestMisc):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporterNoTorch
self.PackageExporter = PackageExporterNoTorch
if __name__ == "__main__":
run_tests()
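The file_structure()/has_file() API used above, condensed (package_a.subpackage is a test fixture, assumed importable):

from io import BytesIO

import package_a.subpackage
from torch.package import PackageExporter, PackageImporter

buffer = BytesIO()
with PackageExporter(buffer) as exporter:
    exporter.intern("**")
    exporter.save_pickle("obj", "obj.pkl", package_a.subpackage.PackageASubpackageObject())
buffer.seek(0)

file_structure = PackageImporter(buffer).file_structure()
assert file_structure.has_file("package_a/subpackage.py")
assert not file_structure.has_file("package_a/subpackage")  # directories are not files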

View File

@ -624,6 +624,7 @@ class TestPackageScript(PackageTestCase):
buffer_1.seek(0)
importer = PackageImporter(buffer_1)
loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")
self.assertEqual(
loaded_mod_1.tensor.storage()._cdata,
loaded_mod_1.sub_mod_0.tensor.storage()._cdata,

View File

@ -7,12 +7,6 @@ from torch.package import (
PackageImporter,
sys_importer,
)
from torch.package.package_exporter_no_torch import (
PackageExporter as PackageExporterNoTorch,
)
from torch.package.package_importer_no_torch import (
PackageImporter as PackageImporterNoTorch,
)
from torch.testing._internal.common_utils import run_tests
try:
@ -25,28 +19,23 @@ except ImportError:
class TestRepackage(PackageTestCase):
"""Tests for repackaging."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporter
self.PackageExporter = PackageExporter
def test_repackage_import_indirectly_via_parent_module(self):
from package_d.imports_directly import ImportsDirectlyFromSubSubPackage
from package_d.imports_indirectly import ImportsIndirectlyFromSubPackage
model_a = ImportsDirectlyFromSubSubPackage()
buffer = BytesIO()
with self.PackageExporter(buffer) as pe:
with PackageExporter(buffer) as pe:
pe.intern("**")
pe.save_pickle("default", "model.py", model_a)
buffer.seek(0)
pi = self.PackageImporter(buffer)
pi = PackageImporter(buffer)
loaded_model = pi.load_pickle("default", "model.py")
model_b = ImportsIndirectlyFromSubPackage()
buffer = BytesIO()
with self.PackageExporter(
with PackageExporter(
buffer,
importer=(
pi,
@ -57,12 +46,5 @@ class TestRepackage(PackageTestCase):
pe.save_pickle("default", "model_b.py", model_b)
class TestRepackageNoTorch(TestRepackage):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporterNoTorch
self.PackageExporter = PackageExporterNoTorch
if __name__ == "__main__":
run_tests()

View File

@ -7,12 +7,6 @@ from textwrap import dedent
from unittest import skipIf
from torch.package import PackageExporter, PackageImporter
from torch.package.package_exporter_no_torch import (
PackageExporter as PackageExporterNoTorch,
)
from torch.package.package_importer_no_torch import (
PackageImporter as PackageImporterNoTorch,
)
from torch.testing._internal.common_utils import run_tests
try:
@ -26,15 +20,10 @@ except ImportError:
class TestResources(PackageTestCase):
"""Tests for access APIs for packaged resources."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporter
self.PackageExporter = PackageExporter
def test_resource_reader(self):
"""Test compliance with the get_resource_reader importlib API."""
buffer = BytesIO()
with self.PackageExporter(buffer) as pe:
with PackageExporter(buffer) as pe:
# Layout looks like:
# package
# ├── one/
@ -58,7 +47,7 @@ class TestResources(PackageTestCase):
pe.save_text("two", "g.txt", "hello, g!")
buffer.seek(0)
importer = self.PackageImporter(buffer)
importer = PackageImporter(buffer)
reader_one = importer.get_resource_reader("one")
with self.assertRaises(FileNotFoundError):
@ -102,19 +91,19 @@ class TestResources(PackageTestCase):
"""
)
buffer = BytesIO()
with self.PackageExporter(buffer) as pe:
with PackageExporter(buffer) as pe:
pe.save_source_string("foo.bar", mod_src)
pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")
buffer.seek(0)
importer = self.PackageImporter(buffer)
importer = PackageImporter(buffer)
self.assertEqual(
importer.import_module("foo.bar").secret_message(), "my sekrit plays"
)
def test_importer_access(self):
buffer = BytesIO()
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
he.save_text("main", "main", "my string")
he.save_binary("main", "main_binary", "my string".encode("utf-8"))
src = dedent(
@ -128,7 +117,7 @@ class TestResources(PackageTestCase):
)
he.save_source_string("main", src, is_package=True)
buffer.seek(0)
hi = self.PackageImporter(buffer)
hi = PackageImporter(buffer)
m = hi.import_module("main")
self.assertEqual(m.t, "my string")
self.assertEqual(m.b, "my string".encode("utf-8"))
@ -138,7 +127,7 @@ class TestResources(PackageTestCase):
Tests that packaged code can use importlib.resources.path.
"""
buffer = BytesIO()
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
he.save_binary("string_module", "my_string", "my string".encode("utf-8"))
src = dedent(
"""\
@ -152,17 +141,10 @@ class TestResources(PackageTestCase):
)
he.save_source_string("main", src, is_package=True)
buffer.seek(0)
hi = self.PackageImporter(buffer)
hi = PackageImporter(buffer)
m = hi.import_module("main")
self.assertEqual(m.s, "my string")
class TestResourcesNoTorch(TestResources):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporterNoTorch
self.PackageExporter = PackageExporterNoTorch
if __name__ == "__main__":
run_tests()
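The text/binary resource round trip exercised above, condensed:

from io import BytesIO

from torch.package import PackageExporter, PackageImporter

buffer = BytesIO()
with PackageExporter(buffer) as exporter:
    exporter.save_text("main", "main", "my string")
    exporter.save_binary("main", "main_binary", "my string".encode("utf-8"))
buffer.seek(0)

importer = PackageImporter(buffer)
assert importer.load_text("main", "main") == "my string"
assert importer.load_binary("main", "main_binary") == b"my string"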

View File

@ -6,12 +6,6 @@ from textwrap import dedent
from unittest import skipIf
from torch.package import PackageExporter, PackageImporter, sys_importer
from torch.package.package_exporter_no_torch import (
PackageExporter as PackageExporterNoTorch,
)
from torch.package.package_importer_no_torch import (
PackageImporter as PackageImporterNoTorch,
)
from torch.testing._internal.common_utils import run_tests, IS_FBCODE, IS_SANDCASTLE
try:
@ -28,21 +22,16 @@ packaging_directory = Path(__file__).parent
class TestSaveLoad(PackageTestCase):
"""Core save_* and loading API tests."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporter
self.PackageExporter = PackageExporter
@skipIf(
IS_FBCODE or IS_SANDCASTLE,
"Tests that use temporary files are disabled in fbcode",
)
def test_saving_source(self):
filename = self.temp()
with self.PackageExporter(filename) as he:
with PackageExporter(filename) as he:
he.save_source_file("foo", str(packaging_directory / "module_a.py"))
he.save_source_file("foodir", str(packaging_directory / "package_a"))
hi = self.PackageImporter(filename)
hi = PackageImporter(filename)
foo = hi.import_module("foo")
s = hi.import_module("foodir.subpackage")
self.assertEqual(foo.result, "module_a")
@ -54,7 +43,7 @@ class TestSaveLoad(PackageTestCase):
)
def test_saving_string(self):
filename = self.temp()
with self.PackageExporter(filename) as he:
with PackageExporter(filename) as he:
src = dedent(
"""\
import math
@ -62,7 +51,7 @@ class TestSaveLoad(PackageTestCase):
"""
)
he.save_source_string("my_mod", src)
hi = self.PackageImporter(filename)
hi = PackageImporter(filename)
m = hi.import_module("math")
import math
@ -76,13 +65,13 @@ class TestSaveLoad(PackageTestCase):
)
def test_save_module(self):
filename = self.temp()
with self.PackageExporter(filename) as he:
with PackageExporter(filename) as he:
import module_a
import package_a
he.save_module(module_a.__name__)
he.save_module(package_a.__name__)
hi = self.PackageImporter(filename)
hi = PackageImporter(filename)
module_a_i = hi.import_module("module_a")
self.assertEqual(module_a_i.result, "module_a")
self.assertIsNot(module_a, module_a_i)
@ -92,7 +81,7 @@ class TestSaveLoad(PackageTestCase):
def test_dunder_imports(self):
buffer = BytesIO()
with self.PackageExporter(buffer) as he:
with PackageExporter(buffer) as he:
import package_b
obj = package_b.PackageBObject
@ -100,7 +89,7 @@ class TestSaveLoad(PackageTestCase):
he.save_pickle("res", "obj.pkl", obj)
buffer.seek(0)
hi = self.PackageImporter(buffer)
hi = PackageImporter(buffer)
loaded_obj = hi.load_pickle("res", "obj.pkl")
package_b = hi.import_module("package_b")
@ -124,21 +113,21 @@ class TestSaveLoad(PackageTestCase):
def test_bad_dunder_imports(self):
"""Test to ensure bad __imports__ don't cause PackageExporter to fail."""
buffer = BytesIO()
with self.PackageExporter(buffer) as e:
with PackageExporter(buffer) as e:
e.save_source_string(
"m", '__import__(these, unresolvable, "things", wont, crash, me)'
)
def test_save_module_binary(self):
f = BytesIO()
with self.PackageExporter(f) as he:
with PackageExporter(f) as he:
import module_a
import package_a
he.save_module(module_a.__name__)
he.save_module(package_a.__name__)
f.seek(0)
hi = self.PackageImporter(f)
hi = PackageImporter(f)
module_a_i = hi.import_module("module_a")
self.assertEqual(module_a_i.result, "module_a")
self.assertIsNot(module_a, module_a_i)
@ -157,10 +146,10 @@ class TestSaveLoad(PackageTestCase):
obj2 = package_a.PackageAObject(obj)
filename = self.temp()
with self.PackageExporter(filename) as he:
with PackageExporter(filename) as he:
he.intern("**")
he.save_pickle("obj", "obj.pkl", obj2)
hi = self.PackageImporter(filename)
hi = PackageImporter(filename)
# check we got dependencies
sp = hi.import_module("package_a.subpackage")
@ -190,19 +179,19 @@ class TestSaveLoad(PackageTestCase):
obj = package_a.subpackage.PackageASubpackageObject()
obj2 = package_a.PackageAObject(obj)
f1 = self.temp()
with self.PackageExporter(f1) as pe:
with PackageExporter(f1) as pe:
pe.intern("**")
pe.save_pickle("obj", "obj.pkl", obj2)
importer1 = self.PackageImporter(f1)
importer1 = PackageImporter(f1)
loaded1 = importer1.load_pickle("obj", "obj.pkl")
importer2 = self.PackageImporter(f1)
importer2 = PackageImporter(f1)
loaded2 = importer2.load_pickle("obj", "obj.pkl")
f2 = self.temp()
def make_exporter():
pe = self.PackageExporter(f2, importer=[importer1, sys_importer])
pe = PackageExporter(f2, importer=[importer1, sys_importer])
# Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first.
return pe
@ -231,21 +220,19 @@ class TestSaveLoad(PackageTestCase):
obj2 = package_a.PackageAObject(obj)
buffer = BytesIO()
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.intern("**")
exporter.save_pickle("model", "model.pkl", obj2)
buffer.seek(0)
importer = self.PackageImporter(buffer)
importer = PackageImporter(buffer)
imported_obj2 = importer.load_pickle("model", "model.pkl")
imported_obj2_module = imported_obj2.__class__.__module__
# Should export without error.
buffer2 = BytesIO()
with self.PackageExporter(
buffer2, importer=(importer, sys_importer)
) as exporter:
with PackageExporter(buffer2, importer=(importer, sys_importer)) as exporter:
exporter.intern("**")
exporter.save_module(imported_obj2_module)
@ -254,29 +241,20 @@ class TestSaveLoad(PackageTestCase):
import package_a.use_torch_package_importer # noqa: F401
buffer = BytesIO()
with self.PackageExporter(buffer) as exporter:
with PackageExporter(buffer) as exporter:
exporter.intern("**")
exporter.save_module("package_a.use_torch_package_importer")
buffer.seek(0)
importer = self.PackageImporter(buffer)
importer = PackageImporter(buffer)
# Should export without error.
buffer2 = BytesIO()
with self.PackageExporter(
buffer2, importer=(importer, sys_importer)
) as exporter:
with PackageExporter(buffer2, importer=(importer, sys_importer)) as exporter:
exporter.intern("**")
exporter.save_module("package_a.use_torch_package_importer")
class TestSaveLoadNoTorch(TestSaveLoad):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.PackageImporter = PackageImporterNoTorch
self.PackageExporter = PackageExporterNoTorch
if __name__ == "__main__":
run_tests()

View File

@ -3,7 +3,6 @@ import torch
from torch.package._package_pickler import create_pickler
from torch.package._package_unpickler import PackageUnpickler
from torch.package import sys_importer, OrderedImporter, PackageImporter, Importer
from torch.package._zip_file_torchscript import TorchScriptPackageZipFileReader
from torch.serialization import _maybe_decode_ascii
def _save_storages(importer, obj):
@ -49,10 +48,7 @@ def _save_storages(importer, obj):
pickler.persistent_id = persistent_id
pickler.dump(obj)
data_value = data_buf.getvalue()
assert (not importer or isinstance(importer.zip_reader, TorchScriptPackageZipFileReader)), \
f'importer {importer}\'s zip reader is of type {type(importer.zip_reader)} not TorchScriptPackageZipFileReader'
return data_value, serialized_storages, serialized_dtypes, importer.zip_reader.zip_reader if importer else None
return data_value, serialized_storages, serialized_dtypes, importer.zip_reader if importer else None
def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):

View File

@ -44,7 +44,6 @@ from torch.jit._monkeytype_config import (
JitTypeTraceStore
)
from torch._classes import classes
from torch.package._zip_file_torchscript import TorchScriptPackageZipFileWriter, TorchScriptPackageZipFileReader
type_trace_db = JitTypeTraceStore() # DB to hold all call traces from MonkeyType
@ -343,21 +342,15 @@ def unpackage_script_module(importer: PackageImporter, script_module_id: str) ->
Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function.
Performs work of loading and returning a ScriptModule from a ``torch.package`` archive.
"""
if not isinstance(importer.zip_reader, TorchScriptPackageZipFileReader):
raise RuntimeError(
f"Loading ScriptObjects from a PackageImporter must be done using a TorchScriptPackageZipFileReader"
f"not an object of type {type(importer.zip_reader)}"
)
if importer.zip_reader.is_directory():
if not isinstance(importer.zip_reader, torch._C.PyTorchFileReader):
raise RuntimeError(
"Loading ScriptObjects from a PackageImporter created from a "
f"directory is not supported. Use a package archive file instead. is of type {type(importer.zip_reader)}"
"directory is not supported. Use a package archive file instead."
)
cu = torch._C.CompilationUnit()
cpp_module = torch._C._import_ir_module_from_package(
cu,
importer.zip_reader.zip_reader, # type: ignore[arg-type]
importer.zip_reader,
importer.storage_context,
validate_map_location(importer.last_map_location),
script_module_id,
@ -542,10 +535,8 @@ if _enabled:
Returns method to load the ScriptModule from a ``torch.package.PackageImporter``'s
Pickler's ``persistent_load`` function.
"""
assert isinstance(exporter.zip_file, TorchScriptPackageZipFileWriter)
script_module_serializer = exporter.zip_file.script_module_serializer
script_module_id = exporter.get_unique_id()
script_module_serializer.serialize(self._c, int(script_module_id))
exporter.script_module_serializer.serialize(self._c, int(script_module_id))
return (unpackage_script_module, (script_module_id,))
class RecursiveScriptModule(ScriptModule):

View File

@ -8,6 +8,5 @@ from .importer import (
OrderedImporter,
sys_importer,
)
from .package_exporter import PackageExporter
from .package_exporter_no_torch import EmptyMatchError, PackagingError
from .package_exporter import EmptyMatchError, PackageExporter, PackagingError
from .package_importer import PackageImporter

View File

@ -1,5 +1,17 @@
import os.path
from glob import glob
from typing import cast
import torch
from torch.types import Storage
# because get_storage_from_record returns a tensor!?
class _HasStorage(object):
def __init__(self, storage):
self._storage = storage
def storage(self):
return self._storage
class DirectoryReader(object):
@ -7,6 +19,9 @@ class DirectoryReader(object):
Class to allow PackageImporter to operate on unzipped packages. Methods
copy the behavior of the internal PyTorchFileReader class (which is used for
accessing packages in all other cases).
N.B.: ScriptObjects are not depickleable or accessible via this DirectoryReader
class due to ScriptObjects requiring an actual PyTorchFileReader instance.
"""
def __init__(self, directory):
@ -17,6 +32,12 @@ class DirectoryReader(object):
with open(filename, "rb") as f:
return f.read()
def get_storage_from_record(self, name, numel, dtype):
filename = f"{self.directory}/{name}"
nbytes = torch._utils._element_size(dtype) * numel
storage = cast(Storage, torch._UntypedStorage)
return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))
def has_record(self, path):
full_path = os.path.join(self.directory, path)
return os.path.isfile(full_path)
@ -29,6 +50,3 @@ class DirectoryReader(object):
if not os.path.isdir(filename):
files.append(filename[len(self.directory) + 1 :])
return files
def close(self):
pass
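DirectoryReader can also be driven directly; a minimal sketch against a hypothetical on-disk layout (the tests above reach it indirectly through PackageImporter):

import os

from torch.package._directory_reader import DirectoryReader

os.makedirs("/tmp/pkg_demo_dir/package_a", exist_ok=True)  # hypothetical layout
with open("/tmp/pkg_demo_dir/package_a/__init__.py", "w") as f:
    f.write("result = 'package_a'\n")

reader = DirectoryReader("/tmp/pkg_demo_dir")
assert reader.has_record("package_a/__init__.py")
assert not reader.has_record("package_a")  # only files count as records
assert reader.get_record("package_a/__init__.py").startswith(b"result")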

View File

@ -1,34 +0,0 @@
from typing import cast
import torch
from torch.types import Storage
from ._directory_reader import DirectoryReader
# because get_storage_from_record returns a tensor!?
class _HasStorage(object):
def __init__(self, storage):
self._storage = storage
def storage(self):
return self._storage
class TorchScriptDirectoryReader(DirectoryReader):
"""
Class to allow PackageImporter to operate on unzipped packages which include
torchscript modules. Methods copy the behavior of the internal PyTorchFileReader
class (which is used for accessing packages in all other cases).
N.B.: ScriptObjects are not depickleable or accessible via this TorchScriptDirectoryReader
class due to ScriptObjects requiring an actual PyTorchFileReader instance.
"""
def __init__(self, directory):
super().__init__(directory)
def get_storage_from_record(self, name, numel, dtype):
filename = f"{self.directory}/{name}"
nbytes = torch._utils._element_size(dtype) * numel
storage = cast(Storage, torch._UntypedStorage)
return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))

View File

@ -1,38 +0,0 @@
import weakref
from collections import OrderedDict
from typing import Any
class RemovableHandle:
"""A handle which provides the capability to remove a hook."""
id: int
next_id: int = 0
def __init__(self, hooks_dict: Any) -> None:
self.hooks_dict_ref = weakref.ref(hooks_dict)
self.id = RemovableHandle.next_id
RemovableHandle.next_id += 1
def remove(self) -> None:
hooks_dict = self.hooks_dict_ref()
if hooks_dict is not None and self.id in hooks_dict:
del hooks_dict[self.id]
def __getstate__(self):
return (self.hooks_dict_ref(), self.id)
def __setstate__(self, state) -> None:
if state[0] is None:
# create a dead reference
self.hooks_dict_ref = weakref.ref(OrderedDict())
else:
self.hooks_dict_ref = weakref.ref(state[0])
self.id = state[1]
RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)
def __enter__(self) -> "RemovableHandle":
return self
def __exit__(self, type: Any, value: Any, tb: Any) -> None:
self.remove()
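The deleted class appears to be a copy of torch.utils.hooks.RemovableHandle, which exposes the same interface; a usage sketch against that public class:

from collections import OrderedDict

from torch.utils.hooks import RemovableHandle

hooks = OrderedDict()
handle = RemovableHandle(hooks)
hooks[handle.id] = lambda name: print(f"hook saw {name}")

hooks[handle.id]("module_a")  # fire the registered hook
handle.remove()               # drops the entry from the hooks dict
assert handle.id not in hooks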

View File

@ -1,163 +0,0 @@
import os.path
import zipfile
from abc import ABC, abstractmethod
from io import BytesIO
from pathlib import Path
from typing import List, Union, BinaryIO, Optional
from ._directory_reader import DirectoryReader
class PackageZipFileReader(ABC):
"""
Class to allow PackageImporter to operate on objects. To create a custom
zip file reader for PackageImporter, simply inherit from this class.
"""
def __init__(self, file_or_buffer: Union[str, Path, BinaryIO]):
raise NotImplementedError(
f"init(self, name: str) is not implemented in {type(self)}"
)
@abstractmethod
def get_record(self, name: str) -> bytes:
raise NotImplementedError(
f"get_record(self, name: str) is not implemented in {type(self)}"
)
@abstractmethod
def has_record(self, path: str) -> bool:
raise NotImplementedError(
f"has_record(self, path: str) is not implemented in {type(self)}"
)
@abstractmethod
def get_all_records(self) -> List[str]:
raise NotImplementedError(
f"get_all_records(self) is not implemented in {type(self)}"
)
@abstractmethod
def get_filename(self) -> str:
raise NotImplementedError(
f"get_filename(self) is not implemented in {type(self)}"
)
def is_directory(self) -> bool:
raise NotImplementedError(
f"is_directory(self) is not implemented in {type(self)}"
)
@abstractmethod
def close(self):
raise NotImplementedError(f"close(self) is not implemented in {type(self)}")
class PackageZipFileWriter(ABC):
"""
Class to allow PackageExporter to operate on objects. To create a custom
zip file writer for PackageExporter, simply inherit from this class.
"""
def __init__(self, f: Union[str, Path, BinaryIO]):
raise NotImplementedError(
f"init(self, name: str) is not implemented in {type(self)}"
)
@abstractmethod
def write_record(self, f, str_or_bytes: Union[str, bytes], size: int):
raise NotImplementedError(
f"write_record(self, f, str_or_bytes, size) is not implemented in {type(self)}"
)
@abstractmethod
def close(self):
raise NotImplementedError(f"close(self) is not implemented in {type(self)}")
class DefaultPackageZipFileWriter(zipfile.ZipFile, PackageZipFileWriter):
"""
Class to allow PackageExporter to operate on general objects. This is the default
zipfile writer. It is effectively a wrapper around ZipFile to provide a similar
API to torch._C.PyTorchFileWriter.
"""
def __init__(self, f: Union[str, Path, BinaryIO]):
if isinstance(f, (Path, str)):
f = str(f)
self.buffer: Optional[BinaryIO] = None
else: # is a byte buffer
self.buffer = f
super().__init__(f, mode="w")
self.prefix: str = "archive"
if isinstance(f, (Path, str)):
self.prefix = "/".join(str(f).strip("/").split("/")[1:])
super().writestr(f"{self.prefix}/.data/version", "6\n")
def write_record(self, f: str, str_or_bytes: Union[str, bytes], size: int = None):
super().writestr(f"{self.prefix}/{f}", str_or_bytes)
def close(self):
if self.buffer:
self.buffer.flush()
super().close()
class DefaultPackageZipFileReader(PackageZipFileReader):
"""
Class to allow PackageImporter to operate on general objects. This is the default
zipfile reader. It is effectively a wrapper around ZipFile to provide a similar
API to torch._C.PyTorchFileReader.
"""
def __init__(self, file_or_buffer: Union[str, Path, BinaryIO]):
if isinstance(file_or_buffer, (Path, str)):
self.filename = str(file_or_buffer)
if not os.path.isdir(self.filename):
self.zip_reader: Union[
zipfile.ZipFile, DirectoryReader
] = zipfile.ZipFile(self.filename)
else:
self.zip_reader = DirectoryReader(self.filename)
else:
self.filename = "<binary>"
self.zip_reader = zipfile.ZipFile(file_or_buffer)
if isinstance(self.zip_reader, DirectoryReader):
self.records = self.zip_reader.get_all_records()
elif isinstance(self.zip_reader, zipfile.ZipFile):
prefixed_records = self.zip_reader.namelist()
self.records = []
if isinstance(file_or_buffer, BytesIO):
self.prefix = "archive"
else:
self.prefix = "/".join(str(file_or_buffer).strip("/").split("/")[1:])
for record in prefixed_records:
self.records.append(record[len(self.prefix) + 1 :])
def get_record(self, name: str) -> bytes:
if isinstance(self.zip_reader, DirectoryReader):
return self.zip_reader.get_record(f"{name}")
else:
return self.zip_reader.read(f"{self.prefix}/{name}")
def has_record(self, path: str) -> bool:
return path in self.records
def get_all_records(self) -> List[str]:
return list(self.records)
def get_filename(self) -> str:
return self.filename
def is_directory(self) -> bool:
return isinstance(self.zip_reader, DirectoryReader)
def close(self):
self.zip_reader.close()
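The abstract classes deleted above were the extension point of the reverted change; a sketch of the kind of subclass they invited (InMemoryZipFileReader is hypothetical and assumes the deleted PackageZipFileReader is in scope):

from typing import Dict, List

class InMemoryZipFileReader(PackageZipFileReader):
    """Hypothetical reader backed by a plain dict of records."""

    def __init__(self, records: Dict[str, bytes]):
        self._records = records

    def get_record(self, name: str) -> bytes:
        return self._records[name]

    def has_record(self, path: str) -> bool:
        return path in self._records

    def get_all_records(self) -> List[str]:
        return list(self._records)

    def get_filename(self) -> str:
        return "<in-memory>"

    def is_directory(self) -> bool:
        return False

    def close(self):
        self._records.clear()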

View File

@ -1,91 +0,0 @@
import os.path
from pathlib import Path
from typing import List, BinaryIO, Union, Optional
import torch
from ._directory_reader_torchscript import TorchScriptDirectoryReader, _HasStorage
from ._zip_file import PackageZipFileReader, PackageZipFileWriter
class TorchScriptPackageZipFileWriter(PackageZipFileWriter):
"""
Class to allow PackageExporter to operate on torchscript objects. This
is a wrapper around the PyTorchFileWriter and ScriptModuleSerializer classes.
"""
def __init__(self, f: Union[str, Path, BinaryIO]):
if isinstance(f, (Path, str)):
f = str(f)
self.buffer: Optional[BinaryIO] = None
else: # is a byte buffer
self.buffer = f
self.zip_file_writer = torch._C.PyTorchFileWriter(f)
self.zip_file_writer.set_min_version(6)
self.script_module_serializer = torch._C.ScriptModuleSerializer(
self.zip_file_writer
)
self.storage_context = self.script_module_serializer.storage_context()
def write_record(self, f: str, str_or_bytes: Union[str, bytes], size: int):
if isinstance(str_or_bytes, str):
str_or_bytes = str.encode(f)
self.zip_file_writer.write_record(f, str_or_bytes, size)
def close(self):
self.script_module_serializer.write_files()
if self.buffer:
self.buffer.flush()
class TorchScriptPackageZipFileReader(PackageZipFileReader):
"""
Class to allow PackageImporter to operate on torchscript objects. This
is a wrapper around the PyTorchFileReader class.
"""
def __init__(
self, file_or_buffer: Union[str, torch._C.PyTorchFileReader, Path, BinaryIO]
):
if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
self.filename = "<pytorch_file_reader>"
self.zip_reader: Union[
torch._C.PyTorchFileReader, TorchScriptDirectoryReader
] = file_or_buffer
elif isinstance(file_or_buffer, (Path, str)):
self.filename = str(file_or_buffer)
if not os.path.isdir(self.filename):
self.zip_reader = torch._C.PyTorchFileReader(self.filename)
else:
self.zip_reader = TorchScriptDirectoryReader(self.filename)
else:
self.filename = "<binary>"
self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)
def get_record(self, name: str) -> bytes:
return self.zip_reader.get_record(name)
# NOTE: for has_record, get_all_records, and get_storage_from_record, pybind doesn't reveal
# the attributes of PyTorchFileReader, so the type checker raises an error on these calls.
# Strangely, this error doesn't have an error code, which is why it's ignored.
def has_record(self, path: str) -> bool:
return self.zip_reader.has_record(path) # type: ignore[union-attr]
def get_all_records(self) -> List[str]:
return self.zip_reader.get_all_records() # type: ignore[union-attr]
def get_storage_from_record(
self, name: str, numel: int, dtype: torch.dtype
) -> _HasStorage:
return self.zip_reader.get_storage_from_record(name, numel, dtype) # type: ignore[union-attr]
def get_filename(self) -> str:
return self.filename
def is_directory(self) -> bool:
return isinstance(self.zip_reader, TorchScriptDirectoryReader)
def close(self):
pass

View File

@ -1,6 +1,6 @@
from typing import Dict, List
from ..package_exporter_no_torch import PackagingError
from ..package_exporter import PackagingError
def find_first_use_of_broken_modules(exc: PackagingError) -> Dict[str, List[str]]:

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,21 +1,55 @@
import builtins
import importlib
import inspect
import io
import linecache
import os.path
import types
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Union, BinaryIO, Callable, Dict
from typing import cast, Any, BinaryIO, Callable, Dict, List, Optional, Union
from weakref import WeakValueDictionary
import torch
from torch.serialization import _get_restore_location
from torch.serialization import _get_restore_location, _maybe_decode_ascii
from ._directory_reader_torchscript import TorchScriptDirectoryReader
from ._zip_file_torchscript import TorchScriptPackageZipFileReader
from .package_importer_no_torch import _maybe_decode_ascii
from .package_importer_no_torch import PackageImporter as DefaultPackageImporter
from ._directory_reader import DirectoryReader
from ._importlib import (
_calc___package__,
_normalize_line_endings,
_normalize_path,
_resolve_name,
_sanity_check,
)
from ._mangling import PackageMangler, demangle
from ._package_unpickler import PackageUnpickler
from .file_structure_representation import Directory, _create_directory_from_file_list
from .glob_group import GlobPattern
from .importer import Importer
class PackageImporter(DefaultPackageImporter):
class PackageImporter(Importer):
"""Importers allow you to load code written to packages by :class:`PackageExporter`.
Code is loaded in a hermetic way, using files from the package
rather than the normal python import system. This allows
for the packaging of PyTorch model code and data so that it can be run
on a server or used in the future for transfer learning.
The importer for packages ensures that code in the module can only be loaded from
within the package, except for modules explicitly listed as external during export.
The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
This prevents "implicit" dependencies where the package runs locally because it is importing
a locally-installed package, but then fails when the package is copied to another machine.
"""
"""The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but
local to this importer.
"""
modules: Dict[str, types.ModuleType]
def __init__(
self,
file_or_buffer: Union[str, Path, BinaryIO],
file_or_buffer: Union[str, torch._C.PyTorchFileReader, Path, BinaryIO],
module_allowed: Callable[[str], bool] = lambda module_name: True,
):
"""Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
@ -31,69 +65,107 @@ class PackageImporter(DefaultPackageImporter):
Raises:
ImportError: If the package will use a disallowed module.
"""
        self.zip_reader: Any
        if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
            self.filename = "<pytorch_file_reader>"
            self.zip_reader = file_or_buffer
        elif isinstance(file_or_buffer, (Path, str)):
            self.filename = str(file_or_buffer)
            if not os.path.isdir(self.filename):
                self.zip_reader = torch._C.PyTorchFileReader(self.filename)
            else:
                self.zip_reader = DirectoryReader(self.filename)
        else:
            self.filename = "<binary>"
            self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)

        self.root = _PackageNode(None)
        self.modules = {}
        self.extern_modules = self._read_extern()

        for extern_module in self.extern_modules:
            if not module_allowed(extern_module):
                raise ImportError(
                    f"package '{file_or_buffer}' needs the external module '{extern_module}' "
                    f"but that module has been disallowed"
                )
            self._add_extern(extern_module)

        for fname in self.zip_reader.get_all_records():
            self._add_file(fname)

        self.patched_builtins = builtins.__dict__.copy()
        self.patched_builtins["__import__"] = self.__import__
        # Allow packaged modules to reference their PackageImporter
        self.modules["torch_package_importer"] = self  # type: ignore[assignment]

        self._mangler = PackageMangler()

        # used for reduce deserialization
        self.storage_context: Any = None
        self.last_map_location = None

        # used for torch.serialization._load
        self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)
def import_module(self, name: str, package=None):
"""Load a module from the package if it hasn't already been loaded, and then return
the module. Modules are loaded locally
to the importer and will appear in ``self.modules`` rather than ``sys.modules``.
Args:
name (str): Fully qualified name of the module to load.
package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``.
Returns:
types.ModuleType: The (possibly already) loaded module.
"""
# We should always be able to support importing modules from this package.
# This is to support something like:
# obj = importer.load_pickle(...)
# importer.import_module(obj.__module__) <- this string will be mangled
#
# Note that _mangler.demangle will not demangle any module names
# produced by a different PackageImporter instance.
name = self._mangler.demangle(name)
return self._gcd_import(name)
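    # A minimal sketch, assuming ``importer`` wraps an archive containing a
    # hypothetical ``my_package.utils``: the loaded module is registered in
    # ``importer.modules``, not in ``sys.modules``.
    #
    #     m = importer.import_module("my_package.utils")
    #     assert importer.modules["my_package.utils"] is m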
def load_binary(self, package: str, resource: str) -> bytes:
"""Load raw bytes.
Args:
package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
resource (str): The unique name for the resource.
Returns:
bytes: The loaded data.
"""
path = self._zipfile_path(package, resource)
return self.zip_reader.get_record(path)
def load_text(
self,
package: str,
resource: str,
encoding: str = "utf-8",
errors: str = "strict",
) -> str:
"""Load a string.
Args:
package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
resource (str): The unique name for the resource.
encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.
Returns:
str: The loaded text.
"""
data = self.load_binary(package, resource)
return data.decode(encoding, errors)
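    # For example, assuming the package ships these hypothetical resources:
    #
    #     cfg = importer.load_text("my_package", "config.json")
    #     blob = importer.load_binary("my_package", "weights.bin")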
def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
"""Unpickles the resource from the package, loading any modules that are needed to construct the objects
using :meth:`import_module`.
@@ -107,16 +179,49 @@ class PackageImporter(DefaultPackageImporter):
Any: The unpickled object.
"""
pickle_file = self._zipfile_path(package, resource)
        restore_location = _get_restore_location(map_location)
        loaded_storages = {}
        loaded_reduces = {}
        storage_context = torch._C.DeserializationStorageContext()

        def load_tensor(dtype, size, key, location, restore_location):
            name = f"{key}.storage"

            if storage_context.has_storage(name):
                storage = storage_context.get_storage(name, dtype).storage()
            else:
                tensor = self.zip_reader.get_storage_from_record(
                    ".data/" + name, size, dtype
                )
                if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
                    storage_context.add_storage(name, tensor)
                storage = tensor.storage()
            loaded_storages[key] = restore_location(storage, location)

        def persistent_load(saved_id):
            assert isinstance(saved_id, tuple)
            typename = _maybe_decode_ascii(saved_id[0])
            data = saved_id[1:]

            if typename == "storage":
                storage_type, key, location, size = data
                dtype = storage_type.dtype

                if key not in loaded_storages:
                    load_tensor(
                        dtype,
                        size,
                        key,
                        _maybe_decode_ascii(location),
                        restore_location,
                    )
                storage = loaded_storages[key]
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with _TypedStorage
                return torch.storage._TypedStorage(
                    wrap_storage=storage._untyped(), dtype=dtype
                )
            elif typename == "reduce_package":
                # to fix BC breaking change, objects on this load path
                # will be loaded multiple times erroneously
                if len(data) == 2:
@@ -132,13 +237,466 @@ class PackageImporter(DefaultPackageImporter):
        # Load the data (which may in turn use `persistent_load` to load tensors)
        data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
        unpickler = self.Unpickler(data_file)
        unpickler.persistent_load = persistent_load

        @contextmanager
        def set_deserialization_context():
            # to let reduce_package access deserialization context
            self.storage_context = storage_context
            self.last_map_location = map_location
            try:
                yield
            finally:
                self.storage_context = None
                self.last_map_location = None

        with set_deserialization_context():
            result = unpickler.load()

            # TODO from zdevito:
            #   This stateful weird function will need to be removed in our efforts
            #   to unify the format. It has a race condition if multiple python
            #   threads try to read independent files
            torch._utils._validate_loaded_sparse_tensors()

        return result
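    # Sketch: ``map_location`` follows ``torch.load`` semantics, so a package
    # pickled with CUDA tensors can be brought up on CPU (the resource names
    # here are hypothetical):
    #
    #     model = importer.load_pickle("my_package", "model.pkl", map_location="cpu")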
def id(self):
"""
Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances.
Looks like::
<torch_package_0>
"""
return self._mangler.parent_name()
def file_structure(
self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
) -> Directory:
"""Returns a file structure representation of package's zipfile.
Args:
include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings
            for the names of the files to be included in the zipfile representation. This can also be
            a glob-style pattern, as described in :meth:`PackageExporter.mock`
        exclude (Union[List[str], str]): An optional pattern that excludes files whose names match the pattern.
Returns:
:class:`Directory`
"""
return _create_directory_from_file_list(
self.filename, self.zip_reader.get_all_records(), include, exclude
)
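    # For instance, assuming the glob syntax described in
    # :meth:`PackageExporter.mock`, listing only the packaged Python sources:
    #
    #     tree = importer.file_structure(include="**/*.py")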
def python_version(self):
"""Returns the version of python that was used to create this package.
Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock
file later on.
Returns:
:class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package
"""
python_version_path = ".data/python_version"
return (
self.zip_reader.get_record(python_version_path).decode("utf-8").strip()
if self.zip_reader.has_record(python_version_path)
else None
)
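    # e.g., assuming the archive recorded a version at export time:
    #
    #     importer.python_version()  # "3.8.9", or None for older archives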
def _read_extern(self):
return (
self.zip_reader.get_record(".data/extern_modules")
.decode("utf-8")
.splitlines(keepends=False)
)
def _make_module(
self, name: str, filename: Optional[str], is_package: bool, parent: str
):
mangled_filename = self._mangler.mangle(filename) if filename else None
spec = importlib.machinery.ModuleSpec(
name,
self, # type: ignore[arg-type]
origin="<package_importer>",
is_package=is_package,
)
module = importlib.util.module_from_spec(spec)
self.modules[name] = module
module.__name__ = self._mangler.mangle(name)
ns = module.__dict__
ns["__spec__"] = spec
ns["__loader__"] = self
ns["__file__"] = mangled_filename
ns["__cached__"] = None
ns["__builtins__"] = self.patched_builtins
ns["__torch_package__"] = True
# Add this module to our private global registry. It should be unique due to mangling.
assert module.__name__ not in _package_imported_modules
_package_imported_modules[module.__name__] = module
# pre-emptively install on the parent to prevent IMPORT_FROM from trying to
# access sys.modules
self._install_on_parent(parent, name, module)
if filename is not None:
assert mangled_filename is not None
# pre-emptively install the source in `linecache` so that stack traces,
# `inspect`, etc. work.
assert filename not in linecache.cache # type: ignore[attr-defined]
linecache.lazycache(mangled_filename, ns)
code = self._compile_source(filename, mangled_filename)
exec(code, ns)
return module
def _load_module(self, name: str, parent: str):
cur: _PathNode = self.root
for atom in name.split("."):
if not isinstance(cur, _PackageNode) or atom not in cur.children:
raise ModuleNotFoundError(
f'No module named "{name}" in self-contained archive "{self.filename}"'
f" and the module is also not in the list of allowed external modules: {self.extern_modules}",
name=name,
)
cur = cur.children[atom]
if isinstance(cur, _ExternNode):
module = self.modules[name] = importlib.import_module(name)
return module
return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined]
def _compile_source(self, fullpath: str, mangled_filename: str):
source = self.zip_reader.get_record(fullpath)
source = _normalize_line_endings(source)
return compile(source, mangled_filename, "exec", dont_inherit=True)
# note: named `get_source` so that linecache can find the source
# when this is the __loader__ of a module.
def get_source(self, module_name) -> str:
# linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here.
module = self.import_module(demangle(module_name))
return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8")
# note: named `get_resource_reader` so that importlib.resources can find it.
# This is otherwise considered an internal method.
def get_resource_reader(self, fullname):
try:
package = self._get_package(fullname)
except ImportError:
return None
if package.__loader__ is not self:
return None
return _PackageResourceReader(self, fullname)
def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
if not parent:
return
# Set the module as an attribute on its parent.
parent_module = self.modules[parent]
if parent_module.__loader__ is self:
setattr(parent_module, name.rpartition(".")[2], module)
# note: copied from cpython's import code, with call to create module replaced with _make_module
def _do_find_and_load(self, name):
path = None
parent = name.rpartition(".")[0]
if parent:
if parent not in self.modules:
self._gcd_import(parent)
# Crazy side-effects!
if name in self.modules:
return self.modules[name]
parent_module = self.modules[parent]
try:
path = parent_module.__path__ # type: ignore[attr-defined]
except AttributeError:
msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
raise ModuleNotFoundError(msg, name=name) from None
module = self._load_module(name, parent)
self._install_on_parent(parent, name, module)
return module
# note: copied from cpython's import code
def _find_and_load(self, name):
module = self.modules.get(name, _NEEDS_LOADING)
if module is _NEEDS_LOADING:
return self._do_find_and_load(name)
if module is None:
message = "import of {} halted; " "None in sys.modules".format(name)
raise ModuleNotFoundError(message, name=name)
# To handle https://github.com/pytorch/pytorch/issues/57490, where std's
# creation of fake submodules via the hacking of sys.modules is not import
# friendly
if name == "os":
self.modules["os.path"] = cast(Any, module).path
elif name == "typing":
self.modules["typing.io"] = cast(Any, module).io
self.modules["typing.re"] = cast(Any, module).re
return module
def _gcd_import(self, name, package=None, level=0):
"""Import and return the module based on its name, the package the call is
being made from, and the level adjustment.
This function represents the greatest common denominator of functionality
between import_module and __import__. This includes setting __package__ if
the loader did not.
"""
_sanity_check(name, package, level)
if level > 0:
name = _resolve_name(name, package, level)
return self._find_and_load(name)
# note: copied from cpython's import code
def _handle_fromlist(self, module, fromlist, *, recursive=False):
"""Figure out what __import__ should return.
The import_ parameter is a callable which takes the name of module to
import. It is required to decouple the function from assuming importlib's
import implementation is desired.
"""
module_name = demangle(module.__name__)
# The hell that is fromlist ...
# If a package was imported, try to import stuff from fromlist.
if hasattr(module, "__path__"):
for x in fromlist:
if not isinstance(x, str):
if recursive:
where = module_name + ".__all__"
else:
where = "``from list''"
raise TypeError(
f"Item in {where} must be str, " f"not {type(x).__name__}"
)
elif x == "*":
if not recursive and hasattr(module, "__all__"):
self._handle_fromlist(module, module.__all__, recursive=True)
elif not hasattr(module, x):
from_name = "{}.{}".format(module_name, x)
try:
self._gcd_import(from_name)
except ModuleNotFoundError as exc:
# Backwards-compatibility dictates we ignore failed
# imports triggered by fromlist for modules that don't
# exist.
if (
exc.name == from_name
and self.modules.get(from_name, _NEEDS_LOADING) is not None
):
continue
raise
return module
def __import__(self, name, globals=None, locals=None, fromlist=(), level=0):
if level == 0:
module = self._gcd_import(name)
else:
globals_ = globals if globals is not None else {}
package = _calc___package__(globals_)
module = self._gcd_import(name, package, level)
if not fromlist:
# Return up to the first dot in 'name'. This is complicated by the fact
# that 'name' may be relative.
if level == 0:
return self._gcd_import(name.partition(".")[0])
elif not name:
return module
else:
# Figure out where to slice the module's name up to the first dot
# in 'name'.
cut_off = len(name) - len(name.partition(".")[0])
# Slice end needs to be positive to alleviate need to special-case
# when ``'.' not in name``.
module_name = demangle(module.__name__)
return self.modules[module_name[: len(module_name) - cut_off]]
else:
return self._handle_fromlist(module, fromlist)
def _get_package(self, package):
"""Take a package name or module object and return the module.
If a name, the module is imported. If the passed or imported module
object is not a package, raise an exception.
"""
if hasattr(package, "__spec__"):
if package.__spec__.submodule_search_locations is None:
raise TypeError("{!r} is not a package".format(package.__spec__.name))
else:
return package
else:
module = self.import_module(package)
if module.__spec__.submodule_search_locations is None:
raise TypeError("{!r} is not a package".format(package))
else:
return module
def _zipfile_path(self, package, resource=None):
package = self._get_package(package)
assert package.__loader__ is self
name = demangle(package.__name__)
if resource is not None:
resource = _normalize_path(resource)
return f"{name.replace('.', '/')}/{resource}"
else:
return f"{name.replace('.', '/')}"
def _get_or_create_package(
self, atoms: List[str]
) -> "Union[_PackageNode, _ExternNode]":
cur = self.root
for i, atom in enumerate(atoms):
node = cur.children.get(atom, None)
if node is None:
node = cur.children[atom] = _PackageNode(None)
if isinstance(node, _ExternNode):
return node
if isinstance(node, _ModuleNode):
name = ".".join(atoms[:i])
raise ImportError(
f"inconsistent module structure. module {name} is not a package, but has submodules"
)
assert isinstance(node, _PackageNode)
cur = node
return cur
def _add_file(self, filename: str):
"""Assembles a Python module out of the given file. Will ignore files in the .data directory.
Args:
filename (str): the name of the file inside of the package archive to be added
"""
*prefix, last = filename.split("/")
if len(prefix) > 1 and prefix[0] == ".data":
return
package = self._get_or_create_package(prefix)
if isinstance(package, _ExternNode):
raise ImportError(
f"inconsistent module structure. package contains a module file {filename}"
f" that is a subpackage of a module marked external."
)
if last == "__init__.py":
package.source_file = filename
elif last.endswith(".py"):
package_name = last[: -len(".py")]
package.children[package_name] = _ModuleNode(filename)
def _add_extern(self, extern_name: str):
*prefix, last = extern_name.split(".")
package = self._get_or_create_package(prefix)
if isinstance(package, _ExternNode):
return # the shorter extern covers this extern case
package.children[last] = _ExternNode()
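    # For example, hypothetical externs ["yaml", "package_a.subpackage"] leave
    # _ExternNode leaves at root.children["yaml"] and
    # root.children["package_a"].children["subpackage"]; a shorter extern such
    # as "package_a", registered first, would make the longer one a no-op.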
_NEEDS_LOADING = object()
_ERR_MSG_PREFIX = "No module named "
_ERR_MSG = _ERR_MSG_PREFIX + "{!r}"
class _PathNode:
pass
class _PackageNode(_PathNode):
def __init__(self, source_file: Optional[str]):
self.source_file = source_file
self.children: Dict[str, _PathNode] = {}
class _ModuleNode(_PathNode):
__slots__ = ["source_file"]
def __init__(self, source_file: str):
self.source_file = source_file
class _ExternNode(_PathNode):
pass
# A private global registry of all modules that have been package-imported.
_package_imported_modules: WeakValueDictionary = WeakValueDictionary()
# `inspect` by default only looks in `sys.modules` to find source files for classes.
# Patch it to check our private registry of package-imported modules as well.
_orig_getfile = inspect.getfile
def patched_getfile(object):
if inspect.isclass(object):
if object.__module__ in _package_imported_modules:
return _package_imported_modules[object.__module__].__file__
return _orig_getfile(object)
inspect.getfile = patched_getfile
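# A minimal sketch of what this patch enables (names are hypothetical):
#
#     cls = importer.import_module("my_package.models").MyModel
#     inspect.getfile(cls)  # resolves through _package_imported_modules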
class _PackageResourceReader:
"""Private class used to support PackageImporter.get_resource_reader().
    Conforms to the importlib.abc.ResourceReader interface. Allowed to access
the innards of PackageImporter.
"""
def __init__(self, importer, fullname):
self.importer = importer
self.fullname = fullname
def open_resource(self, resource):
from io import BytesIO
return BytesIO(self.importer.load_binary(self.fullname, resource))
def resource_path(self, resource):
# The contract for resource_path is that it either returns a concrete
# file system path or raises FileNotFoundError.
if isinstance(
self.importer.zip_reader, DirectoryReader
) and self.importer.zip_reader.has_record(
os.path.join(self.fullname, resource)
):
return os.path.join(
self.importer.zip_reader.directory, self.fullname, resource
)
raise FileNotFoundError
def is_resource(self, name):
path = self.importer._zipfile_path(self.fullname, name)
return self.importer.zip_reader.has_record(path)
def contents(self):
from pathlib import Path
filename = self.fullname.replace(".", "/")
fullname_path = Path(self.importer._zipfile_path(self.fullname))
files = self.importer.zip_reader.get_all_records()
subdirs_seen = set()
for filename in files:
try:
relative = Path(filename).relative_to(fullname_path)
except ValueError:
continue
# If the path of the file (which is relative to the top of the zip
# namespace), relative to the package given when the resource
# reader was created, has a parent, then it's a name in a
# subdirectory and thus we skip it.
parent_name = relative.parent.name
if len(parent_name) == 0:
yield relative.name
elif parent_name not in subdirs_seen:
subdirs_seen.add(parent_name)
yield parent_name
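# Resource-reader usage sketch (``importlib.resources`` on Python >= 3.7;
# package and resource names are hypothetical):
#
#     import importlib.resources
#     pkg = importer.import_module("my_package")
#     data = importlib.resources.read_binary(pkg, "weights.bin")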
View File

@@ -1,672 +0,0 @@
import builtins
import importlib
import inspect
import io
import linecache
import os.path
import types
from pathlib import Path
from typing import cast, Any, BinaryIO, Callable, Dict, List, Optional, Union, Type
from weakref import WeakValueDictionary
from ._directory_reader import DirectoryReader
from ._importlib import (
_calc___package__,
_normalize_line_endings,
_normalize_path,
_resolve_name,
_sanity_check,
)
from ._mangling import PackageMangler, demangle
from ._package_unpickler import PackageUnpickler
from ._zip_file import PackageZipFileReader, DefaultPackageZipFileReader
from .file_structure_representation import Directory, _create_directory_from_file_list
from .glob_group import GlobPattern
from .importer import Importer
def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
# When using encoding='bytes' in Py3, some **internal** keys stored as
# strings in Py2 are loaded as bytes. This function decodes them with
# ascii encoding, one that Py3 uses by default.
#
# NOTE: This should only be used on internal keys (e.g., `typename` and
    # `location` in `persistent_load` below)!
if isinstance(bytes_str, bytes):
return bytes_str.decode("ascii")
return bytes_str
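# For instance, ``_maybe_decode_ascii(b"cpu")`` returns ``"cpu"``, while a
# ``str`` input is passed through unchanged.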
class PackageImporter(Importer):
"""Importers allow you to load code written to packages by :class:`PackageExporter`.
Code is loaded in a hermetic way, using files from the package
rather than the normal python import system. This allows
for the packaging of PyTorch model code and data so that it can be run
on a server or used in the future for transfer learning.
The importer for packages ensures that code in the module can only be loaded from
within the package, except for modules explicitly listed as external during export.
The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
This prevents "implicit" dependencies where the package runs locally because it is importing
a locally-installed package, but then fails when the package is copied to another machine.
"""
"""The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but
local to this importer.
"""
modules: Dict[str, types.ModuleType]
def get_zip_reader(self, file_or_buffer):
zip_reader: Any
if isinstance(file_or_buffer, DefaultPackageZipFileReader):
filename = "<pytorch_file_reader>"
zip_reader = file_or_buffer
elif isinstance(file_or_buffer, (Path, str)):
filename = str(file_or_buffer)
if not os.path.isdir(filename):
zip_reader = DefaultPackageZipFileReader(filename)
else:
zip_reader = DirectoryReader(filename)
else:
filename = "<binary>"
zip_reader = DefaultPackageZipFileReader(file_or_buffer)
return filename, zip_reader
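    # Dispatch sketch, assuming a hypothetical archive path: a zip file path
    # yields a DefaultPackageZipFileReader, a directory path a DirectoryReader,
    # and anything else is treated as a binary file-like buffer.
    #
    #     filename, reader = importer.get_zip_reader("exported_package.pt")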
def __init__(
self,
file_or_buffer: Union[str, Path, BinaryIO],
module_allowed: Callable[[str], bool] = lambda module_name: True,
zip_file_reader_type: Type[PackageZipFileReader] = DefaultPackageZipFileReader,
):
"""Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
allowed by ``module_allowed``
Args:
file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
a string, or an ``os.PathLike`` object containing a filename.
        module_allowed (Callable[[str], bool], optional): A method to determine if an externally provided module
should be allowed. Can be used to ensure packages loaded do not depend on modules that the server
does not support. Defaults to allowing anything.
        zip_file_reader_type: A subclass of PackageZipFileReader used to instantiate the zip file reader. Defaults to ``DefaultPackageZipFileReader``.
Raises:
ImportError: If the package will use a disallowed module.
"""
self.zip_reader = zip_file_reader_type(file_or_buffer)
self.root = _PackageNode(None)
self.modules = {}
self.extern_modules = self._read_extern()
for extern_module in self.extern_modules:
if not module_allowed(extern_module):
raise ImportError(
f"package '{file_or_buffer}' needs the external module '{extern_module}' "
f"but that module has been disallowed"
)
self._add_extern(extern_module)
for fname in self.zip_reader.get_all_records():
self._add_file(fname)
self.patched_builtins = builtins.__dict__.copy()
self.patched_builtins["__import__"] = self.__import__
# Allow packaged modules to reference their PackageImporter
self.modules["torch_package_importer"] = self # type: ignore[assignment]
self._mangler = PackageMangler()
        # used for reduce deserialization
self.storage_context: Any = None
self.last_map_location = None
# used for torch.serialization._load
self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)
def import_module(self, name: str, package=None):
"""Load a module from the package if it hasn't already been loaded, and then return
the module. Modules are loaded locally
to the importer and will appear in ``self.modules`` rather than ``sys.modules``.
Args:
name (str): Fully qualified name of the module to load.
package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``.
Returns:
types.ModuleType: The (possibly already) loaded module.
"""
# We should always be able to support importing modules from this package.
# This is to support something like:
# obj = importer.load_pickle(...)
# importer.import_module(obj.__module__) <- this string will be mangled
#
# Note that _mangler.demangle will not demangle any module names
# produced by a different PackageImporter instance.
name = self._mangler.demangle(name)
return self._gcd_import(name)
def load_binary(self, package: str, resource: str) -> bytes:
"""Load raw bytes.
Args:
package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
resource (str): The unique name for the resource.
Returns:
bytes: The loaded data.
"""
path = self._zipfile_path(package, resource)
return self.zip_reader.get_record(path)
def load_text(
self,
package: str,
resource: str,
encoding: str = "utf-8",
errors: str = "strict",
) -> str:
"""Load a string.
Args:
package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
resource (str): The unique name for the resource.
encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.
Returns:
str: The loaded text.
"""
data = self.load_binary(package, resource)
return data.decode(encoding, errors)
def persistent_load(self, typename, data):
# meant to be overwritten
return None
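    # Override sketch (hypothetical subclass): a torch-aware importer would
    # intercept "storage" entries here and return the reconstructed storage,
    # returning None for anything it does not handle:
    #
    #     class TorchAwareImporter(PackageImporter):
    #         def persistent_load(self, typename, data):
    #             if typename == "storage":
    #                 return self._rebuild_storage(data)  # hypothetical helper
    #             return None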
def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
"""Unpickles the resource from the package, loading any modules that are needed to construct the objects
using :meth:`import_module`.
Args:
package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
resource (str): The unique name for the resource.
map_location: Passed to `torch.load` to determine how tensors are mapped to devices. Defaults to ``None``.
Returns:
Any: The unpickled object.
"""
pickle_file = self._zipfile_path(package, resource)
loaded_reduces = {}
def _persistent_load(saved_id):
assert isinstance(saved_id, tuple)
typename = _maybe_decode_ascii(saved_id[0])
data = saved_id[1:]
module = self.persistent_load(typename, data)
if module is not None:
return module
if typename == "reduce_package":
# to fix BC breaking change, objects on this load path
# will be loaded multiple times erroneously
if len(data) == 2:
func, args = data
return func(self, *args)
reduce_id, func, args = data
if reduce_id not in loaded_reduces:
loaded_reduces[reduce_id] = func(self, *args)
return loaded_reduces[reduce_id]
            else:
                raise RuntimeError(
                    f"Unknown typename for persistent_load, got '{typename}'"
                )
# Load the data (which may in turn use `persistent_load` to load tensors)
data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
unpickler = self.Unpickler(data_file)
unpickler.persistent_load = _persistent_load
result = unpickler.load()
return result
def id(self):
"""
Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances.
Looks like::
<torch_package_0>
"""
return self._mangler.parent_name()
def file_structure(
self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
) -> Directory:
"""Returns a file structure representation of package's zipfile.
Args:
include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings
            for the names of the files to be included in the zipfile representation. This can also be
            a glob-style pattern, as described in :meth:`PackageExporter.mock`
        exclude (Union[List[str], str]): An optional pattern that excludes files whose names match the pattern.
Returns:
:class:`Directory`
"""
return _create_directory_from_file_list(
self.zip_reader.get_filename(),
self.zip_reader.get_all_records(),
include,
exclude,
)
def python_version(self):
"""Returns the version of python that was used to create this package.
Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock
file later on.
Returns:
:class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package
"""
python_version_path = ".data/python_version"
return (
self.zip_reader.get_record(python_version_path).decode("utf-8").strip()
if self.zip_reader.has_record(python_version_path)
else None
)
def _read_extern(self):
return (
self.zip_reader.get_record(".data/extern_modules")
.decode("utf-8")
.splitlines(keepends=False)
)
def _make_module(
self, name: str, filename: Optional[str], is_package: bool, parent: str
):
mangled_filename = self._mangler.mangle(filename) if filename else None
spec = importlib.machinery.ModuleSpec(
name,
self, # type: ignore[arg-type]
origin="<package_importer>",
is_package=is_package,
)
module = importlib.util.module_from_spec(spec)
self.modules[name] = module
module.__name__ = self._mangler.mangle(name)
ns = module.__dict__
ns["__spec__"] = spec
ns["__loader__"] = self
ns["__file__"] = mangled_filename
ns["__cached__"] = None
ns["__builtins__"] = self.patched_builtins
ns["__torch_package__"] = True
# Add this module to our private global registry. It should be unique due to mangling.
assert module.__name__ not in _package_imported_modules
_package_imported_modules[module.__name__] = module
# pre-emptively install on the parent to prevent IMPORT_FROM from trying to
# access sys.modules
self._install_on_parent(parent, name, module)
if filename is not None:
assert mangled_filename is not None
# pre-emptively install the source in `linecache` so that stack traces,
# `inspect`, etc. work.
assert filename not in linecache.cache # type: ignore[attr-defined]
linecache.lazycache(mangled_filename, ns)
code = self._compile_source(filename, mangled_filename)
exec(code, ns)
return module
def _load_module(self, name: str, parent: str):
cur: _PathNode = self.root
for atom in name.split("."):
if not isinstance(cur, _PackageNode) or atom not in cur.children:
raise ModuleNotFoundError(
f'No module named "{name}" in self-contained archive "{self.zip_reader.get_filename()}"'
f" and the module is also not in the list of allowed external modules: {self.extern_modules}",
name=name,
)
cur = cur.children[atom]
if isinstance(cur, _ExternNode):
module = self.modules[name] = importlib.import_module(name)
return module
return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined]
def _compile_source(self, fullpath: str, mangled_filename: str):
source = self.zip_reader.get_record(fullpath)
source = _normalize_line_endings(source)
return compile(source, mangled_filename, "exec", dont_inherit=True)
# note: named `get_source` so that linecache can find the source
# when this is the __loader__ of a module.
def get_source(self, module_name) -> str:
# linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here.
module = self.import_module(demangle(module_name))
return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8")
# note: named `get_resource_reader` so that importlib.resources can find it.
# This is otherwise considered an internal method.
def get_resource_reader(self, fullname):
try:
package = self._get_package(fullname)
except ImportError:
return None
if package.__loader__ is not self:
return None
return _PackageResourceReader(self, fullname)
def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
if not parent:
return
# Set the module as an attribute on its parent.
parent_module = self.modules[parent]
if parent_module.__loader__ is self:
setattr(parent_module, name.rpartition(".")[2], module)
# note: copied from cpython's import code, with call to create module replaced with _make_module
def _do_find_and_load(self, name):
path = None
parent = name.rpartition(".")[0]
if parent:
if parent not in self.modules:
self._gcd_import(parent)
# Crazy side-effects!
if name in self.modules:
return self.modules[name]
parent_module = self.modules[parent]
try:
path = parent_module.__path__ # type: ignore[attr-defined]
except AttributeError:
msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
raise ModuleNotFoundError(msg, name=name) from None
module = self._load_module(name, parent)
self._install_on_parent(parent, name, module)
return module
# note: copied from cpython's import code
def _find_and_load(self, name):
module = self.modules.get(name, _NEEDS_LOADING)
if module is _NEEDS_LOADING:
return self._do_find_and_load(name)
if module is None:
message = "import of {} halted; " "None in sys.modules".format(name)
raise ModuleNotFoundError(message, name=name)
# To handle https://github.com/pytorch/pytorch/issues/57490, where std's
# creation of fake submodules via the hacking of sys.modules is not import
# friendly
if name == "os":
self.modules["os.path"] = cast(Any, module).path
elif name == "typing":
self.modules["typing.io"] = cast(Any, module).io
self.modules["typing.re"] = cast(Any, module).re
return module
def _gcd_import(self, name, package=None, level=0):
"""Import and return the module based on its name, the package the call is
being made from, and the level adjustment.
This function represents the greatest common denominator of functionality
between import_module and __import__. This includes setting __package__ if
the loader did not.
"""
_sanity_check(name, package, level)
if level > 0:
name = _resolve_name(name, package, level)
return self._find_and_load(name)
# note: copied from cpython's import code
def _handle_fromlist(self, module, fromlist, *, recursive=False):
"""Figure out what __import__ should return.
The import_ parameter is a callable which takes the name of module to
import. It is required to decouple the function from assuming importlib's
import implementation is desired.
"""
module_name = demangle(module.__name__)
# The hell that is fromlist ...
# If a package was imported, try to import stuff from fromlist.
if hasattr(module, "__path__"):
for x in fromlist:
if not isinstance(x, str):
if recursive:
where = module_name + ".__all__"
else:
where = "``from list''"
raise TypeError(
f"Item in {where} must be str, " f"not {type(x).__name__}"
)
elif x == "*":
if not recursive and hasattr(module, "__all__"):
self._handle_fromlist(module, module.__all__, recursive=True)
elif not hasattr(module, x):
from_name = "{}.{}".format(module_name, x)
try:
self._gcd_import(from_name)
except ModuleNotFoundError as exc:
# Backwards-compatibility dictates we ignore failed
# imports triggered by fromlist for modules that don't
# exist.
if (
exc.name == from_name
and self.modules.get(from_name, _NEEDS_LOADING) is not None
):
continue
raise
return module
def __import__(self, name, globals=None, locals=None, fromlist=(), level=0):
if level == 0:
module = self._gcd_import(name)
else:
globals_ = globals if globals is not None else {}
package = _calc___package__(globals_)
module = self._gcd_import(name, package, level)
if not fromlist:
# Return up to the first dot in 'name'. This is complicated by the fact
# that 'name' may be relative.
if level == 0:
return self._gcd_import(name.partition(".")[0])
elif not name:
return module
else:
# Figure out where to slice the module's name up to the first dot
# in 'name'.
cut_off = len(name) - len(name.partition(".")[0])
# Slice end needs to be positive to alleviate need to special-case
# when ``'.' not in name``.
module_name = demangle(module.__name__)
return self.modules[module_name[: len(module_name) - cut_off]]
else:
return self._handle_fromlist(module, fromlist)
def _get_package(self, package):
"""Take a package name or module object and return the module.
If a name, the module is imported. If the passed or imported module
object is not a package, raise an exception.
"""
if hasattr(package, "__spec__"):
if package.__spec__.submodule_search_locations is None:
raise TypeError("{!r} is not a package".format(package.__spec__.name))
else:
return package
else:
module = self.import_module(package)
if module.__spec__.submodule_search_locations is None:
raise TypeError("{!r} is not a package".format(package))
else:
return module
def _zipfile_path(self, package, resource=None):
package = self._get_package(package)
assert package.__loader__ is self
name = demangle(package.__name__)
if resource is not None:
resource = _normalize_path(resource)
return f"{name.replace('.', '/')}/{resource}"
else:
return f"{name.replace('.', '/')}"
def _get_or_create_package(
self, atoms: List[str]
) -> "Union[_PackageNode, _ExternNode]":
cur = self.root
for i, atom in enumerate(atoms):
node = cur.children.get(atom, None)
if node is None:
node = cur.children[atom] = _PackageNode(None)
if isinstance(node, _ExternNode):
return node
if isinstance(node, _ModuleNode):
name = ".".join(atoms[:i])
raise ImportError(
f"inconsistent module structure. module {name} is not a package, but has submodules"
)
assert isinstance(node, _PackageNode)
cur = node
return cur
def _add_file(self, filename: str):
"""Assembles a Python module out of the given file. Will ignore files in the .data directory.
Args:
filename (str): the name of the file inside of the package archive to be added
"""
*prefix, last = filename.split("/")
if len(prefix) > 1 and prefix[0] == ".data":
return
package = self._get_or_create_package(prefix)
if isinstance(package, _ExternNode):
raise ImportError(
f"inconsistent module structure. package contains a module file {filename}"
f" that is a subpackage of a module marked external."
)
if last == "__init__.py":
package.source_file = filename
elif last.endswith(".py"):
package_name = last[: -len(".py")]
package.children[package_name] = _ModuleNode(filename)
def _add_extern(self, extern_name: str):
*prefix, last = extern_name.split(".")
package = self._get_or_create_package(prefix)
if isinstance(package, _ExternNode):
return # the shorter extern covers this extern case
package.children[last] = _ExternNode()
_NEEDS_LOADING = object()
_ERR_MSG_PREFIX = "No module named "
_ERR_MSG = _ERR_MSG_PREFIX + "{!r}"
class _PathNode:
pass
class _PackageNode(_PathNode):
def __init__(self, source_file: Optional[str]):
self.source_file = source_file
self.children: Dict[str, _PathNode] = {}
class _ModuleNode(_PathNode):
__slots__ = ["source_file"]
def __init__(self, source_file: str):
self.source_file = source_file
class _ExternNode(_PathNode):
pass
# A private global registry of all modules that have been package-imported.
_package_imported_modules: WeakValueDictionary = WeakValueDictionary()
# `inspect` by default only looks in `sys.modules` to find source files for classes.
# Patch it to check our private registry of package-imported modules as well.
_orig_getfile = inspect.getfile
def patched_getfile(object):
if inspect.isclass(object):
if object.__module__ in _package_imported_modules:
return _package_imported_modules[object.__module__].__file__
return _orig_getfile(object)
inspect.getfile = patched_getfile
class _PackageResourceReader:
"""Private class used to support PackageImporter.get_resource_reader().
    Conforms to the importlib.abc.ResourceReader interface. Allowed to access
the innards of PackageImporter.
"""
def __init__(self, importer, fullname):
self.importer = importer
self.fullname = fullname
def open_resource(self, resource):
from io import BytesIO
return BytesIO(self.importer.load_binary(self.fullname, resource))
def resource_path(self, resource):
# The contract for resource_path is that it either returns a concrete
# file system path or raises FileNotFoundError.
if (
self.importer.zip_reader.is_directory()
and self.importer.zip_reader.has_record(
os.path.join(self.fullname, resource)
)
):
return os.path.join(
self.importer.zip_reader.get_filename(), self.fullname, resource
)
raise FileNotFoundError
def is_resource(self, name):
path = self.importer._zipfile_path(self.fullname, name)
return self.importer.zip_reader.has_record(path)
def contents(self):
from pathlib import Path
filename = self.fullname.replace(".", "/")
fullname_path = Path(self.importer._zipfile_path(self.fullname))
files = self.importer.zip_reader.get_all_records()
subdirs_seen = set()
for filename in files:
try:
relative = Path(filename).relative_to(fullname_path)
except ValueError:
continue
# If the path of the file (which is relative to the top of the zip
# namespace), relative to the package given when the resource
# reader was created, has a parent, then it's a name in a
# subdirectory and thus we skip it.
parent_name = relative.parent.name
if len(parent_name) == 0:
yield relative.name
elif parent_name not in subdirs_seen:
subdirs_seen.add(parent_name)
yield parent_name