mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[package] patch inspect.getfile to work with PackageImporter (#51568)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51568 The default behavior of inspect.getfile doesn't work on classes imported from PackageImporter, because it returns the following. sys.modules[kls.__module__].__file__ Looking in `sys.modules` is hard-coded behavior. So, patch it to first check a similar registry of PackageImported modules we maintain. Test Plan: Imported from OSS Reviewed By: yf225 Differential Revision: D26201236 Pulled By: suo fbshipit-source-id: aaf5d7ee8ca0155619c8185e64f70a30152ac567
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b6c6fb7252
commit
55a4aa79aa
@ -1,4 +1,5 @@
|
||||
from unittest import skipIf
|
||||
import inspect
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
|
||||
from tempfile import NamedTemporaryFile
|
||||
from torch.package import PackageExporter, PackageImporter
|
||||
@ -507,6 +508,24 @@ def load():
|
||||
with self.assertRaises(NotImplementedError):
|
||||
hi.load_pickle('obj', 'obj.pkl')
|
||||
|
||||
def test_inspect_class(self):
|
||||
"""Should be able to retrieve source for a packaged class."""
|
||||
import package_a.subpackage
|
||||
buffer = BytesIO()
|
||||
obj = package_a.subpackage.PackageASubpackageObject()
|
||||
|
||||
with PackageExporter(buffer, verbose=False) as pe:
|
||||
pe.save_pickle('obj', 'obj.pkl', obj)
|
||||
|
||||
buffer.seek(0)
|
||||
pi = PackageImporter(buffer)
|
||||
packaged_class = pi.import_module('package_a.subpackage').PackageASubpackageObject
|
||||
regular_class = package_a.subpackage.PackageASubpackageObject
|
||||
|
||||
packaged_src = inspect.getsourcelines(packaged_class)
|
||||
regular_src = inspect.getsourcelines(regular_class)
|
||||
self.assertEqual(packaged_src, regular_src)
|
||||
|
||||
|
||||
class ManglingTest(TestCase):
|
||||
def test_unique_manglers(self):
|
||||
|
@ -1,6 +1,8 @@
|
||||
from typing import List, Callable, Dict, Optional, Any, Union, BinaryIO
|
||||
from types import ModuleType
|
||||
import builtins
|
||||
import importlib
|
||||
import inspect
|
||||
import linecache
|
||||
from torch.serialization import _load
|
||||
import pickle
|
||||
@ -168,6 +170,10 @@ class PackageImporter:
|
||||
ns['__cached__'] = None
|
||||
ns['__builtins__'] = self.patched_builtins
|
||||
|
||||
# Add this module to our private global registry. It should be unique due to mangling.
|
||||
assert module.__name__ not in _package_imported_modules
|
||||
_package_imported_modules[module.__name__] = module
|
||||
|
||||
# pre-emptively install on the parent to prevent IMPORT_FROM from trying to
|
||||
# access sys.modules
|
||||
self._install_on_parent(parent, name, module)
|
||||
@ -429,3 +435,16 @@ class _ModuleNode(_PathNode):
|
||||
|
||||
class _ExternNode(_PathNode):
|
||||
pass
|
||||
|
||||
# A private global registry of all modules that have been package-imported.
|
||||
_package_imported_modules: Dict[str, ModuleType] = {}
|
||||
|
||||
# `inspect` by default only looks in `sys.modules` to find source files for classes.
|
||||
# Patch it to check our private registry of package-imported modules as well.
|
||||
_orig_getfile = inspect.getfile
|
||||
def patched_getfile(object):
|
||||
if inspect.isclass(object):
|
||||
if object.__module__ in _package_imported_modules:
|
||||
return _package_imported_modules[object.__module__].__file__
|
||||
return _orig_getfile(object)
|
||||
inspect.getfile = patched_getfile
|
||||
|
Reference in New Issue
Block a user