add support for pickle v4 (#70642)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70642

Review history on https://github.com/pytorch/pytorch/pull/70014

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D33414364

Pulled By: PaliC

fbshipit-source-id: 7e7ed491c6f16d4fac3a03f7e403935823c03aa6
This commit is contained in:
Sahan Paliskara
2022-01-10 11:10:06 -08:00
committed by Facebook GitHub Bot
parent 118bd82dde
commit 3ef10da97d
2 changed files with 54 additions and 19 deletions

View File

@ -98,10 +98,10 @@ class PackagePickler(_Pickler):
self.memoize(obj)
def create_pickler(data_buf, importer):
def create_pickler(data_buf, importer, protocol=4):
if importer is sys_importer:
# if we are using the normal import library system, then
# we can use the C implementation of pickle which is faster
return Pickler(data_buf, protocol=3)
return Pickler(data_buf, protocol=protocol)
else:
return PackagePickler(importer, data_buf, protocol=3)
return PackagePickler(importer, data_buf, protocol=protocol)

View File

@ -19,6 +19,7 @@ from typing import (
Set,
Union,
cast,
DefaultDict,
)
import torch
@ -39,7 +40,6 @@ _gate_torchscript_serialization = True
ActionHook = Callable[["PackageExporter", str], None]
class _ModuleProviderAction(Enum):
"""Represents one of the actions that :class:`PackageExporter` can take on a module.
@ -550,7 +550,10 @@ class PackageExporter:
self.add_dependency(dep)
def save_pickle(
self, package: str, resource: str, obj: Any, dependencies: bool = True
self, package: str,
resource: str, obj: Any,
dependencies: bool = True,
pickle_protocol: int = 3,
):
"""Save a python object to the archive using pickle. Equivalent to :func:`torch.save` but saving into
the archive rather than a stand-alone file. Stanard pickle does not save the code, only the objects.
@ -568,10 +571,13 @@ class PackageExporter:
obj (Any): The object to save, must be picklable.
dependencies (bool, optional): If ``True``, we scan the source for dependencies.
"""
assert ((pickle_protocol == 4) or (pickle_protocol == 3)), "torch.package only supports pickle protocols 3 and 4"
filename = self._filename(package, resource)
# Write the pickle data for `obj`
data_buf = io.BytesIO()
pickler = create_pickler(data_buf, self.importer)
pickler = create_pickler(data_buf, self.importer, protocol=pickle_protocol)
pickler.persistent_id = self._persistent_id
pickler.dump(obj)
data_value = data_buf.getvalue()
@ -584,32 +590,61 @@ class PackageExporter:
is_pickle=True,
)
def _check_mocked_error(module: Optional[str], field: Optional[str]):
assert isinstance(module, str)
assert isinstance(field, str)
if self._can_implicitly_extern(module):
return
for pattern, pattern_info in self.patterns.items():
if pattern.matches(module):
if pattern_info.action == _ModuleProviderAction.MOCK:
raise NotImplementedError(
f"Object '{field}' from module {module} was mocked out during packaging "
f"but is being used in resource - {resource} in package {package}. "
"If this error is happening during 'save_pickle', please ensure that your "
"pickled object doesn't contain any mocked objects."
)
else:
return
if dependencies:
all_dependencies = []
module = None
field = None
memo: DefaultDict[int, str] = defaultdict(None)
memo_count = 0
# pickletools.dis(data_value)
for opcode, arg, pos in pickletools.genops(data_value):
if opcode.name == "GLOBAL": # a global reference
if pickle_protocol == 4:
if opcode.name == "SHORT_BINUNICODE" or opcode.name == "BINUNICODE8":
assert isinstance(arg, str)
module = field
field = arg
memo[memo_count] = arg
elif opcode.name == "BINGET_LONG" or opcode.name == "BINGET" or opcode.name == "GET":
assert isinstance(arg, int)
module = field
field = memo.get(arg, None)
elif opcode.name == "MEMOIZE":
memo_count += 1
elif opcode.name == "STACK_GLOBAL":
assert isinstance(module, str)
if module not in all_dependencies:
all_dependencies.append(module)
_check_mocked_error(module, field)
elif pickle_protocol == 3 and opcode.name == "GLOBAL": # a global reference
assert isinstance(arg, str)
module, field = arg.split(" ")
if module not in all_dependencies:
all_dependencies.append(module)
for pattern, pattern_info in self.patterns.items():
if pattern.matches(module):
if pattern_info.action == _ModuleProviderAction.MOCK:
raise NotImplementedError(
f"Object '{field}' from module {module} was mocked out during packaging "
f"but is being used in resource - {resource} in package {package}. "
"If this error is happening during 'save_pickle', please ensure that your "
"pickled object doesn't contain any mocked objects."
)
else:
break
_check_mocked_error(module, field)
for module_name in all_dependencies:
self.dependency_graph.add_edge(name_in_dependency_graph, module_name)
self.add_dependency(module_name)
self._write(filename, data_value)
def save_text(self, package: str, resource: str, text: str):
"""Save text data to the package.