mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72237 add a generic zip file reader/writer to torch.package in order to get rid of dependency on torch for non torchscript / tensor related usages of package. This also enables users to create a derived class from the zip file reader/writer classes to have their own serialization/deserialization if it's desired for performance needs. https://www.internalfb.com/intern/diff/D35423079/ was reverted due to this refactor changing the name of where most of the implementation components of PackageExporter/PackageImporter come from like ModuleActionType_ etc. This diff also changes the import paths where these components come from to point to the correct file compared to D35423079 Test Plan: Imported from OSS Reviewed By: malfet Differential Revision: D35423079 Pulled By: PaliC fbshipit-source-id: 31abc4364d5fd007911cfb67cf36ebfac5d786f4 (cherry picked from commit 023b0d1445e0b1e1bb7a03c660cd62eb9d26d2a6)
68 lines
2.5 KiB
Python
68 lines
2.5 KiB
Python
from pathlib import Path
|
|
from typing import (
|
|
cast,
|
|
BinaryIO,
|
|
Sequence,
|
|
Union,
|
|
)
|
|
|
|
import torch
|
|
from torch.serialization import location_tag, normalize_storage_type
|
|
from torch.types import Storage
|
|
|
|
from ._zip_file_torchscript import TorchScriptPackageZipFileWriter
|
|
from .importer import sys_importer, Importer
|
|
from .package_exporter_no_torch import PackageExporter as DefaultPackageExporter
|
|
|
|
|
|
class PackageExporter(DefaultPackageExporter):
    """
    A package exporter with specialized functionality for torch. Specifically it uses the
    optimizations of torch's storage in order to not duplicate tensors.
    """

    def __init__(
        self,
        f: Union[str, Path, BinaryIO],
        importer: Union[Importer, Sequence[Importer]] = sys_importer,
    ):
        """Create an exporter that writes to ``f``.

        Args:
            f: Destination for the package — a path (``str``/``Path``) or a
                binary file-like object.
            importer: Importer (or sequence of importers, tried in order) used
                to resolve modules while packaging.
        """
        # Use the TorchScript-aware zip writer so tensor storages are tracked
        # in a shared storage context and deduplicated across the archive.
        super().__init__(
            f, importer, zip_file_writer_type=TorchScriptPackageZipFileWriter
        )

    def persistent_id(self, obj):
        """Pickle hook: serialize torch storages out-of-band into the zip file.

        Returns a ``("storage", storage_type, storage_id, location, numel)``
        tuple for storage objects (the storage bytes themselves are written to
        ``.data/{storage_id}.storage`` exactly once), or ``None`` to let the
        pickler handle ``obj`` normally.
        """
        assert isinstance(self.zip_file, TorchScriptPackageZipFileWriter)
        # needed for 'storage' typename which is a way in which torch models are saved
        if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage):
            if isinstance(obj, torch.storage._TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                storage = obj._storage
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                # NOTE(review): dtype is computed but never used below; kept for
                # parity with torch.serialization — confirm before removing.
                dtype = obj.dtype
                storage_numel = obj.size()
            else:
                storage = obj
                storage_type = normalize_storage_type(type(storage))
                dtype = torch.uint8
                storage_numel = storage.nbytes()

            storage = cast(Storage, storage)
            location = location_tag(storage)

            # serialize storage if not already written
            storage_present = self.zip_file.storage_context.has_storage(storage)
            storage_id = self.zip_file.storage_context.get_or_add_storage(storage)
            if not storage_present:
                # Storages must live on CPU before their raw bytes can be
                # copied into the archive.
                if storage.device.type != "cpu":
                    storage = storage.cpu()
                num_bytes = storage.nbytes()
                self.zip_file.write_record(
                    f".data/{storage_id}.storage", storage.data_ptr(), num_bytes
                )
            return ("storage", storage_type, storage_id, location, storage_numel)
        return None
|