Mirror of https://github.com/pytorch/pytorch.git — synced 2025-11-04 16:04:58 +08:00
Ensure torch.save() deterministic output (#57536)
Summary: Fixes https://github.com/pytorch/pytorch/issues/42163.

## 🔥 Pitch

Currently, the binary output produced by `torch.save()` is non-deterministic (as pointed out in https://github.com/pytorch/pytorch/issues/42163). This means that running a simple snippet that saves a tensor (or a model) twice will produce output files with different `md5` sums.

**Why does this occur?** The cause of this behavior is that `obj._cdata` is used to identify a tensor's storage and is written to the file, but the `_cdata` attribute holds a runtime memory address and is therefore non-deterministic: a80b215a9a/torch/serialization.py (L416)

**Why does this matter?** Reproducibility is essential for many machine-learning projects. For instance, when using [`dvc`](https://dvc.org/) you would expect that if none of the dependencies of a pipeline stage has changed, then running the same stage again will produce the same binary output. For the reasons explained above, with `torch` this was not the case, so this PR fixes that issue.

## 📌 Content of this PR

### What changes?

- The `persistent_id()` function now returns a deterministic value, rather than `obj._cdata` (which depends on runtime state).
- As a consequence, `torch.save(obj, "output.pt")` produces a deterministic output, i.e. the `md5` hash of `output.pt` is deterministic. See **Test 1** and **Test 2** below.

### What does not change?

- If an `obj` contains several tensors that share the same underlying data (e.g. they are views of the same tensor), the `obj_key` returned by `persistent_id()` is still the same for all of them.
- As a consequence, serialization still optimizes disk storage by writing each underlying storage only once, rather than one copy per view. See **Test 3** below.

## How to test

### Test 1: snippet from https://github.com/pytorch/pytorch/issues/42163

Consider the following `snippet_1.py` (from https://github.com/pytorch/pytorch/issues/42163).
```python
import hashlib

import torch


def get_sha256_hash(file: str, chunk_size: int = 4096) -> str:
    hasher = hashlib.sha256()
    with open(file, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


file = "tensor.pt"
hashes = []
for _ in range(5):
    obj = torch.ones(1)
    torch.save(obj, file)
    hashes.append(get_sha256_hash(file)[:8])
    del obj

hash = hashes[0]
assert all(other == hash for other in hashes[1:])
print(hash)
```

On `master` you obtain an error:

```bash
$ python snippet_1.py
Traceback (most recent call last):
  File "save_tensor.py", line 84, in <module>
    assert all(other == hash for other in hashes[1:])
AssertionError
```

while on this PR branch you should get the following consistent behaviour:

```bash
$ for run in {1..2}; do python snippet_1.py; done
600a83cb
600a83cb
```

### Test 2: Deterministic save of `Tensor` and `nn.Module` instances

Consider the following `snippet_2.py`:

```python
import torch

torch.manual_seed(0)

x = torch.tensor([8., 8., 5., 0.])
torch.save(x, "out_tensor.pt")

model = torch.nn.Sequential(
    torch.nn.Linear(3, 1),
    torch.nn.Flatten(0, 1)
)
torch.save(model, "out_model.pt")
```

On `master`, the `md5` hashes of `out_tensor.pt` and `out_model.pt` are non-deterministic; for instance you may get:

```bash
$ for run in {1..2}; do python snippet_2.py; md5 out_*pt; done
MD5 (out_model.pt) = 92dca4a310b691e893f3cb41d64d5af1
MD5 (out_tensor.pt) = a4ef290583f50a9c203a42d0cfc078af
MD5 (out_model.pt) = de3cb9791a66af8aed77ed7224bd1d5c
MD5 (out_tensor.pt) = 3b8a6009d3a0be5b9dd94152dcc0c7cb
```

while on this PR branch you should get the following consistent behaviour:

```bash
$ for run in {1..2}; do python snippet_2.py; md5 out_*pt; done
MD5 (out_model.pt) = dba75fd50a190e4e7fa89b7a2477bab7
MD5 (out_tensor.pt) = 029f52f0706d6c813cc796d3cdcd3eb0
MD5 (out_model.pt) = dba75fd50a190e4e7fa89b7a2477bab7
MD5 (out_tensor.pt) = 029f52f0706d6c813cc796d3cdcd3eb0
```

### Test 3: Views of the same tensor are not re-written to file

Consider the following `snippet_3.py`:

```python
import torch

torch.manual_seed(0)

x = torch.rand(1_000, 1_000)
y = x.T
z = x.view(1_000_000, 1)
torch.save({"x": x}, "out_tensor_x.pt")
torch.save({"x": x, "y": y, "z": z}, "out_tensor_xyz.pt")
```

Both on `master` and on this PR branch you should get two output files of the same size:

```bash
$ python snippet_3.py && du -sh out_tensor*pt && md5 out_*pt
3.8M	out_tensor_x.pt
3.8M	out_tensor_xyz.pt
MD5 (out_tensor_x.pt) = eda516d9156177b27bdc2a75c9064d9b
MD5 (out_tensor_xyz.pt) = 333b869f5b93ced7b8649ab1571eb8e3
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57536
Reviewed By: bdhirsh
Differential Revision: D28304728
Pulled By: ailzhang
fbshipit-source-id: 49788e566a3cd2c6c36dc801e6bdd8f42c9459cb
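The change described under "What changes?" boils down to replacing a raw runtime pointer with an index assigned in encounter order. The following is only an illustrative sketch of that technique, not the actual `torch.serialization` code (the helper name `deterministic_key` and the sample addresses are made up for the example):

```python
# Map arbitrary runtime values (such as the memory addresses held by
# `_cdata`) to dense string keys assigned in first-seen order.
id_map = {}

def deterministic_key(runtime_id):
    # setdefault returns the existing key for a repeated id, so storages
    # shared by several views still collapse onto a single key.
    return id_map.setdefault(runtime_id, str(len(id_map)))

assert deterministic_key(0x7f3a01) == "0"
assert deterministic_key(0x55be42) == "1"
assert deterministic_key(0x7f3a01) == "0"  # repeated id reuses its key
```

Because the keys depend only on the order in which storages are first encountered, two runs that save the same object graph emit the same keys even though the underlying addresses differ between runs.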
Committed by: Facebook GitHub Bot
Parent: fe3c63d9d3
Commit: fea3824214
```diff
@@ -456,6 +456,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
 def _save(obj, zip_file, pickle_module, pickle_protocol):
     serialized_storages = {}
+    id_map: Dict[int, str] = {}

     def persistent_id(obj):
         # FIXME: the docs say that persistent_id should only return a string
@@ -465,7 +466,7 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
         # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
         if torch.is_storage(obj):
             storage_type = normalize_storage_type(type(obj))
-            obj_key = str(obj._cdata)
+            obj_key = id_map.setdefault(obj._cdata, str(len(id_map)))
             location = location_tag(obj)
             serialized_storages[obj_key] = obj
```
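The diff above works through pickle's documented `persistent_id` hook: objects recognized by the hook are serialized by reference through a key, and only that key ends up in the stream. The following self-contained sketch shows why an encounter-order key makes repeated dumps byte-identical; the `Blob` class is a hypothetical stand-in for a torch storage, and `dump_bytes` is an invented helper, not PyTorch code:

```python
import io
import pickle


class Blob:
    """Hypothetical stand-in for a torch storage object."""
    def __init__(self, data):
        self.data = data


def dump_bytes(obj, storage_out):
    # A fresh id_map per dump mimics _save(): the key written to the
    # stream depends only on the order in which Blobs are encountered,
    # never on their memory addresses.
    id_map = {}
    buf = io.BytesIO()

    class DeterministicPickler(pickle.Pickler):
        def persistent_id(self, o):
            if isinstance(o, Blob):
                key = id_map.setdefault(id(o), str(len(id_map)))
                storage_out[key] = o
                return key
            return None  # everything else is pickled normally

    DeterministicPickler(buf).dump(obj)
    return buf.getvalue()


b1, b2 = Blob(b"payload"), Blob(b"payload")
s1, s2 = {}, {}
out1 = dump_bytes({"x": b1, "y": b1}, s1)  # id(b1) is arbitrary...
out2 = dump_bytes({"x": b2, "y": b2}, s2)  # ...and id(b2) differs,
assert out1 == out2                        # yet the streams match.
assert list(s1) == ["0"]  # one storage entry despite two references
```

The two dumps see different memory addresses, but both assign the single shared `Blob` the key `"0"`, so the pickle streams are identical — which is exactly the property `torch.save()` gains from this PR, while views sharing one storage still collapse onto one key.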