[package] patch inspect.getfile to work with PackageImporter (#51568)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51568

The default behavior of inspect.getfile doesn't work on classes imported
from PackageImporter, because for a class it returns the following:

    sys.modules[kls.__module__].__file__

Looking in `sys.modules` is hard-coded behavior. So, patch it to first
check a similar registry of PackageImported modules we maintain.

Test Plan: Imported from OSS

Reviewed By: yf225

Differential Revision: D26201236

Pulled By: suo

fbshipit-source-id: aaf5d7ee8ca0155619c8185e64f70a30152ac567
This commit is contained in:
Michael Suo
2021-02-02 11:27:04 -08:00
committed by Facebook GitHub Bot
parent b6c6fb7252
commit 55a4aa79aa
2 changed files with 38 additions and 0 deletions

View File

@ -1,4 +1,5 @@
from unittest import skipIf
import inspect
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
from tempfile import NamedTemporaryFile
from torch.package import PackageExporter, PackageImporter
@ -507,6 +508,24 @@ def load():
with self.assertRaises(NotImplementedError):
hi.load_pickle('obj', 'obj.pkl')
def test_inspect_class(self):
    """Source lines of a packaged class should match the regular import."""
    import package_a.subpackage

    # Pickle an instance into an in-memory package archive.
    archive = BytesIO()
    instance = package_a.subpackage.PackageASubpackageObject()
    with PackageExporter(archive, verbose=False) as exporter:
        exporter.save_pickle('obj', 'obj.pkl', instance)

    # Rewind so the importer reads the archive from the beginning.
    archive.seek(0)
    importer = PackageImporter(archive)
    reimported_module = importer.import_module('package_a.subpackage')
    reimported_cls = reimported_module.PackageASubpackageObject
    original_cls = package_a.subpackage.PackageASubpackageObject

    # `inspect` must resolve the packaged class to the same source text.
    self.assertEqual(
        inspect.getsourcelines(reimported_cls),
        inspect.getsourcelines(original_cls),
    )
class ManglingTest(TestCase):
def test_unique_manglers(self):

View File

@ -1,6 +1,8 @@
from typing import List, Callable, Dict, Optional, Any, Union, BinaryIO
from types import ModuleType
import builtins
import importlib
import inspect
import linecache
from torch.serialization import _load
import pickle
@ -168,6 +170,10 @@ class PackageImporter:
ns['__cached__'] = None
ns['__builtins__'] = self.patched_builtins
# Add this module to our private global registry. It should be unique due to mangling.
assert module.__name__ not in _package_imported_modules
_package_imported_modules[module.__name__] = module
# pre-emptively install on the parent to prevent IMPORT_FROM from trying to
# access sys.modules
self._install_on_parent(parent, name, module)
@ -429,3 +435,16 @@ class _ModuleNode(_PathNode):
# Subclass of _PathNode that adds no attributes or methods of its own.
class _ExternNode(_PathNode):
pass
# A private global registry of all modules that have been package-imported.
_package_imported_modules: Dict[str, ModuleType] = {}

# `inspect` by default only looks in `sys.modules` to find source files for classes.
# Patch it to check our private registry of package-imported modules as well.
_orig_getfile = inspect.getfile


def patched_getfile(object):
    """Replacement for ``inspect.getfile`` that knows about packaged modules.

    For a class whose ``__module__`` names an entry in
    ``_package_imported_modules``, return that module's ``__file__``.
    Everything else (non-class objects, classes from ordinary modules, and
    registered modules lacking a usable ``__file__``) is delegated to the
    original ``inspect.getfile``.

    Args:
        object: The object to locate. The parameter name mirrors the
            signature of ``inspect.getfile`` so keyword callers keep working.

    Returns:
        The file name recorded for the object's defining module.

    Raises:
        TypeError, OSError: Whatever the original ``inspect.getfile`` raises
            when it cannot determine a file for ``object``.
    """
    if inspect.isclass(object):
        module = _package_imported_modules.get(getattr(object, '__module__', None))
        # A registered module may have no `__file__` (or `None`); fall through
        # so the caller gets the exception `inspect` documents, not AttributeError.
        filename = getattr(module, '__file__', None)
        if filename:
            return filename
    return _orig_getfile(object)


inspect.getfile = patched_getfile