Commit Graph

163 Commits

Author SHA1 Message Date
33dee721ae Reraise worker errors as runtime errors in more cases when the original exception can't be constructed (#140911)
related to https://github.com/pytorch/pytorch/issues/34130

when pytorch attempts to re-raise an exception from a worker process (e.g. multiprocessing dataloader), if it can't reconstruct the original exception message due to a type error, it instead raises it as a runtime error. However, if it can't reconstruct the exception for some other reason, it throws an error with a stacktrace pointing to the `ExceptionWrapper` code rather than the original underlying issue.

One case in which I run into this is with boto3's [HTTPClientError](66dc1f8d52/botocore/exceptions.py (L94))s. They must be constructed with a keyword argument `error`, but if `error` isn't passed, a `KeyError` is thrown instead of a `TypeError`, due to the particular way it is implemented:

* [HTTPClientError](66dc1f8d52/botocore/exceptions.py (L94))'s constructor excepts variable keyword arguments it passes to `super` (BotoCoreError)
* [it also defines a field `fmt` with `error`](66dc1f8d52/botocore/exceptions.py (L95))
* BotoCoreError [expects to be able to format that string with the kwargs](66dc1f8d52/botocore/exceptions.py (L41))

So in this case, if a HTTPClientError occurs on a worker process, you simply get a `KeyError: error` with a stacktrace pointing to [this line](3e2f276a14/torch/_utils.py (L710)) which is unhelpful.

Instead, I propose to reraise the error as a `RuntimeError` unconditionally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140911
Approved by: https://github.com/vmoens
2024-12-14 03:11:36 +00:00
dc23f1944a Remove unused Python variables in torch/[_-a]* (#133492)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492
Approved by: https://github.com/albanD
2024-12-12 17:39:14 +00:00
5c97ac9721 Revert "Remove unused Python variables in torch/[_-a]* (#133492)"
This reverts commit fda975a7b3071a20dab8fc2c4e453479e1bb7cf2.

Reverted https://github.com/pytorch/pytorch/pull/133492 on behalf of https://github.com/clee2000 due to Sorry, I need to revert this in order to revert something else.  The only thing you need to do is rebase and remerge ([comment](https://github.com/pytorch/pytorch/pull/133492#issuecomment-2536635516))
2024-12-11 17:29:12 +00:00
fda975a7b3 Remove unused Python variables in torch/[_-a]* (#133492)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492
Approved by: https://github.com/albanD
2024-12-10 21:48:44 +00:00
c30dd35877 [Device] Add "mps" to torch._utils._get_device_attr (#142447)
Follow up after  https://github.com/pytorch/pytorch/pull/141098
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142447
Approved by: https://github.com/kit1980
2024-12-10 20:57:17 +00:00
e1196dfe51 Deprecate torch._utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-12-08 22:55:36 +00:00
2ee2dcb736 [Device] Add mps as device type in torch._utils._get_available_device_type() (#141098)
As the title states

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141098
Approved by: https://github.com/malfet
2024-11-20 20:45:59 +00:00
1d28b8b6d5 Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit e84d1121ad66a453c8c24fcc098625e2e9764fca.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. More details in D65483292 ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2458381056))
2024-11-05 23:10:38 +00:00
e84d1121ad Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-11-05 10:44:56 +00:00
37149d032c Fix .to(cpu) for Storage (#138011)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138011
Approved by: https://github.com/albanD
2024-10-23 01:31:48 +00:00
70288c3c2d Remove dependency on numpy for serialization for XLA/open registration devices without numpy (#137444)
Related: https://github.com/pytorch/xla/issues/7799#issuecomment-2375818263

Follow ups: Do the same for maia and mtia

## Motivation

With the move to `weights_only` by default, we are making an explicit decision not to allowlist GLOBALs required to deserialize `numpy` tensors  by default. The implication is that backends relying on numpy for serialization will fail loudly when `torch.load` flips `weights_only`.

However, we make the observation that this dependency on numpy was legacy and is not actually needed anymore. So we can remove it, which aligns with our weights_only strategy.

## Why is this ok?

The following comment on why numpy is necessary for serialization is legacy

c87c9f0a01/torch/_tensor.py (L303-L312)

We no longer do the following, though it was the case 5 years ago in the PR that added this
> CPU storage is reconstructed with randomly initialized data, moved onto backend device, and then storage is updated to the serialized content

**Instead what now happens is that CPU storage is constructed with data from the file **and then** moved onto backend device.**

Old behavior (`legacy_load`): 67adda891a/torch/serialization.py (L620)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137444
Approved by: https://github.com/albanD
2024-10-09 19:35:55 +00:00
a096f2899d Add torch.serialization.skip_data context manager (#134504)
## Semantic

The semantic is
(1) By default `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```

(2)  With `torch.serialization.skip_data(materialize_fake_tensors=True)`If FakeTensor is passed to `torch.save` the pickler will treat these FakeTensors as being "materialized" space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor)

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')

sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])

```

## Follow Ups

- [ ] `torch.load` semantic for skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)

Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
2024-09-05 16:53:39 +00:00
2fd36086bc Revert "Add torch.serialization.skip_data context manager (#134504)"
This reverts commit 94db935749b8de99d8c3ab23fb880c67c8f3e67a.

Reverted https://github.com/pytorch/pytorch/pull/134504 on behalf of https://github.com/kit1980 due to See D62082697 ([comment](https://github.com/pytorch/pytorch/pull/134504#issuecomment-2327542276))
2024-09-03 22:21:27 +00:00
94db935749 Add torch.serialization.skip_data context manager (#134504)
## Semantic

The semantic is
(1) By default `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```

(2)  With `torch.serialization.skip_data(materialize_fake_tensors=True)`If FakeTensor is passed to `torch.save` the pickler will treat these FakeTensors as being "materialized" space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor)

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')

sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])

```

## Follow Ups

- [ ] `torch.load` semantic for skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
2024-08-29 04:52:52 +00:00
1285443994 Revert "Add torch.serialization.skip_data context manager (#134504)"
This reverts commit 202600bc2384cb19a29b8fca503bafc289158c32.

Reverted https://github.com/pytorch/pytorch/pull/134504 on behalf of https://github.com/mikaylagawarecki due to This is breaking Windows docs tests due to NamedTemporaryFile on Windows not working well ([comment](https://github.com/pytorch/pytorch/pull/134504#issuecomment-2316543901))
2024-08-29 01:30:49 +00:00
202600bc23 Add torch.serialization.skip_data context manager (#134504)
## Semantic

The semantic is
(1) By default `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```

(2)  With `torch.serialization.skip_data(materialize_fake_tensors=True)`If FakeTensor is passed to `torch.save` the pickler will treat these FakeTensors as being "materialized" space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor)

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')

sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])

