mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Torch Package] Make get names of OrderedImporters support fallback to importers (#155743)
Summary: OrderedImporters is supposed to be an importer which tries out every single importer in self._importers. However the get_name API does not follow this behavior and only uses the get_name from the basic Importer class. This change is to update the OrderedImporters get_name API so that it tries the get_name API of every single importers. Differential Revision: D76463252 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155743 Approved by: https://github.com/jcwchen, https://github.com/jingsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
4604f0482c
commit
8ce81bcee1
@ -208,11 +208,10 @@ class TestSaveLoad(PackageTestCase):
|
||||
# Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first.
|
||||
return pe
|
||||
|
||||
# This should fail. The 'PackageAObject' type defined from 'importer1'
|
||||
# is not necessarily the same 'obj2's version of 'PackageAObject'.
|
||||
# This succeeds because OrderedImporter.get_name() properly
|
||||
# falls back to sys_importer which can find the original PackageAObject
|
||||
pe = make_exporter()
|
||||
with self.assertRaises(pickle.PicklingError):
|
||||
pe.save_pickle("obj", "obj.pkl", obj2)
|
||||
pe.save_pickle("obj", "obj.pkl", obj2)
|
||||
|
||||
# This should also fail. The 'PackageAObject' type defined from 'importer1'
|
||||
# is not necessarily the same as the one defined from 'importer2'
|
||||
|
@ -1,5 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import importlib
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from pickle import ( # type: ignore[attr-defined]
|
||||
_getattribute,
|
||||
@ -13,6 +14,7 @@ from ._mangling import demangle, get_mangle_prefix, is_mangled
|
||||
|
||||
|
||||
__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"]
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ObjNotFoundError(Exception):
|
||||
@ -204,6 +206,20 @@ class OrderedImporter(Importer):
|
||||
return True
|
||||
return module.__file__ is None
|
||||
|
||||
def get_name(self, obj: Any, name: Optional[str] = None) -> tuple[str, str]:
|
||||
for importer in self._importers:
|
||||
try:
|
||||
return importer.get_name(obj, name)
|
||||
except (ObjNotFoundError, ObjMismatchError) as e:
|
||||
warning_message = (
|
||||
f"Tried to call get_name with obj {obj}, "
|
||||
f"and name {name} on {importer} and got {e}"
|
||||
)
|
||||
log.warning(warning_message)
|
||||
raise ObjNotFoundError(
|
||||
f"Could not find obj {obj} and name {name} in any of the importers {self._importers}"
|
||||
)
|
||||
|
||||
def import_module(self, module_name: str) -> ModuleType:
|
||||
last_err = None
|
||||
for importer in self._importers:
|
||||
|
Reference in New Issue
Block a user