mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
deploy: add dummy metadata for builtin packages (#76211)
Summary: This adds dummy metadata for frozen builtin packages when using `torch::deploy`. This is a bit hacky but unblocks allows Huggingface transformers library to be used within `torch::deploy` which depends on `importlib.metadata.version` to detect whether torch is installed or not. https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py#L49 Pull Request resolved: https://github.com/pytorch/pytorch/pull/76211 Test Plan: Added `importlib.metadata.version("torch")` unit test Reviewed By: kiukchung, PaliC Differential Revision: D35834831 Pulled By: d4l3k fbshipit-source-id: e58365e1ada69299adea96f0ca1fe211e092dd97 (cherry picked from commit c4b4152a24dcdf359503db2112a10a88633e67b6)
This commit is contained in:
committed by
PyTorch MergeBot
parent
a9deda5469
commit
bbc6fcd730
@ -66,6 +66,7 @@ void BuiltinRegistry::runPreInitialization() {
|
||||
|
||||
const char* metaPathSetupTemplate = R"PYTHON(
|
||||
import sys
|
||||
from importlib.metadata import DistributionFinder, Distribution
|
||||
# We need to register a custom meta path finder because we are registering
|
||||
# `torch._C` as a builtin module.
|
||||
#
|
||||
@ -74,12 +75,36 @@ import sys
|
||||
# are top-level imports. Since `torch._C` is a submodule of `torch`, the
|
||||
# BuiltinImporter skips it.
|
||||
class F:
|
||||
MODULES = {<<<DEPLOY_BUILTIN_MODULES_CSV>>>}
|
||||
|
||||
def find_spec(self, fullname, path, target=None):
|
||||
if fullname in [<<<DEPLOY_BUILTIN_MODULES_CSV>>>]:
|
||||
if fullname in self.MODULES:
|
||||
# Load this module using `BuiltinImporter`, but set `path` to None
|
||||
# in order to trick it into loading our module.
|
||||
return sys.meta_path[1].find_spec(fullname, path=None, target=None)
|
||||
return None
|
||||
|
||||
def find_distributions(self, context=DistributionFinder.Context()):
|
||||
modules = {"torch"} | self.MODULES
|
||||
# Insert dummy distribution records for each builtin module so
|
||||
# importlib.metadata.version(...) works.
|
||||
if context.name is None:
|
||||
for name in modules:
|
||||
yield DummyDistribution(name)
|
||||
if context.name in modules:
|
||||
yield DummyDistribution(context.name)
|
||||
|
||||
class DummyDistribution(Distribution):
|
||||
def __init__(self, name):
|
||||
self._metadata = {
|
||||
"Name": name,
|
||||
"Version": "0.0.1+fake_multipy",
|
||||
}
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return self._metadata
|
||||
|
||||
sys.meta_path.insert(0, F())
|
||||
)PYTHON";
|
||||
|
||||
@ -87,9 +112,9 @@ void BuiltinRegistry::runPostInitialization() {
|
||||
TORCH_INTERNAL_ASSERT(Py_IsInitialized());
|
||||
std::string metaPathSetupScript(metaPathSetupTemplate);
|
||||
std::string replaceKey = "<<<DEPLOY_BUILTIN_MODULES_CSV>>>";
|
||||
auto itr = metaPathSetupScript.find(replaceKey);
|
||||
if (itr != std::string::npos) {
|
||||
metaPathSetupScript.replace(itr, replaceKey.size(), getBuiltinModulesCSV());
|
||||
size_t pos = metaPathSetupScript.find(replaceKey);
|
||||
if (pos != std::string::npos) {
|
||||
metaPathSetupScript.replace(pos, replaceKey.size(), getBuiltinModulesCSV());
|
||||
}
|
||||
int r = PyRun_SimpleString(metaPathSetupScript.c_str());
|
||||
TORCH_INTERNAL_ASSERT(r == 0);
|
||||
|
||||
@ -22,7 +22,7 @@
|
||||
* BuiltinRegisterer object. The constructor of BuiltinRegisterer does the real
|
||||
* registration work.
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <gtest/gtest_prod.h>
|
||||
#include <cstdarg>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
@ -451,6 +451,18 @@ result = torch.Tensor([1,2,3])
|
||||
EXPECT_TRUE(w_grad0.equal(w_grad1));
|
||||
}
|
||||
|
||||
TEST(TorchpyTest, ImportlibMetadata) {
|
||||
torch::deploy::InterpreterManager m(1);
|
||||
m.registerModuleSource("importlib_test", R"PYTHON(
|
||||
from importlib.metadata import version
|
||||
|
||||
result = version("torch")
|
||||
)PYTHON");
|
||||
auto I = m.allInstances()[0].acquireSession();
|
||||
auto ver = I.global("importlib_test", "result").toIValue().toString();
|
||||
ASSERT_EQ(ver->string(), "0.0.1+fake_multipy");
|
||||
}
|
||||
|
||||
// OSS build does not have bultin numpy support yet. Use this flag to guard the
|
||||
// test case.
|
||||
#if HAS_NUMPY
|
||||
|
||||
Reference in New Issue
Block a user