## What's the problem?
The popular `fx.node.map_arg()` and `fx.node.map_aggregate()` apply operations recursively to `dict`s, `tuple`s, `list`s, etc., and return a new collection of the same type.
Unfortunately, their base input type is `Argument`, which is [very unspecific indeed](5d55a6585d/torch/fx/node.py (L48-L58)): as far as the type checker is concerned, most type information is simply thrown away at the call site of either of these functions.
As `torch` moves to a more typed code base, this would force innocent, unsuspecting developers to add logically unnecessary casts or `# type: ignore` statements.
## What's the solution?
Making these two `node.map_*` functions generic on the first argument and return type means that type information is preserved for the type checker. (The signature of the other parameter, the function that visits the nodes and subnodes, has not changed, nor should it.)
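As an illustration of the typing change, here is a minimal hand-written sketch of the idea (not the actual `torch.fx` code): parameterizing the input and return type on a `TypeVar` lets the checker see that a `tuple` in yields a `tuple` out.
```python
from typing import Any, Callable, TypeVar

T = TypeVar("T")

def map_aggregate_sketch(a: T, fn: Callable[[Any], Any]) -> T:
    """Recursively apply `fn` to leaves, preserving the container type for the checker."""
    if isinstance(a, tuple):
        return type(a)(map_aggregate_sketch(x, fn) for x in a)  # type: ignore
    if isinstance(a, list):
        return [map_aggregate_sketch(x, fn) for x in a]  # type: ignore
    if isinstance(a, dict):
        return {k: map_aggregate_sketch(v, fn) for k, v in a.items()}  # type: ignore
    return fn(a)

args = (1, 2, 3)
out = map_aggregate_sketch(args, lambda x: x + 1)  # inferred as a tuple of ints, not Argument
```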
## Won't it break everything?
It doesn't break the type checker; only one place needed an extra hint.
There have been a few downstream code breakages: one has been resolved, and at least one new one has appeared... we'll see!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146248
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007
Rationale: While NumPy doesn't support `bfloat16`, and there is therefore no official typestr for `bfloat16` in `__array_interface__` (https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.interface.html#__array_interface__), JAX/ml_dtypes uses `"<V2"`:
```
>>> from jax import numpy as jnp
>>> jnp.bfloat16.dtype.str
'<V2'
```
Using the same in PyTorch has the upside of making the typestrs returned by `__cuda_array_interface__` identify the torch dtype uniquely.
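For example (illustrative only; requires a CUDA build and device), with this change a `bfloat16` CUDA tensor would report the same typestr:
```
>>> import torch
>>> torch.zeros(4, dtype=torch.bfloat16, device='cuda').__cuda_array_interface__['typestr']
'<V2'
```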
### Misc notes
(1) JAX itself just refuses to do `__cuda_array_interface__` for `bfloat16`:
```
>>> from jax import numpy as jnp
>>> jnp.arange(10, dtype=jnp.bfloat16).__cuda_array_interface__
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: __cuda_array_interface__ is not supported for bfloat16 buffers.
```
(2) The "official" description of `__cuda_array_interface__` doesn't mention bfloat16, it just references `__array_interface__`: https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html
(3) Ongoing issue for numpy to support bfloat16: https://github.com/numpy/numpy/issues/19808
(4) Tweet that triggered this: https://x.com/HeinrichKuttler/status/1866761979349844211, with @ezyang responding.
(5) "<V2" is kinda weird, as it's a "little-endian void" type. When given to Numpy, it gets turned into endian-agnostic:
```
>>> import numpy as np
>>> import ml_dtypes
>>> np.dtype("bfloat16").str
'<V2'
>>> np.dtype("<V2").str
'|V2'
```
Still, it makes sense to have a unique string for `bfloat16`, and since Google chose `"<V2"`, we might as well use it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143042
Approved by: https://github.com/ezyang
Related: https://github.com/pytorch/xla/issues/7799#issuecomment-2375818263
Follow-ups: do the same for `maia` and `mtia`.
## Motivation
With the move to `weights_only` by default for `torch.load`, we are making an explicit decision not to allowlist the GLOBALs required to deserialize `numpy` tensors. The implication is that backends relying on numpy for serialization will fail loudly when `torch.load` flips `weights_only`.
However, we observe that this dependency on numpy is legacy and is no longer actually needed, so we can remove it, which aligns with our `weights_only` strategy.
## Why is this ok?
The following comment on why numpy is necessary for serialization is legacy:
c87c9f0a01/torch/_tensor.py (L303-L312)
We no longer do the following, though it was the case five years ago in the PR that added this comment:
> CPU storage is reconstructed with randomly initialized data, moved onto backend device, and then storage is updated to the serialized content
**Instead, what now happens is that CPU storage is constructed with data from the file and _then_ moved onto the backend device.**
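A conceptual sketch of that current path (plain illustration, not the actual `torch.serialization` code; the byte buffer here is a stand-in for data read from the checkpoint):
```python
import torch

raw = bytearray(16)                                 # stand-in for bytes read from the file
cpu = torch.frombuffer(raw, dtype=torch.float32)    # CPU storage built directly from file data
dev = cpu.to("cuda") if torch.cuda.is_available() else cpu  # then moved onto the backend device
```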
Old behavior (`legacy_load`): 67adda891a/torch/serialization.py (L620)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137444
Approved by: https://github.com/albanD
Fixes #130154
This PR takes the strategy outlined in the above issue and clears out any cached sizes / strides PyCapsules before serialization. This affects the default subclass serialization logic.
The PyCapsule issue also affects `deepcopy`, so that's fixed here as well.
Note: I originally tried using a context manager to remove and then restore the cached PyCapsules after serialization, but in practice the state returned from `_reduce_ex_internal()` references the actual `tensor.__dict__`, so the problem persists once the cached values are restored. Instead, we have to be careful to remove the cached values in the right place so they're not re-cached when pulling out size / stride information for serialization.
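A toy illustration of that aliasing pitfall in plain Python (not the actual tensor machinery): the captured state is the same object as `__dict__`, so restoring the cached entry afterwards reintroduces it into what gets pickled.
```python
class Holder:
    pass

h = Holder()
h.cached_capsule = "opaque"              # stands in for the cached sizes/strides PyCapsule

state = h.__dict__                       # reduce-style state holds a reference, not a copy
del h.__dict__["cached_capsule"]         # temporarily strip the cached entry...
h.__dict__["cached_capsule"] = "opaque"  # ...but restoring it also mutates `state`

print(state)                             # {'cached_capsule': 'opaque'} -- the removal didn't stick
```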
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137030
Approved by: https://github.com/albanD
## Semantics
The semantics are:
(1) By default, `torch.serialization.skip_data(materialize_fake_tensors=False)` makes `torch.save` skip writing storages (but reserve space for them in the checkpoint).
```python
import torch
import torch.nn as nn
sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```
(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if a FakeTensor is passed to `torch.save`, the pickler will treat these FakeTensors as "materialized": space will be reserved in the checkpoint for the associated storage bytes, and when loading, the type will be `Tensor` instead of `FakeTensor`.
```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')
    sd = m.state_dict()

with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])
```
## Follow Ups
- [ ] `torch.load` semantic for skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)
Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
Adds the following to the allowed globals for the `weights_only` unpickler:
- [x] `torch._utils._rebuild_qtensor` and qtensor related types
- [x] `torch._utils._rebuild_parameter_with_state` (used when deserializing a parameter that has user-defined attributes like `Param.foo`)
The remaining rebuild functions that have not been allowlisted are (a workaround sketch follows the list):
- [x] `torch._utils._rebuild_wrapper_subclass` (allowlisted in above PR)
- [ ] `torch._utils._rebuild_device_tensor_from_numpy`
- [ ] `torch._utils._rebuild_xla_tensor` (legacy)
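If a checkpoint still needs one of the not-yet-allowlisted functions under `weights_only=True`, one possible workaround is to allowlist it explicitly. This is a sketch, assuming `torch.serialization.add_safe_globals` is available in your version and these private internals keep their current names; `checkpoint.pt` is a placeholder.
```python
import torch
from torch.serialization import add_safe_globals

# Explicitly trust a specific private rebuild helper; its location may change between releases.
add_safe_globals([torch._utils._rebuild_device_tensor_from_numpy])

obj = torch.load("checkpoint.pt", weights_only=True)
```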
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124330
Approved by: https://github.com/albanD
This PR:
- disallows `FakeTensor.data_ptr` when it is called inside PT2 or fx tracing.
- disallows `FunctionalTensor.data_ptr` (the Python `FunctionalTensor` is only used in PT2).

The motivation behind this is that the leading cause of segfaults when using custom ops with PT2 is calling `.data_ptr` on a `FunctionalTensor` or `FakeTensor`.
This change is BC-breaking. If your code broke as a result of this, it's because there was a bug in it (these `.data_ptr` calls should never have been made!). You can either fix the bug (recommended) or get the previous behavior back with:
```python
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor

# `tensor` here is whatever tensor you were previously calling .data_ptr() on.
data_ptr = 0 if isinstance(tensor, (FakeTensor, FunctionalTensor)) else tensor.data_ptr()
```
Test Plan:
- existing tests
Differential Revision: [D55366199](https://our.internmc.facebook.com/intern/diff/D55366199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122514
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/yifuwang, https://github.com/kurtamohler
Make `torch.__future__.get_swap_module_params_on_conversion() == True` account for the `assign` argument to `nn.Module.load_state_dict`.
Similar to the case where `torch.__future__.get_swap_module_params_on_conversion()` is `False`, `assign=True` means that we do not incur a `self.copy_(other)`, and the properties of `other` will be preserved.
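A small sketch of the interaction described above (the module choice is arbitrary):
```python
import torch
import torch.nn as nn

torch.__future__.set_swap_module_params_on_conversion(True)

m = nn.Linear(3, 5)
sd = {k: v.clone() for k, v in m.state_dict().items()}

# With assign=True, no self.copy_(other) happens; the properties of the
# state-dict values are preserved on the resulting module parameters.
m.load_state_dict(sd, assign=True)
```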
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121158
Approved by: https://github.com/albanD
ghstack dependencies: #121157
Added a `torch.Tensor` method that defines how to transform `other`, a value in the state dictionary, to be loaded into `self`, a param/buffer in an `nn.Module`, before swapping via `torch.utils.swap_tensors`:
* `param.module_load(sd[key])`
This method can be overridden using `__torch_function__`.
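A minimal usage sketch (the values are arbitrary; assumes a PyTorch version that has `Tensor.module_load`):
```python
import torch
import torch.nn as nn

param = nn.Parameter(torch.zeros(3))   # a param living in some nn.Module
incoming = torch.ones(3)               # the corresponding state-dict value

# Transform `incoming` into the value that will be swapped in for `param`;
# tensor subclasses can customize this via __torch_function__.
new_value = param.module_load(incoming)
```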
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117913
Approved by: https://github.com/albanD