[3.13] fix 3.13 pickle error in serialization.py (#136034)

Error encountered when adding dynamo 3.13 support.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136034
Approved by: https://github.com/albanD
This commit is contained in:
William Wen
2024-09-13 13:02:33 -07:00
committed by PyTorch MergeBot
parent b608ff3bea
commit a00faf4408

View File

@ -1005,8 +1005,12 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
pickle_module.dump(sys_info, f, protocol=pickle_protocol)
pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
pickler.persistent_id = persistent_id
class PyTorchLegacyPickler(pickle_module.Pickler):
def persistent_id(self, obj):
return persistent_id(obj)
pickler = PyTorchLegacyPickler(f, protocol=pickle_protocol)
pickler.dump(obj)
serialized_storage_keys = sorted(serialized_storages.keys())
@ -1083,8 +1087,12 @@ def _save(
# Write the pickle data for `obj`
data_buf = io.BytesIO()
pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
pickler.persistent_id = persistent_id
class PyTorchPickler(pickle_module.Pickler): # type: ignore[name-defined]
def persistent_id(self, obj):
return persistent_id(obj)
pickler = PyTorchPickler(data_buf, protocol=pickle_protocol)
pickler.dump(obj)
data_value = data_buf.getvalue()
zip_file.write_record("data.pkl", data_value, len(data_value))