```

## Follow Ups

- [ ] `torch.load` semantic for skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
2024-08-28 23:53:17 +00:00
cbee9c1fd2 Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit 0e7e61f7cec82a43f2de52b83eff152d703be7a3.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2272370386))
2024-08-07 00:05:20 +00:00
0e7e61f7ce Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-08-03 09:43:38 +00:00
25cec43678 Remove dependency on private _compat_pickle in CPython (#129509)
Use the IMPORT_MAPPING and NAME_MAPPING from here https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129509
Approved by: https://github.com/malfet
ghstack dependencies: #129239, #129396
2024-06-26 14:20:27 +00:00
f85d1e845a [BE] enable UFMT for torch/nn/*.py (#128593)
Part of #123062

- #123062
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128593
Approved by: https://github.com/mikaylagawarecki
2024-06-23 16:05:13 +00:00
cc8193c707 Revert "[BE] enable UFMT for torch/nn/functional.py (#128592)"
This reverts commit f6e6e55fa7d883a89ba99584f8632c260519ba73.

Reverted https://github.com/pytorch/pytorch/pull/128592 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128592#issuecomment-2181783936))
2024-06-21 00:44:16 +00:00
f6e6e55fa7 [BE] enable UFMT for torch/nn/functional.py (#128592)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128592
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #128596, #128594
2024-06-17 16:29:29 +00:00
90bb510ece Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit 348b181a97abc2e636a6c18e5880a78e5d1dab94.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/clee2000 due to sorry I think https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456 is still relevant, I will reach out to them to see what needs to be done in internal to get this remerged ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2159248859))
2024-06-10 20:44:42 +00:00
afe15d2d2f Flip default value for mypy disallow_untyped_defs [3/11] (#127840)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127840
Approved by: https://github.com/oulgen
2024-06-08 18:28:01 +00:00
348b181a97 Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007
2024-06-08 15:25:03 +00:00
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0a8325cbad4734a563aa459ca611991.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
ba3b05fdf3 [1/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort stdlib (#127122)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127122
Approved by: https://github.com/kit1980
2024-05-25 08:25:50 +00:00
87f79af24d Fix map_location for wrapper subclass and device tensors that go through numpy (#126728)
Fixes https://github.com/pytorch/pytorch/issues/124418

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126728
Approved by: https://github.com/albanD
2024-05-24 16:39:30 +00:00
8f0c207e18 xpu: implement xpu serialization (#125530)
Fixes: #125529

BC-breaking note:
The deprecated "async" argument to the Storage.cuda and Storage.hpu has been removed. Use non_blocking instead.

CC: @jbschlosser, @frank-wei @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @albanD

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125530
Approved by: https://github.com/guangyey, https://github.com/albanD
2024-05-16 20:22:17 +00:00
8573d9551a Fix to preserve tensor wrapper subclass dtype through multiprocessing serialization (#125615)
Fixes #125583

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125615
Approved by: https://github.com/albanD
2024-05-07 14:35:48 +00:00
73744a2c00 torch.mtia module for MTIA device backend (#123612)
MTIA device has its own Module in PyTorch now.
torch.mtia has following APIs similar to other backends. The lazy_init is also supported.
```
__all__ = [
    "init",
    "is_available",
    "synchronize",
    "device_count",
    "current_device",
    "current_stream",
    "default_stream",
    "set_stream",
    "stream",
    "device",
]

