mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	deploy: add dummy metadata for builtin packages (#76211)
Summary: This adds dummy metadata for frozen builtin packages when using `torch::deploy`. This is a bit hacky but unblocks allows Huggingface transformers library to be used within `torch::deploy` which depends on `importlib.metadata.version` to detect whether torch is installed or not. https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py#L49 Pull Request resolved: https://github.com/pytorch/pytorch/pull/76211 Test Plan: Added `importlib.metadata.version("torch")` unit test Reviewed By: kiukchung, PaliC Differential Revision: D35834831 Pulled By: d4l3k fbshipit-source-id: e58365e1ada69299adea96f0ca1fe211e092dd97 (cherry picked from commit c4b4152a24dcdf359503db2112a10a88633e67b6)
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							a9deda5469
						
					
				
				
					commit
					bbc6fcd730
				
			| @ -66,6 +66,7 @@ void BuiltinRegistry::runPreInitialization() { | ||||
|  | ||||
| const char* metaPathSetupTemplate = R"PYTHON( | ||||
| import sys | ||||
| from importlib.metadata import DistributionFinder, Distribution | ||||
| # We need to register a custom meta path finder because we are registering | ||||
| # `torch._C` as a builtin module. | ||||
| # | ||||
| @ -74,12 +75,36 @@ import sys | ||||
| # are top-level imports.  Since `torch._C` is a submodule of `torch`, the | ||||
| # BuiltinImporter skips it. | ||||
| class F: | ||||
|     MODULES = {<<<DEPLOY_BUILTIN_MODULES_CSV>>>} | ||||
|  | ||||
|     def find_spec(self, fullname, path, target=None): | ||||
|         if fullname in [<<<DEPLOY_BUILTIN_MODULES_CSV>>>]: | ||||
|         if fullname in self.MODULES: | ||||
|             # Load this module using `BuiltinImporter`, but set `path` to None | ||||
|             # in order to trick it into loading our module. | ||||
|             return sys.meta_path[1].find_spec(fullname, path=None, target=None) | ||||
|         return None | ||||
|  | ||||
|     def find_distributions(self, context=DistributionFinder.Context()): | ||||
|         modules = {"torch"} | self.MODULES | ||||
|         # Insert dummy distribution records for each builtin module so | ||||
|         # importlib.metadata.version(...) works. | ||||
|         if context.name is None: | ||||
|             for name in modules: | ||||
|                 yield DummyDistribution(name) | ||||
|         if context.name in modules: | ||||
|             yield DummyDistribution(context.name) | ||||
|  | ||||
| class DummyDistribution(Distribution): | ||||
|     def __init__(self, name): | ||||
|         self._metadata = { | ||||
|             "Name": name, | ||||
|             "Version": "0.0.1+fake_multipy", | ||||
|         } | ||||
|  | ||||
|     @property | ||||
|     def metadata(self): | ||||
|         return self._metadata | ||||
|  | ||||
| sys.meta_path.insert(0, F()) | ||||
| )PYTHON"; | ||||
|  | ||||
| @ -87,9 +112,9 @@ void BuiltinRegistry::runPostInitialization() { | ||||
|   TORCH_INTERNAL_ASSERT(Py_IsInitialized()); | ||||
|   std::string metaPathSetupScript(metaPathSetupTemplate); | ||||
|   std::string replaceKey = "<<<DEPLOY_BUILTIN_MODULES_CSV>>>"; | ||||
|   auto itr = metaPathSetupScript.find(replaceKey); | ||||
|   if (itr != std::string::npos) { | ||||
|     metaPathSetupScript.replace(itr, replaceKey.size(), getBuiltinModulesCSV()); | ||||
|   size_t pos = metaPathSetupScript.find(replaceKey); | ||||
|   if (pos != std::string::npos) { | ||||
|     metaPathSetupScript.replace(pos, replaceKey.size(), getBuiltinModulesCSV()); | ||||
|   } | ||||
|   int r = PyRun_SimpleString(metaPathSetupScript.c_str()); | ||||
|   TORCH_INTERNAL_ASSERT(r == 0); | ||||
|  | ||||
| @ -22,7 +22,7 @@ | ||||
|  * BuiltinRegisterer object. The constructor of BuiltinRegisterer does the real | ||||
|  * registration work. | ||||
|  */ | ||||
| #include <gtest/gtest.h> | ||||
| #include <gtest/gtest_prod.h> | ||||
| #include <cstdarg> | ||||
| #include <memory> | ||||
| #include <unordered_map> | ||||
|  | ||||
| @ -451,6 +451,18 @@ result = torch.Tensor([1,2,3]) | ||||
|   EXPECT_TRUE(w_grad0.equal(w_grad1)); | ||||
| } | ||||
|  | ||||
| TEST(TorchpyTest, ImportlibMetadata) { | ||||
|   torch::deploy::InterpreterManager m(1); | ||||
|   m.registerModuleSource("importlib_test", R"PYTHON( | ||||
| from importlib.metadata import version | ||||
|  | ||||
| result = version("torch") | ||||
| )PYTHON"); | ||||
|   auto I = m.allInstances()[0].acquireSession(); | ||||
|   auto ver = I.global("importlib_test", "result").toIValue().toString(); | ||||
|   ASSERT_EQ(ver->string(), "0.0.1+fake_multipy"); | ||||
| } | ||||
|  | ||||
| // OSS build does not have bultin numpy support yet. Use this flag to guard the | ||||
| // test case. | ||||
| #if HAS_NUMPY | ||||
|  | ||||
		Reference in New Issue
	
	Block a user