mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53049 This makes our API symmetric: we now have `Importer`-aware Pickler and Unpickler implementations that have similar interfaces. Test Plan: Imported from OSS Reviewed By: Lilyjjo Differential Revision: D26734593 Pulled By: suo fbshipit-source-id: 3479437cf6b98e0d6a8aa4907c75f0c61d5495d4
88 lines · 3.4 KiB · Python
from pickle import Pickler, _Pickler, _getattribute, _extension_registry, _compat_pickle # type: ignore
|
|
from pickle import GLOBAL, STACK_GLOBAL, EXT1, EXT2, EXT4, PicklingError
|
|
from types import FunctionType
|
|
from struct import pack
|
|
|
|
from .importer import Importer, sys_importer, ObjMismatchError, ObjNotFoundError
|
|
|
|
|
|
class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save.
    """

    # Private copy of the pure-Python pickler's type -> save-function table,
    # taken once when this class is defined. Copying lets us override the
    # FunctionType entry (below) without mutating ``pickle._Pickler.dispatch``.
    dispatch = _Pickler.dispatch.copy()

    def __init__(self, importer: Importer, *args, **kwargs):
        """Create a pickler that resolves globals through ``importer``.

        Args:
            importer: Used by ``save_global`` to map an object to its
                ``(module_name, name)`` and to import that module.
            *args, **kwargs: Forwarded unchanged to ``pickle._Pickler``
                (e.g. the output file and ``protocol``).
        """
        self.importer = importer
        super().__init__(*args, **kwargs)

    def save_global(self, obj, name=None):
        """Write a by-reference record for the global ``obj``.

        Mirrors the stdlib pure-Python ``_Pickler.save_global``. The only
        difference is that the object's location is resolved through
        ``self.importer`` instead of ``__import__``/``sys.modules``.

        Depending on the protocol, the record is written in one of these forms:

        * protocol >= 2 with ``(module_name, name)`` in the copyreg extension
          registry: an EXT1/EXT2/EXT4 opcode. This path returns without
          memoizing ``obj``.
        * protocol >= 4: the module name and qualified name, then STACK_GLOBAL.
        * nested attribute (parent is not the module), protocol < 4: a reduce
          of ``getattr(parent, lastname)``.
        * protocol 3: a GLOBAL opcode with UTF-8 names.
        * otherwise: a GLOBAL opcode with ASCII names. When ``fix_imports`` is
          set, names are first remapped through ``_compat_pickle``.

        Args:
            obj: The global object (function, class, ...) to reference.
            name: Optional explicit name, passed through to
                ``importer.get_name``.

        Raises:
            PicklingError: If the importer cannot locate ``obj``
                (``ObjNotFoundError`` / ``ObjMismatchError``), or if, for
                protocols < 3, the names cannot be ASCII-encoded.
        """
        # unfortunately the pickler code is factored in a way that
        # forces us to copy/paste this function. The only change is marked
        # CHANGED below.
        write = self.write

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {str(err)}") from None

        module = self.importer.import_module(module_name)
        # NOTE(review): relies on the private pickle helper ``_getattribute``
        # returning ``(obj, parent)`` as in the CPython version this code was
        # copied from -- confirm on interpreter upgrades.
        _, parent = _getattribute(module, name)
        # END CHANGED

        if self.proto >= 2:
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xff:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xffff:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return
        lastname = name.rpartition('.')[2]
        if parent is module:
            name = lastname
        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:
            self.save(module_name)
            self.save(name)
            write(STACK_GLOBAL)
        elif parent is not module:
            self.save_reduce(getattr, (parent, lastname))
        elif self.proto >= 3:
            write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
                  bytes(name, "utf-8") + b'\n')
        else:
            if self.fix_imports:
                # Map names back to their Python 2 locations for old protocols.
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
                      bytes(name, "ascii") + b'\n')
            except UnicodeEncodeError:
                # Report the dotted global name; ``module`` is the imported
                # module object and would render as "<module ...>".
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)) from None

        self.memoize(obj)

    # Route plain functions through the importer-aware save_global above.
    dispatch[FunctionType] = save_global
|
|
|
|
def create_pickler(data_buf, importer, protocol=3):
    """Return a pickler that writes to ``data_buf`` and resolves globals via ``importer``.

    Args:
        data_buf: Output file object handed to the pickler as its ``file``.
        importer: If this is ``sys_importer``, the C-implemented
            ``pickle.Pickler`` is returned. Otherwise a ``PackagePickler``
            bound to ``importer`` is returned.
        protocol: Pickle protocol to use. Defaults to 3, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        A ``pickle.Pickler`` or a ``PackagePickler`` instance.
    """
    if importer is sys_importer:
        # if we are using the normal import library system, then
        # we can use the C implementation of pickle which is faster
        return Pickler(data_buf, protocol=protocol)
    return PackagePickler(importer, data_buf, protocol=protocol)
|