[Torch Package] Make get names of OrderedImporters support fallback to importers (#155743)

Summary:
OrderedImporters is supposed to be an importer which tries out every single importer in self._importers. However the get_name API does not follow this behavior and only uses the get_name from the basic Importer class.
This change is to update the OrderedImporters get_name API so that it tries the get_name API of every single importers.

Differential Revision: D76463252

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155743
Approved by: https://github.com/jcwchen, https://github.com/jingsh
This commit is contained in:
Dave Lei
2025-08-06 02:26:07 +00:00
committed by PyTorch MergeBot
parent 4604f0482c
commit 8ce81bcee1
2 changed files with 19 additions and 4 deletions

View File

@ -208,11 +208,10 @@ class TestSaveLoad(PackageTestCase):
# Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first.
return pe
# This should fail. The 'PackageAObject' type defined from 'importer1'
# is not necessarily the same 'obj2's version of 'PackageAObject'.
# This succeeds because OrderedImporter.get_name() properly
# falls back to sys_importer which can find the original PackageAObject
pe = make_exporter()
with self.assertRaises(pickle.PicklingError):
pe.save_pickle("obj", "obj.pkl", obj2)
pe.save_pickle("obj", "obj.pkl", obj2)
# This should also fail. The 'PackageAObject' type defined from 'importer1'
# is not necessarily the same as the one defined from 'importer2'

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import importlib
import logging
from abc import ABC, abstractmethod
from pickle import ( # type: ignore[attr-defined]
_getattribute,
@ -13,6 +14,7 @@ from ._mangling import demangle, get_mangle_prefix, is_mangled
__all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"]
log = logging.getLogger(__name__)
class ObjNotFoundError(Exception):
@ -204,6 +206,20 @@ class OrderedImporter(Importer):
return True
return module.__file__ is None
def get_name(self, obj: Any, name: Optional[str] = None) -> tuple[str, str]:
for importer in self._importers:
try:
return importer.get_name(obj, name)
except (ObjNotFoundError, ObjMismatchError) as e:
warning_message = (
f"Tried to call get_name with obj {obj}, "
f"and name {name} on {importer} and got {e}"
)
log.warning(warning_message)
raise ObjNotFoundError(
f"Could not find obj {obj} and name {name} in any of the importers {self._importers}"
)
def import_module(self, module_name: str) -> ModuleType:
last_err = None
for importer in self._importers: