Fixes#42376
`torch.save` serializes bound methods inside LR scheduler resulting in large serialized file.
Test cases include checking file size, checking if the `anneal_func` is bounded and file is loaded correctly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102627
Approved by: https://github.com/albanD
Using [`nanoGPT/model.py`](https://github.com/karpathy/nanoGPT/blob/master/model.py) run
<details><summary><b>Click for script to save gpt2-xlarge (1.5B params)</b></summary>
```
# test_load_save_gpt.py
from model import GPT
import torch
import time
torch.manual_seed(5)
# gpt2-xlarge 1558M parameters
class GPTConfig:
block_size: int = 1024
vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: int = 48
n_head: int = 25
n_embd: int = 1600
dropout: float = 0.0
bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
def f():
model = GPT(GPTConfig())
state_dict = model.state_dict()
start_saving = time.time()
torch.save(state_dict, "gpt2-xlarge.pth")
end_saving = time.time()
if __name__ == "__main__":
f()
```
</details>
<details><summary><b>Click for script to load</b></summary>
```
# test_load_gpt.py
import torch
from model import GPT
from test_load_save_gpt import GPTConfig
import time
import argparse
def f(mmap, meta):
device = 'meta' if meta else 'cpu'
assign = True if meta else False
with torch.device(device):
model = GPT(GPTConfig())
start_loading = time.time()
loaded_state_dict = torch.load("gpt2-xlarge.pth", _mmap=mmap)
end_loading = time.time()
print(f"loading time using torch.load with mmap={mmap}: ", end_loading - start_loading)
model.load_state_dict(loaded_state_dict, assign=assign)
end_load_state_dict = time.time()
print("load_state_dict time: ", end_load_state_dict - end_loading)
model.cuda()
end_cuda = time.time()
print("cuda time using torch.load with mmap: ", end_cuda - end_load_state_dict)
if __name__ == "__main__":
parser = argparse.ArgumentParser(prog='load_gpt_xlarge')
parser.add_argument('-m', '--mmap', action='store_true')
parser.add_argument('-d', '--devicemeta', action='store_true')
args = parser.parse_args()
mmap = args.mmap
meta = args.devicemeta
f(mmap, meta)
```
</details>
`python test_load_gpt.py`
<img width="614" alt="Screenshot 2023-06-06 at 1 35 43 PM" src="https://github.com/pytorch/pytorch/assets/35276741/ee06e5b3-b610-463b-a867-df995d21af29">
`python test_load_gpt.py --mmap`
<img width="622" alt="Screenshot 2023-06-06 at 1 35 30 PM" src="https://github.com/pytorch/pytorch/assets/35276741/00d2fdd0-b1f5-4313-83dc-e540b654b2af">
If we further use the `with torch.device('meta')` context manager and pull the changes from https://github.com/pytorch/pytorch/pull/102212 that allow the model to reuse tensors from the state_dict, we have
`python test_load_gpt.py --mmap --devicemeta`
<img width="727" alt="Screenshot 2023-06-06 at 1 35 51 PM" src="https://github.com/pytorch/pytorch/assets/35276741/b50257d9-092a-49c3-acae-876ee44d009f">
\
\
Running the above in a docker container containing a build of PyTorch with RAM limited to 512mb by
1) running `make -f docker.Makefile` from `pytorch/` directory
2) `docker run -m 512m -it <image> bash`
3) docker cp `gpt2-xlarge.pth` and `test_load_gpt.py` into the image
`python test_load_gpt.py`
Docker will Kill the process due to OOM whereas
`python test_load_gpt.py --mmap --devicemeta`
<img width="635" alt="Screenshot 2023-06-06 at 1 55 48 PM" src="https://github.com/pytorch/pytorch/assets/35276741/f3820d9e-f24c-43e7-885b-3bfdf24ef8ad">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102549
Approved by: https://github.com/albanD
#85303 added a patch to `torch.testing.assert_close` to handle `torch.storage.TypedStorage`'s. This change is not reflected in the docs and is not intended for the public API. This PR removes the patch ones again and moves the behavior to `TestCase.assertEqual` instead. Meaning, `TypedStorage`'s are again not supported by the public API, but the behavior is the same for all internal use cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89557
Approved by: https://github.com/kurtamohler, https://github.com/mruberry
Preparation for the next PR in this stack: #89559.
I replaced
- `self.assertTrue(torch.equal(...))` with `self.assertEqual(..., rtol=0, atol=0, exact_device=True)`,
- the same for `self.assertFalse(...)` with `self.assertNotEqual(...)`, and
- `assert torch.equal(...)` with `torch.testing.assert_close(..., rtol=0, atol=0)` (note that we don't need to set `check_device=True` here since that is the default).
There were a few instances where the result of `torch.equal` is used directly. In that cases I've replaced with `(... == ...).all().item()` while sometimes also dropping the `.item()` depending on the context.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89527
Approved by: https://github.com/mruberry
Fixes#81690
TODO:
* [x] C++ Unpickler Fix (locally tested pickled in Python and unpickled in C++)
* [x] C++ Pickler Fix (locally tested pickled in C++ and unpickled in Python)
* [x] Do quant_tensor, sparse_tensor, etc require similar changes? (Sparse and Quant don't need this)
* [x] Add Comments
* [x] How to make sure C++ and Python are in sync? (Functions in `pickler.h` help in getting and setting Tensor Metadata (math-bits for now) on a tensor. They are the only place which should handle this.)
Notes:
Quant Tensor don't support complex dtypes and for float they segfault with `_neg_view` : https://github.com/pytorch/pytorch/issues/88484
Sparse Tensor:
```python
>>> a = torch.tensor([[0, 2.], [3j, 0]]).to_sparse()
>>> a.conj().is_conj()
False
>>> a._neg_view()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NotImplementedError: Cannot access storage of SparseTensorImpl
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88182
Approved by: https://github.com/ezyang, https://github.com/anjali411
Avoid double exception in destructor if attempting to serialize to
python object that does not have `write` method
Use `Finalizer` class in `PyTorchStreamWriter::writeEndOfFile()` to a
always set `finailized_` property even if excretion occurs. (as there
isn't much one can do at this point)
Add expicit check for the attribue to `_open_zipfile_writer_buffer` and
add unitests
Modernize code a bit by using Python-3 `super()` method
Fixes https://github.com/pytorch/pytorch/issues/87997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88128
Approved by: https://github.com/albanD
This addresses the security issue in default Python's `unpickler` that allows arbitrary code execution while unpickling.
Restrict classes allowed to be unpicked to in `None`, `int`, `bool`, `str`, `float`, `list`, `tuple`, `dict`/`OrderedDict` as well as `torch.Size`, `torch.nn.Param` as well as `torch.Tensor` and `torch.Storage` variants.
Defaults `weights_only` is set to `False`, but allows global override to safe only load via `TORCH_FORCE_WEIGHTS_ONLY_LOAD` environment variable.
To some extent, addresses https://github.com/pytorch/pytorch/issues/52596
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86812
Approved by: https://github.com/ezyang
### Description
Since the major changes for `_TypedStorage` and `_UntypedStorage` are now complete, they can be renamed to be public.
`TypedStorage._untyped()` is renamed to `TypedStorage.untyped()`.
Documentation for storages is improved as well.
### Issue
Fixes#82436
### Testing
N/A
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82438
Approved by: https://github.com/ezyang
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71708
In Python 3.2, a number of asserts were deprecated.
In Python 3.11, these asserts are deleted completely. The files in this change still use the deprecated asserts.
Switch over to the supported syntax for 3.2 onwards.
Test Plan: Tested on the internal test suite runner.
Reviewed By: ajtulloch
Differential Revision: D33503694
fbshipit-source-id: a150f296033260acf8365d77b837ce0679f57361
(cherry picked from commit abf60ed97409265222915d8265aaabedd625fd93)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62030
Remove dtype tracking from Python Storage interface, remove all the different `<type>Storage` classes except for `ByteStorage`, and update serialization accordingly, while maintaining as much FC/BC as possible
Fixes https://github.com/pytorch/pytorch/issues/47442
* **THE SERIALIZATION FORMAT IS FULLY FC/BC.** We worked very hard to make sure this is the case. We will probably want to break FC at some point to make the serialization structure of tensors make more sense, but not today.
* There is now only a single torch.ByteStorage class. Methods like `Tensor.set_` no longer check that the dtype of storage is appropriate.
* As we no longer know what dtype of a storage is, we've **removed** the size method from Storage, replacing it with nbytes. This is to help catch otherwise silent errors where you confuse number of elements with number of bytes.
* `Storage._new_shared` takes a `nbytes` kwarg and will reject previous positional only calls. `Storage._new_with_file` and `_set_from_file` require explicit element size arguments.
* It's no longer possible to convert storages to different types using the float/double/etc methods. Instead, do the conversion using a tensor.
* It's no longer possible to allocate a typed storage directly using FloatStorage/DoubleStorage/etc constructors. Instead, construct a tensor and extract its storage. The classes still exist but they are used purely for unpickling.
* The preexisting serialization format stores dtype with storage, and in fact this dtype is used to determine the dtype of the tensor overall.
To accommodate this case, we introduce a new TypedStorage concept that exists only during unpickling time which is used to temporarily store the dtype so we can construct a tensor. **If you overrode the handling of pickling/unpickling, you MUST add handling for TypedStorage** or your serialization code will degrade to standard file-based serialization.
Original pull request: https://github.com/pytorch/pytorch/pull/59671
Reviewed By: soulitzer, ngimel
Differential Revision: D29466819
Pulled By: ezyang
fbshipit-source-id: 4a14e5d3c2b08e06e558683d97f7378a3180b00e
Summary:
Happy to get any feedback on how to make this code cleaner!
This:
- Fix Tensor attribute deepcopy BC-breaking?
- Add a test for Tensor attribute deepcopy
- Fix subclass deepcopy
- Moves the subclass serialization tests into their own class not to interfere with other serialization test logic
- Add a test for subclass deepcopy
cc ezyang gchanan
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65584
Reviewed By: gchanan
Differential Revision: D31206590
Pulled By: albanD
fbshipit-source-id: 74a8f0767f4933b9c941fbea880a8fd1b893ea2f