```
------------
For device management. We expand AccleratorHooksInterface to support generic device management and it can be used in both C++ and PyThon.
```
def _accelerator_hooks_device_count() -> _int: ...
def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
def _accelerator_hooks_get_current_device() -> _int : ...
def _accelerator_hooks_exchange_device(device_index: _int) -> _int : ...
def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int : ...
```

---------
Adding get_device_module API to retrieve device modules for different device types.
```
def get_device_module(device: Optional[Union[torch.device, str]] = None)
```
---------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123612
Approved by: https://github.com/albanD
ghstack dependencies: #123611
2024-04-26 16:17:54 +00:00
e04c7b19f4 Revert "torch.mtia module for MTIA device backend (#123612)"
This reverts commit 381653de63df4b1b31cc95531320caf83b1b60b3.

Reverted https://github.com/pytorch/pytorch/pull/123612 on behalf of https://github.com/jeffdaily due to this PR broke ROCm with message RuntimeError: Cannot have MTIA with other devices ([comment](https://github.com/pytorch/pytorch/pull/123612#issuecomment-2077649762))
2024-04-25 16:06:46 +00:00
381653de63 torch.mtia module for MTIA device backend (#123612)
MTIA device has its own Module in PyTorch now.
torch.mtia has following APIs similar to other backends. The lazy_init is also supported.
```
__all__ = [
    "init",
    "is_available",
    "synchronize",
    "device_count",
    "current_device",
    "current_stream",
    "default_stream",
    "set_stream",
    "stream",
    "device",
]

```
------------
For device management. We expand AccleratorHooksInterface to support generic device management and it can be used in both C++ and PyThon.
```
def _accelerator_hooks_device_count() -> _int: ...
def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
def _accelerator_hooks_get_current_device() -> _int : ...
def _accelerator_hooks_exchange_device(device_index: _int) -> _int : ...
def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int : ...
```

---------
Adding get_device_module API to retrieve device modules for different device types.
```
def get_device_module(device: Optional[Union[torch.device, str]] = None)
```
---------

Differential Revision: [D56443356](https://our.internmc.facebook.com/intern/diff/D56443356)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123612
Approved by: https://github.com/albanD
ghstack dependencies: #123611
2024-04-24 20:51:20 +00:00
929242a15c Revert "torch.mtia module for MTIA device backend (#123612)"
This reverts commit d7e1bf9ff908d2a9c20d5354426d34c539fcb7a1.

Reverted https://github.com/pytorch/pytorch/pull/123612 on behalf of https://github.com/jeffdaily due to This broke ROCm. see test_overrides.py ([comment](https://github.com/pytorch/pytorch/pull/123611#issuecomment-2067363780))
2024-04-19 22:44:26 +00:00
d7e1bf9ff9 torch.mtia module for MTIA device backend (#123612)
MTIA device has its own Module in PyTorch now.
torch.mtia has following APIs similar to other backends. The lazy_init is also supported.
```
__all__ = [
    "init",
    "is_available",
    "synchronize",
    "device_count",
    "current_device",
    "current_stream",
    "default_stream",
    "set_stream",
    "stream",
    "device",
]

