mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add support for pickle v4 (#70642)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70642 Review history on https://github.com/pytorch/pytorch/pull/70014 Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D33414364 Pulled By: PaliC fbshipit-source-id: 7e7ed491c6f16d4fac3a03f7e403935823c03aa6
This commit is contained in:
committed by
Facebook GitHub Bot
parent
118bd82dde
commit
3ef10da97d
@ -98,10 +98,10 @@ class PackagePickler(_Pickler):
|
||||
self.memoize(obj)
|
||||
|
||||
|
||||
def create_pickler(data_buf, importer):
|
||||
def create_pickler(data_buf, importer, protocol=4):
|
||||
if importer is sys_importer:
|
||||
# if we are using the normal import library system, then
|
||||
# we can use the C implementation of pickle which is faster
|
||||
return Pickler(data_buf, protocol=3)
|
||||
return Pickler(data_buf, protocol=protocol)
|
||||
else:
|
||||
return PackagePickler(importer, data_buf, protocol=3)
|
||||
return PackagePickler(importer, data_buf, protocol=protocol)
|
||||
|
@ -19,6 +19,7 @@ from typing import (
|
||||
Set,
|
||||
Union,
|
||||
cast,
|
||||
DefaultDict,
|
||||
)
|
||||
|
||||
import torch
|
||||
@ -39,7 +40,6 @@ _gate_torchscript_serialization = True
|
||||
|
||||
ActionHook = Callable[["PackageExporter", str], None]
|
||||
|
||||
|
||||
class _ModuleProviderAction(Enum):
|
||||
"""Represents one of the actions that :class:`PackageExporter` can take on a module.
|
||||
|
||||
@ -550,7 +550,10 @@ class PackageExporter:
|
||||
self.add_dependency(dep)
|
||||
|
||||
def save_pickle(
|
||||
self, package: str, resource: str, obj: Any, dependencies: bool = True
|
||||
self, package: str,
|
||||
resource: str, obj: Any,
|
||||
dependencies: bool = True,
|
||||
pickle_protocol: int = 3,
|
||||
):
|
||||
"""Save a python object to the archive using pickle. Equivalent to :func:`torch.save` but saving into
|
||||
the archive rather than a stand-alone file. Stanard pickle does not save the code, only the objects.
|
||||
@ -568,10 +571,13 @@ class PackageExporter:
|
||||
obj (Any): The object to save, must be picklable.
|
||||
dependencies (bool, optional): If ``True``, we scan the source for dependencies.
|
||||
"""
|
||||
|
||||
assert ((pickle_protocol == 4) or (pickle_protocol == 3)), "torch.package only supports pickle protocols 3 and 4"
|
||||
|
||||
filename = self._filename(package, resource)
|
||||
# Write the pickle data for `obj`
|
||||
data_buf = io.BytesIO()
|
||||
pickler = create_pickler(data_buf, self.importer)
|
||||
pickler = create_pickler(data_buf, self.importer, protocol=pickle_protocol)
|
||||
pickler.persistent_id = self._persistent_id
|
||||
pickler.dump(obj)
|
||||
data_value = data_buf.getvalue()
|
||||
@ -584,32 +590,61 @@ class PackageExporter:
|
||||
is_pickle=True,
|
||||
)
|
||||
|
||||
def _check_mocked_error(module: Optional[str], field: Optional[str]):
|
||||
assert isinstance(module, str)
|
||||
assert isinstance(field, str)
|
||||
if self._can_implicitly_extern(module):
|
||||
return
|
||||
for pattern, pattern_info in self.patterns.items():
|
||||
if pattern.matches(module):
|
||||
if pattern_info.action == _ModuleProviderAction.MOCK:
|
||||
raise NotImplementedError(
|
||||
f"Object '{field}' from module {module} was mocked out during packaging "
|
||||
f"but is being used in resource - {resource} in package {package}. "
|
||||
"If this error is happening during 'save_pickle', please ensure that your "
|
||||
"pickled object doesn't contain any mocked objects."
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
if dependencies:
|
||||
all_dependencies = []
|
||||
module = None
|
||||
field = None
|
||||
memo: DefaultDict[int, str] = defaultdict(None)
|
||||
memo_count = 0
|
||||
# pickletools.dis(data_value)
|
||||
for opcode, arg, pos in pickletools.genops(data_value):
|
||||
if opcode.name == "GLOBAL": # a global reference
|
||||
if pickle_protocol == 4:
|
||||
if opcode.name == "SHORT_BINUNICODE" or opcode.name == "BINUNICODE8":
|
||||
assert isinstance(arg, str)
|
||||
module = field
|
||||
field = arg
|
||||
memo[memo_count] = arg
|
||||
elif opcode.name == "BINGET_LONG" or opcode.name == "BINGET" or opcode.name == "GET":
|
||||
assert isinstance(arg, int)
|
||||
module = field
|
||||
field = memo.get(arg, None)
|
||||
elif opcode.name == "MEMOIZE":
|
||||
memo_count += 1
|
||||
elif opcode.name == "STACK_GLOBAL":
|
||||
assert isinstance(module, str)
|
||||
if module not in all_dependencies:
|
||||
all_dependencies.append(module)
|
||||
_check_mocked_error(module, field)
|
||||
elif pickle_protocol == 3 and opcode.name == "GLOBAL": # a global reference
|
||||
assert isinstance(arg, str)
|
||||
module, field = arg.split(" ")
|
||||
if module not in all_dependencies:
|
||||
all_dependencies.append(module)
|
||||
for pattern, pattern_info in self.patterns.items():
|
||||
if pattern.matches(module):
|
||||
if pattern_info.action == _ModuleProviderAction.MOCK:
|
||||
raise NotImplementedError(
|
||||
f"Object '{field}' from module {module} was mocked out during packaging "
|
||||
f"but is being used in resource - {resource} in package {package}. "
|
||||
"If this error is happening during 'save_pickle', please ensure that your "
|
||||
"pickled object doesn't contain any mocked objects."
|
||||
)
|
||||
else:
|
||||
break
|
||||
|
||||
_check_mocked_error(module, field)
|
||||
for module_name in all_dependencies:
|
||||
self.dependency_graph.add_edge(name_in_dependency_graph, module_name)
|
||||
self.add_dependency(module_name)
|
||||
|
||||
self._write(filename, data_value)
|
||||
|
||||
|
||||
def save_text(self, package: str, resource: str, text: str):
|
||||
"""Save text data to the package.
|
||||
|
||||
|
Reference in New Issue
Block a user