mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 04:04:57 +08:00 
			
		
		
		
	Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58573

	Users can create invalid imports, like:

	```
	# in a top-level package
	if False:
	    from .. import foo
	```

	Since this code is never executed, it will not cause the module to fail to load. But our dependency analysis walks every `import` statement in the AST, and will attempt to resolve the (incorrectly formed) import, throwing an exception.

	For posterity, the code that triggered this: https://git.io/JsCgM

	Differential Revision: D28543980

	Test Plan: Added a unit test

	Reviewed By: Chillee

	Pulled By: suo

	fbshipit-source-id: 03b7e274633945b186500fab6f974973ef8c7c7d

	Co-authored-by: Michael Suo <suo@fb.com>
		
			
				
	
	
		
			313 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			313 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import importlib
 | |
| from io import BytesIO
 | |
| from sys import version_info
 | |
| from textwrap import dedent
 | |
| from unittest import skipIf
 | |
| 
 | |
| from torch.package import EmptyMatchError, Importer, PackageExporter, PackageImporter
 | |
| from torch.package.package_exporter import PackagingError
 | |
| from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
 | |
| 
 | |
| try:
 | |
|     from .common import PackageTestCase
 | |
| except ImportError:
 | |
|     # Support the case where we run this file directly.
 | |
|     from common import PackageTestCase
 | |
| 
 | |
| 
 | |
