[package] _custom_import_pickler -> _package_pickler (#53048)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53048

I am planning the custom pickler and unpicklers that we use as
semi-public interfaces for `torch.rpc` to consume. Some prefatory
movements here.

Test Plan: Imported from OSS

Reviewed By: Lilyjjo

Differential Revision: D26734594

Pulled By: suo

fbshipit-source-id: 105ae1161d90f24efc7070a8d80c6ac3d2111bea
This commit is contained in:
Michael Suo
2021-03-01 18:34:24 -08:00
committed by Facebook GitHub Bot
parent 272dfc7bb9
commit ec128eadea
3 changed files with 7 additions and 7 deletions

View File

@ -1,7 +1,7 @@
import io
import torch
import importlib
from torch.package._custom_import_pickler import create_custom_import_pickler
from torch.package._package_pickler import create_pickler
from torch.package.package_importer import _UnpicklerWrapper
from torch.package import sys_importer, OrderedImporter, PackageImporter, Importer
from torch.serialization import _maybe_decode_ascii
@ -32,7 +32,7 @@ def _save_storages(importer, obj):
importers = OrderedImporter(importer, sys_importer)
else:
importers = sys_importer
pickler = create_custom_import_pickler(data_buf, importers)
pickler = create_pickler(data_buf, importers)
pickler.persistent_id = persistent_id
pickler.dump(obj)
data_value = data_buf.getvalue()

View File

@ -6,7 +6,7 @@ from struct import pack
from .importer import Importer, sys_importer, ObjMismatchError, ObjNotFoundError
class CustomImportPickler(_Pickler):
class PackagePickler(_Pickler):
dispatch = _Pickler.dispatch.copy()
def __init__(self, importer: Importer, *args, **kwargs):
@ -73,10 +73,10 @@ class CustomImportPickler(_Pickler):
self.memoize(obj)
dispatch[FunctionType] = save_global
def create_custom_import_pickler(data_buf, importer):
def create_pickler(data_buf, importer):
if importer is sys_importer:
# if we are using the normal import library system, then
# we can use the C implementation of pickle which is faster
return Pickler(data_buf, protocol=3)
else:
return CustomImportPickler(importer, data_buf, protocol=3)
return PackagePickler(importer, data_buf, protocol=3)

View File

@ -4,7 +4,7 @@ import collections
import io
import pickletools
from .find_file_dependencies import find_files_source_depends_on
from ._custom_import_pickler import create_custom_import_pickler
from ._package_pickler import create_pickler
from ._file_structure_representation import _create_folder_from_file_list, Folder
from ._glob_group import GlobPattern, _GlobGroup
from ._importlib import _normalize_path
@ -302,7 +302,7 @@ node [shape=box];
filename = self._filename(package, resource)
# Write the pickle data for `obj`
data_buf = io.BytesIO()
pickler = create_custom_import_pickler(data_buf, self.importer)
pickler = create_pickler(data_buf, self.importer)
pickler.persistent_id = self._persistent_id
pickler.dump(obj)
data_value = data_buf.getvalue()