mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[package] add mock/extern hooks (#58000)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58000 Directly overriding save_extern and save_mock may mess with our invariants in weird ways. This is less pronounced now, but once we switch to graph-based dependency management things will get broken subtly if people fail to call `super()`. Better to add hook support to reflect that really you can only do a side effect. Also has the bonus that people are likely familiar with it from `nn.Module` hooks. Differential Revision: D28339191 Test Plan: Imported from OSS Reviewed By: SplitInfinity Pulled By: suo fbshipit-source-id: 63ffd39d2dcb1a7524f3c2c6a23bd399e754cc44
This commit is contained in:
committed by
Facebook GitHub Bot
parent
d9ea93181b
commit
29cfcf70be
@ -3,6 +3,7 @@ import io
|
||||
import linecache
|
||||
import pickletools
|
||||
import types
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
@ -20,6 +21,7 @@ from urllib.parse import quote
|
||||
|
||||
import torch
|
||||
from torch.serialization import location_tag, normalize_storage_type
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
from ._digraph import DiGraph
|
||||
from ._importlib import _normalize_path
|
||||
@ -111,6 +113,12 @@ class PackageExporter:
|
||||
self.provided: Dict[str, bool] = {}
|
||||
self.verbose = verbose
|
||||
|
||||
# These are OrderedDicts for compatibility with RemovableHandle.
|
||||
# Generic OrderedDict type annotations are not present until 3.7.
|
||||
# The real type signature is OrderedDict[int, Callable[[PackageExporter, str], None]]
|
||||
self._extern_hooks: OrderedDict = OrderedDict()
|
||||
self._mock_hooks: OrderedDict = OrderedDict()
|
||||
|
||||
if isinstance(importer, Importer):
|
||||
self.importer = importer
|
||||
else:
|
||||
@ -275,7 +283,7 @@ node [shape=box];
|
||||
f"implicitly adding {root_name} to external modules "
|
||||
f"since it is part of the standard library and is a dependency."
|
||||
)
|
||||
self.save_extern_module(root_name)
|
||||
self._save_extern_module(root_name)
|
||||
return
|
||||
|
||||
for i, (pattern, action, _) in enumerate(self.patterns):
|
||||
@ -379,6 +387,48 @@ node [shape=box];
|
||||
filename = self._filename(package, resource)
|
||||
self._write(filename, binary)
|
||||
|
||||
def register_extern_hook(
|
||||
self, hook: Callable[["PackageExporter", str], None]
|
||||
) -> RemovableHandle:
|
||||
"""Registers an extern hook on the exporter.
|
||||
|
||||
The hook will be called each time a module matches against an :meth:`extern` pattern.
|
||||
It should have the following signature::
|
||||
|
||||
hook(exporter: PackageExporter, module_name: str) -> None
|
||||
|
||||
Hooks will be called in order of registration.
|
||||
|
||||
Returns:
|
||||
:class:`torch.utils.hooks.RemovableHandle`:
|
||||
a handle that can be used to remove the added hook by calling
|
||||
``handle.remove()``
|
||||
"""
|
||||
handle = RemovableHandle(self._extern_hooks)
|
||||
self._extern_hooks[handle.id] = hook
|
||||
return handle
|
||||
|
||||
def register_mock_hook(
|
||||
self, hook: Callable[["PackageExporter", str], None]
|
||||
) -> RemovableHandle:
|
||||
"""Registers a mock hook on the exporter.
|
||||
|
||||
The hook will be called each time a module matches against a :meth:`mock` pattern.
|
||||
It should have the following signature::
|
||||
|
||||
hook(exporter: PackageExporter, module_name: str) -> None
|
||||
|
||||
Hooks will be called in order of registration.
|
||||
|
||||
Returns:
|
||||
:class:`torch.utils.hooks.RemovableHandle`:
|
||||
a handle that can be used to remove the added hook by calling
|
||||
``handle.remove()``
|
||||
"""
|
||||
handle = RemovableHandle(self._mock_hooks)
|
||||
self._mock_hooks[handle.id] = hook
|
||||
return handle
|
||||
|
||||
def mock(
|
||||
self,
|
||||
include: "GlobPattern",
|
||||
@ -412,7 +462,7 @@ node [shape=box];
|
||||
|
||||
"""
|
||||
self.patterns.append(
|
||||
(GlobGroup(include, exclude=exclude), self.save_mock_module, allow_empty)
|
||||
(GlobGroup(include, exclude=exclude), self._save_mock_module, allow_empty)
|
||||
)
|
||||
|
||||
def extern(
|
||||
@ -440,7 +490,7 @@ node [shape=box];
|
||||
|
||||
"""
|
||||
self.patterns.append(
|
||||
(GlobGroup(include, exclude=exclude), self.save_extern_module, allow_empty)
|
||||
(GlobGroup(include, exclude=exclude), self._save_extern_module, allow_empty)
|
||||
)
|
||||
|
||||
def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()):
|
||||
@ -457,20 +507,26 @@ node [shape=box];
|
||||
(GlobGroup(include, exclude=exclude), self._reject_denied_module, True)
|
||||
)
|
||||
|
||||
def save_extern_module(self, module_name: str):
|
||||
def _save_extern_module(self, module_name: str):
|
||||
"""Add `module_name` to the list of external modules, regardless of whether it is
|
||||
required by other modules.
|
||||
|
||||
Prefer using :meth:`extern` to only mark modules extern if they are actually required by the packaged code.
|
||||
"""
|
||||
for hook in self._extern_hooks.values():
|
||||
hook(self, module_name)
|
||||
|
||||
self.extern_modules[module_name] = True
|
||||
|
||||
def save_mock_module(self, module_name: str):
|
||||
def _save_mock_module(self, module_name: str):
|
||||
"""Add `module_name` to the package, implemented it with a mocked out version that
|
||||
can be imported but does not include any implementations.
|
||||
|
||||
Prefer using `mock` to only include this module if it is required by other modules.
|
||||
"""
|
||||
for hook in self._mock_hooks.values():
|
||||
hook(self, module_name)
|
||||
|
||||
if "_mock" not in self.provided:
|
||||
self.save_source_string(
|
||||
"_mock",
|
||||
|
Reference in New Issue
Block a user