Many extensions (including pybind helpers) call `Tensor.__dlpack__()` without a stream argument. Before #150217, `stream=None` behaved like “no cross-stream sync” and was safe inside CUDA Graph capture. After #150217, `stream=None` maps to the legacy default stream, adding a cross-stream wait that invalidates capture when running on a non-default stream.
See this example
```
import torch
s = torch.cuda.Stream()
x = torch.randn(8, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.stream(s):
with torch.cuda.graph(g):
_ = x + 1
cap = x.__dlpack__()
_ = torch.utils.dlpack.from_dlpack(cap)
```
This PR partially reverts #150217 that stream=None defaults to no sync.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163242
Approved by: https://github.com/ngimel
The big semantic change (and the reason for this port) is that we no longer monkeypatch Tensor with torchdim's special methods. The new algorithm for handling dispatch is that we first land in `__torch_function__` and we see if a special FCD implementation needs to be dispatch to first, and if there is nothing we fallback to the standard level strategy.
Because there is no longer C binding equivalent of classes, we've condensed _C.Dim and Dim together, and similar for Tensor. This resulted in some bugs as the Python API is sometimes different from the C API. I've attempted to disambiguate these but there may still be mistakes (many early bugs were due to this problem). Dim and DimEntry are especially painful as Dim must abide by Tensor equality semantics, but is pointer equality in C (DimEntry doesn't have this problem). Another difference between C/Python that is subtle is we no longer get implicit conversions from Dim to DimEntry, this also caused some bugs.
Much of the mechanical porting work was done by claude code. I have a separate PR that deletes functorch._C, but it was useful having dim.cpp to point claude at it so I haven't done it in this PR. From a reviewing perspective, I need to re-review that I didn't forget to port anything, some noticeably missing "small" things are patched_dim_method. I am still in progress of carefully doing a side-by-side review of ports; "simplifications" from claude code were also a major source of bugs.
There are two major feature gaps in the implementation:
- DelayedTensor and dot handling are not implemented yet. This should be reasonably easy, just need to do it. However, for the purposes of sharded propagation it is actually better not to reconstruct matmuls.
- Splitting dimensions with an index like `[x, y]` doesn't work. The problem is that `__getitem__` interprets this as advanced indexing and sends the list to torch.tensor to turn into a tensor, instead of being eligible for `__torch_function__`. I think I might need to hard code a special case for this or something?
Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160236
Approved by: https://github.com/zdevito, https://github.com/albanD
This PR introduces the rest of the keyword-arguments added in DLPack
version 2023.12: `dl_device` and `copy`.
In summary, we handle these arguments in the C++ implementation of
`to_dlpack(...)` at _torch/csrc/Module.cpp_, by calling the
`maybeCopyTensor` function at _aten/src/ATen/DLConvertor.cpp_. It also
introduces the following changes:
- Add a new Python API `torchDeviceToDLDevice()`, which is simply a
refactoring of the `getDLDevice()` function at
_aten/src/ATen/DLConvertor.cpp_.
- Add both keyword-arguments to the `from_dlpack()` function at
_torch/utils/dlpack.py_ and to the `Tensor.__dlpack__()` dunder
method.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150218
Approved by: https://github.com/albanD
ghstack dependencies: #150216, #150217
Summary:
NumPy based tensor rebuilding from serialization has been deprecated by other backends (eg. [XLA](https://github.com/pytorch/pytorch/pull/137444)). The new flow has CPU storage being constructed with data from the file and then moved to the target backend device.
Furthermore, relying on numpy for serialization will fail loudly when torch.load flips weights_only.
Reviewed By: andyanwang
Differential Revision: D77843238
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157884
Approved by: https://github.com/albanD
This PR makes the necessary changes in order to upgrade PyTorch DLPack
support to version 1.0. In summary, we add support for the following:
- Support both `DLManagedTensor` and `DLManagedTensorVersioned` when
producing and consuming DLPack capsules
- New parameter for `__dlpack__` method: `max_version`
- Version checks:
- Fallback to old implementation if no `max_version` or if version
lower than 1.0
- Check that the to-be-consumed capsule is of version up to 1.X
In order to accommodate these new specifications, this PR adds the
following main changes:
- `torch._C._to_dlpack_versioned` Python API (Module.cpp): new Python
API for creating a versioned DLPack capsule (called by `__dlpack__`
method)
- `DLPackTraits<T>` class (DLConvertor.h): select the correct
traits (e.g. capsule name, conversion functions) depending on which
DLPack tensor class is being used
- `toDLPackImpl<T>` function (DLConvertor.cpp): populates the
common fields of both classes
- `fromDLPackImpl<T>` function (DLConvertor.cpp): constructs a tensor
from a DLPAck capsule
- `fillVersion<T>` function (DLConvertor.cpp): populates the version
field for `DLManagedTensorVersioned` (no-op for `DLManagedTensor`)
- `tensor_fromDLPackImpl<T>` function (tensor_new.cpp): outer function
for constructing a tensor out of a DLPack capsule that also marks the
capsule as used
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145000
Approved by: https://github.com/albanD
This PR makes the necessary changes in order to upgrade PyTorch DLPack
support to version 1.0. In summary, we add support for the following:
- Support both `DLManagedTensor` and `DLManagedTensorVersioned` when
producing and consuming DLPack capsules
- New parameter for `__dlpack__` method: `max_version`
- Version checks:
- Fallback to old implementation if no `max_version` or if version
lower than 1.0
- Check that the to-be-consumed capsule is of version up to 1.X
In order to accommodate these new specifications, this PR adds the
following main changes:
- `torch._C._to_dlpack_versioned` Python API (Module.cpp): new Python
API for creating a versioned DLPack capsule (called by `__dlpack__`
method)
- `DLPackTraits<T>` class (DLConvertor.h): select the correct
traits (e.g. capsule name, conversion functions) depending on which
DLPack tensor class is being used
- `toDLPackImpl<T>` function (DLConvertor.cpp): populates the
common fields of both classes
- `fromDLPackImpl<T>` function (DLConvertor.cpp): constructs a tensor
from a DLPAck capsule
- `fillVersion<T>` function (DLConvertor.cpp): populates the version
field for `DLManagedTensorVersioned` (no-op for `DLManagedTensor`)
- `tensor_fromDLPackImpl<T>` function (tensor_new.cpp): outer function
for constructing a tensor out of a DLPack capsule that also marks the
capsule as used
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145000
Approved by: https://github.com/albanD
## What's the problem?
The popular `fx.node.map_arg()` and `fx.node.map_aggregate()` apply operations recursively on `dict`s, `tuples`, `list`s, etc, and return a new collection of the same type.
Unfortunately, their base input type is `Argument`, which is [very unspecific indeed](5d55a6585d/torch/fx/node.py (L48-L58)): most type information is just thrown away at the call site of either of these functions, as far as the type checker goes.
As `torch` moves to a more typed code base, this would force innocent, unsuspecting developers to add logically unnecessary casts or `# type: ignore` statements.
## What's the solution?
Making these two `node.map_*` functions generic on the first argument and return type means that type information is preserved for the type checker. (The signature of the other parameter, the function that visits the nodes and subnodes, has not changed, nor should it.)
## Won't it break everything?
It doesn't break the type checker - one place needed an extra hint.
There have been code breakages, resolved one, at least one new one... we'll see!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146248
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007
Rationale: While Numpy doesn't support `bfloat16` and therefore there's no official typestr for `bfloat16` in `__array_interface__` (https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.interface.html#__array_interface__), JAX/ml_dtypes uses "<V2":
```
>>> from jax import numpy as jnp
>>> jnp.bfloat16.dtype.str
'<V2'
```
Using the same in PyTorch has the upside of making the typestrs returned by `__cuda_array_interface__` identify the torch dtype uniquely.
### Misc notes
(1) JAX itself just refuses to do `__cuda_array_interface__` for `bfloat16`:
```
>>> from jax import numpy as jnp
>>> jnp.arange(10, dtype=jnp.bfloat16).__cuda_array_interface__
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: __cuda_array_interface__ is not supported for bfloat16 buffers.
```
(2) The "official" description of `__cuda_array_interface__` doesn't mention bfloat16, it just references `__array_interface__`: https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html
(3) Ongoing issue for numpy to support bfloat16: https://github.com/numpy/numpy/issues/19808
(4) Tweet that triggered this: https://x.com/HeinrichKuttler/status/1866761979349844211, with @ezyang responding.
(5) "<V2" is kinda weird, as it's a "little-endian void" type. When given to Numpy, it gets turned into endian-agnostic:
```
>>> import numpy as np
>>> import ml_dtypes
>>> np.dtype("bfloat16").str
'<V2'
>>> np.dtype("<V2").str
'|V2'
```
Still, it makes sense to have a unique string for `bfloat16` and since Google chose "<V2" we might as well use that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143042
Approved by: https://github.com/ezyang
Related: https://github.com/pytorch/xla/issues/7799#issuecomment-2375818263
Follow ups: Do the same for maia and mtia
## Motivation
With the move to `weights_only` by default, we are making an explicit decision not to allowlist GLOBALs required to deserialize `numpy` tensors by default. The implication is that backends relying on numpy for serialization will fail loudly when `torch.load` flips `weights_only`.
However, we make the observation that this dependency on numpy was legacy and is not actually needed anymore. So we can remove it, which aligns with our weights_only strategy.
## Why is this ok?
The following comment on why numpy is necessary for serialization is legacy
c87c9f0a01/torch/_tensor.py (L303-L312)
We no longer do the following, though it was the case 5 years ago in the PR that added this
> CPU storage is reconstructed with randomly initialized data, moved onto backend device, and then storage is updated to the serialized content
**Instead what now happens is that CPU storage is constructed with data from the file **and then** moved onto backend device.**
Old behavior (`legacy_load`): 67adda891a/torch/serialization.py (L620)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137444
Approved by: https://github.com/albanD
Fixes#130154
This PR takes the strategy outlined in the above issue and clears out any cached sizes / strides PyCapsules before serialization. This affects the default subclass serialization logic.
The PyCapsule issue also affects `deepcopy`, so that's fixed here as well.
Note: I originally tried utilizing a context manager to remove / restore cached PyCapsules after serialization, but in practice the state returned from `_reduce_ex_internal()` references the actual `tensor.__dict__()`, so the problem persists once the cached values are restored. Instead, we have to be careful to remove the cached values in the right place so they're not re-cached when pulling out size / stride information for serialization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137030
Approved by: https://github.com/albanD
## Semantic
The semantic is
(1) By default `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).
```python
import torch
import torch.nn as nn
sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```
(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`If FakeTensor is passed to `torch.save` the pickler will treat these FakeTensors as being "materialized" space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor)
```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode():
m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')
sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])
```
## Follow Ups
- [ ] `torch.load` semantic for skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)
Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
## Semantic
The semantic is
(1) By default `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).
```python
import torch
import torch.nn as nn
sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```
(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`If FakeTensor is passed to `torch.save` the pickler will treat these FakeTensors as being "materialized" space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor)
```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode():
m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')
sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])
```
## Follow Ups
- [ ] `torch.load` semantic for skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
## Semantic
The semantic is
(1) By default `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).
```python
import torch
import torch.nn as nn
sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```
(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`If FakeTensor is passed to `torch.save` the pickler will treat these FakeTensors as being "materialized" space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor)
```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode():
m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')
sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])
```
## Follow Ups
- [ ] `torch.load` semantic for skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD