Files
pytorch/torch/package/package_exporter.py
Sahan Paliskara d4a709be3d [pkg] add generic ZipFile Reader/Writer (#72237)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72237

Add a generic zip file reader/writer to torch.package in order to remove the dependency on torch for non-TorchScript / non-tensor usages of package. This also enables users to derive their own classes from the zip file reader/writer classes to implement custom serialization/deserialization if desired for performance reasons.

https://www.internalfb.com/intern/diff/D35423079/ was reverted due to this refactor changing the name of where most of the implementation components of PackageExporter/PackageImporter come from like ModuleActionType_ etc.
This diff also changes the import paths where these components come from to point to the correct file compared to D35423079

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D35423079

Pulled By: PaliC

fbshipit-source-id: 31abc4364d5fd007911cfb67cf36ebfac5d786f4
(cherry picked from commit 023b0d1445e0b1e1bb7a03c660cd62eb9d26d2a6)
2022-04-06 16:11:13 -07:00

68 lines
2.5 KiB
Python

from pathlib import Path
from typing import (
cast,
BinaryIO,
Sequence,
Union,
)
import torch
from torch.serialization import location_tag, normalize_storage_type
from torch.types import Storage
from ._zip_file_torchscript import TorchScriptPackageZipFileWriter
from .importer import sys_importer, Importer
from .package_exporter_no_torch import PackageExporter as DefaultPackageExporter
class PackageExporter(DefaultPackageExporter):
    """
    A package exporter with specialized functionality for torch. Specifically it uses the
    optimizations of torch's storage in order to not duplicate tensors: tensors that share
    a storage are written to the archive once.
    """

    def __init__(
        self,
        f: Union[str, Path, BinaryIO],
        importer: Union[Importer, Sequence[Importer]] = sys_importer,
    ):
        """
        Create an exporter that writes to ``f`` using the TorchScript-aware zip
        writer, whose storage context ``persistent_id`` relies on for
        deduplication.

        Args:
            f: Destination of the package — a file path or a binary stream.
            importer: Importer(s) used to resolve the modules being packaged.
        """
        super().__init__(
            f, importer, zip_file_writer_type=TorchScriptPackageZipFileWriter
        )

    def persistent_id(self, obj):
        """
        Pickle hook that serializes tensor storages out-of-band.

        For storage objects, writes the raw bytes into the zip file the first
        time each storage is seen and returns a
        ``('storage', storage_type, storage_id, location, storage_numel)``
        tuple; returns ``None`` for everything else so pickle handles it
        normally.
        """
        # The torch-specific writer is required: it carries the storage
        # context used to deduplicate storages below.
        assert isinstance(self.zip_file, TorchScriptPackageZipFileWriter)
        # needed for 'storage' typename which is a way in which torch models are saved
        if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage):
            if isinstance(obj, torch.storage._TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                storage = obj._storage
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj.size()
            else:
                storage = obj
                storage_type = normalize_storage_type(type(storage))
                storage_numel = storage.nbytes()

            storage = cast(Storage, storage)
            location = location_tag(storage)

            # serialize storage if not already written
            storage_present = self.zip_file.storage_context.has_storage(storage)
            storage_id = self.zip_file.storage_context.get_or_add_storage(storage)
            if not storage_present:
                # Move to CPU before writing — presumably data_ptr() must
                # point at host memory for write_record (TODO confirm).
                if storage.device.type != "cpu":
                    storage = storage.cpu()
                num_bytes = storage.nbytes()
                self.zip_file.write_record(
                    f".data/{storage_id}.storage", storage.data_ptr(), num_bytes
                )
            return ("storage", storage_type, storage_id, location, storage_numel)

        return None