mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[package] _custom_import_pickler -> _package_pickler (#53048)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53048 I am planning the custom pickler and unpicklers that we use as semi-public interfaces for `torch.rpc` to consume. Some prefatory movements here. Test Plan: Imported from OSS Reviewed By: Lilyjjo Differential Revision: D26734594 Pulled By: suo fbshipit-source-id: 105ae1161d90f24efc7070a8d80c6ac3d2111bea
This commit is contained in:
committed by
Facebook GitHub Bot
parent
272dfc7bb9
commit
ec128eadea
@ -1,7 +1,7 @@
|
||||
import io
|
||||
import torch
|
||||
import importlib
|
||||
from torch.package._custom_import_pickler import create_custom_import_pickler
|
||||
from torch.package._package_pickler import create_pickler
|
||||
from torch.package.package_importer import _UnpicklerWrapper
|
||||
from torch.package import sys_importer, OrderedImporter, PackageImporter, Importer
|
||||
from torch.serialization import _maybe_decode_ascii
|
||||
@ -32,7 +32,7 @@ def _save_storages(importer, obj):
|
||||
importers = OrderedImporter(importer, sys_importer)
|
||||
else:
|
||||
importers = sys_importer
|
||||
pickler = create_custom_import_pickler(data_buf, importers)
|
||||
pickler = create_pickler(data_buf, importers)
|
||||
pickler.persistent_id = persistent_id
|
||||
pickler.dump(obj)
|
||||
data_value = data_buf.getvalue()
|
||||
|
||||
@ -6,7 +6,7 @@ from struct import pack
|
||||
from .importer import Importer, sys_importer, ObjMismatchError, ObjNotFoundError
|
||||
|
||||
|
||||
class CustomImportPickler(_Pickler):
|
||||
class PackagePickler(_Pickler):
|
||||
dispatch = _Pickler.dispatch.copy()
|
||||
|
||||
def __init__(self, importer: Importer, *args, **kwargs):
|
||||
@ -73,10 +73,10 @@ class CustomImportPickler(_Pickler):
|
||||
self.memoize(obj)
|
||||
dispatch[FunctionType] = save_global
|
||||
|
||||
def create_custom_import_pickler(data_buf, importer):
|
||||
def create_pickler(data_buf, importer):
|
||||
if importer is sys_importer:
|
||||
# if we are using the normal import library system, then
|
||||
# we can use the C implementation of pickle which is faster
|
||||
return Pickler(data_buf, protocol=3)
|
||||
else:
|
||||
return CustomImportPickler(importer, data_buf, protocol=3)
|
||||
return PackagePickler(importer, data_buf, protocol=3)
|
||||
@ -4,7 +4,7 @@ import collections
|
||||
import io
|
||||
import pickletools
|
||||
from .find_file_dependencies import find_files_source_depends_on
|
||||
from ._custom_import_pickler import create_custom_import_pickler
|
||||
from ._package_pickler import create_pickler
|
||||
from ._file_structure_representation import _create_folder_from_file_list, Folder
|
||||
from ._glob_group import GlobPattern, _GlobGroup
|
||||
from ._importlib import _normalize_path
|
||||
@ -302,7 +302,7 @@ node [shape=box];
|
||||
filename = self._filename(package, resource)
|
||||
# Write the pickle data for `obj`
|
||||
data_buf = io.BytesIO()
|
||||
pickler = create_custom_import_pickler(data_buf, self.importer)
|
||||
pickler = create_pickler(data_buf, self.importer)
|
||||
pickler.persistent_id = self._persistent_id
|
||||
pickler.dump(obj)
|
||||
data_value = data_buf.getvalue()
|
||||
|
||||
Reference in New Issue
Block a user