[package] add mock/extern hooks (#58000)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58000

Directly overriding save_extern and save_mock may mess with our
invariants in weird ways. This is less pronounced now, but once we
switch to graph-based dependency management things will get broken
subtly if people fail to call `super()`.

Better to add hook support to reflect that really you can only do a side
effect. Also has the bonus that people are likely familiar with it from
`nn.Module` hooks.

Differential Revision: D28339191

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Pulled By: suo

fbshipit-source-id: 63ffd39d2dcb1a7524f3c2c6a23bd399e754cc44
This commit is contained in:
Michael Suo
2021-05-11 16:45:32 -07:00
committed by Facebook GitHub Bot
parent d9ea93181b
commit 29cfcf70be
3 changed files with 183 additions and 32 deletions

View File

@ -3,6 +3,7 @@ import io
import linecache
import pickletools
import types
from collections import OrderedDict
from pathlib import Path
from typing import (
Any,
@ -20,6 +21,7 @@ from urllib.parse import quote
import torch
from torch.serialization import location_tag, normalize_storage_type
from torch.utils.hooks import RemovableHandle
from ._digraph import DiGraph
from ._importlib import _normalize_path
@ -111,6 +113,12 @@ class PackageExporter:
self.provided: Dict[str, bool] = {}
self.verbose = verbose
# These are OrderedDicts for compatibility with RemovableHandle.
# Generic OrderedDict type annotations are not present until 3.7.
# The real type signature is OrderedDict[int, Callable[[PackageExporter, str], None]]
self._extern_hooks: OrderedDict = OrderedDict()
self._mock_hooks: OrderedDict = OrderedDict()
if isinstance(importer, Importer):
self.importer = importer
else:
@ -275,7 +283,7 @@ node [shape=box];
f"implicitly adding {root_name} to external modules "
f"since it is part of the standard library and is a dependency."
)
self.save_extern_module(root_name)
self._save_extern_module(root_name)
return
for i, (pattern, action, _) in enumerate(self.patterns):
@ -379,6 +387,48 @@ node [shape=box];
filename = self._filename(package, resource)
self._write(filename, binary)
def register_extern_hook(
self, hook: Callable[["PackageExporter", str], None]
) -> RemovableHandle:
"""Registers an extern hook on the exporter.
The hook will be called each time a module matches against an :meth:`extern` pattern.
It should have the following signature::
hook(exporter: PackageExporter, module_name: str) -> None
Hooks will be called in order of registration.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = RemovableHandle(self._extern_hooks)
self._extern_hooks[handle.id] = hook
return handle
def register_mock_hook(
self, hook: Callable[["PackageExporter", str], None]
) -> RemovableHandle:
"""Registers a mock hook on the exporter.
The hook will be called each time a module matches against a :meth:`mock` pattern.
It should have the following signature::
hook(exporter: PackageExporter, module_name: str) -> None
Hooks will be called in order of registration.
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = RemovableHandle(self._mock_hooks)
self._mock_hooks[handle.id] = hook
return handle
def mock(
self,
include: "GlobPattern",
@ -412,7 +462,7 @@ node [shape=box];
"""
self.patterns.append(
(GlobGroup(include, exclude=exclude), self.save_mock_module, allow_empty)
(GlobGroup(include, exclude=exclude), self._save_mock_module, allow_empty)
)
def extern(
@ -440,7 +490,7 @@ node [shape=box];
"""
self.patterns.append(
(GlobGroup(include, exclude=exclude), self.save_extern_module, allow_empty)
(GlobGroup(include, exclude=exclude), self._save_extern_module, allow_empty)
)
def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()):
@ -457,20 +507,26 @@ node [shape=box];
(GlobGroup(include, exclude=exclude), self._reject_denied_module, True)
)
def save_extern_module(self, module_name: str):
def _save_extern_module(self, module_name: str):
"""Add `module_name` to the list of external modules, regardless of whether it is
required by other modules.
Prefer using :meth:`extern` to only mark modules extern if they are actually required by the packaged code.
"""
for hook in self._extern_hooks.values():
hook(self, module_name)
self.extern_modules[module_name] = True
def save_mock_module(self, module_name: str):
def _save_mock_module(self, module_name: str):
"""Add `module_name` to the package, implemented it with a mocked out version that
can be imported but does not include any implementations.
Prefer using `mock` to only include this module if it is required by other modules.
"""
for hook in self._mock_hooks.values():
hook(self, module_name)
if "_mock" not in self.provided:
self.save_source_string(
"_mock",