```
------------
For device management. We expand AccleratorHooksInterface to support generic device management and it can be used in both C++ and PyThon.
```
def _accelerator_hooks_device_count() -> _int: ...
def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
def _accelerator_hooks_get_current_device() -> _int : ...
def _accelerator_hooks_exchange_device(device_index: _int) -> _int : ...
def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int : ...
```

---------
Adding get_device_module API to retrieve device modules for different device types.
```
def get_device_module(device: Optional[Union[torch.device, str]] = None)
```
---------
@exported-using-ghexport

Differential Revision: [D52923602](https://our.internmc.facebook.com/intern/diff/D52923602/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123612
Approved by: https://github.com/albanD
ghstack dependencies: #123611
2024-04-18 17:38:06 +00:00
eb7adc3ae0 Refactor gpu trace to be device-agnostic (#121794)
# Motivation
Refactor gpu trace to be device-agnostic. gpu trace is usually used in runtime components, including Device, Stream, Event, Guard, and Allocator. It should be device-agnostic and can be shared among each device backend.

# Solution
move `_cuda_trace.py` to `_gpu_trace.py`, which makes each device backend owns their callback, respectively.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121794
Approved by: https://github.com/jgong5, https://github.com/albanD, https://github.com/EikanWang, https://github.com/gujinghui
2024-03-30 13:04:38 +00:00
968c4c4154 Revert "Refactor gpu trace to be device-agnostic (#121794)"
This reverts commit 74deacbf31d032a2659dc1633dc3e5248921d466.

Reverted https://github.com/pytorch/pytorch/pull/121794 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it breaks ROCm jobs in trunk 74deacbf31, please help take a look and reland the change ([comment](https://github.com/pytorch/pytorch/pull/121794#issuecomment-2013674083))
2024-03-21 20:33:17 +00:00
74deacbf31 Refactor gpu trace to be device-agnostic (#121794)
# Motivation
Refactor gpu trace to be device-agnostic. gpu trace is usually used in runtime components, including Device, Stream, Event, Guard, and Allocator. It should be device-agnostic and can be shared among each device backend.

# Solution
move `_cuda_trace.py` to `_gpu_trace.py`, which makes each device backend owns their callback, respectively.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121794
Approved by: https://github.com/jgong5, https://github.com/albanD, https://github.com/EikanWang, https://github.com/gujinghui
2024-03-21 01:52:58 +00:00
f9ed1c432d Revert "Refactor gpu trace to be device-agnostic (#121794)"
This reverts commit 0ff1109e2688b8c841c9dd0eeecfba16f027b049.

Reverted https://github.com/pytorch/pytorch/pull/121794 on behalf of https://github.com/jeanschmidt due to Reverting to see if rocm trunk errors are related ([comment](https://github.com/pytorch/pytorch/pull/121794#issuecomment-2007519408))
2024-03-19 15:40:26 +00:00
0ff1109e26 Refactor gpu trace to be device-agnostic (#121794)
# Motivation
Refactor gpu trace to be device-agnostic. gpu trace is usually used in runtime components, including Device, Stream, Event, Guard, and Allocator. It should be device-agnostic and can be shared among each device backend.

# Solution
move `_cuda_trace.py` to `_gpu_trace.py`, which makes each device backend owns their callback, respectively.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121794
Approved by: https://github.com/jgong5, https://github.com/albanD, https://github.com/EikanWang, https://github.com/gujinghui
2024-03-19 06:02:28 +00:00
4b18ab869f [torch.export] Support is_compiling() flag for non-strict mode (#119602)
Summary: In non-strict mode of torch.export() we didn't set those `is_compiling()` to `True` which is needed by some models.

Test Plan: Unit tests and manual testing.

Differential Revision: D53624452

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119602
Approved by: https://github.com/suo
2024-02-29 05:52:51 +00:00
46e3f670b4 refactor code to share across different devices (#120602)
# Motivation
Refactor utils code to make it possible to share across CUDA, XPU, and other backends.

# Solution
Move `_dummy_type` and `_LazySeedTracker` to torch._utils;

# Additional Context
When upstreaming, refactor these code changes by isolating them into in an additional PR to minimize their impact on the CUDA code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120602
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang
2024-02-28 09:42:58 +00:00
761fa5d6ec Add FakeTensor support to torch._utils._rebuild_tensor (#108186)
There are two scenarios:

* Scenario 1: The checkpoint was saved with pytorch < 1.6
* Scenario 2: The checkpoint was saved with pytorch >= 1.6

Repro Scenario 1:

```python
from torch._subclasses import fake_tensor
import transformers

fake_mode = fake_tensor.FakeTensorMode()
with fake_mode:
    fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")
