torch.Package zipfile debugging printer (#52176)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52176

Added tooling to print out zipfile structure for PackageExporter and PackageImporter.

API looks like:
```
exporter.print_file_structure("sss" /*only include files with this in the path*/)
importer3.print_file_structure(False /*don't print storage*/, "sss" /*only include files with this in the path*/)
```

The output looks like this with the storage hidden by default:
```
─── resnet.zip
    ├── .data
    │   ├── extern_modules
    │   └── version
    ├── models
    │   └── models1.pkl
    └── torchvision
        └── models
            ├── resnet.py
            └── utils.py
```
The output looks like this with the storage being printed out:
```
─── resnet_added_attr_test.zip
    ├── .data
    │   ├── 94574437434544.storage
    │   ├── 94574468343696.storage
    │   ├── 94574470147744.storage
    │   ├── 94574470198784.storage
    │   ├── 94574470267968.storage
    │   ├── 94574474917984.storage
    │   ├── extern_modules
    │   └── version
    ├── models
    │   └── models1.pkl
    └── torchvision
        └── models
            ├── resnet.py
            └── utils.py
```

If the output is filtered with the string 'utils', it looks like this:
```
─── resnet_added_attr_test.zip
    └── torchvision
        └── models
            └── utils.py
```

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D26429795

Pulled By: Lilyjjo

fbshipit-source-id: 4fa25b0426912f939c7b52cedd6e217672891f21
This commit is contained in:
Lillian Johnson
2021-02-22 15:00:34 -08:00
committed by Facebook GitHub Bot
parent b72a72a477
commit 0bc57f47f0
5 changed files with 209 additions and 41 deletions

View File

@ -5,16 +5,17 @@ import io
import pickletools
from .find_file_dependencies import find_files_source_depends_on
from ._custom_import_pickler import create_custom_import_pickler
from ._file_structure_representation import _create_folder_from_file_list, Folder
from ._glob_group import GlobPattern, _GlobGroup
from ._importlib import _normalize_path
from ._mangling import is_mangled
from ._stdlib import is_stdlib_module
from .importer import Importer, OrderedImporter, sys_importer
import types
from typing import List, Any, Callable, Dict, Sequence, Tuple, Union, Iterable, BinaryIO, Optional
from typing import List, Any, Callable, Dict, Sequence, Tuple, Union, BinaryIO, Optional
from pathlib import Path
import linecache
from urllib.parse import quote
import re
class PackageExporter:
@ -133,6 +134,19 @@ class PackageExporter:
is_package = path.name == '__init__.py'
self.save_source_string(module_name, _read_file(file_or_directory), is_package, dependencies, file_or_directory)
def file_structure(self, *, include: 'GlobPattern' = "**", exclude: 'GlobPattern' = ()) -> Folder:
    """Returns a file structure representation of package's zipfile.

    Args:
        include (Union[List[str], str]): An optional string e.g. "my_package.my_subpackage", or optional
            list of strings for the names of the files to be included in the zipfile representation.
            This can also be a glob-style pattern, as described in :meth:`mock`.
        exclude (Union[List[str], str]): An optional pattern that excludes files whose name match
            the pattern.

    Returns:
        The :class:`Folder` produced by ``_create_folder_from_file_list`` from the archive name and
        the records written to the zip file so far, filtered by ``include`` / ``exclude``.
    """
    archive_name = self.zip_file.archive_name()
    written_records = self.zip_file.get_all_written_records()
    return _create_folder_from_file_list(archive_name, written_records, include, exclude)
def save_source_string(self, module_name: str, src: str, is_package: bool = False,
dependencies: bool = True, orig_file_name: str = None):
"""Adds `src` as the source code for `module_name` in the exported package.
@ -473,42 +487,3 @@ def _read_file(filename: str) -> str:
with open(filename, 'rb') as f:
b = f.read()
return b.decode('utf-8')
# Type alias for glob arguments: either a single pattern string or an iterable of pattern strings.
GlobPattern = Union[str, Iterable[str]]
class _GlobGroup:
def __init__(self, include: 'GlobPattern', exclude: 'GlobPattern'):
self._dbg = f'_GlobGroup(include={include}, exclude={exclude})'
self.include = _GlobGroup._glob_list(include)
self.exclude = _GlobGroup._glob_list(exclude)
def __str__(self):
return self._dbg
def matches(self, candidate: str) -> bool:
candidate = '.' + candidate
return any(p.fullmatch(candidate) for p in self.include) and all(not p.fullmatch(candidate) for p in self.exclude)
@staticmethod
def _glob_list(elems: 'GlobPattern'):
if isinstance(elems, str):
return [_GlobGroup._glob_to_re(elems)]
else:
return [_GlobGroup._glob_to_re(e) for e in elems]
@staticmethod
def _glob_to_re(pattern: str):
# to avoid corner cases for the first component, we prefix the candidate string
# with '.' so `import torch` will regex against `.torch`
def component_to_re(component):
if '**' in component:
if component == '**':
return '(\\.[^.]+)*'
else:
raise ValueError('** can only appear as an entire path segment')
else:
return '\\.' + '[^.]*'.join(re.escape(x) for x in component.split('*'))
result = ''.join(component_to_re(c) for c in pattern.split('.'))
return re.compile(result)