Files
pytorch/test/package/test_package_script.py
Sam Estep 75024e228c Add lint for unqualified type: ignore (#56290)
Summary:
The other half of https://github.com/pytorch/pytorch/issues/56272.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56290

Test Plan:
CI should pass on the tip of this PR, and we know that the lint works because the following CI runs (before this PR was finished) failed:

- https://github.com/pytorch/pytorch/runs/2384511062
- https://github.com/pytorch/pytorch/actions/runs/765036024

Reviewed By: seemethere

Differential Revision: D27867219

Pulled By: samestep

fbshipit-source-id: e648f07b6822867e70833e23ddafe7fb7eaca235
2021-04-21 08:07:23 -07:00

146 lines
4.9 KiB
Python

from io import BytesIO
from textwrap import dedent
import torch
from torch.package import (
PackageExporter,
PackageImporter,
)
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase
class TestPackageScript(PackageTestCase):
    """Tests for compatibility with TorchScript."""

    def test_package_interface(self):
        """Packaging an interface class should work correctly.

        A module that uses a TorchScript interface is pickled into a package,
        loaded back, scripted, and compared against the scripted original.
        """
        import package_a.fake_interface as fake

        uses_interface = fake.UsesInterface()
        scripted = torch.jit.script(uses_interface)
        scripted.proxy_mod = torch.jit.script(fake.NewModule())

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_pickle("model", "model.pkl", uses_interface)
        # Rewind after leaving the `with` block so the importer reads the
        # buffer from the start.
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.load_pickle("model", "model.pkl")

        scripted_loaded = torch.jit.script(loaded)
        scripted_loaded.proxy_mod = torch.jit.script(fake.NewModule())

        # Named `sample` rather than `input` to avoid shadowing the builtin.
        sample = torch.tensor(1)
        self.assertTrue(torch.allclose(scripted(sample), scripted_loaded(sample)))

    def test_different_package_interface(self):
        """Test a case where the interface defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        # Import one version of the interface
        import package_a.fake_interface as fake

        # Simulate a package that contains a different version of the
        # interface, with the exact same name.
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_source_string(
                fake.__name__,
                # The embedded source must keep its own relative indentation:
                # `dedent` only strips the common leading whitespace.
                dedent(
                    """\
                    import torch
                    from torch import Tensor
                    @torch.jit.interface
                    class ModuleInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            pass
                    class ImplementsInterface(torch.nn.Module):
                        def one(self, inp1: Tensor) -> Tensor:
                            return inp1 + 1
                    class UsesInterface(torch.nn.Module):
                        proxy_mod: ModuleInterface
                        def __init__(self):
                            super().__init__()
                            self.proxy_mod = ImplementsInterface()
                        def forward(self, input: Tensor) -> Tensor:
                            return self.proxy_mod.one(input)
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)

        # We should be able to script successfully.
        torch.jit.script(diff_fake.UsesInterface())

    def test_package_script_class(self):
        """Packaging a module that uses a script class should work correctly.

        The module is saved into a package and imported back; the packaged
        `uses_script_class` must produce the same result as the original.
        """
        import package_a.fake_script_class as fake

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_module(fake.__name__)
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        loaded = package_importer.import_module(fake.__name__)

        sample = torch.tensor(1)
        self.assertTrue(
            torch.allclose(
                fake.uses_script_class(sample), loaded.uses_script_class(sample)
            )
        )

    def test_different_package_script_class(self):
        """Test a case where the script class defined in the package is
        different than the one defined in the loading environment, to make
        sure TorchScript can distinguish between the two.
        """
        import package_a.fake_script_class as fake

        # Simulate a package that contains a different version of the
        # script class, with the attribute `bar` instead of `foo`
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as pe:
            pe.save_source_string(
                fake.__name__,
                dedent(
                    """\
                    import torch
                    @torch.jit.script
                    class MyScriptClass:
                        def __init__(self, x):
                            self.bar = x
                    """
                ),
            )
        buffer.seek(0)

        package_importer = PackageImporter(buffer)
        diff_fake = package_importer.import_module(fake.__name__)

        sample = torch.rand(2, 3)
        loaded_script_class = diff_fake.MyScriptClass(sample)
        orig_script_class = fake.MyScriptClass(sample)
        # Each class stores its constructor argument under its own attribute
        # name, so both must hold the same tensor values.
        self.assertTrue(torch.allclose(loaded_script_class.bar, orig_script_class.foo))
# Run the test suite only when this file is executed directly as a script,
# not when it is imported by a test runner.
if __name__ == "__main__":
    run_tests()