mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[package] populate a special attribute on imported modules (#55255)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55255 This allows packaged code to detect whether or not they are used in a packaged context, and do different things depending on that. An example where this might be useful is to control dynamic dependency loading depending on whether or not something is packaged. Test Plan: Imported from OSS Reviewed By: Lilyjjo Differential Revision: D27544245 Pulled By: suo fbshipit-source-id: 55d44ef57281524b8d9ab890bd387de97f20bd9f
This commit is contained in:
committed by
Facebook GitHub Bot
parent
041b4431b2
commit
a84c92b78b
10
test/package/package_a/use_dunder_package.py
Normal file
10
test/package/package_a/use_dunder_package.py
Normal file
@ -0,0 +1,10 @@
|
||||
if "__torch_package__" in dir():
|
||||
|
||||
def is_from_package():
|
||||
return True
|
||||
|
||||
|
||||
else:
|
||||
|
||||
def is_from_package():
|
||||
return False
|
@ -134,6 +134,46 @@ class TestMisc(PackageTestCase):
|
||||
regular_src = inspect.getsourcelines(regular_class)
|
||||
self.assertEqual(packaged_src, regular_src)
|
||||
|
||||
def test_dunder_package_present(self):
|
||||
"""
|
||||
The attribute '__torch_package__' should be populated on imported modules.
|
||||
"""
|
||||
import package_a.subpackage
|
||||
|
||||
buffer = BytesIO()
|
||||
obj = package_a.subpackage.PackageASubpackageObject()
|
||||
|
||||
with PackageExporter(buffer, verbose=False) as pe:
|
||||
pe.save_pickle("obj", "obj.pkl", obj)
|
||||
|
||||
buffer.seek(0)
|
||||
pi = PackageImporter(buffer)
|
||||
mod = pi.import_module(
|
||||
"package_a.subpackage"
|
||||
)
|
||||
self.assertTrue(hasattr(mod, "__torch_package__"))
|
||||
|
||||
def test_dunder_package_works_from_package(self):
|
||||
"""
|
||||
The attribute '__torch_package__' should be accessible from within
|
||||
the module itself, so that packaged code can detect whether it's
|
||||
being used in a packaged context or not.
|
||||
"""
|
||||
import package_a.use_dunder_package as mod
|
||||
|
||||
buffer = BytesIO()
|
||||
|
||||
with PackageExporter(buffer, verbose=False) as pe:
|
||||
pe.save_module(mod.__name__)
|
||||
|
||||
buffer.seek(0)
|
||||
pi = PackageImporter(buffer)
|
||||
imported_mod = pi.import_module(
|
||||
mod.__name__
|
||||
)
|
||||
self.assertTrue(imported_mod.is_from_package())
|
||||
self.assertFalse(mod.is_from_package())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -259,6 +259,7 @@ class PackageImporter(Importer):
|
||||
ns["__file__"] = mangled_filename
|
||||
ns["__cached__"] = None
|
||||
ns["__builtins__"] = self.patched_builtins
|
||||
ns["__torch_package__"] = True
|
||||
|
||||
# Add this module to our private global registry. It should be unique due to mangling.
|
||||
assert module.__name__ not in _package_imported_modules
|
||||
|
Reference in New Issue
Block a user