reimport pr137735 due to merging check issues (#138959)

This is a cherry-pick from #137735 by @mikaylagawarecki, which could not be merged due to a (wrongly) failing check for codev

@diff-train-skip-merge

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138959
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Wouter Devriendt
2024-10-27 16:31:34 +00:00
committed by PyTorch MergeBot
parent 144d75d934
commit bae3426af7
10 changed files with 119 additions and 36 deletions

View File

@ -36,7 +36,7 @@ def define_targets(rules):
"caffe2/serialize/istream_adapter.cc",
"caffe2/serialize/read_adapter_interface.cc",
],
copts = ["-fexceptions"],
copts = ["-fexceptions", "-DFBCODE_CAFFE2"],
tags = [
"-fbcode",
"supermodule:android/default/pytorch",

View File

@ -621,15 +621,17 @@ size_t ostream_write_func(
return ret;
}
PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name)
: archive_name_(basename(file_name)) {
PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name, bool compute_crc32)
: archive_name_(basename(file_name)),
compute_crc32_(compute_crc32) {
setup(file_name);
}
PyTorchStreamWriter::PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)> writer_func)
const std::function<size_t(const void*, size_t)> writer_func, bool compute_crc32)
: archive_name_("archive"),
writer_func_(writer_func) {
writer_func_(writer_func),
compute_crc32_(compute_crc32) {
setup(archive_name_);
}
@ -695,6 +697,11 @@ void PyTorchStreamWriter::writeRecord(
size_t padding_size =
detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
if (!compute_crc32_) {
#if (!defined(FBCODE_CAFFE2))
flags |= MZ_ZIP_FLAG_DO_NOT_COMPUTE_CRC32;
#endif
}
mz_zip_writer_add_mem_ex_v2(
/*pZip=*/ar_.get(),
/*pArchive_name=*/full_name.c_str(),

View File

@ -205,9 +205,9 @@ class TORCH_API PyTorchStreamReader final {
class TORCH_API PyTorchStreamWriter final {
public:
explicit PyTorchStreamWriter(const std::string& archive_name);
explicit PyTorchStreamWriter(const std::string& archive_name, bool compute_crc32 = true);
explicit PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)> writer_func);
const std::function<size_t(const void*, size_t)> writer_func, bool compute_crc32 = true);
void setMinVersion(const uint64_t version);
@ -248,6 +248,7 @@ class TORCH_API PyTorchStreamWriter final {
std::function<size_t(const void*, size_t)> writer_func_;
uint64_t combined_uncomp_crc32_ = 0;
std::string serialization_id_;
bool compute_crc32_;
// This number will be updated when the model has operators
// that have valid upgraders.

View File

@ -390,6 +390,8 @@ The following utility functions are related to serialization:
.. currentmodule:: torch.serialization
.. autofunction:: register_package
.. autofunction:: get_crc32_options
.. autofunction:: set_crc32_options
.. autofunction:: get_default_load_endianness
.. autofunction:: set_default_load_endianness
.. autofunction:: get_default_mmap_options

View File

@ -4334,6 +4334,35 @@ class TestSerialization(TestCase, SerializationMixin):
else:
os.environ[env_var] = old_value
@unittest.skipIf(IS_FBCODE, "miniz version differs between fbcode and oss")
@parametrize("compute_crc32", (True, False))
@parametrize("filename", (True, False))
def test_crc32_options(self, compute_crc32, filename):
# test both path and buffer case
file_creation_func = TemporaryFileName if filename else tempfile.NamedTemporaryFile
sd = torch.nn.Linear(3, 5).state_dict()
with file_creation_func() as f:
try:
torch.serialization.set_crc32_options(compute_crc32)
torch.save(sd, f)
if not filename:
f.seek(0)
sd_loaded = torch.load(f, weights_only=True)
self.assertEqual(sd_loaded, sd)
finally:
torch.serialization.set_crc32_options(True)
args = () if compute_crc32 else (zipfile.BadZipFile, "Bad CRC-32 for file")
ctx = contextlib.nullcontext if compute_crc32 else self.assertRaisesRegex
if not filename:
f.seek(0)
# zip_file.extractall() will raise BadZipFile if CRC32 is not populated
# we use the context manager to check whether CRC32 was populated
with ctx(*args), tempfile.TemporaryDirectory() as temp_dir:
with zipfile.ZipFile(f) as zip_file:
zip_file.extractall(path=temp_dir)
def run(self, *args, **kwargs):
with serialization_method(use_zip=True):
return super().run(*args, **kwargs)

View File

@ -6251,6 +6251,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
mz_uint8 extra_data[MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE];
mz_uint16 bit_flags = 0;
mz_bool write_metadata_only = buf_size && !pBuf;
mz_bool skip_crc32 = write_metadata_only || (level_and_flags & MZ_ZIP_FLAG_DO_NOT_COMPUTE_CRC32);
if ((int)level_and_flags < 0)
level_and_flags = MZ_DEFAULT_LEVEL;
@ -6309,7 +6310,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
if (!(level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA))
{
if (!write_metadata_only) {
if (!skip_crc32) {
uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size);
}
uncomp_size = buf_size;

View File

@ -1001,7 +1001,8 @@ typedef enum {
MZ_ZIP_FLAG_VALIDATE_HEADERS_ONLY = 0x2000, /* validate the local headers, but don't decompress the entire file and check the crc32 */
MZ_ZIP_FLAG_WRITE_ZIP64 = 0x4000, /* always use the zip64 file format, instead of the original zip file format with automatic switch to zip64. Use as flags parameter with mz_zip_writer_init*_v2 */
MZ_ZIP_FLAG_WRITE_ALLOW_READING = 0x8000,
MZ_ZIP_FLAG_ASCII_FILENAME = 0x10000
MZ_ZIP_FLAG_ASCII_FILENAME = 0x10000,
MZ_ZIP_FLAG_DO_NOT_COMPUTE_CRC32 = 0x20000, /* don't compute the crc32 of file data that's being added. */
} mz_zip_flags;
typedef enum {

View File

@ -1447,9 +1447,9 @@ class PyTorchFileReader:
class PyTorchFileWriter:
@overload
def __init__(self, name: str) -> None: ...
def __init__(self, name: str, compute_crc32 = True) -> None: ...
@overload
def __init__(self, buffer: BinaryIO) -> None: ...
def __init__(self, buffer: BinaryIO, compute_crc32 = True) -> None: ...
def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ...
def write_end_of_file(self) -> None: ...
def set_min_version(self, version: _int) -> None: ...

View File

@ -1389,8 +1389,12 @@ void initJITBindings(PyObject* module) {
"fallback", [](GraphExecutorState& s) { return s.fallback; });
py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
.def(py::init<std::string>())
.def(py::init([](const py::object& buffer) {
.def(
py::init<std::string, bool>(),
py::arg("file_name"),
py::arg("compute_crc32") = true)
.def(
py::init([](const py::object& buffer, bool compute_crc32 = true) {
auto writer_func = [=](const void* data, size_t size) {
// Writing an empty file is a noop
if (size == 0) {
@ -1408,9 +1412,15 @@ void initJITBindings(PyObject* module) {
}
return size;
};
return std::make_unique<PyTorchStreamWriter>(std::move(writer_func));
}))
.def(py::init<const std::function<size_t(const void*, size_t)>&>())
return std::make_unique<PyTorchStreamWriter>(
std::move(writer_func), compute_crc32);
}),
py::arg("buffer"),
py::arg("compute_crc32") = true)
.def(
py::init<const std::function<size_t(const void*, size_t)>&, bool>(),
py::arg("writer_func"),
py::arg("compute_crc32") = true)
// [Note: write_record_metadata]
// The write_record_metadata function is intended to write metadata (i.e.
// the zipfile header and end of central directory record) for a file

View File

@ -53,6 +53,8 @@ __all__ = [
"load",
"StorageType",
"LoadEndianness",
"get_crc32_options",
"set_crc32_options",
"get_default_load_endianness",
"set_default_load_endianness",
"get_default_mmap_options",
@ -167,6 +169,34 @@ def set_default_load_endianness(endianness):
_default_load_endian = endianness
_compute_crc32: bool = True
def get_crc32_options() -> bool:
"""
Get whether :func:`torch.save` computes and writes crc32 for each record.
Defaults to ``True``.
"""
return _compute_crc32
def set_crc32_options(compute_crc32: bool):
"""
Set whether :func:`torch.save` computes and writes crc32 for each record.
.. note::
Setting this to ``False`` may make unzipping of the ``torch.save`` output
fail or warn due to corrupted CRC32. However ``torch.load`` will be
able to load the file.
Args:
        compute_crc32 (bool): set crc32 computation flag
"""
global _compute_crc32
_compute_crc32 = compute_crc32
_default_mmap_options: int = MAP_PRIVATE
@ -682,9 +712,11 @@ class _open_zipfile_writer_file(_opener):
# For filenames with non-ascii characters, we rely on Python
# for writing out the file.
self.file_stream = io.FileIO(self.name, mode="w")
super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
super().__init__(
torch._C.PyTorchFileWriter(self.file_stream, _compute_crc32)
)
else:
super().__init__(torch._C.PyTorchFileWriter(self.name))
super().__init__(torch._C.PyTorchFileWriter(self.name, _compute_crc32))
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
@ -700,7 +732,7 @@ class _open_zipfile_writer_buffer(_opener):
raise AttributeError(msg)
raise TypeError(msg)
self.buffer = buffer
super().__init__(torch._C.PyTorchFileWriter(buffer))
super().__init__(torch._C.PyTorchFileWriter(buffer, _compute_crc32))
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()