# Owner(s): ["oncall: package/deploy"]

from io import BytesIO
from textwrap import dedent
from unittest import skipIf

import torch
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase

try:
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")


class TestPackageScript(PackageTestCase):
    """Tests for compatibility with TorchScript."""
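
    # Most of the tests below follow the same round-trip pattern: write a package
    # with PackageExporter, rewind the buffer, then read everything back with
    # PackageImporter. A minimal sketch of that pattern, using a hypothetical
    # eager module `my_mod`:
    #
    #     buffer = BytesIO()
    #     with PackageExporter(buffer) as pe:
    #         pe.intern("**")  # mark all dependencies to have their source packaged
    #         pe.save_pickle("model", "model.pkl", my_mod)
    #     buffer.seek(0)
    #     loaded = PackageImporter(buffer).load_pickle("model", "model.pkl")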

    def test_package_interface(self):
        """Packaging an interface class should work correctly."""

        import package_a.fake_interface as fake

        uses_interface = fake.UsesInterface()
        scripted = torch.jit.script(uses_interface)
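        # Swapping in a freshly scripted NewModule for proxy_mod (here and again
        # on the loaded copy below) checks that the interface-typed attribute
        # still accepts conforming implementations after packaging.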
        scripted.proxy_mod = torch.jit.script(fake.NewModule())

        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", uses_interface)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.load_pickle("model", "model.pkl")

        scripted_loaded = torch.jit.script(loaded)
        scripted_loaded.proxy_mod = torch.jit.script(fake.NewModule())

        input = torch.tensor(1)

        self.assertEqual(scripted(input), scripted_loaded(input))

    def test_different_package_interface(self):
        """Test a case where the interface defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        # Import one version of the interface
        import package_a.fake_interface as fake

        # Simulate a package that contains a different version of the
        # interface, with the exact same name.
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch
                    from torch import Tensor

                    @torch.jit.interface
                    class ModuleInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            pass

                    class ImplementsInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            return inp1 + 1

                    class UsesInterface(torch.nn.Module):
                        proxy_mod: ModuleInterface

                        def __init__(self) -> None:
                            super().__init__()
                            self.proxy_mod = ImplementsInterface()

                        def forward(self, input: Tensor) -> Tensor:
                            return self.proxy_mod.one(input)
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)
        # We should be able to script successfully.
        torch.jit.script(diff_fake.UsesInterface())

    def test_package_script_class(self):
        import package_a.fake_script_class as fake

        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_module(fake.__name__)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.import_module(fake.__name__)

        input = torch.tensor(1)
        self.assertTrue(
            torch.allclose(
                fake.uses_script_class(input), loaded.uses_script_class(input)
            )
        )

    def test_package_script_class_referencing_self(self):
        import package_a.fake_script_class as fake

        obj = fake.UsesIdListFeature()
        # intentionally script here to fill the compilation cache, to make sure
        # there is no false sharing between scripted types coming from the
        # package vs. outside environment.
        torch.jit.script(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        obj_loaded = importer.load_pickle("obj", "obj.pkl")
        scripted_obj_loaded = torch.jit.script(obj_loaded)

        # Make sure the scripted object can be serialized without error.
        buffer2 = scripted_obj_loaded.save_to_buffer()
        torch.jit.load(BytesIO(buffer2))

    def test_different_package_script_class(self):
        """Test a case where the script class defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        import package_a.fake_script_class as fake

        # Simulate a package that contains a different version of the
        # script class, with the attribute `bar` instead of `foo`.
        buffer = BytesIO()
        with PackageExporter(buffer) as pe2:
            pe2.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch

                    @torch.jit.script
                    class MyScriptClass:
                        def __init__(self, x):
                            self.bar = x
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)
        input = torch.rand(2, 3)
        loaded_script_class = diff_fake.MyScriptClass(input)
        orig_script_class = fake.MyScriptClass(input)
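        # The packaged class stores its constructor argument in `bar`, while the
        # class in the loading environment stores it in `foo`; both received the
        # same input, so the two attributes should hold the same value.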
        self.assertEqual(loaded_script_class.bar, orig_script_class.foo)

    def test_save_scriptmodule(self):
        """
        Test basic saving of ScriptModule.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod = importer.load_pickle("res", "mod.pkl", map_location="cpu")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod(input), scripted_mod(input))

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_save_scriptmodule_file(self):
        """
        Test basic saving of ScriptModule in file.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        importer = PackageImporter(filename)
        loaded_mod = importer.load_pickle("res", "mod.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod(input), scripted_mod(input))

    def test_save_scriptmodule_with_submods(self):
        """
        Test basic saving of ScriptModule with submodule.
        """
        from package_a.test_module import ModWithSubmod, ModWithTensor

        scripted_mod = torch.jit.script(
            ModWithSubmod(ModWithTensor(torch.rand(1, 2, 3)))
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod = importer.load_pickle("res", "mod.pkl", map_location="cpu")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod(input), scripted_mod(input))

    def test_save_scriptmodules_submod_redefinition(self):
        """
        Test to verify saving multiple ScriptModules with the same top module
        but different submodules works. The submodule is redefined between
        the two scriptings of the top module to check that the different
        concrete types of the modules are correctly recognized by the
        serialization code.
        """

        class Submod(torch.nn.Module):
            def forward(self, input: str):
                input = input + "_submod"
                return input

        class TopMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.modB = Submod()

            def forward(self, input: str):
                return self.modB(input)

        scripted_mod_0 = torch.jit.script(TopMod())

        # redefinition is intentional: changing a single inner string
        # attribute should trigger a new module type
        class Submod(torch.nn.Module):  # noqa: F811
            def forward(self, input: str):
                input = input + "_submod(changed)"
                return input

        scripted_mod_1 = torch.jit.script(TopMod())

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)
            e.save_pickle("res", "mod2.pkl", scripted_mod_1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod1.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod2.pkl")
        self.assertEqual(loaded_mod_0("input"), scripted_mod_0("input"))
        self.assertEqual(loaded_mod_1("input"), scripted_mod_1("input"))
        self.assertNotEqual(loaded_mod_0("input"), loaded_mod_1("input"))

    def test_save_independent_scriptmodules(self):
        """
        Test to verify saving multiple ScriptModules with completely
        separate code works.
        """
        from package_a.test_module import ModWithTensor, SimpleTest

        scripted_mod_0 = torch.jit.script(SimpleTest())
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)
            e.save_pickle("res", "mod2.pkl", scripted_mod_1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod1.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod2.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_mod_1(input), scripted_mod_1(input))

    def test_save_repeat_scriptmodules(self):
        """
        Test to verify that saving several different modules, as well as
        repeats of the same ScriptModule, in one package works. Also tests
        that PyTorchStreamWriter writing ScriptModule code files multiple
        times does not hide any code from PyTorchStreamReader.
        """
        from package_a.test_module import (
            ModWithSubmodAndTensor,
            ModWithTensor,
            SimpleTest,
        )

        scripted_mod_0 = torch.jit.script(SimpleTest())
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_2 = torch.jit.script(
            ModWithSubmodAndTensor(
                torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
            )
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "mod0.pkl", scripted_mod_0)
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)
            e.save_pickle("res", "mod2.pkl", scripted_mod_0)
            e.save_pickle("res", "mod3.pkl", scripted_mod_1)
            e.save_pickle("res", "mod4.pkl", scripted_mod_2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_0 = importer.load_pickle("res", "mod0.pkl")
        loaded_mod_1 = importer.load_pickle("res", "mod3.pkl")
        loaded_mod_2 = importer.load_pickle("res", "mod4.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_mod_1(input), scripted_mod_1(input))
        self.assertEqual(loaded_mod_2(input), scripted_mod_2(input))

    def test_scriptmodules_repeat_save(self):
        """
        Test to verify saving and loading the same ScriptModule object works
        across multiple packages.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        scripted_mod_0 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_1 = torch.jit.script(
            ModWithSubmodAndTensor(
                torch.rand(1, 2, 3), ModWithTensor(torch.rand(1, 2, 3))
            )
        )

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_module_0 = importer_0.load_pickle("res", "mod1.pkl")

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)
            e.save_pickle("res", "mod2.pkl", loaded_module_0)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)
        loaded_module_1 = importer_1.load_pickle("res", "mod1.pkl")
        reloaded_module_0 = importer_1.load_pickle("res", "mod2.pkl")

        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_module_0(input), scripted_mod_0(input))
        self.assertEqual(loaded_module_0(input), reloaded_module_0(input))
        self.assertEqual(loaded_module_1(input), scripted_mod_1(input))

    @skipIfNoTorchVision
    def test_save_scriptmodule_only_necessary_code(self):
        """
        Test to verify that, when saving multiple packages with the same
        compilation unit (CU), the packages don't include unnecessary
        TorchScript code files. The TorchVision code should only be saved
        in the package that relies on it.
        """
        from package_a.test_module import ModWithTensor

        class ModWithTorchVision(torch.nn.Module):
            def __init__(self, name: str):
                super().__init__()
                self.tvmod = resnet18()

            def forward(self, input):
                return input * 4

        scripted_mod_0 = torch.jit.script(ModWithTorchVision("foo"))
        scripted_mod_1 = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_0)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1) as e:
            e.save_pickle("res", "mod1.pkl", scripted_mod_1)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)

        self.assertTrue("torchvision" in str(importer_0.file_structure()))
        self.assertFalse("torchvision" in str(importer_1.file_structure()))

    def test_save_scriptmodules_in_container(self):
        """
        Test saving of ScriptModules inside of a container. Checks that relations
        between shared modules are upheld.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        scripted_mod_a = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))
        scripted_mod_b = torch.jit.script(
            ModWithSubmodAndTensor(torch.rand(1, 2, 3), scripted_mod_a)
        )
        script_mods_list = [scripted_mod_a, scripted_mod_b]

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("res", "list.pkl", script_mods_list)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_list = importer.load_pickle("res", "list.pkl")
        input = torch.rand(1, 2, 3)
        self.assertEqual(loaded_mod_list[0](input), scripted_mod_a(input))
        self.assertEqual(loaded_mod_list[1](input), scripted_mod_b(input))

    def test_save_eager_mods_sharing_scriptmodule(self):
        """
        Test saving of single ScriptModule shared by multiple
        eager modules (ScriptModule should be saved just once
        even though it is contained in multiple pickles).
        """
        from package_a.test_module import ModWithSubmod, SimpleTest

        scripted_mod = torch.jit.script(SimpleTest())

        mod1 = ModWithSubmod(scripted_mod)
        mod2 = ModWithSubmod(scripted_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)
            e.save_pickle("res", "mod2.pkl", mod2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        file_structure = importer.file_structure()
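        # Packaged TorchScript archives are written under `.data/ts_code/<n>`.
        # Because both pickles reference the same ScriptModule instance, only the
        # first archive (index 0) is expected to exist.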
        self.assertTrue(file_structure.has_file(".data/ts_code/0"))
        self.assertFalse(file_structure.has_file(".data/ts_code/1"))

    def test_load_shared_scriptmodules(self):
        """
        Test loading of single ScriptModule shared by multiple eager
        modules in single pickle (ScriptModule objects should be the same).
        """
        from package_a.test_module import (
            ModWithMultipleSubmods,
            ModWithSubmod,
            SimpleTest,
        )

        scripted_mod = torch.jit.script(SimpleTest())

        mod1 = ModWithSubmod(scripted_mod)
        mod2 = ModWithSubmod(scripted_mod)

        mod_parent = ModWithMultipleSubmods(mod1, mod2)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod.pkl", mod_parent)

        buffer.seek(0)
        importer = PackageImporter(buffer)

        loaded_mod = importer.load_pickle("res", "mod.pkl")
        self.assertTrue(
            id(loaded_mod.mod1.script_mod) == id(loaded_mod.mod2.script_mod)
        )

    def test_save_shared_tensors(self):
        """
        Test tensors shared across eager and ScriptModules are serialized once.
        """
        from package_a.test_module import ModWithSubmodAndTensor, ModWithTensor

        shared_tensor = torch.rand(2, 3, 4)
        scripted_mod = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)
        mod2 = ModWithSubmodAndTensor(shared_tensor, scripted_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "tensor", shared_tensor)
            e.save_pickle("res", "mod1.pkl", mod1)
            e.save_pickle("res", "mod2.pkl", mod2)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        # assert that there is only one storage stored in package
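        # Storages are deduplicated by the exporter: each unique tensor storage is
        # written once as a `.data/*.storage` entry and referenced by every pickle
        # that uses it, which is why a single child is expected below.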
        file_structure = importer.file_structure(include=".data/*.storage")
        self.assertTrue(len(file_structure.children[".data"].children) == 1)

        input = torch.rand(2, 3, 4)
        self.assertEqual(loaded_mod_1(input), mod1(input))

    def test_load_shared_tensors(self):
        """
        Test tensors shared across eager and ScriptModules on load
        are the same.
        """
        from package_a.test_module import ModWithTensor, ModWithTwoSubmodsAndTensor

        shared_tensor = torch.ones(3, 3)

        scripted_mod_0 = torch.jit.script(ModWithTensor(shared_tensor))
        scripted_mod_1 = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithTwoSubmodsAndTensor(shared_tensor, scripted_mod_0, scripted_mod_1)

        self.assertEqual(
            shared_tensor.storage()._cdata,
            scripted_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            shared_tensor.storage()._cdata,
            scripted_mod_1.tensor.storage()._cdata,
        )

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_1.tensor.storage()._cdata,
        )
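
        # The loaded modules share one storage, so an in-place update through the
        # top-level tensor should be visible through both submodules as well.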
        loaded_mod_1.tensor.add_(torch.ones(3, 3))

        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_0.tensor)
        )
        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_1.tensor)
        )

    def test_load_shared_tensors_repackaged(self):
        """
        Test tensors shared across eager and ScriptModules on load
        are the same across multiple package saves and loads. This is
        an important test because not all of the tensor information is restored
        in python between packages. The python identity is not maintained, but
        the backing cpp TensorImpl is. We load/save storages based off of this
        cpp TensorImpl and not the python identity.
        """
        from package_a.test_module import ModWithTensor, ModWithTwoSubmodsAndTensor

        shared_tensor = torch.ones(3, 3)

        scripted_mod_0 = torch.jit.script(ModWithTensor(shared_tensor))
        scripted_mod_1 = torch.jit.script(ModWithTensor(shared_tensor))

        mod1 = ModWithTwoSubmodsAndTensor(shared_tensor, scripted_mod_0, scripted_mod_1)

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", mod1)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_mod_0 = importer_0.load_pickle("res", "mod1.pkl")

        buffer_1 = BytesIO()
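        # Passing `importer=importer_0` lets this exporter find the source for
        # modules and types that live in the first package (rather than in the
        # local environment) when repackaging the loaded module.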
        with PackageExporter(buffer_1, importer=importer_0) as e:
            e.intern("**")
            e.save_pickle("res", "mod1.pkl", loaded_mod_0)

        buffer_1.seek(0)
        importer = PackageImporter(buffer_1)
        loaded_mod_1 = importer.load_pickle("res", "mod1.pkl")

        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_0.tensor.storage()._cdata,
        )
        self.assertEqual(
            loaded_mod_1.tensor.storage()._cdata,
            loaded_mod_1.sub_mod_1.tensor.storage()._cdata,
        )

        loaded_mod_1.tensor.add_(
            torch.ones(3, 3)
        )  # all tensors should reflect this change

        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_0.tensor)
        )
        self.assertTrue(
            torch.allclose(loaded_mod_1.tensor, loaded_mod_1.sub_mod_1.tensor)
        )

    def test_saving_and_scripting_packaged_mod(self):
        """
        Test scripting a module loaded from a package
        and saving it in a new package as a script object.
        """
        from package_a.test_module import SimpleTest

        orig_mod = SimpleTest()

        buffer_0 = BytesIO()
        with PackageExporter(buffer_0) as e:
            e.intern("**")
            e.save_pickle("model", "model.pkl", orig_mod)

        buffer_0.seek(0)
        importer_0 = PackageImporter(buffer_0)
        loaded_mod = importer_0.load_pickle("model", "model.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_mod(input), orig_mod(input))

        scripted_mod = torch.jit.script(loaded_mod)

        buffer_1 = BytesIO()
        with PackageExporter(buffer_1, importer=importer_0) as e:
            e.intern("**")
            e.save_pickle("res", "scripted_mod.pkl", scripted_mod)

        buffer_1.seek(0)
        importer_1 = PackageImporter(buffer_1)
        loaded_mod_scripted = importer_1.load_pickle("res", "scripted_mod.pkl")

        self.assertEqual(loaded_mod_scripted(input), orig_mod(input))

    def test_mixing_packaged_and_inline_modules(self):
        """
        Test saving inline and imported modules in same package with
        independent code.
        """

        class InlineMod(torch.nn.Module):
            def __init__(self, name: str):
                super().__init__()
                self.name = name
                self.tensor = torch.rand(1, 2, 3)

            def forward(self, input: str):
                input = input + "_modInline:" + self.name
                return input, (self.tensor * 4)

        inline_mod = InlineMod("inline")
        scripted_inline = torch.jit.script(inline_mod)

        from package_a.test_module import SimpleTest

        imported_mod = SimpleTest()
        scripted_imported = torch.jit.script(imported_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("model", "inline.pkl", scripted_inline)
            e.save_pickle("model", "imported.pkl", scripted_imported)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_inline = importer.load_pickle("model", "inline.pkl")
        loaded_imported = importer.load_pickle("model", "imported.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_imported(input), imported_mod(input))
        self.assertEqual(loaded_inline("input"), inline_mod("input"))

    @skipIfNoTorchVision
    def test_mixing_packaged_and_inline_modules_shared_code(self):
        """
        Test saving inline and imported modules in same package that
        share code.
        """

        class TorchVisionTestInline(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tvmod = resnet18()

            def forward(self, x):
                x = a_non_torch_leaf(x, x)
                return torch.relu(x + 3.0)

        def a_non_torch_leaf(a, b):
            return a + b

        inline_mod = TorchVisionTestInline()
        scripted_inline = torch.jit.script(inline_mod)

        from package_c.test_module import TorchVisionTest

        imported_mod = TorchVisionTest()
        scripted_imported = torch.jit.script(imported_mod)

        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_pickle("model", "inline.pkl", scripted_inline)
            e.save_pickle("model", "imported.pkl", scripted_imported)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        loaded_inline = importer.load_pickle("model", "inline.pkl")
        loaded_imported = importer.load_pickle("model", "imported.pkl")

        input = torch.rand(2, 3)
        self.assertEqual(loaded_imported(input), imported_mod(input))
        self.assertEqual(loaded_inline(input), inline_mod(input))

    def test_tensor_sharing_pickle(self):
        """Test that saving a ScriptModule and separately saving a tensor
        object causes no issues.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.ones(2, 3)

            def forward(self):
                return self.foo

        scripted_m = torch.jit.script(M())
        original_tensor = torch.ones(0)

        f = BytesIO()
        with torch.package.PackageExporter(f) as exporter:
            exporter.save_pickle("model", "model.pkl", scripted_m)
            exporter.save_pickle("model", "input.pkl", original_tensor)

        f.seek(0)
        # Should be able to load correctly
        importer = PackageImporter(f)
        loaded_m = importer.load_pickle("model", "model.pkl")
        loaded_tensor = importer.load_pickle("model", "input.pkl")

        self.assertEqual(scripted_m.foo, loaded_m.foo)
        self.assertEqual(original_tensor, loaded_tensor)


if __name__ == "__main__":
    run_tests()