Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Add write_record_metadata
to PyTorchFileWriter (#125184)"
This reverts commit dd92637f445d2787f83829079276f71b1ad1fc7c. Reverted https://github.com/pytorch/pytorch/pull/125184 on behalf of https://github.com/izaitsevfb due to breaks internal builds, see D56962076 ([comment](https://github.com/pytorch/pytorch/pull/125184#issuecomment-2094976897))
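For context, the reverted PR exposed write_record_metadata on PyTorchFileWriter so that a record's zip entry (name, size, offsets) could be written without supplying the payload bytes; the test removed below exercises exactly that. The following is a minimal sketch of that flow, assuming a build that still contains #125184; it uses the private torch.serialization._open_zipfile_writer helper just as the removed test does, and the sizes are illustrative only.

import io
import torch

sd = torch.nn.Linear(3, 5).state_dict()
weight_nbytes = sd['weight'].untyped_storage().nbytes()

buf = io.BytesIO()
with torch.serialization._open_zipfile_writer(buf) as zip_file:
    # Ordinary record: the zip header and the payload are both written.
    pkl = b'fake-pickle-bytes'
    zip_file.write_record('data.pkl', pkl, len(pkl))
    # Pre-revert only: write the record's zip metadata and reserve
    # weight_nbytes of payload space without providing the data
    # (this is the method removed by this revert).
    zip_file.write_record_metadata('data/0', weight_nbytes)
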
@@ -620,35 +620,15 @@ size_t ostream_write_func(
   return ret;
 }
 
-// This func will not update combined_uncomp_crc32_ with the uncomp_crc32
-// since there is no way to get the uncomp_crc32 when no buffer is provided.
-size_t ostream_seek_func(
-    void* pOpaque,
-    mz_uint64 file_ofs,
-    size_t n) {
-  auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
-  if (self->current_pos_ != file_ofs) {
-    CAFFE_THROW("unexpected pos ", self->current_pos_, " vs ", file_ofs);
-  }
-  size_t ret = self->seek_func_(n);
-  if (self->current_pos_ + n != ret) {
-    self->err_seen_ = true;
-  }
-  self->current_pos_ += n;
-  return n;
-}
-
 PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name)
     : archive_name_(basename(file_name)) {
   setup(file_name);
 }
 
 PyTorchStreamWriter::PyTorchStreamWriter(
-    const std::function<size_t(const void*, size_t)> writer_func,
-    const std::function<size_t(size_t)> seek_func)
+    const std::function<size_t(const void*, size_t)> writer_func)
     : archive_name_("archive"),
-      writer_func_(writer_func),
-      seek_func_(seek_func) {
+      writer_func_(writer_func) {
   setup(archive_name_);
 }
 
@@ -677,15 +657,10 @@ void PyTorchStreamWriter::setup(const string& file_name) {
       file_stream_.write(static_cast<const char*>(buf), nbytes);
       return !file_stream_ ? 0 : nbytes;
     };
-    seek_func_ = [this](size_t nbytes) -> size_t {
-      file_stream_.seekp(nbytes, std::ios_base::cur);
-      return file_stream_.tellp();
-    };
   }
 
   ar_->m_pIO_opaque = this;
   ar_->m_pWrite = ostream_write_func;
-  ar_->m_pSeek = ostream_seek_func;
 
   mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
   valid("initializing archive ", file_name.c_str());
@@ -715,20 +690,20 @@ void PyTorchStreamWriter::writeRecord(
   detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
   uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
   mz_zip_writer_add_mem_ex_v2(
-      /*pZip=*/ar_.get(),
-      /*pArchive_name=*/full_name.c_str(),
-      /*pBuf=*/data,
-      /*buf_size=*/size,
-      /*pComment=*/nullptr,
-      /*comment_size=*/0,
-      /*level_and_flags=*/flags,
-      /*uncomp_size=*/0,
-      /*uncomp_crc32=*/0,
-      /*last_modified=*/nullptr,
-      /*user_extra_data=*/padding_.c_str(),
-      /*user_extra_data_len=*/padding_size,
-      /*user_extra_data_central=*/nullptr,
-      /*user_extra_data_central_len=*/0);
+      ar_.get(),
+      full_name.c_str(),
+      data,
+      size,
+      nullptr,
+      0,
+      flags,
+      0,
+      0,
+      nullptr,
+      padding_.c_str(),
+      padding_size,
+      nullptr,
+      0);
   valid("writing file ", name.c_str());
   files_written_.insert(name);
 }
@@ -203,21 +203,11 @@ class TORCH_API PyTorchStreamReader final {
   size_t additional_reader_size_threshold_;
 };
 
-namespace {
-
-size_t default_seek_func(size_t nbytes) {
-  TORCH_CHECK(false, "attempting to write record metadata but seek_func unimplemented, please implement seek_func");
-  return 0;
-}
-
-} // namespace
-
 class TORCH_API PyTorchStreamWriter final {
  public:
   explicit PyTorchStreamWriter(const std::string& archive_name);
   explicit PyTorchStreamWriter(
-      const std::function<size_t(const void*, size_t)> writer_func,
-      const std::function<size_t(size_t)> seek_func = default_seek_func);
+      const std::function<size_t(const void*, size_t)> writer_func);
 
   void setMinVersion(const uint64_t version);
 
@@ -256,7 +246,6 @@ class TORCH_API PyTorchStreamWriter final {
   std::string padding_;
   std::ofstream file_stream_;
   std::function<size_t(const void*, size_t)> writer_func_;
-  std::function<size_t(size_t)> seek_func_;
   uint64_t combined_uncomp_crc32_ = 0;
   std::string serialization_id_;
 
@@ -270,10 +259,6 @@ class TORCH_API PyTorchStreamWriter final {
       uint64_t file_ofs,
       const void* pBuf,
       size_t n);
-  friend size_t ostream_seek_func(
-      void* pOpaque,
-      uint64_t file_ofs,
-      size_t n);
 };
 
 namespace detail {
@@ -4000,50 +4000,6 @@ class TestSerialization(TestCase, SerializationMixin):
             y['even'][0] = torch.tensor(-0.25, dtype=dtype)
             self.assertEqual(y['x'][:2].to(dtype=torch.float32), torch.tensor([-0.25, 0.25]))
 
-    @parametrize('filename', (True, False))
-    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
-    def test_filewriter_metadata_writing(self, filename):
-        sd = torch.nn.Linear(3, 5).state_dict()
-        weight_nbytes = sd['weight'].untyped_storage().nbytes()
-        bias_nbytes = sd['bias'].untyped_storage().nbytes()
-        # TemporaryFileName will give a string
-        # NamedTemporaryFile will be treated as a buffer
-        file_creation_func = TemporaryFileName if filename else tempfile.NamedTemporaryFile
-
-        with file_creation_func() as f, file_creation_func() as g:
-            # save state_dict in f
-            torch.save(sd, f)
-            if not filename:
-                f.seek(0)
-            # extract 'data.pkl' for use in our fake checkpoint
-            with torch.serialization._open_file_like(f, 'rb') as opened_file:
-                with torch.serialization._open_zipfile_reader(opened_file) as zip_file:
-                    data_file = io.BytesIO(zip_file.get_record('data.pkl'))
-                    data_0_offset = zip_file.get_record_offset('data/0')
-                    data_1_offset = zip_file.get_record_offset('data/1')
-
-            # write nulls for 'data/0' and 'data/1'
-            with open(f if filename else f.name, 'rb+') as opened_f:
-                opened_f.seek(data_0_offset)
-                opened_f.write(b'0' * weight_nbytes)
-                opened_f.seek(data_1_offset)
-                opened_f.write(b'0' * bias_nbytes)
-
-            with torch.serialization._open_zipfile_writer(g) as zip_file:
-                data_value = data_file.getvalue()
-                zip_file.write_record('data.pkl', data_value, len(data_value))
-                zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))
-                # Only write metadata for storages
-                zip_file.write_record_metadata('data/0', weight_nbytes)
-                zip_file.write_record_metadata('data/1', bias_nbytes)
-
-            if not filename:
-                f.seek(0)
-                g.seek(0)
-            sd_loaded = torch.load(g)
-            sd_loaded_ref = torch.load(f)
-            self.assertEqual(sd_loaded, sd_loaded_ref)
-
     def run(self, *args, **kwargs):
         with serialization_method(use_zip=True):
             return super().run(*args, **kwargs)
third_party/miniz-2.1.0/miniz.c (vendored, 20 changed lines)
@@ -6250,7 +6250,6 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
     mz_uint32 extra_size = 0;
     mz_uint8 extra_data[MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE];
     mz_uint16 bit_flags = 0;
-    mz_bool write_metadata_only = buf_size && !pBuf;
 
     if ((int)level_and_flags < 0)
         level_and_flags = MZ_DEFAULT_LEVEL;
@@ -6264,7 +6263,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
     level = level_and_flags & 0xF;
     store_data_uncompressed = ((!level) || (level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA));
 
-    if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING) || (!pArchive_name) || ((comment_size) && (!pComment)) || (level > MZ_UBER_COMPRESSION))
+    if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING) || ((buf_size) && (!pBuf)) || (!pArchive_name) || ((comment_size) && (!pComment)) || (level > MZ_UBER_COMPRESSION))
         return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER);
 
     pState = pZip->m_pState;
@@ -6309,9 +6308,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
 
     if (!(level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA))
     {
-        if (!write_metadata_only) {
-            uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size);
-        }
+        uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size);
         uncomp_size = buf_size;
         if (uncomp_size <= 3)
         {
@@ -6333,8 +6330,8 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
     if (!pState->m_zip64)
     {
         /* Bail early if the archive would obviously become too large */
-        if ((pZip->m_archive_size + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + archive_name_size
-             + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + comment_size + user_extra_data_len +
+        if ((pZip->m_archive_size + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + archive_name_size
+             + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + comment_size + user_extra_data_len +
              pState->m_central_dir.m_size + MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE + user_extra_data_central_len
              + MZ_ZIP_DATA_DESCRIPTER_SIZE32) > 0xFFFFFFFF)
         {
@@ -6445,14 +6442,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
 
     if (store_data_uncompressed)
    {
-        mz_bool write_failed;
-        if (write_metadata_only) {
-            write_failed = pZip->m_pSeek(pZip->m_pIO_opaque, cur_archive_file_ofs, buf_size) != buf_size;
-        } else {
-            write_failed = pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pBuf, buf_size) != buf_size;
-        }
-
-        if (write_failed)
+        if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pBuf, buf_size) != buf_size)
         {
             pZip->m_pFree(pZip->m_pAlloc_opaque, pComp);
             return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED);
third_party/miniz-2.1.0/miniz.h (vendored, 6 changed lines)
@@ -116,7 +116,7 @@
 
 
 
-/* Defines to completely disable specific portions of miniz.c:
+/* Defines to completely disable specific portions of miniz.c:
    If all macros here are defined the only functionality remaining will be CRC-32, adler-32, tinfl, and tdefl. */
 
 /* Define MINIZ_NO_STDIO to disable all usage and any functions which rely on stdio for file I/O. */
@@ -139,7 +139,7 @@
 /* Define MINIZ_NO_ZLIB_COMPATIBLE_NAME to disable zlib names, to prevent conflicts against stock zlib. */
 #define MINIZ_NO_ZLIB_COMPATIBLE_NAMES
 
-/* Define MINIZ_NO_MALLOC to disable all calls to malloc, free, and realloc.
+/* Define MINIZ_NO_MALLOC to disable all calls to malloc, free, and realloc.
    Note if MINIZ_NO_MALLOC is defined then the user must always provide custom user alloc/free/realloc
    callbacks to the zlib and archive API's, and a few stand-alone helper API's which don't provide custom user
    functions (such as tdefl_compress_mem_to_heap() and tinfl_decompress_mem_to_heap()) won't work. */
@@ -980,7 +980,6 @@ typedef struct
 
 typedef size_t (*mz_file_read_func)(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n);
 typedef size_t (*mz_file_write_func)(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n);
-typedef size_t (*mz_file_seek_func)(void *pOpaque, mz_uint64 file_ofs, size_t n);
 typedef mz_bool (*mz_file_needs_keepalive)(void *pOpaque);
 
 struct mz_zip_internal_state_tag;
@@ -1072,7 +1071,6 @@ typedef struct mz_zip_archive /* note: added name so it can be forward declared
 
     mz_file_read_func m_pRead;
     mz_file_write_func m_pWrite;
-    mz_file_seek_func m_pSeek;
     mz_file_needs_keepalive m_pNeeds_keepalive;
     void *m_pIO_opaque;
 
@@ -1394,21 +1394,9 @@ void initJITBindings(PyObject* module) {
               buffer.attr("write")(std::move(memory_view));
               return size;
             };
-            auto seek_func = [=](size_t offset) {
-              auto current_pos = py::cast<size_t>(buffer.attr("tell")());
-              buffer.attr("seek")(
-                  offset, py::module::import("os").attr("SEEK_CUR"));
-              return current_pos + offset;
-            };
-            return std::make_unique<PyTorchStreamWriter>(
-                std::move(writer_func), std::move(seek_func));
+            return std::make_unique<PyTorchStreamWriter>(std::move(writer_func));
           }))
       .def(py::init<const std::function<size_t(const void*, size_t)>&>())
-      .def(
-          "write_record_metadata",
-          [](PyTorchStreamWriter& self, const std::string& name, size_t size) {
-            return self.writeRecord(name, nullptr, size);
-          })
       .def(
           "write_record",
          [](PyTorchStreamWriter& self,