mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[rfc][pkg] check spec for module source before falling back to file in package exporter (#90258)
Summary: To get source for a particular module, the "correct" thing to do is to check the module's spec and use `get_source` if it's a SourceFileLoader, since subclasses may look elsewhere than the `__file__`, and the spec will give the source of truth. For torch packager, however, we prefer to use linecache, but the loader could still change the file, so we figure out the file for the module using the spec's loader rather than using `module.__file__`, if possible. Test Plan: This code path will get exercised by CI. Also added a test for remapped files. Differential Revision: D41412983 Pull Request resolved: https://github.com/pytorch/pytorch/pull/90258 Approved by: https://github.com/PaliC
This commit is contained in:
committed by
PyTorch MergeBot
parent
e1674d7dc0
commit
0c972fb5c7
1
test/package/module_a_remapped_path.py
Normal file
1
test/package/module_a_remapped_path.py
Normal file
@ -0,0 +1 @@
|
||||
result = "module_a_remapped_path"
|
@ -2,7 +2,9 @@
|
||||
# Owner(s): ["oncall: package/deploy"]
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
@ -104,6 +106,60 @@ class TestMisc(PackageTestCase):
|
||||
import_exclude,
|
||||
)
|
||||
|
||||
def test_loaders_that_remap_files_work_ok(self):
|
||||
from importlib.abc import MetaPathFinder
|
||||
from importlib.machinery import SourceFileLoader
|
||||
from importlib.util import spec_from_loader
|
||||
|
||||
class LoaderThatRemapsModuleA(SourceFileLoader):
|
||||
def get_filename(self, name):
|
||||
result = super().get_filename(name)
|
||||
if name == "module_a":
|
||||
return os.path.join(os.path.dirname(result), "module_a_remapped_path.py")
|
||||
else:
|
||||
return result
|
||||
|
||||
class FinderThatRemapsModuleA(MetaPathFinder):
|
||||
def find_spec(self, fullname, path, target):
|
||||
"""Try to find the original spec for module_a using all the
|
||||
remaining meta_path finders."""
|
||||
if fullname != "module_a":
|
||||
return None
|
||||
spec = None
|
||||
for finder in sys.meta_path:
|
||||
if finder is self:
|
||||
continue
|
||||
if hasattr(finder, "find_spec"):
|
||||
spec = finder.find_spec(fullname, path, target=target)
|
||||
elif hasattr(finder, "load_module"):
|
||||
spec = spec_from_loader(fullname, finder)
|
||||
if spec is not None:
|
||||
break
|
||||
assert spec is not None and isinstance(spec.loader, SourceFileLoader)
|
||||
spec.loader = LoaderThatRemapsModuleA(spec.loader.name, spec.loader.path)
|
||||
return spec
|
||||
|
||||
sys.meta_path.insert(0, FinderThatRemapsModuleA())
|
||||
# clear it from sys.modules so that we use the custom finder next time
|
||||
# it gets imported
|
||||
sys.modules.pop("module_a", None)
|
||||
try:
|
||||
buffer = BytesIO()
|
||||
with PackageExporter(buffer) as he:
|
||||
import module_a
|
||||
|
||||
he.intern("**")
|
||||
he.save_module(module_a.__name__)
|
||||
|
||||
|
||||
buffer.seek(0)
|
||||
hi = PackageImporter(buffer)
|
||||
self.assertTrue("remapped_path" in hi.get_source("module_a"))
|
||||
finally:
|
||||
# pop it again to ensure it does not mess up other tests
|
||||
sys.modules.pop("module_a", None)
|
||||
sys.meta_path.pop(0)
|
||||
|
||||
def test_python_version(self):
|
||||
"""
|
||||
Tests that the current python version is stored in the package and is available
|
||||
|
@ -8,6 +8,7 @@ import types
|
||||
from collections import defaultdict, OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from importlib.machinery import SourceFileLoader
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
@ -422,17 +423,20 @@ class PackageExporter:
|
||||
return False
|
||||
|
||||
def _get_source_of_module(self, module: types.ModuleType) -> Optional[str]:
|
||||
filename = getattr(module, "__file__", None)
|
||||
result = (
|
||||
None
|
||||
if filename is None or not filename.endswith(".py")
|
||||
else linecache.getlines(filename, module.__dict__)
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return "".join(result)
|
||||
filename = None
|
||||
spec = getattr(module, "__spec__", None)
|
||||
if spec is not None:
|
||||
loader = getattr(spec, "loader", None)
|
||||
if loader is not None and isinstance(loader, SourceFileLoader):
|
||||
try:
|
||||
filename = loader.get_filename(module.__name__)
|
||||
except ImportError:
|
||||
pass
|
||||
if filename is None:
|
||||
filename = getattr(module, "__file__", None)
|
||||
if isinstance(filename, str) and filename.endswith(".py"):
|
||||
return "".join(linecache.getlines(filename, module.__dict__))
|
||||
return None
|
||||
|
||||
def add_dependency(self, module_name: str, dependencies=True):
|
||||
"""Given a module, add it to the dependency graph according to patterns
|
||||
|
Reference in New Issue
Block a user