[rfc][pkg] check spec for module source before falling back to file in package exporter (#90258)

Summary: To get source for a particular module, the "correct" thing to do is to check the module's spec and use `get_source` if it's a SourceFileLoader, since subclasses may look elsewhere than the `__file__`, and the spec will give the source of truth. For torch packager, however, we prefer to use linecache, but the loader could still change the file, so we figure out the file for the module using the spec's loader rather than using `module.__file__`, if possible.

Test Plan: This code path will get exercised by CI. Also added a test for remapped files.

Differential Revision: D41412983

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90258
Approved by: https://github.com/PaliC
This commit is contained in:
Stephen Macke
2022-12-08 20:24:45 +00:00
committed by PyTorch MergeBot
parent e1674d7dc0
commit 0c972fb5c7
3 changed files with 72 additions and 11 deletions

View File

@ -0,0 +1 @@
result = "module_a_remapped_path"

View File

@ -2,7 +2,9 @@
# Owner(s): ["oncall: package/deploy"]
import inspect
import os
import platform
import sys
from io import BytesIO
from pathlib import Path
from textwrap import dedent
@ -104,6 +106,60 @@ class TestMisc(PackageTestCase):
import_exclude,
)
def test_loaders_that_remap_files_work_ok(self):
from importlib.abc import MetaPathFinder
from importlib.machinery import SourceFileLoader
from importlib.util import spec_from_loader
class LoaderThatRemapsModuleA(SourceFileLoader):
def get_filename(self, name):
result = super().get_filename(name)
if name == "module_a":
return os.path.join(os.path.dirname(result), "module_a_remapped_path.py")
else:
return result
class FinderThatRemapsModuleA(MetaPathFinder):
def find_spec(self, fullname, path, target):
"""Try to find the original spec for module_a using all the
remaining meta_path finders."""
if fullname != "module_a":
return None
spec = None
for finder in sys.meta_path:
if finder is self:
continue
if hasattr(finder, "find_spec"):
spec = finder.find_spec(fullname, path, target=target)
elif hasattr(finder, "load_module"):
spec = spec_from_loader(fullname, finder)
if spec is not None:
break
assert spec is not None and isinstance(spec.loader, SourceFileLoader)
spec.loader = LoaderThatRemapsModuleA(spec.loader.name, spec.loader.path)
return spec
sys.meta_path.insert(0, FinderThatRemapsModuleA())
# clear it from sys.modules so that we use the custom finder next time
# it gets imported
sys.modules.pop("module_a", None)
try:
buffer = BytesIO()
with PackageExporter(buffer) as he:
import module_a
he.intern("**")
he.save_module(module_a.__name__)
buffer.seek(0)
hi = PackageImporter(buffer)
self.assertTrue("remapped_path" in hi.get_source("module_a"))
finally:
# pop it again to ensure it does not mess up other tests
sys.modules.pop("module_a", None)
sys.meta_path.pop(0)
def test_python_version(self):
"""
Tests that the current python version is stored in the package and is available

View File

@ -8,6 +8,7 @@ import types
from collections import defaultdict, OrderedDict
from dataclasses import dataclass
from enum import Enum
from importlib.machinery import SourceFileLoader
from pathlib import Path
from typing import (
Any,
@ -422,17 +423,20 @@ class PackageExporter:
return False
def _get_source_of_module(self, module: types.ModuleType) -> Optional[str]:
filename = getattr(module, "__file__", None)
result = (
None
if filename is None or not filename.endswith(".py")
else linecache.getlines(filename, module.__dict__)
)
if result is None:
return None
return "".join(result)
filename = None
spec = getattr(module, "__spec__", None)
if spec is not None:
loader = getattr(spec, "loader", None)
if loader is not None and isinstance(loader, SourceFileLoader):
try:
filename = loader.get_filename(module.__name__)
except ImportError:
pass
if filename is None:
filename = getattr(module, "__file__", None)
if isinstance(filename, str) and filename.endswith(".py"):
return "".join(linecache.getlines(filename, module.__dict__))
return None
def add_dependency(self, module_name: str, dependencies=True):
"""Given a module, add it to the dependency graph according to patterns