mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[package] populate a special attribute on imported modules (#55255)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55255 This allows packaged code to detect whether or not they are used in a packaged context, and do different things depending on that. An example where this might be useful is to control dynamic dependency loading depending on whether or not something is packaged. Test Plan: Imported from OSS Reviewed By: Lilyjjo Differential Revision: D27544245 Pulled By: suo fbshipit-source-id: 55d44ef57281524b8d9ab890bd387de97f20bd9f
This commit is contained in:
committed by
Facebook GitHub Bot
parent
041b4431b2
commit
a84c92b78b
@ -134,6 +134,46 @@ class TestMisc(PackageTestCase):
|
||||
regular_src = inspect.getsourcelines(regular_class)
|
||||
self.assertEqual(packaged_src, regular_src)
|
||||
|
||||
def test_dunder_package_present(self):
|
||||
"""
|
||||
The attribute '__torch_package__' should be populated on imported modules.
|
||||
"""
|
||||
import package_a.subpackage
|
||||
|
||||
buffer = BytesIO()
|
||||
obj = package_a.subpackage.PackageASubpackageObject()
|
||||
|
||||
with PackageExporter(buffer, verbose=False) as pe:
|
||||
pe.save_pickle("obj", "obj.pkl", obj)
|
||||
|
||||
buffer.seek(0)
|
||||
pi = PackageImporter(buffer)
|
||||
mod = pi.import_module(
|
||||
"package_a.subpackage"
|
||||
)
|
||||
self.assertTrue(hasattr(mod, "__torch_package__"))
|
||||
|
||||
def test_dunder_package_works_from_package(self):
|
||||
"""
|
||||
The attribute '__torch_package__' should be accessible from within
|
||||
the module itself, so that packaged code can detect whether it's
|
||||
being used in a packaged context or not.
|
||||
"""
|
||||
import package_a.use_dunder_package as mod
|
||||
|
||||
buffer = BytesIO()
|
||||
|
||||
with PackageExporter(buffer, verbose=False) as pe:
|
||||
pe.save_module(mod.__name__)
|
||||
|
||||
buffer.seek(0)
|
||||
pi = PackageImporter(buffer)
|
||||
imported_mod = pi.import_module(
|
||||
mod.__name__
|
||||
)
|
||||
self.assertTrue(imported_mod.is_from_package())
|
||||
self.assertFalse(mod.is_from_package())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user