mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
torch.Package zipfile debugging printer (#52176)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52176 Added tooling to print out zipfile structure for PackageExporter and PackageImporter. API looks like: ``` exporter.print_file_structure("sss" /*only include files with this in the path*/) importer3.print_file_structure(False /*don't print storage*/, "sss" /*only include files with this in the path*/) ``` The output looks like this with the storage hidden by default: ``` ─── resnet.zip ├── .data │ ├── extern_modules │ └── version ├── models │ └── models1.pkl └── torchvision └── models ├── resnet.py └── utils.py ``` The output looks like this with the storage being printed out: ``` ─── resnet_added_attr_test.zip ├── .data │ ├── 94574437434544.storage │ ├── 94574468343696.storage │ ├── 94574470147744.storage │ ├── 94574470198784.storage │ ├── 94574470267968.storage │ ├── 94574474917984.storage │ ├── extern_modules │ └── version ├── models │ └── models1.pkl └── torchvision └── models ├── resnet.py └── utils.py ``` If the output is filtered with the string 'utils' it'd looks like this: ``` ─── resnet_added_attr_test.zip └── torchvision └── models └── utils.py ``` Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D26429795 Pulled By: Lilyjjo fbshipit-source-id: 4fa25b0426912f939c7b52cedd6e217672891f21
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b72a72a477
commit
0bc57f47f0
@ -82,6 +82,59 @@ the_math = math
|
||||
self.assertEqual(package_a_i.result, 'package_a')
|
||||
self.assertIsNot(package_a_i, package_a)
|
||||
|
||||
def test_file_structure(self):
|
||||
filename = self.temp()
|
||||
|
||||
export_plain = """\
|
||||
├── main
|
||||
│ └── main
|
||||
├── obj
|
||||
│ └── obj.pkl
|
||||
├── package_a
|
||||
│ ├── __init__.py
|
||||
│ └── subpackage.py
|
||||
└── module_a.py
|
||||
"""
|
||||
export_include = """\
|
||||
├── obj
|
||||
│ └── obj.pkl
|
||||
└── package_a
|
||||
└── subpackage.py
|
||||
"""
|
||||
import_exclude = """\
|
||||
├── .data
|
||||
│ ├── extern_modules
|
||||
│ └── version
|
||||
├── main
|
||||
│ └── main
|
||||
├── obj
|
||||
│ └── obj.pkl
|
||||
├── package_a
|
||||
│ ├── __init__.py
|
||||
│ └── subpackage.py
|
||||
└── module_a.py
|
||||
"""
|
||||
|
||||
with PackageExporter(filename, verbose=False) as he:
|
||||
import module_a
|
||||
import package_a
|
||||
import package_a.subpackage
|
||||
obj = package_a.subpackage.PackageASubpackageObject()
|
||||
he.save_module(module_a.__name__)
|
||||
he.save_module(package_a.__name__)
|
||||
he.save_pickle('obj', 'obj.pkl', obj)
|
||||
he.save_text('main', 'main', "my string")
|
||||
|
||||
export_file_structure = he.file_structure()
|
||||
# remove first line from testing because WINDOW/iOS/Unix treat the filename differently
|
||||
self.assertEqual('\n'.join(str(export_file_structure).split('\n')[1:]), export_plain)
|
||||
export_file_structure = he.file_structure(include=["**/subpackage.py", "**/*.pkl"])
|
||||
self.assertEqual('\n'.join(str(export_file_structure).split('\n')[1:]), export_include)
|
||||
|
||||
hi = PackageImporter(filename)
|
||||
import_file_structure = hi.file_structure(exclude="**/*.storage")
|
||||
self.assertEqual('\n'.join(str(import_file_structure).split('\n')[1:]), import_exclude)
|
||||
|
||||
def test_save_module_binary(self):
|
||||
f = BytesIO()
|
||||
with PackageExporter(f, verbose=False) as he:
|
||||
|
||||
78
torch/package/_file_structure_representation.py
Normal file
78
torch/package/_file_structure_representation.py
Normal file
@ -0,0 +1,78 @@
|
||||
from ._glob_group import GlobPattern, _GlobGroup
|
||||
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
class Folder:
|
||||
def __init__(self, name: str, is_dir: bool):
|
||||
self.name = name
|
||||
self.is_dir = is_dir
|
||||
self.children: Dict[str, Folder] = {}
|
||||
|
||||
def get_folder(self, folders: List[str]):
|
||||
# builds path of folders if not yet built, returns last folder
|
||||
if len(folders) == 0:
|
||||
return self
|
||||
folder_name = folders[0]
|
||||
if folder_name not in self.children:
|
||||
self.children[folder_name] = Folder(folder_name, True)
|
||||
return self.children[folder_name].get_folder(folders[1:])
|
||||
|
||||
def add_file(self, file_path):
|
||||
*folders, file = file_path.split("/")
|
||||
folder = self.get_folder(folders)
|
||||
folder.children[file] = Folder(file, False)
|
||||
|
||||
def __str__(self):
|
||||
str_list: List[str] = []
|
||||
self.stringify_tree(str_list)
|
||||
return "".join(str_list)
|
||||
|
||||
def stringify_tree(
|
||||
self, str_list: List[str], preamble: str = "", folder_ptr: str = "─── "
|
||||
):
|
||||
space = " "
|
||||
branch = "│ "
|
||||
tee = "├── "
|
||||
last = "└── "
|
||||
|
||||
# add this folder's representation
|
||||
str_list.append(f"{preamble}{folder_ptr}{self.name}\n")
|
||||
|
||||
# add folder's children representations
|
||||
if folder_ptr == tee:
|
||||
preamble = preamble + branch
|
||||
else:
|
||||
preamble = preamble + space
|
||||
|
||||
file_keys: List[str] = []
|
||||
dir_keys: List[str] = []
|
||||
for key, val in self.children.items():
|
||||
if val.is_dir:
|
||||
dir_keys.append(key)
|
||||
else:
|
||||
file_keys.append(key)
|
||||
|
||||
for index, key in enumerate(sorted(dir_keys)):
|
||||
if (index == len(dir_keys) - 1) and len(file_keys) == 0:
|
||||
self.children[key].stringify_tree(str_list, preamble, last)
|
||||
else:
|
||||
self.children[key].stringify_tree(str_list, preamble, tee)
|
||||
for index, file in enumerate(sorted(file_keys)):
|
||||
pointer = last if (index == len(file_keys) - 1) else tee
|
||||
str_list.append(f"{preamble}{pointer}{file}\n")
|
||||
|
||||
|
||||
def _create_folder_from_file_list(
|
||||
filename: str,
|
||||
file_list: List[str],
|
||||
include: "GlobPattern" = "**",
|
||||
exclude: "GlobPattern" = (),
|
||||
) -> Folder:
|
||||
glob_pattern = _GlobGroup(include, exclude, "/")
|
||||
|
||||
top_folder = Folder(filename, True)
|
||||
for file in file_list:
|
||||
if glob_pattern.matches(file):
|
||||
top_folder.add_file(file)
|
||||
return top_folder
|
||||
48
torch/package/_glob_group.py
Normal file
48
torch/package/_glob_group.py
Normal file
@ -0,0 +1,48 @@
|
||||
from typing import Union, Iterable
|
||||
import re
|
||||
|
||||
GlobPattern = Union[str, Iterable[str]]
|
||||
|
||||
|
||||
class _GlobGroup:
|
||||
def __init__(
|
||||
self, include: "GlobPattern", exclude: "GlobPattern", separator: str = "."
|
||||
):
|
||||
self._dbg = f"_GlobGroup(include={include}, exclude={exclude})"
|
||||
self.include = _GlobGroup._glob_list(include, separator)
|
||||
self.exclude = _GlobGroup._glob_list(exclude, separator)
|
||||
self.separator = separator
|
||||
|
||||
def __str__(self):
|
||||
return self._dbg
|
||||
|
||||
def matches(self, candidate: str) -> bool:
|
||||
candidate = self.separator + candidate
|
||||
return any(p.fullmatch(candidate) for p in self.include) and all(
|
||||
not p.fullmatch(candidate) for p in self.exclude
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _glob_list(elems: "GlobPattern", separator: str = "."):
|
||||
if isinstance(elems, str):
|
||||
return [_GlobGroup._glob_to_re(elems, separator)]
|
||||
else:
|
||||
return [_GlobGroup._glob_to_re(e, separator) for e in elems]
|
||||
|
||||
@staticmethod
|
||||
def _glob_to_re(pattern: str, separator: str = "."):
|
||||
# to avoid corner cases for the first component, we prefix the candidate string
|
||||
# with '.' so `import torch` will regex against `.torch`, assuming '.' is the separator
|
||||
def component_to_re(component):
|
||||
if "**" in component:
|
||||
if component == "**":
|
||||
return "(" + re.escape(separator) + "[^" + separator + "]+)*"
|
||||
else:
|
||||
raise ValueError("** can only appear as an entire path segment")
|
||||
else:
|
||||
return re.escape(separator) + ("[^" + separator + "]*").join(
|
||||
re.escape(x) for x in component.split("*")
|
||||
)
|
||||
|
||||
result = "".join(component_to_re(c) for c in pattern.split(separator))
|
||||
return re.compile(result)
|
||||
@ -5,16 +5,17 @@ import io
|
||||
import pickletools
|
||||
from .find_file_dependencies import find_files_source_depends_on
|
||||
from ._custom_import_pickler import create_custom_import_pickler
|
||||
from ._file_structure_representation import _create_folder_from_file_list, Folder
|
||||
from ._glob_group import GlobPattern, _GlobGroup
|
||||
from ._importlib import _normalize_path
|
||||
from ._mangling import is_mangled
|
||||
from ._stdlib import is_stdlib_module
|
||||
from .importer import Importer, OrderedImporter, sys_importer
|
||||
import types
|
||||
from typing import List, Any, Callable, Dict, Sequence, Tuple, Union, Iterable, BinaryIO, Optional
|
||||
from typing import List, Any, Callable, Dict, Sequence, Tuple, Union, BinaryIO, Optional
|
||||
from pathlib import Path
|
||||
import linecache
|
||||
from urllib.parse import quote
|
||||
import re
|
||||
|
||||
|
||||
class PackageExporter:
|
||||
@ -133,6 +134,19 @@ class PackageExporter:
|
||||
is_package = path.name == '__init__.py'
|
||||
self.save_source_string(module_name, _read_file(file_or_directory), is_package, dependencies, file_or_directory)
|
||||
|
||||
def file_structure(self, *, include: 'GlobPattern' = "**", exclude: 'GlobPattern' = ()) -> Folder:
|
||||
"""Returns a file structure representation of package's zipfile.
|
||||
|
||||
Args:
|
||||
include (Union[List[str], str]): An optional string e.g. "my_package.my_subpackage", or optional list of strings
|
||||
for the names of the files to be inluded in the zipfile representation. This can also be
|
||||
a glob-style pattern, as described in :meth:`mock`
|
||||
|
||||
exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern.
|
||||
"""
|
||||
return _create_folder_from_file_list(self.zip_file.archive_name(), self.zip_file.get_all_written_records(),
|
||||
include, exclude)
|
||||
|
||||
def save_source_string(self, module_name: str, src: str, is_package: bool = False,
|
||||
dependencies: bool = True, orig_file_name: str = None):
|
||||
"""Adds `src` as the source code for `module_name` in the exported package.
|
||||
@ -473,42 +487,3 @@ def _read_file(filename: str) -> str:
|
||||
with open(filename, 'rb') as f:
|
||||
b = f.read()
|
||||
return b.decode('utf-8')
|
||||
|
||||
GlobPattern = Union[str, Iterable[str]]
|
||||
|
||||
|
||||
class _GlobGroup:
|
||||
def __init__(self, include: 'GlobPattern', exclude: 'GlobPattern'):
|
||||
self._dbg = f'_GlobGroup(include={include}, exclude={exclude})'
|
||||
self.include = _GlobGroup._glob_list(include)
|
||||
self.exclude = _GlobGroup._glob_list(exclude)
|
||||
|
||||
def __str__(self):
|
||||
return self._dbg
|
||||
|
||||
def matches(self, candidate: str) -> bool:
|
||||
candidate = '.' + candidate
|
||||
return any(p.fullmatch(candidate) for p in self.include) and all(not p.fullmatch(candidate) for p in self.exclude)
|
||||
|
||||
@staticmethod
|
||||
def _glob_list(elems: 'GlobPattern'):
|
||||
if isinstance(elems, str):
|
||||
return [_GlobGroup._glob_to_re(elems)]
|
||||
else:
|
||||
return [_GlobGroup._glob_to_re(e) for e in elems]
|
||||
|
||||
@staticmethod
|
||||
def _glob_to_re(pattern: str):
|
||||
# to avoid corner cases for the first component, we prefix the candidate string
|
||||
# with '.' so `import torch` will regex against `.torch`
|
||||
def component_to_re(component):
|
||||
if '**' in component:
|
||||
if component == '**':
|
||||
return '(\\.[^.]+)*'
|
||||
else:
|
||||
raise ValueError('** can only appear as an entire path segment')
|
||||
else:
|
||||
return '\\.' + '[^.]*'.join(re.escape(x) for x in component.split('*'))
|
||||
|
||||
result = ''.join(component_to_re(c) for c in pattern.split('.'))
|
||||
return re.compile(result)
|
||||
|
||||
@ -15,6 +15,8 @@ from pathlib import Path
|
||||
|
||||
from ._importlib import _normalize_line_endings, _resolve_name, _sanity_check, _calc___package__, \
|
||||
_normalize_path
|
||||
from ._file_structure_representation import _create_folder_from_file_list, Folder
|
||||
from ._glob_group import GlobPattern
|
||||
from ._mock_zipreader import MockZipReader
|
||||
from ._mangling import PackageMangler, demangle
|
||||
from .importer import Importer
|
||||
@ -192,6 +194,18 @@ class PackageImporter(Importer):
|
||||
"""
|
||||
return self._mangler.parent_name()
|
||||
|
||||
def file_structure(self, *, include: 'GlobPattern' = "**", exclude: 'GlobPattern' = ()) -> Folder:
|
||||
"""Returns a file structure representation of package's zipfile.
|
||||
|
||||
Args:
|
||||
include (Union[List[str], str]): An optional string e.g. "my_package.my_subpackage", or optional list of strings
|
||||
for the names of the files to be inluded in the zipfile representation. This can also be
|
||||
a glob-style pattern, as described in exporter's :meth:`mock`
|
||||
|
||||
exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern.
|
||||
"""
|
||||
return _create_folder_from_file_list(self.filename, self.zip_reader.get_all_records(), include, exclude)
|
||||
|
||||
def _read_extern(self):
|
||||
return self.zip_reader.get_record('.data/extern_modules').decode('utf-8').splitlines(keepends=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user