class TestDependencyAPI(PackageTestCase):
    """Dependency management API tests.

    Exercises the ``PackageExporter`` calls that decide how a dependency is
    handled during packaging:

    - mock()
    - extern()
    - deny()
    - intern() (explicitly, and implicitly through save_module())
    """

    def test_extern(self):
        """Externed modules resolve to the very same objects as a normal import."""
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as exporter:
            exporter.extern(["package_a.subpackage", "module_a"])
            exporter.save_source_string(
                "foo", "import package_a.subpackage; import module_a"
            )
        buffer.seek(0)
        importer = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = importer.import_module("module_a")
        importer.import_module("package_a.subpackage")
        package_a_im = importer.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        # Only the subpackage was externed; the parent ``package_a`` handed
        # back by the importer is a distinct object from the regular import.
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob(self):
        """Same as test_extern, but the externed modules are selected by glob."""
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as exporter:
            exporter.extern(["package_a.*", "module_*"])
            exporter.save_module("package_a")
            exporter.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        importer = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = importer.import_module("module_a")
        importer.import_module("package_a.subpackage")
        package_a_im = importer.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob_allow_empty(self):
        """
        Test that an error is thrown when a extern glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer, verbose=False) as exporter:
                # Nothing saved below depends on ``package_b``, so this
                # pattern never matches anything.
                exporter.extern(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    def test_deny(self):
        """
        Test marking packages as "deny" during export.
        """
        buffer = BytesIO()

        with self.assertRaisesRegex(PackagingError, "denied"):
            with PackageExporter(buffer, verbose=False) as exporter:
                exporter.deny(["package_a.subpackage", "module_a"])
                exporter.save_source_string("foo", "import package_a.subpackage")

    def test_deny_glob(self):
        """
        Test marking packages as "deny" using globs instead of package names.
        """
        buffer = BytesIO()
        with self.assertRaises(PackagingError):
            with PackageExporter(buffer, verbose=False) as exporter:
                exporter.deny(["package_a.*", "module_*"])
                exporter.save_source_string(
                    "test_module",
                    dedent(
                        """\
                        import package_a.subpackage
                        import module_a
                        """
                    ),
                )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock(self):
        """Calling an attribute of a mocked module raises NotImplementedError."""
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as exporter:
            exporter.mock(["package_a.subpackage", "module_a"])
            # Import something that depends on package_a.subpackage
            exporter.save_source_string("foo", "import package_a.subpackage")
        buffer.seek(0)
        importer = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = importer.import_module("package_a.subpackage")
        # Looking the attribute up succeeds; only calling it is expected to fail.
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock_glob(self):
        """Same as test_mock, but the mocked modules are selected by glob."""
        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as exporter:
            exporter.mock(["package_a.*", "module*"])
            exporter.save_module("package_a")
            exporter.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        importer = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = importer.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    def test_mock_glob_allow_empty(self):
        """
        Test that an error is thrown when a mock glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer, verbose=False) as exporter:
                # Nothing saved below depends on ``package_b``, so this
                # pattern never matches anything.
                exporter.mock(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked(self):
        """Loading a pickle that references a mocked module raises NotImplementedError."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer, verbose=False) as exporter:
            exporter.mock(include="package_a.subpackage")
            exporter.intern("**")
            exporter.save_pickle("obj", "obj.pkl", obj2)

        buffer.seek(0)

        importer = PackageImporter(buffer)
        with self.assertRaises(NotImplementedError):
            importer.load_pickle("obj", "obj.pkl")

    def test_allow_empty_with_error(self):
        """If an error occurs during packaging, it should not be shadowed by the allow_empty error."""
        buffer = BytesIO()
        with self.assertRaises(ModuleNotFoundError):
            with PackageExporter(buffer, verbose=False) as exporter:
                # Even though we did not extern a module that matches this
                # pattern, we want to show the save_module error, not the allow_empty error.

                exporter.extern("foo", allow_empty=False)
                exporter.save_module("aodoifjodisfj")  # will error

                # we never get here, so technically the allow_empty check
                # should raise an error. However, the error above is more
                # informative to what's actually going wrong with packaging.
                exporter.save_source_string("bar", "import foo\n")

    def test_implicit_intern(self):
        """The save_module APIs should implicitly intern the module being saved."""
        import package_a  # noqa: F401

        buffer = BytesIO()
        # No explicit intern() call: the test passes as long as exporting
        # finishes without raising.
        with PackageExporter(buffer, verbose=False) as exporter:
            exporter.save_module("package_a")

    def test_intern_error(self):
        """Failure to handle all dependencies should lead to an error."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer, verbose=False) as exporter:
                exporter.save_pickle("obj", "obj.pkl", obj2)

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module did not match against any action pattern. Extern, mock, or intern it.
                    package_a
                    package_a.subpackage
                """
            ),
        )

        # Interning all dependencies should work
        with PackageExporter(buffer, verbose=False) as exporter:
            exporter.intern(["package_a", "package_a.subpackage"])
            exporter.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(IS_WINDOWS, "extension modules have a different file extension on windows")
    def test_broken_dependency(self):
        """An unpackageable dependency should raise a PackagingError."""
        # ``import importlib`` alone does not guarantee that the ``machinery``
        # and ``util`` submodules are loaded, so import them explicitly rather
        # than relying on some other module having done it already.
        import importlib.machinery
        import importlib.util

        def create_module(name):
            """Build a stand-in module object whose ``__file__`` ends in ``.so``."""
            # The test case instance is only a placeholder loader here.
            spec = importlib.machinery.ModuleSpec(name, self, is_package=False)  # type: ignore[arg-type]
            module = importlib.util.module_from_spec(spec)
            ns = module.__dict__
            ns["__spec__"] = spec
            ns["__loader__"] = self
            ns["__file__"] = f"{name}.so"
            ns["__cached__"] = None
            return module

        class BrokenImporter(Importer):
            """Importer that serves the fake ``.so``-backed modules built above."""

            def __init__(self):
                self.modules = {
                    "foo": create_module("foo"),
                    "bar": create_module("bar"),
                }

            def import_module(self, module_name):
                return self.modules[module_name]

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(
                buffer, verbose=False, importer=BrokenImporter()
            ) as exporter:
                exporter.intern(["foo", "bar"])
                exporter.save_source_string("my_module", "import foo; import bar")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module is a C extension module. torch.package supports Python modules only.
                    foo
                    bar
                """
            ),
        )

    def test_invalid_import(self):
        """An incorrectly-formed import should raise a PackagingError."""
        buffer = BytesIO()
        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer, verbose=False) as exporter:
                # This import will fail to load.
                exporter.save_source_string("foo", "from ........ import lol")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Dependency resolution failed.
                    foo
                      Context: attempted relative import beyond top-level package
                """
            ),
        )
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Executed as a script rather than imported: hand control to the
    # run_tests() entry point imported from torch's common test utilities.
    run_tests()
 |