mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
torch.Package zipfile debugging printer (#52176)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52176 Added tooling to print out zipfile structure for PackageExporter and PackageImporter. API looks like: ``` exporter.print_file_structure("sss" /*only include files with this in the path*/) importer3.print_file_structure(False /*don't print storage*/, "sss" /*only include files with this in the path*/) ``` The output looks like this with the storage hidden by default: ``` ─── resnet.zip ├── .data │ ├── extern_modules │ └── version ├── models │ └── models1.pkl └── torchvision └── models ├── resnet.py └── utils.py ``` The output looks like this with the storage being printed out: ``` ─── resnet_added_attr_test.zip ├── .data │ ├── 94574437434544.storage │ ├── 94574468343696.storage │ ├── 94574470147744.storage │ ├── 94574470198784.storage │ ├── 94574470267968.storage │ ├── 94574474917984.storage │ ├── extern_modules │ └── version ├── models │ └── models1.pkl └── torchvision └── models ├── resnet.py └── utils.py ``` If the output is filtered with the string 'utils' it looks like this: ``` ─── resnet_added_attr_test.zip └── torchvision └── models └── utils.py ``` Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D26429795 Pulled By: Lilyjjo fbshipit-source-id: 4fa25b0426912f939c7b52cedd6e217672891f21
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b72a72a477
commit
0bc57f47f0
@ -5,16 +5,17 @@ import io
|
||||
import pickletools
|
||||
from .find_file_dependencies import find_files_source_depends_on
|
||||
from ._custom_import_pickler import create_custom_import_pickler
|
||||
from ._file_structure_representation import _create_folder_from_file_list, Folder
|
||||
from ._glob_group import GlobPattern, _GlobGroup
|
||||
from ._importlib import _normalize_path
|
||||
from ._mangling import is_mangled
|
||||
from ._stdlib import is_stdlib_module
|
||||
from .importer import Importer, OrderedImporter, sys_importer
|
||||
import types
|
||||
from typing import List, Any, Callable, Dict, Sequence, Tuple, Union, Iterable, BinaryIO, Optional
|
||||
from typing import List, Any, Callable, Dict, Sequence, Tuple, Union, BinaryIO, Optional
|
||||
from pathlib import Path
|
||||
import linecache
|
||||
from urllib.parse import quote
|
||||
import re
|
||||
|
||||
|
||||
class PackageExporter:
|
||||
@ -133,6 +134,19 @@ class PackageExporter:
|
||||
is_package = path.name == '__init__.py'
|
||||
self.save_source_string(module_name, _read_file(file_or_directory), is_package, dependencies, file_or_directory)
|
||||
|
||||
def file_structure(self, *, include: 'GlobPattern' = "**", exclude: 'GlobPattern' = ()) -> Folder:
    """Return a file structure representation of this package's zipfile.

    The archive name and the records written so far are read from
    ``self.zip_file`` and handed, together with the filters, to
    ``_create_folder_from_file_list``.

    Args:
        include (Union[List[str], str]): An optional string
            e.g. "my_package.my_subpackage", or optional list of strings
            for the names of the files to be included in the zipfile
            representation. This can also be a glob-style pattern, as
            described in :meth:`mock`.

        exclude (Union[List[str], str]): An optional pattern that excludes
            files whose name match the pattern.

    Returns:
        Folder: the value produced by ``_create_folder_from_file_list``.
    """
    archive_name = self.zip_file.archive_name()
    written_records = self.zip_file.get_all_written_records()
    return _create_folder_from_file_list(archive_name, written_records, include, exclude)
|
||||
|
||||
def save_source_string(self, module_name: str, src: str, is_package: bool = False,
|
||||
dependencies: bool = True, orig_file_name: str = None):
|
||||
"""Adds `src` as the source code for `module_name` in the exported package.
|
||||
@ -473,42 +487,3 @@ def _read_file(filename: str) -> str:
|
||||
with open(filename, 'rb') as f:
|
||||
b = f.read()
|
||||
return b.decode('utf-8')
|
||||
|
||||
# Type alias for pattern arguments: either a single pattern string or an
# iterable of pattern strings (see `_GlobGroup._glob_list`, which accepts both).
GlobPattern = Union[str, Iterable[str]]
|
||||
|
||||
|
||||
class _GlobGroup:
|
||||
def __init__(self, include: 'GlobPattern', exclude: 'GlobPattern'):
|
||||
self._dbg = f'_GlobGroup(include={include}, exclude={exclude})'
|
||||
self.include = _GlobGroup._glob_list(include)
|
||||
self.exclude = _GlobGroup._glob_list(exclude)
|
||||
|
||||
def __str__(self):
|
||||
return self._dbg
|
||||
|
||||
def matches(self, candidate: str) -> bool:
|
||||
candidate = '.' + candidate
|
||||
return any(p.fullmatch(candidate) for p in self.include) and all(not p.fullmatch(candidate) for p in self.exclude)
|
||||
|
||||
@staticmethod
|
||||
def _glob_list(elems: 'GlobPattern'):
|
||||
if isinstance(elems, str):
|
||||
return [_GlobGroup._glob_to_re(elems)]
|
||||
else:
|
||||
return [_GlobGroup._glob_to_re(e) for e in elems]
|
||||
|
||||
@staticmethod
|
||||
def _glob_to_re(pattern: str):
|
||||
# to avoid corner cases for the first component, we prefix the candidate string
|
||||
# with '.' so `import torch` will regex against `.torch`
|
||||
def component_to_re(component):
|
||||
if '**' in component:
|
||||
if component == '**':
|
||||
return '(\\.[^.]+)*'
|
||||
else:
|
||||
raise ValueError('** can only appear as an entire path segment')
|
||||
else:
|
||||
return '\\.' + '[^.]*'.join(re.escape(x) for x in component.split('*'))
|
||||
|
||||
result = ''.join(component_to_re(c) for c in pattern.split('.'))
|
||||
return re.compile(result)
|
||||
|
Reference in New Issue
Block a user