```

Error:

```bash
Some weights of the model checkpoint at sshleifer/tiny-gpt2 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:463 in           │
│ load_state_dict                                                                                  │
│                                                                                                  │
│    460 │   │   │   )                                                                             │
│    461 │   │   return safe_load_file(checkpoint_file)                                            │
│    462 │   try:                                                                                  │
│ ❱  463 │   │   return torch.load(checkpoint_file, map_location="cpu")                            │
│    464 │   except Exception as e:                                                                │
│    465 │   │   try:                                                                              │
│    466 │   │   │   with open(checkpoint_file) as f:                                              │
│                                                                                                  │
│ /opt/pytorch/torch/serialization.py:1030 in load                                                 │
│                                                                                                  │
│   1027 │   │   │   │   return _legacy_load(opened_file, map_location, _weights_only_unpickler,   │
│   1028 │   │   │   except RuntimeError as e:                                                     │
│   1029 │   │   │   │   raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None           │
│ ❱ 1030 │   │   return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args  │
│   1031                                                                                           │
│   1032                                                                                           │
│   1033 # Register pickling support for layout instances such as                                  │
│                                                                                                  │
│ /opt/pytorch/torch/serialization.py:1258 in _legacy_load                                         │
│                                                                                                  │
│   1255 │   _sys_info = pickle_module.load(f, **pickle_load_args)                                 │
│   1256 │   unpickler = UnpicklerWrapper(f, **pickle_load_args)                                   │
│   1257 │   unpickler.persistent_load = persistent_load                                           │
│ ❱ 1258 │   result = unpickler.load()                                                             │
│   1259 │                                                                                         │
│   1260 │   deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)                 │
│   1261                                                                                           │
│                                                                                                  │
│ /opt/pytorch/torch/_utils.py:201 in _rebuild_tensor_v2                                           │
│                                                                                                  │
│   198 def _rebuild_tensor_v2(                                                                    │
│   199 │   storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None    │
│   200 ):                                                                                         │
│ ❱ 201 │   tensor = _rebuild_tensor(storage, storage_offset, size, stride)                        │
│   202 │   tensor.requires_grad = requires_grad                                                   │
│   203 │   if metadata:                                                                           │
│   204 │   │   set_tensor_metadata(tensor, metadata)                                              │
│                                                                                                  │
│ /opt/pytorch/torch/_utils.py:180 in _rebuild_tensor                                              │
│                                                                                                  │
│   177 def _rebuild_tensor(storage, storage_offset, size, stride):                                │
│   178 │   # first construct a tensor with the correct dtype/device                               │
│   179 │   t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device)      │
│ ❱ 180 │   return t.set_(storage._untyped_storage, storage_offset, size, stride)                  │
│   181                                                                                            │
│   182                                                                                            │
│   183 def get_tensor_metadata(tensor):                                                           │
│                                                                                                  │
│ /opt/pytorch/torch/utils/_stats.py:20 in wrapper                                                 │
│                                                                                                  │
│   17 │   │   if fn.__qualname__ not in simple_call_counter:                                      │
│   18 │   │   │   simple_call_counter[fn.__qualname__] = 0                                        │
│   19 │   │   simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1     │
│ ❱ 20 │   │   return fn(*args, **kwargs)                                                          │
│   21 │   return wrapper                                                                          │
│   22                                                                                             │
│                                                                                                  │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1160 in __torch_dispatch__                         │
│                                                                                                  │
│   1157 │   def __torch_dispatch__(self, func, types, args=(), kwargs=None):                      │
│   1158 │   │   assert self not in _get_current_dispatch_mode_stack(), func                       │
│   1159 │   │   try:                                                                              │
│ ❱ 1160 │   │   │   return self.dispatch(func, types, args, kwargs)                               │
│   1161 │   │   except TypeError:                                                                 │
│   1162 │   │   │   log.exception("fake tensor raised TypeError")                                 │
│   1163 │   │   │   raise                                                                         │
│                                                                                                  │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1318 in dispatch                                   │
│                                                                                                  │
│   1315 │   │                                                                                     │
│   1316 │   │   # we are falling through to running non constant tensors, any input constant tha  │
│   1317 │   │   # is written to must be invalidated                                               │
│ ❱ 1318 │   │   self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)   │
│   1319 │   │                                                                                     │
│   1320 │   │   # Try for fastpath                                                                │
│   1321 │   │   if has_symbolic_sizes:                                                            │
│                                                                                                  │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1557 in invalidate_written_to_constants            │
│                                                                                                  │
│   1554 │   │   any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)         │
│   1555 │   │   if any_constant and get_schema_info(func).is_mutable():                           │
│   1556 │   │   │   schema_info = get_schema_info(func)                                           │
│ ❱ 1557 │   │   │   _, new_kwargs = normalize_function(                                           │
│   1558 │   │   │   │   func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True         │
│   1559 │   │   │   )                                                                             │
│   1560 │   │   │   for k, v in new_kwargs.items():                                               │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:297 in normalize_function                              │
│                                                                                                  │
│   294 │   │   new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs,    │
│   295 │   else:                                                                                  │
│   296 │   │   assert callable(target)                                                            │
│ ❱ 297 │   │   torch_op_schemas = get_signature_for_torch_op(target)                              │
│   298 │   │   matched_schemas = []                                                               │
│   299 │   │   if torch_op_schemas:                                                               │
│   300 │   │   │   # Iterate through all of the schema until we find one that matches             │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:167 in get_signature_for_torch_op                      │
│                                                                                                  │
│   164 │   │   │   return (None, None) if return_schemas else None                                │
│   165 │   │   schemas = torch._C._jit_get_schemas_for_operator(aten_fn)                          │
│   166 │                                                                                          │
│ ❱ 167 │   signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]          │
│   168 │   return (signatures, schemas) if return_schemas else signatures                         │
│   169                                                                                            │
│   170 @compatibility(is_backward_compatible=False)                                               │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:167 in <listcomp>                                      │
│                                                                                                  │
│   164 │   │   │   return (None, None) if return_schemas else None                                │
│   165 │   │   schemas = torch._C._jit_get_schemas_for_operator(aten_fn)                          │
│   166 │                                                                                          │
│ ❱ 167 │   signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]          │
│   168 │   return (signatures, schemas) if return_schemas else signatures                         │
│   169                                                                                            │
│   170 @compatibility(is_backward_compatible=False)                                               │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:70 in _torchscript_schema_to_signature                 │
│                                                                                                  │
│    67 │   from inspect import Parameter                                                          │
│    68 │   parameters : List[Parameter] = []                                                      │
│    69 │   for arg in ts_schema.arguments:                                                        │
│ ❱  70 │   │   arg_type = _torchscript_type_to_python_type(arg.type)                              │
│    71 │   │   default = arg.default_value if arg.has_default_value() else Parameter.empty        │
│    72 │   │   # TODO: Figure out if this is safe. It seems like when generating the type signa   │
│    73 │   │   # PythonArgParser, we emit signatures with `input` instead of `self` as the firs   │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:64 in _torchscript_type_to_python_type                 │
│                                                                                                  │
│    61 │   eval'ing the annotation_str. _type_eval_globals sets up expressions                    │
│    62 │   like "List" and "Future" to map to actual types (typing.List and jit.Future)           │
│    63 │   """                                                                                    │
│ ❱  64 │   return eval(ts_type.annotation_str, _type_eval_globals)                                │
│    65                                                                                            │
│    66 def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Sig   │
│    67 │   from inspect import Parameter                                                          │
│ <string>:1 in <module>                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NameError: name 'Storage' is not defined

During handling of the above exception, another exception occurred:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:467 in           │
│ load_state_dict                                                                                  │
│                                                                                                  │
│    464 │   except Exception as e:                                                                │
│    465 │   │   try:                                                                              │
│    466 │   │   │   with open(checkpoint_file) as f:                                              │
│ ❱  467 │   │   │   │   if f.read(7) == "version":                                                │
│    468 │   │   │   │   │   raise OSError(                                                        │
│    469 │   │   │   │   │   │   "You seem to have cloned a repository without having git-lfs ins  │
│    470 │   │   │   │   │   │   "git-lfs and run `git lfs install` followed by `git lfs pull` in  │
│                                                                                                  │
│ /opt/conda/envs/ptca/lib/python3.8/codecs.py:322 in decode                                       │
│                                                                                                  │
│    319 │   def decode(self, input, final=False):                                                 │
│    320 │   │   # decode input (taking the buffer into account)                                   │
│    321 │   │   data = self.buffer + input                                                        │
│ ❱  322 │   │   (result, consumed) = self._buffer_decode(data, self.errors, final)                │
│    323 │   │   # keep undecoded input until the next call                                        │
│    324 │   │   self.buffer = data[consumed:]                                                     │
│    325 │   │   return result                                                                     │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte

During handling of the above exception, another exception occurred:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/pytorch/bug_repro.py:16 in <module>                                                         │
│                                                                                                  │
│   13 fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")                  │
│   14 assert fake_model is not None                                                               │
│   15 with fake_mode:                                                                             │
│ ❱ 16 │   fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")  # raises    │
│                                                                                                  │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py:484 in │
│ from_pretrained                                                                                  │
│                                                                                                  │
│   481 │   │   │   )                                                                              │
│   482 │   │   elif type(config) in cls._model_mapping.keys():                                    │
│   483 │   │   │   model_class = _get_model_class(config, cls._model_mapping)                     │
│ ❱ 484 │   │   │   return model_class.from_pretrained(                                            │
│   485 │   │   │   │   pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs,   │
│   486 │   │   │   )                                                                              │
│   487 │   │   raise ValueError(                                                                  │
│                                                                                                  │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:2604 in          │
│ from_pretrained                                                                                  │
│                                                                                                  │
│   2601 │   │   if from_pt:                                                                       │
│   2602 │   │   │   if not is_sharded and state_dict is None:                                     │
│   2603 │   │   │   │   # Time to load the checkpoint                                             │
│ ❱ 2604 │   │   │   │   state_dict = load_state_dict(resolved_archive_file)                       │
│   2605 │   │   │                                                                                 │
│   2606 │   │   │   # set dtype to instantiate the model under:                                   │
│   2607 │   │   │   # 1. If torch_dtype is not None, we use that dtype                            │
│                                                                                                  │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:479 in           │
│ load_state_dict                                                                                  │
│                                                                                                  │
│    476 │   │   │   │   │   │   "model. Make sure you have saved the model properly."             │
│    477 │   │   │   │   │   ) from e                                                              │
│    478 │   │   except (UnicodeDecodeError, ValueError):                                          │
│ ❱  479 │   │   │   raise OSError(                                                                │
│    480 │   │   │   │   f"Unable to load weights from pytorch checkpoint file for '{checkpoint_f  │
│    481 │   │   │   │   f"at '{checkpoint_file}'. "                                               │
│    482 │   │   │   │   "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please s  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OSError: Unable to load weights from pytorch checkpoint file for '/root/.cache/huggingface/hub/models--sshleifer--tiny-gpt2/snapshots/5f91d94bd9cd7190a9f3216ff93cd1dd95f2c7be/pytorch_model.bin' at
'/root/.cache/huggingface/hub/models--sshleifer--tiny-gpt2/snapshots/5f91d94bd9cd7190a9f3216ff93cd1dd95f2c7be/pytorch_model.bin'. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set
from_tf=True.
```

Repro scenario 2:

```python
import tempfile
import torch
from torch._subclasses import fake_tensor

class TheModelClass(torch.nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.fc1 = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.fc1(x)

with tempfile.NamedTemporaryFile() as state_dict_file:
    # Create state_dict to be loaded later
    model = TheModelClass()
    torch.save(model.state_dict(), state_dict_file.name)

    fake_mode = fake_tensor.FakeTensorMode()
    with fake_mode:
        # This is where the bug is triggered
        state_dict = torch.load(state_dict_file.name)
```

Error:

```bash
Traceback (most recent call last):
  File "issue_gh_torch_105077.py", line 22, in <module>
    state_dict = torch.load(state_dict_file.name)
  File "/opt/pytorch/torch/serialization.py", line 1014, in load
    return _load(opened_zipfile,
  File "/opt/pytorch/torch/serialization.py", line 1422, in _load
    result = unpickler.load()
  File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
    return t.set_(storage._untyped_storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
    self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
    _, new_kwargs = normalize_function(
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
    torch_op_schemas = get_signature_for_torch_op(target)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
    arg_type = _torchscript_type_to_python_type(arg.type)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
    return eval(ts_type.annotation_str, _type_eval_globals)
  File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```

This PR adds the ability to create fake tensors during torch.load (when fake mode is active) by changing the storage's device to 'meta'.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang, https://github.com/atalman
2024-02-16 23:42:50 +00:00
458e83b5b3 Revert "Add FakeTensor support to torch._utils._rebuild_tensor (#108186)"
This reverts commit 113506d2d4a0120e912c8f36e70a621f55378f81.

Reverted https://github.com/pytorch/pytorch/pull/108186 on behalf of https://github.com/atalman due to Reverted Internally ([comment](https://github.com/pytorch/pytorch/pull/108186#issuecomment-1935310344))
2024-02-09 04:19:20 +00:00
113506d2d4 Add FakeTensor support to torch._utils._rebuild_tensor (#108186)
Partially fixes https://github.com/pytorch/pytorch/issues/105077

Repro:

```python
import tempfile
import torch
from torch._subclasses import fake_tensor

class TheModelClass(torch.nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.fc1 = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.fc1(x)

with tempfile.NamedTemporaryFile() as state_dict_file:
    # Create state_dict to be loaded later
    model = TheModelClass()
    torch.save(model.state_dict(), state_dict_file.name)

    fake_mode = fake_tensor.FakeTensorMode()
    with fake_mode:
        # This is where the bug is triggered
        state_dict = torch.load(state_dict_file.name)
```

Error:

```bash
Traceback (most recent call last):
  File "issue_gh_torch_105077.py", line 22, in <module>
    state_dict = torch.load(state_dict_file.name)
  File "/opt/pytorch/torch/serialization.py", line 1014, in load
    return _load(opened_zipfile,
  File "/opt/pytorch/torch/serialization.py", line 1422, in _load
    result = unpickler.load()
  File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
    return t.set_(storage._untyped_storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
    self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
    _, new_kwargs = normalize_function(
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
    torch_op_schemas = get_signature_for_torch_op(target)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
    arg_type = _torchscript_type_to_python_type(arg.type)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
    return eval(ts_type.annotation_str, _type_eval_globals)
  File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```

This PR adds the ability to create fake tensors during `torch.load` by wrapping the `torch.tensor.set_` call around a `torch.utils._mode_utils.no_dispatch()` to skip fake mode dispatcher for it and thus create a real tensor. It later calls `fake_mode.from_tensor(t)` to finally create the fake tensor.

Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang
2024-02-08 03:01:34 +00:00
499040ac32 Revert "Add FakeTensor support to torch._utils._rebuild_tensor (#108186)"
This reverts commit 426339e4de2efc0cbd501e2bff947ba890ec9817.

Reverted https://github.com/pytorch/pytorch/pull/108186 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/108186#issuecomment-1929978008))
2024-02-06 15:04:48 +00:00
426339e4de Add FakeTensor support to torch._utils._rebuild_tensor (#108186)
Partially fixes https://github.com/pytorch/pytorch/issues/105077

Repro:

```python
import tempfile
import torch
from torch._subclasses import fake_tensor

class TheModelClass(torch.nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.fc1 = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.fc1(x)

with tempfile.NamedTemporaryFile() as state_dict_file:
    # Create state_dict to be loaded later
    model = TheModelClass()
    torch.save(model.state_dict(), state_dict_file.name)

    fake_mode = fake_tensor.FakeTensorMode()
    with fake_mode:
        # This is where the bug is triggered
        state_dict = torch.load(state_dict_file.name)
```

Error:

```bash
Traceback (most recent call last):
  File "issue_gh_torch_105077.py", line 22, in <module>
    state_dict = torch.load(state_dict_file.name)
  File "/opt/pytorch/torch/serialization.py", line 1014, in load
    return _load(opened_zipfile,
  File "/opt/pytorch/torch/serialization.py", line 1422, in _load
    result = unpickler.load()
  File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
    return t.set_(storage._untyped_storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
    self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
    _, new_kwargs = normalize_function(
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
    torch_op_schemas = get_signature_for_torch_op(target)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
    arg_type = _torchscript_type_to_python_type(arg.type)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
    return eval(ts_type.annotation_str, _type_eval_globals)
  File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```

This PR adds the ability to create fake tensors during `torch.load` by wrapping the `torch.tensor.set_` call around a `torch.utils._mode_utils.no_dispatch()` to skip fake mode dispatcher for it and thus create a real tensor. It later calls `fake_mode.from_tensor(t)` to finally create the fake tensor.

Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang
2024-02-02 20:35:38 +00:00
76b1d44d57 pre_dispatch aot_export (#115188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115188
Approved by: https://github.com/bdhirsh
2023-12-25 04:51:21 +00:00
0567f71ac6 Revert " pre_dispatch aot_export (#115188)"
This reverts commit a267d6735051a4714fa2ac1c163315b650118744.

Reverted https://github.com/pytorch/pytorch/pull/115188 on behalf of https://github.com/jeanschmidt due to sadly, it is required to revert this commit in order to revert https://github.com/pytorch/pytorch/pull/115454 ([comment](https://github.com/pytorch/pytorch/pull/115188#issuecomment-1866310014))
2023-12-21 14:03:18 +00:00