From 8ce81bcee1da294a34af0a90dc16483055e8c5a4 Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Wed, 6 Aug 2025 02:26:07 +0000 Subject: [PATCH] [Torch Package] Make get names of OrderedImporters support fallback to importers (#155743) Summary: OrderedImporters is supposed to be an importer which tries out every single importer in self._importers. However the get_name API does not follow this behavior and only uses the get_name from the basic Importer class. This change is to update the OrderedImporters get_name API so that it tries the get_name API of every single importers. Differential Revision: D76463252 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155743 Approved by: https://github.com/jcwchen, https://github.com/jingsh --- test/package/test_save_load.py | 7 +++---- torch/package/importer.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/test/package/test_save_load.py b/test/package/test_save_load.py index a0cc967787e6..edbba9f6f8ee 100644 --- a/test/package/test_save_load.py +++ b/test/package/test_save_load.py @@ -208,11 +208,10 @@ class TestSaveLoad(PackageTestCase): # Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first. return pe - # This should fail. The 'PackageAObject' type defined from 'importer1' - # is not necessarily the same 'obj2's version of 'PackageAObject'. + # This succeeds because OrderedImporter.get_name() properly + # falls back to sys_importer which can find the original PackageAObject pe = make_exporter() - with self.assertRaises(pickle.PicklingError): - pe.save_pickle("obj", "obj.pkl", obj2) + pe.save_pickle("obj", "obj.pkl", obj2) # This should also fail. The 'PackageAObject' type defined from 'importer1' # is not necessarily the same as the one defined from 'importer2' diff --git a/torch/package/importer.py b/torch/package/importer.py index 49b4512f79a6..8cfc1e336a45 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import importlib +import logging from abc import ABC, abstractmethod from pickle import ( # type: ignore[attr-defined] _getattribute, @@ -13,6 +14,7 @@ from ._mangling import demangle, get_mangle_prefix, is_mangled __all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"] +log = logging.getLogger(__name__) class ObjNotFoundError(Exception): @@ -204,6 +206,20 @@ class OrderedImporter(Importer): return True return module.__file__ is None + def get_name(self, obj: Any, name: Optional[str] = None) -> tuple[str, str]: + for importer in self._importers: + try: + return importer.get_name(obj, name) + except (ObjNotFoundError, ObjMismatchError) as e: + warning_message = ( + f"Tried to call get_name with obj {obj}, " + f"and name {name} on {importer} and got {e}" + ) + log.warning(warning_message) + raise ObjNotFoundError( + f"Could not find obj {obj} and name {name} in any of the importers {self._importers}" + ) + def import_module(self, module_name: str) -> ModuleType: last_err = None for importer in self._importers: