Partially fixes https://github.com/pytorch/pytorch/issues/105077
Repro:
```python
import tempfile
import torch
from torch._subclasses import fake_tensor
class TheModelClass(torch.nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.fc1 = torch.nn.Linear(5, 10)
def forward(self, x):
return self.fc1(x)
with tempfile.NamedTemporaryFile() as state_dict_file:
# Create state_dict to be loaded later
model = TheModelClass()
torch.save(model.state_dict(), state_dict_file.name)
fake_mode = fake_tensor.FakeTensorMode()
with fake_mode:
# This is where the bug is triggered
state_dict = torch.load(state_dict_file.name)
```
Error:
```bash
Traceback (most recent call last):
File "issue_gh_torch_105077.py", line 22, in <module>
state_dict = torch.load(state_dict_file.name)
File "/opt/pytorch/torch/serialization.py", line 1014, in load
return _load(opened_zipfile,
File "/opt/pytorch/torch/serialization.py", line 1422, in _load
result = unpickler.load()
File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
tensor = _rebuild_tensor(storage, storage_offset, size, stride)
File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
return t.set_(storage._untyped_storage, storage_offset, size, stride)
File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
_, new_kwargs = normalize_function(
File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
torch_op_schemas = get_signature_for_torch_op(target)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
arg_type = _torchscript_type_to_python_type(arg.type)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
return eval(ts_type.annotation_str, _type_eval_globals)
File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```
This PR adds the ability to create fake tensors during `torch.load` by wrapping the `torch.tensor.set_` call around a `torch.utils._mode_utils.no_dispatch()` to skip fake mode dispatcher for it and thus create a real tensor. It later calls `fake_mode.from_tensor(t)` to finally create the fake tensor.
Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang
Partially fixes https://github.com/pytorch/pytorch/issues/105077
Repro:
```python
import tempfile
import torch
from torch._subclasses import fake_tensor
class TheModelClass(torch.nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.fc1 = torch.nn.Linear(5, 10)
def forward(self, x):
return self.fc1(x)
with tempfile.NamedTemporaryFile() as state_dict_file:
# Create state_dict to be loaded later
model = TheModelClass()
torch.save(model.state_dict(), state_dict_file.name)
fake_mode = fake_tensor.FakeTensorMode()
with fake_mode:
# This is where the bug is triggered
state_dict = torch.load(state_dict_file.name)
```
Error:
```bash
Traceback (most recent call last):
File "issue_gh_torch_105077.py", line 22, in <module>
state_dict = torch.load(state_dict_file.name)
File "/opt/pytorch/torch/serialization.py", line 1014, in load
return _load(opened_zipfile,
File "/opt/pytorch/torch/serialization.py", line 1422, in _load
result = unpickler.load()
File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
tensor = _rebuild_tensor(storage, storage_offset, size, stride)
File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
return t.set_(storage._untyped_storage, storage_offset, size, stride)
File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
_, new_kwargs = normalize_function(
File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
torch_op_schemas = get_signature_for_torch_op(target)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
arg_type = _torchscript_type_to_python_type(arg.type)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
return eval(ts_type.annotation_str, _type_eval_globals)
File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```
This PR adds the ability to create fake tensors during `torch.load` by wrapping the `torch.tensor.set_` call around a `torch.utils._mode_utils.no_dispatch()` to skip fake mode dispatcher for it and thus create a real tensor. It later calls `fake_mode.from_tensor(t)` to finally create the fake tensor.
Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang
Removes an unnecessary duplicated utility functions and just have it rely on itertools. Since the file is low traffic, I also added the modified files to UFMT'd files and formatted them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116192
Approved by: https://github.com/malfet
By finally breaking FC promise on new dtypes by serializing untyped
storage and tensor dtypes
- Add `_rebuild_tensor_v3` that takes an extra dtype argument
- In `Tensor.__reduce_ex__` serialize tensor using untyped storage for
v3_dtypes (which are at the moment limited to float8 dtypes)
Test plan: `python -c "import torch;x=torch.arange(10).to(dtype=torch.float8_e4m3fn);torch.save(x, 'pt.pt');print(torch.load('pt.pt'))"`
Fixes https://github.com/pytorch/pytorch/issues/114634
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114662
Approved by: https://github.com/ngimel
Reland - the previous PR was reverted by internal with this error:
```
File "/data/sandcastle/boxes/eden-trunk-hg-fbcode-fbsource/buck-out/v2/gen/fbcode/363cd7e240f5d021/caffe2/torch/fb/trainer/data_modules/tests/__test_dataloader__/test_dataloader#link-tree/torch/__init__.py", line 29, in <module>
from ._utils_internal import _functionalize_sync as _sync
ImportError: cannot import name '_functionalize_sync' from 'torch._utils_internal'
```
I couldn't figure out why internal was unhappy with the import. One potential reason is that I see a build rule for *another* `_utils_internal.py` in the fb folder here ([link](https://www.internalfb.com/code/fbsource/[30ed85cd88409af98b7490be137aaa5dfd7afd01]/fbcode/caffe2/TARGETS?lines=444))
Rather than burn more time investigating, I confirmed internally that the error goes away if I move the util from `torch/_utils_internal.py` to `torch/_utils.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109518
Approved by: https://github.com/albanD
Resolves https://github.com/pytorch/pytorch/issues/107097
After this PR, instead of
```python
torch.sparse_coo_tensor(indices, values, size)._coalesced_(is_coalesced)
```
(that does not work in the autograd context, see #107097), use
```python
torch.sparse_coo_tensor(indices, values, size, is_coalesced=is_coalesced)
```
All sparse coo factory functions that take indices as input support the `is_coalesced` argument.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107638
Approved by: https://github.com/cpuhrsch
Add similar semantics for creating a buffer object similar to creating a parameter. This is done by introducing a new `Buffer` class that can be used for type disambiguation. The underlying functionality of registering a buffer remains the same as the `register_buffer` method has not been changed. The `persistent` parameter in the `Buffer` type is to indicate whether a buffer object should be persistent or not. Other non-test changes have to do with getting the new `Buffer` type recognized by inductor and dynamo. Remaining changes are test changes to make sure that the `Buffer` type can be used as a drop in replacement for `register_buffer` as it just leads to `register_buffer` being called. The addition of this new functionality still allows for normal tensors to be used as buffers so these changes are intended to be backwards compatible.
Fixes#35735
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104069
Approved by: https://github.com/mikaylagawarecki
- Add get_printoptions and printoptions context manager
- Improve edgeitems handling when it is zero
- Add render_call which can be used to conveniently print command
line arguments of a function call, while suppressing actual
tensor data
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102623
Approved by: https://github.com/albanD
We now always have a `__getstate__`/`__setstate__` pair AND the `__dict__` attribute is lazily initialized. So we need to support that in our serialization code.
A quick audit of the rest doesn't look like the new `__getstate__` is too problematic. But maybe the test suite will bring more things to light.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94096
Approved by: https://github.com/ezyang, https://github.com/malfet
This PR is a copy of https://github.com/pytorch/pytorch/pull/90849 that merge was reverted.
The PR adds "check sparse tensor invariants" flag to Context that when enabled will trigger sparse tensor data invariants checks in unsafe methods of constructing sparse COO/CSR/CSC/BSR/BSC tensors. The feature includes the following changes to UI:
`torch.sparse.check_sparse_tensor_invariants` class provides different ways to enable/disable the invariant checking.
`torch.sparse_coo/csr/csc/bsr/bsc/compressed_tensor` functions have a new optional argument `check_invariants` to enable/disable the invariant checks explicitly. When the `check_invariants` argument is specified, the global state of the feature is temporarily overridden.
The PR fixes https://github.com/pytorch/pytorch/issues/90833
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92094
Approved by: https://github.com/cpuhrsch
This PR adds "check sparse tensor invariants" flag to Context that when enabled will trigger sparse tensor data invariants checks in unsafe methods of constructing sparse COO/CSR/CSC/BSR/BSC tensors. The feature includes the following changes to UI:
- `torch.enable_check_sparse_tensor_invariants` and `torch.is_check_sparse_tensor_invariants_enabled` functions to globally enable/disable the invariant checks and to retrieve the state of the feature, respectively
- `torch.sparse_coo/csr/csc/bsr/bsc/compressed_tensor` functions have a new optional argument `check_invariants` to enable/disable the invariant checks explicitly. When the `check_invariants` argument is specified, the global state of the feature is temporarily overridden.
The PR also fixes https://github.com/pytorch/pytorch/issues/90833
# Main issue
*The following content is outdated after merging the PRs in this ghstack but kept for the record.*
The importance of this feature is that when enabling the invariants checks by default, say, via
<details>
```
$ git diff
diff --git a/torch/__init__.py b/torch/__init__.py
index c8543057c7..19a91d0482 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1239,3 +1239,8 @@ if 'TORCH_CUDA_SANITIZER' in os.environ:
# Populate magic methods on SymInt and SymFloat
import torch.fx.experimental.symbolic_shapes
+
+# temporarily enable sparse tensor arguments validation in unsafe
+# constructors:
+
+torch._C._set_check_sparse_tensor_invariants(True)
```
</details>
a massive number of test failures/errors occur in test_sparse_csr.py tests:
```
$ pytest -sv test/test_sparse_csr.py
<snip>
==== 4293 failed, 1557 passed, 237 skipped, 2744 errors in 69.71s (0:01:09) ====
```
that means that we are silently constructing sparse compressed tensors that do not satisfy the sparse tensor invariants. In particular, the following errors are raised:
```
AssertionError: "resize_as_sparse_compressed_tensor_: self and src must have the same layout" does not match "expected values to be a strided and contiguous tensor"
RuntimeError: CUDA error: device-side assert triggered
RuntimeError: `col_indices[..., crow_indices[..., i - 1]:crow_indices[..., i]] for all i = 1, ..., nrows are sorted and distinct along the last dimension values` is not satisfied.
RuntimeError: expected col_indices to be a strided and contiguous tensor
RuntimeError: expected row_indices to be a strided and contiguous tensor
RuntimeError: expected values to be a strided and contiguous tensor
RuntimeError: for_each: failed to synchronize: cudaErrorAssert: device-side assert triggered
RuntimeError: tensor dimensionality must be sum of batch, base, and dense dimensionalities (=0 + 2 + 0) but got 3
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90849
Approved by: https://github.com/amjames, https://github.com/cpuhrsch
@mlazos: skips `item()` calls if compiling with dynamo, by defining a helper function `_get_value` which either returns the result of `.item()` or the scalar cpu tensor if compiling with dynamo. This was done because removing `item()` calls significantly regresses eager perf. Additionally, `_dispatch_sqrt` calls the appropriate sqrt function (math.sqrt, or torch.sqrt).
Fixes https://github.com/pytorch/torchdynamo/issues/1083
This PR will no longer be needed once symint support is default.
This PR closes all remaining graph breaks in the optimizers (!!)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88173
Approved by: https://github.com/albanD
Fixes#81690
TODO:
* [x] C++ Unpickler Fix (locally tested pickled in Python and unpickled in C++)
* [x] C++ Pickler Fix (locally tested pickled in C++ and unpickled in Python)
* [x] Do quant_tensor, sparse_tensor, etc require similar changes? (Sparse and Quant don't need this)
* [x] Add Comments
* [x] How to make sure C++ and Python are in sync? (Functions in `pickler.h` help in getting and setting Tensor Metadata (math-bits for now) on a tensor. They are the only place which should handle this.)
Notes:
Quant Tensor don't support complex dtypes and for float they segfault with `_neg_view` : https://github.com/pytorch/pytorch/issues/88484
Sparse Tensor:
```python
>>> a = torch.tensor([[0, 2.], [3j, 0]]).to_sparse()
>>> a.conj().is_conj()
False
>>> a._neg_view()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NotImplementedError: Cannot access storage of SparseTensorImpl
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88182
Approved by: https://github.com/ezyang, https://github.com/anjali411
### Description
Since the major changes for `_TypedStorage` and `_UntypedStorage` are now complete, they can be renamed to be public.
`TypedStorage._untyped()` is renamed to `TypedStorage.untyped()`.
Documentation for storages is improved as well.
### Issue
Fixes#82436
### Testing
N/A
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82438
Approved by: https://github.com/ezyang
## Motivation
There is a bug in torch._utils.rebuild_qtensor when to restore a qtensor from pickle for not CPU device type. The tensor is created on the CPU device but set to a storage which maybe a different device type.
## Solution
Create the qtensor based on the storage device type.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78234
Approved by: https://github.com/ezyang