Partially fixes https://github.com/pytorch/pytorch/issues/105077
Repro:
```python
import tempfile
import torch
from torch._subclasses import fake_tensor
class TheModelClass(torch.nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.fc1 = torch.nn.Linear(5, 10)
def forward(self, x):
return self.fc1(x)
with tempfile.NamedTemporaryFile() as state_dict_file:
# Create state_dict to be loaded later
model = TheModelClass()
torch.save(model.state_dict(), state_dict_file.name)
fake_mode = fake_tensor.FakeTensorMode()
with fake_mode:
# This is where the bug is triggered
state_dict = torch.load(state_dict_file.name)
```
Error:
```bash
Traceback (most recent call last):
File "issue_gh_torch_105077.py", line 22, in <module>
state_dict = torch.load(state_dict_file.name)
File "/opt/pytorch/torch/serialization.py", line 1014, in load
return _load(opened_zipfile,
File "/opt/pytorch/torch/serialization.py", line 1422, in _load
result = unpickler.load()
File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
tensor = _rebuild_tensor(storage, storage_offset, size, stride)
File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
return t.set_(storage._untyped_storage, storage_offset, size, stride)
File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
_, new_kwargs = normalize_function(
File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
torch_op_schemas = get_signature_for_torch_op(target)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
arg_type = _torchscript_type_to_python_type(arg.type)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
return eval(ts_type.annotation_str, _type_eval_globals)
File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```
This PR adds the ability to create fake tensors during `torch.load` by wrapping the `torch.tensor.set_` call around a `torch.utils._mode_utils.no_dispatch()` to skip fake mode dispatcher for it and thus create a real tensor. It later calls `fake_mode.from_tensor(t)` to finally create the fake tensor.
Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang
Partially fixes https://github.com/pytorch/pytorch/issues/105077
Repro:
```python
import tempfile
import torch
from torch._subclasses import fake_tensor
class TheModelClass(torch.nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.fc1 = torch.nn.Linear(5, 10)
def forward(self, x):
return self.fc1(x)
with tempfile.NamedTemporaryFile() as state_dict_file:
# Create state_dict to be loaded later
model = TheModelClass()
torch.save(model.state_dict(), state_dict_file.name)
fake_mode = fake_tensor.FakeTensorMode()
with fake_mode:
# This is where the bug is triggered
state_dict = torch.load(state_dict_file.name)
```
Error:
```bash
Traceback (most recent call last):
File "issue_gh_torch_105077.py", line 22, in <module>
state_dict = torch.load(state_dict_file.name)
File "/opt/pytorch/torch/serialization.py", line 1014, in load
return _load(opened_zipfile,
File "/opt/pytorch/torch/serialization.py", line 1422, in _load
result = unpickler.load()
File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
tensor = _rebuild_tensor(storage, storage_offset, size, stride)
File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
return t.set_(storage._untyped_storage, storage_offset, size, stride)
File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
_, new_kwargs = normalize_function(
File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
torch_op_schemas = get_signature_for_torch_op(target)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
arg_type = _torchscript_type_to_python_type(arg.type)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
return eval(ts_type.annotation_str, _type_eval_globals)
File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```
This PR adds the ability to create fake tensors during `torch.load` by wrapping the `torch.tensor.set_` call around a `torch.utils._mode_utils.no_dispatch()` to skip fake mode dispatcher for it and thus create a real tensor. It later calls `fake_mode.from_tensor(t)` to finally create the fake tensor.
Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang
Use proper `os.fspath` to better convert `os.PathLike` object to a path.
Replace `pathlib.Path` with `os.PathLike` which is more generic and typing correct. `pathlib.Path` is an instance of `os.PathLike`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116562
Approved by: https://github.com/malfet
Use the same strategy as for unsafe pickler, i.e. use dummy `torch.serialization.StorageType` to represent legacy typed storage classes during deserialization. Add `_dtype` property to be able to use it for both new and legacy format deserialization.
Parametrize `test_serialization_new_format_old_format_compat`
Add regression test to validate that loading legacy modes can be done
without any warnings
Before the change:
```
% python test_serialization.py -v -k test_serialization_new_format_old_format_compat_
test_serialization_new_format_old_format_compat_cpu (__main__.TestBothSerializationCPU) ... ok
test_serialization_new_format_old_format_compat_safe_cpu (__main__.TestBothSerializationCPU) ... /Users/nshulga/git/pytorch/pytorch/torch/_utils.py:836: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
return self.fget.__get__(instance, owner)()
ok
----------------------------------------------------------------------
Ran 2 tests in 0.116s
OK
```
Without the change but update test to catch warnings:
```
% python test_serialization.py -v -k test_serialization_new_format_old_format_compat_
test_serialization_new_format_old_format_compat_weights_only_False_cpu (__main__.TestBothSerializationCPU) ... ok
test_serialization_new_format_old_format_compat_weights_only_True_cpu (__main__.TestBothSerializationCPU) ... FAIL
======================================================================
FAIL: test_serialization_new_format_old_format_compat_weights_only_True_cpu (__main__.TestBothSerializationCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 2536, in wrapper
method(*args, **kwargs)
File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
result = test(self, **param_kwargs)
File "/Users/nshulga/git/pytorch/pytorch/test/test_serialization.py", line 807, in test_serialization_new_format_old_format_compat
self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
AssertionError: False is not true : Expected no warnings but got ["{message : UserWarning('TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()'), category : 'UserWarning', filename : '/Users/nshulga/git/pytorch/pytorch/torch/_utils.py', lineno : 836, line : None}"]
To execute this test, run the following from the base repo dir:
python test/test_serialization.py -k test_serialization_new_format_old_format_compat_weights_only_True_cpu
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 2 tests in 0.109s
FAILED (failures=1)
```
Fixes problem reported in https://github.com/pytorch/pytorch/issues/52181#issuecomment-1715738910
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113614
Approved by: https://github.com/kit1980, https://github.com/albanD
Fixes#111876
`torch.load` without setting `weights_only=True` is unsafe. So updating examples of `torch.load` to use `weights_only=True` where possible and `weights_only=False` elsewhere with a warning of being unsafety.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112860
Approved by: https://github.com/kit1980
**Get device index by torch.privateuse1._utils._get_device_index, if the metched exists.**
Reason:
Can only get device_index 0 if ```location``` such as 'privateuse1' before modify.
Can get accurate deivce index use _get_device_index in this scenario.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108123
Approved by: https://github.com/albanD
# Motivation
fix hpu deserialization bug. It should check hpu model if and only if location start with hpu. Otherwise, it always raise an AssertError if hpu is not imported. This break the serialization/desirialization functionality abourt other third-party like IPEX.
# Solution
only assert hpu model when start with hpu
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109499
Approved by: https://github.com/ezyang
Fixes#108955.
Right now, the `_is_zipfile` check in `torch.load` performs multiple `read()` calls, reading 1 byte at a time in a loop. This is rather wasteful and leads to performance problems when accessing files on a network share (see #108955) .
This PR replaces those 1 byte calls with a single big call. Functionally, this is equivalent as `read(n)` only reads up to `n` bytes, so even if the file is shorter there should not be any problems.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109119
Approved by: https://github.com/mikaylagawarecki
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.
I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.
I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
Not sure, how it worked before, but if arguments must be annotated is optional if they are defaulted to None
Towards enabling mypy-1.4.1 in lintrunner
<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at 5e1b9f4</samp>
> _We annotate the arguments of doom_
> _To show the `None` values of gloom_
> _We improve the type checking and readability_
> _With `Optional` annotations of metal-ity_
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105022
Approved by: https://github.com/izaitsevfb, https://github.com/huydhn, https://github.com/Skylion007
Using [`nanoGPT/model.py`](https://github.com/karpathy/nanoGPT/blob/master/model.py) run
<details><summary><b>Click for script to save gpt2-xlarge (1.5B params)</b></summary>
```
# test_load_save_gpt.py
from model import GPT
import torch
import time
torch.manual_seed(5)
# gpt2-xlarge 1558M parameters
class GPTConfig:
block_size: int = 1024
vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: int = 48
n_head: int = 25
n_embd: int = 1600
dropout: float = 0.0
bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
def f():
model = GPT(GPTConfig())
state_dict = model.state_dict()
start_saving = time.time()
torch.save(state_dict, "gpt2-xlarge.pth")
end_saving = time.time()
if __name__ == "__main__":
f()
```
</details>
<details><summary><b>Click for script to load</b></summary>
```
# test_load_gpt.py
import torch
from model import GPT
from test_load_save_gpt import GPTConfig
import time
import argparse
def f(mmap, meta):
device = 'meta' if meta else 'cpu'
assign = True if meta else False
with torch.device(device):
model = GPT(GPTConfig())
start_loading = time.time()
loaded_state_dict = torch.load("gpt2-xlarge.pth", _mmap=mmap)
end_loading = time.time()
print(f"loading time using torch.load with mmap={mmap}: ", end_loading - start_loading)
model.load_state_dict(loaded_state_dict, assign=assign)
end_load_state_dict = time.time()
print("load_state_dict time: ", end_load_state_dict - end_loading)
model.cuda()
end_cuda = time.time()
print("cuda time using torch.load with mmap: ", end_cuda - end_load_state_dict)
if __name__ == "__main__":
parser = argparse.ArgumentParser(prog='load_gpt_xlarge')
parser.add_argument('-m', '--mmap', action='store_true')
parser.add_argument('-d', '--devicemeta', action='store_true')
args = parser.parse_args()
mmap = args.mmap
meta = args.devicemeta
f(mmap, meta)
```
</details>
`python test_load_gpt.py`
<img width="614" alt="Screenshot 2023-06-06 at 1 35 43 PM" src="https://github.com/pytorch/pytorch/assets/35276741/ee06e5b3-b610-463b-a867-df995d21af29">
`python test_load_gpt.py --mmap`
<img width="622" alt="Screenshot 2023-06-06 at 1 35 30 PM" src="https://github.com/pytorch/pytorch/assets/35276741/00d2fdd0-b1f5-4313-83dc-e540b654b2af">
If we further use the `with torch.device('meta')` context manager and pull the changes from https://github.com/pytorch/pytorch/pull/102212 that allow the model to reuse tensors from the state_dict, we have
`python test_load_gpt.py --mmap --devicemeta`
<img width="727" alt="Screenshot 2023-06-06 at 1 35 51 PM" src="https://github.com/pytorch/pytorch/assets/35276741/b50257d9-092a-49c3-acae-876ee44d009f">
\
\
Running the above in a docker container containing a build of PyTorch with RAM limited to 512mb by
1) running `make -f docker.Makefile` from `pytorch/` directory
2) `docker run -m 512m -it <image> bash`
3) docker cp `gpt2-xlarge.pth` and `test_load_gpt.py` into the image
`python test_load_gpt.py`
Docker will Kill the process due to OOM whereas
`python test_load_gpt.py --mmap --devicemeta`
<img width="635" alt="Screenshot 2023-06-06 at 1 55 48 PM" src="https://github.com/pytorch/pytorch/assets/35276741/f3820d9e-f24c-43e7-885b-3bfdf24ef8ad">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102549
Approved by: https://github.com/albanD
Summary: The new logger allows passing metadata into the api usage logger. The immediate use case is to pass the serialization_id to the save and load events to be enable tracking serialized models in API events. It could be extended to add more metadata in the future.
Test Plan:
```
buck2 test @//mode/dev //caffe2/caffe2/serialize:inline_container_test
```
Reviewed By: davidberard98
Differential Revision: D45683697
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101762
Approved by: https://github.com/davidberard98
add entry for privateuse1 storage serialization register_package in _register_device_module.
1. User only need to implement `privateuse1_tag` and `privateuse1_deserialize` in the device module of open device. When registering device module, the methods are registered with _package_registry in storage serialization.
2. Provides a fixed sequence number 30 for privateuse1 in storage serialization _package_registry list.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98920
Approved by: https://github.com/ezyang
Changes:
1. `typing_extensions -> typing-extentions` in dependency. Use dash rather than underline to fit the [PEP 503: Normalized Names](https://peps.python.org/pep-0503/#normalized-names) convention.
```python
import re
def normalize(name):
return re.sub(r"[-_.]+", "-", name).lower()
```
2. Import `Literal`, `Protocal`, and `Final` from standard library as of Python 3.8+
3. Replace `Union[Literal[XXX], Literal[YYY]]` to `Literal[XXX, YYY]`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94490
Approved by: https://github.com/ezyang, https://github.com/albanD
Avoid double exception in destructor if attempting to serialize to
python object that does not have `write` method
Use `Finalizer` class in `PyTorchStreamWriter::writeEndOfFile()` to a
always set `finailized_` property even if excretion occurs. (as there
isn't much one can do at this point)
Add expicit check for the attribue to `_open_zipfile_writer_buffer` and
add unitests
Modernize code a bit by using Python-3 `super()` method
Fixes https://github.com/pytorch/pytorch/issues/87997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88128
Approved by: https://github.com/albanD
This addresses the security issue in default Python's `unpickler` that allows arbitrary code execution while unpickling.
Restrict classes allowed to be unpicked to in `None`, `int`, `bool`, `str`, `float`, `list`, `tuple`, `dict`/`OrderedDict` as well as `torch.Size`, `torch.nn.Param` as well as `torch.Tensor` and `torch.Storage` variants.
Defaults `weights_only` is set to `False`, but allows global override to safe only load via `TORCH_FORCE_WEIGHTS_ONLY_LOAD` environment variable.
To some extent, addresses https://github.com/pytorch/pytorch/issues/52596
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86812
Approved by: https://github.com/ezyang