mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Release the GIL in serialization when it is safe to do so (#120818)
In particular this ensures we release the GIL when serializing: PyBytes objects (this is how we get the pickle object) and Storage objects. Other string-like objects keep the GIL, which is fine because we only use this for very small strings today (for endianness), so releasing the GIL is not important there. Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/120818 Approved by: https://github.com/colesbury
This commit is contained in:
@ -1353,7 +1353,7 @@ class PyTorchFileWriter:
|
||||
def __init__(self, name: str) -> None: ...
|
||||
@overload
|
||||
def __init__(self, buffer: BinaryIO) -> None: ...
|
||||
def write_record(self, name: str, data: Union[bytes, _int], size: _int) -> None: ...
|
||||
def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ...
|
||||
def write_end_of_file(self) -> None: ...
|
||||
def set_min_version(self, version: _int) -> None: ...
|
||||
def get_all_written_records(self) -> List[str]: ...
|
||||
|
@ -1383,6 +1383,7 @@ void initJITBindings(PyObject* module) {
|
||||
if (size == 0) {
|
||||
return size;
|
||||
}
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto memory_view = py::memoryview::from_memory(
|
||||
reinterpret_cast<const char*>(data), size);
|
||||
buffer.attr("write")(std::move(memory_view));
|
||||
@ -1396,18 +1397,50 @@ void initJITBindings(PyObject* module) {
|
||||
[](PyTorchStreamWriter& self,
|
||||
const std::string& name,
|
||||
const char* data,
|
||||
size_t size) { return self.writeRecord(name, data, size); })
|
||||
.def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile)
|
||||
.def("set_min_version", &PyTorchStreamWriter::setMinVersion)
|
||||
size_t size) {
|
||||
// Since we don't know where the data come from, we cannot
|
||||
// release the GIL in this overload
|
||||
return self.writeRecord(name, data, size);
|
||||
})
|
||||
.def(
|
||||
"write_record",
|
||||
[](PyTorchStreamWriter& self,
|
||||
const std::string& name,
|
||||
py::bytes data,
|
||||
size_t size) {
|
||||
// It is not clear from the doc but according to CPython own code,
|
||||
// it is ok to use the result of PyBytes_AsString without the GIL
|
||||
// being held
|
||||
// https://github.com/python/cpython/blob/e2a3e4b7488aff6fdc704a0f258bc315e96c1d6e/Objects/stringlib/join.h#L67
|
||||
const char* data_str = PyBytes_AsString(data.ptr());
|
||||
py::gil_scoped_release release;
|
||||
return self.writeRecord(name, data_str, size);
|
||||
})
|
||||
.def(
|
||||
"write_record",
|
||||
[](PyTorchStreamWriter& self,
|
||||
const std::string& name,
|
||||
c10::Storage data,
|
||||
size_t size) {
|
||||
// Reading Tensor data is always ok without the GIL held
|
||||
py::gil_scoped_release release;
|
||||
return self.writeRecord(
|
||||
name, reinterpret_cast<const char*>(data.data()), size);
|
||||
})
|
||||
.def(
|
||||
"write_record",
|
||||
[](PyTorchStreamWriter& self,
|
||||
const std::string& name,
|
||||
uintptr_t data,
|
||||
size_t size) {
|
||||
TORCH_WARN_ONCE(
|
||||
"write_record(): Passing Storage by data pointer is deprecated and will be an error in ",
|
||||
"the future, please pass the Storage object instead.");
|
||||
return self.writeRecord(
|
||||
name, reinterpret_cast<const char*>(data), size);
|
||||
})
|
||||
.def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile)
|
||||
.def("set_min_version", &PyTorchStreamWriter::setMinVersion)
|
||||
.def("archive_name", &PyTorchStreamWriter::archiveName)
|
||||
.def("serialization_id", &PyTorchStreamWriter::serializationId)
|
||||
.def(
|
||||
|
@ -941,7 +941,7 @@ class PackageExporter:
|
||||
storage = storage.cpu()
|
||||
num_bytes = storage.nbytes()
|
||||
self.zip_file.write_record(
|
||||
f".data/{storage_id}.storage", storage.data_ptr(), num_bytes
|
||||
f".data/{storage_id}.storage", storage, num_bytes
|
||||
)
|
||||
return ("storage", storage_type, storage_id, location, storage_numel)
|
||||
|
||||
|
@ -859,7 +859,7 @@ def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_reco
|
||||
storage = storage.cpu()
|
||||
# Now that it is on the CPU we can directly copy it into the zip file
|
||||
num_bytes = storage.nbytes()
|
||||
zip_file.write_record(name, storage.data_ptr(), num_bytes)
|
||||
zip_file.write_record(name, storage, num_bytes)
|
||||
|
||||
|
||||
def load(
|
||||
|
Reference in New Issue
Block a user