116 Commits

Author SHA1 Message Date
e20c9bf288 [torch/utils][Code Clean] Clean asserts in torch/utils/*.py (#165410)
Including:
- `torch/utils/*.py`

Fixes part of #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165410
Approved by: https://github.com/albanD
2025-10-20 23:29:17 +00:00
ed74dc054d add the option to disable functionalization in AOTDispatcher (#164577)
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: https://github.com/pytorch/pytorch/pull/164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164577
Approved by: https://github.com/ezyang
ghstack dependencies: #165372
2025-10-16 15:44:11 +00:00
5e58420dff LocalTensor (#164537)
A LocalTensor is a tensor subclass which simulates a tensor that is
distributed across SPMD ranks.  A LocalTensor might be size N, but in fact
there are world_size shards/replicas of it stored internally.  When you do a
plain PyTorch operation on it, we apply the operation to each shard; when you
do a collective, we do the mathematically equivalent operation on the local
shards.  A LocalTensor is associated with a list of ranks which specify
which ranks it holds local tensors for.

NB, this is NOT a DataParallel like abstraction where you can run operations
on multiple different GPUs. It is intended purely for *debugging* purposes,
the overhead is almost certainly too high to keep eight GPUs (even the C++
autograd needs multithreading to keep up!)  (It might potentially be possible
to trace through this with torch.compile and then compile it with CUDA graphs
but this is currently a non-goal.)

In order to handle MPMD, we provide a helper decorator that allows you to
run a function with no side effects for each LocalTensor shard and combine
results back into LocalTensor or LocalIntNode.

Note: This PR convert all DTensor ops and some DTensor tests to illustrate
intended usage and ensure conrrectness. In subsequent PR more tests will be
converted. DUring test conversion we aim to share as much as possible of
test logic between multi-process / multi-threaded and local tensor tests.
We would like to developers to be able to run both flavors of the tests.

Note: This work is based on the original proposal
by @ezyang (WIP PR https://github.com/pytorch/pytorch/pull/162753).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164537
Approved by: https://github.com/ezyang
2025-10-12 20:06:41 +00:00
086dec3235 Pyrefly suppressions 6/n (#164877)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:

INFO 0 errors (5,064 ignored)

Only four directories left to enable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164877
Approved by: https://github.com/oulgen
2025-10-08 02:30:57 +00:00
35c4130fd1 [2/N] Fix ruff warnings (#164460)
Apply ruff `SIM` rules.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164460
Approved by: https://github.com/ezyang
2025-10-04 03:40:32 +00:00
1e4dfeeb06 Add early_stop kwarg to torch.utils.checkpoint (#160781)
We already have a context manager "set_checkpoint_early_stop". This PR adds a kwarg that toggles the same setting.

It is also useful to have a kwarg version of the setting in addition to the context manager because is annoying to apply a context manager when the AC is being applied via CheckpointWrapper.

Similar to the "debug" kwarg and the corresponding "set_checkpoint_debug_enabled" context manager, the context manager defaults to None and overrides the local setting when non-None.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160781
Approved by: https://github.com/tianyu-l
2025-08-26 22:32:35 +00:00
50cfe76231 Update checkpoint warning to target PyTorch 2.9 (#160725)
Follow-up to #160534. Fixes the docstrings and the warning in checkpoint_sequential, which presumably should have same deprecation notice
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160725
Approved by: https://github.com/soulitzer
2025-08-19 15:08:50 +00:00
4a90dc0c1f Update checkpoint warning to target PyTorch 2.9 (#160643)
Fixes #160534

Updates the warning in torch.utils.checkpoint to state that starting in PyTorch 2.9, calling checkpoint without explicitly passing use_reentrant will raise an exception. Follows the guidance from the issue discussion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160643
Approved by: https://github.com/soulitzer
2025-08-14 20:53:17 +00:00
d40aaa42ee [BE][16/16] fix typos in torch/ (torch/utils/) (#156606)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156606
Approved by: https://github.com/albanD
ghstack dependencies: #156318, #156320, #156602, #156604
2025-07-02 22:55:29 +00:00
3580b8dde4 [BE] Mention debug=True in AC error messages (#155593)
See https://github.com/pytorch/pytorch/issues/155171#issuecomment-2949415407
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155593
Approved by: https://github.com/janeyx99
2025-06-11 00:32:41 +00:00
80703ca332 [FlexAttention] Allow dispatch to SAC for flex (#150080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150080
Approved by: https://github.com/zou3519
2025-06-05 04:34:27 +00:00
b040d63ce4 Prevent SAC cache from being kept alive by reference cycle (#154651)
Fixes https://github.com/pytorch/pytorch/issues/154642
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154651
Approved by: https://github.com/xmfan
2025-05-29 22:27:35 +00:00
5aea57d653 [ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300)
I slap disable on the recomputation hook, otherwise the partitioner may save less/more activations and mismatch with the expected eager count in checkpoint. See code comment `Note: [compiled autograd and checkpoint unpack hook]`.

This fixes all non-nested checkpointing tests. I also wrap nested checkpointing tests, and a few of them still fail.

This also seems to fix all PYTORCH_TEST_WITH_DYNAMO checkpointing tests except for `TestAutograd.test_checkpointing_without_reentrant_custom_function_works`. For those tests, it looks like we fail to HOPify the checkpointed region and when the backward executes the unpack hooks, dynamo tried to trace them. This messed up the internal state tracking of checkpointing, some raising the _StopRecomputationError and others raising the same count mismatch error as CA.

FIXES https://github.com/pytorch/pytorch/issues/127115

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153300
Approved by: https://github.com/jansel
2025-05-16 01:37:48 +00:00
236b08cbf8 Revert "[ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300)"
This reverts commit 4863e5c843722eb2a34fb0ca1d518a33431a38c0.

Reverted https://github.com/pytorch/pytorch/pull/153300 on behalf of https://github.com/malfet due to Looks like it breaks rocm, see fa8543454a/1 ([comment](https://github.com/pytorch/pytorch/pull/153300#issuecomment-2884489459))
2025-05-15 16:58:52 +00:00
4863e5c843 [ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300)
I slap disable on the recomputation hook, otherwise the partitioner may save less/more activations and mismatch with the expected eager count in checkpoint. See code comment `Note: [compiled autograd and checkpoint unpack hook]`.

This fixes all non-nested checkpointing tests. I also wrap nested checkpointing tests, and a few of them still fail.

This also seems to fix all PYTORCH_TEST_WITH_DYNAMO checkpointing tests except for `TestAutograd.test_checkpointing_without_reentrant_custom_function_works`. For those tests, it looks like we fail to HOPify the checkpointed region and when the backward executes the unpack hooks, dynamo tried to trace them. This messed up the internal state tracking of checkpointing, some raising the _StopRecomputationError and others raising the same count mismatch error as CA.

FIXES https://github.com/pytorch/pytorch/issues/127115

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153300
Approved by: https://github.com/jansel
2025-05-15 08:10:35 +00:00
5abe74857a SAC: fix recompute tag propagation for ops with list[tensor] inputs (#152195)
There's an "are we compiling" check in SAC, which we rely on to know when to propagate recompute tags during tracing.

This check was a bit brittle, and missed cases where input ops accept list of tensors - I updated it to check if a `FunctionalTensorMode` is active, which should be a 100% reliable way to know if AOTDispatcher is in the middle of running.

There is a long-standing followup here around unifying `torch.compiler.is_compiling()` to work in all cases. We should probably just update it to always check if FakeMode/FunctionalMode are active and use it there. This has a bit of BC risk though so I opted for the more local fix to SAC.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152195
Approved by: https://github.com/soulitzer
2025-05-05 17:21:00 +00:00
45ef3309e3 [BE] typing for decorators (#144161)
Summary:
Untyped decorators strip annotations from the decorated items.

- _compile
- _inductor/fx_passes/post_grad
- _inductor/lowering
- _library/custom_ops
- _meta_registrations
- _ops
- _refs/nn/functional
- ao/quantization/quantizer/xnnpack_quantizer_utils
- distributed/_composable/contract
- fx/experimental/graph_gradual_typechecker
- fx/experimental/migrate_gradual_types/constraint_generator
- optim/optimizer
- signal/windows/windows
- testing/_internal/common_device_type
- torch/_inductor/decomposition
- utils/flop_counter

Test Plan: unit tests

Differential Revision: D62302684

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144161
Approved by: https://github.com/Skylion007, https://github.com/albanD
2025-01-04 16:40:09 +00:00
be4b7e8131 Param fixes in docstring (#136097)
Fixes wrong param names in docstrings. cc: @kit1980

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136097
Approved by: https://github.com/ezyang
2024-09-21 18:56:34 +00:00
a23dae22d5 Update AC pass use_reentrant message (#134472)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134472
Approved by: https://github.com/albanD
2024-08-26 21:57:38 +00:00
94155ce31b [Torch] Support meta device in checkpoint (#132684)
Summary:
## Why
utils.checkpoint doesn't support meta device:

```
  File "/Users/lyu1/torchdev/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 490, in checkpoint
    next(gen)
  File "/Users/lyu1/torchdev/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 1359, in _checkpoint_without_reentrant_generator
    device_module = _get_device_module(device)
  File "/Users/lyu1/torchdev/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 98, in _get_device_module
    device_module = getattr(torch, device)
  File "/Users/lyu1/torchdev/lib/python3.9/site-packages/torch/__init__.py", line 1938, in __getattr__
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'meta'
```

This blocks us from running model with checkpoint enabled in meta mode.

## What
This diff handles the case of meta device in checkpoint.py.

(in checkpoint.py, device module is manily used when preserve_rng_state=true, which doesn't apply to meta case. So a more elgant fix might be set preserve_rng_state=false when detecting args are on meta device. But I didn't find where to do this check in the minimum way. Let me know if you have ideas.)

Test Plan: Tested with toy model which has checkpoint on its module: P1513716944

Differential Revision: D60749427

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132684
Approved by: https://github.com/kit1980
2024-08-06 20:45:50 +00:00
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function so even if the underlying function is fully typed then callers to it don't get any benefit from type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
973037be6a [BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() (#130199)
This PR changes the empty collection factory call to Python literals:

- `list()` -> `[]`
- `tuple()` -> `()`
- `dict()` -> `{}`

The Python literals are more performant and safer. For example, the bytecode for building an empty dictionary:

```bash
$ python3 -m dis - <<EOS
import collections

d1 = {}
d2 = dict()

dict = collections.OrderedDict
d3 = dict()
EOS
```

```text
  0           0 RESUME                   0

  1           2 LOAD_CONST               0 (0)
              4 LOAD_CONST               1 (None)
              6 IMPORT_NAME              0 (collections)
              8 STORE_NAME               0 (collections)

  3          10 BUILD_MAP                0
             12 STORE_NAME               1 (d1)

  4          14 PUSH_NULL
             16 LOAD_NAME                2 (dict)
             18 CALL                     0
             26 STORE_NAME               3 (d2)

  6          28 LOAD_NAME                0 (collections)
             30 LOAD_ATTR                8 (OrderedDict)
             50 STORE_NAME               2 (dict)

  7          52 PUSH_NULL
             54 LOAD_NAME                2 (dict)
             56 CALL                     0
             64 STORE_NAME               5 (d3)
             66 RETURN_CONST             1 (None)
```

The dict literal `{}` only has one bytecode `BUILD_MAP`, while the factory call `dict()` has three `PUSH_NULL + LOAD_NAME + CALL`. Also, the factory call is not safe if users override the `dict` name in `locals` or `globals` (see the example of replacing with `OrderedDict` above).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130199
Approved by: https://github.com/malfet
2024-07-11 17:30:28 +00:00
eeef68671d [autograd] Do not detach when unpacking tensors that do not require grad (#127959)
In this PR:
- Ensure that if a tensor not requiring grad is saved for backward unpacking does not trigger a detach (unless the user installs a saved tensor pack hook that returns a tensor requiring grad).
- Update non-reentrant checkpoint to also no longer detach for this case.

Alternatives:
- For custom autograd Function, you could directly save on ctx to work around this, but that would not work for when we switch to using custom ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127959
Approved by: https://github.com/YuqingJ
ghstack dependencies: #125795, #128545, #129262
2024-07-01 21:57:36 +00:00
321bdcb372 Fix device propagation for checkpointing (#128671)
Fixes: #128478

In backward() implementation checkpointing code was quering device type from the rng_state tensors saved on forward(). These tensors are CPU only tensors and don't carry device information with them. As a result CUDA device was assumed as a default. Which is not correct if user runs on some other device. For example, on XPU.

This patch saves full device information on forward() and uses it on backward() to get device type. Previously forward save only device index.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128671
Approved by: https://github.com/guangyey, https://github.com/soulitzer
2024-06-27 17:14:13 +00:00
575bc1e3af [Reopen #114036] Allow "must recompute" in torch.compile + selective checkpointing (SAC) (#129295)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129295
Approved by: https://github.com/Chillee
2024-06-25 23:47:08 +00:00
c89a9f5d17 Allow SAC policy_fn to return bool for backward compatibility (#129262)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129262
Approved by: https://github.com/Chillee, https://github.com/fmassa
ghstack dependencies: #125795, #128545
2024-06-24 13:54:30 +00:00
1877b7896c [checkpoint] Clean up selective activation checkpoint and make public (#125795)
### bc-breaking for existing users of the private API:
- Existing policy functions must now change their return value to be [CheckpointPolicy](c0b40ab42e/torch/utils/checkpoint.py (L1204-L1230))  Enum instead of bool.
   - To restore previous behavior, return `PREFER_RECOMPUTE` instead of `False` and `{PREFER,MUST}_SAVE` instead of `True` depending whether you prefer the compiler to override your policy.
- Policy function now accepts a `ctx` object instead of `mode` for its first argument.
   - To restore previous behavior, `mode = "recompute" if ctx.is_recompute else "forward"`.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `create_selective_checkpoint_contexts `. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint).

Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit

Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.

In-place:
- We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place)

Randomness, views
- Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors)

Tensor object preservation
- ~We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object.~ UPDATE: We guarantee that if a tensor is of non-differentiable dtype AND it is not a view, and it is saved, then what you get out is the same tensor object. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor.

Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error.
- The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3).
- The number of times we call the policy_fn is something that should be documented as part of public API. We call the policy function for all ops except ~~detach~~ UPDATE :  metadata ops listed in `torch.utils.checkpoint.SAC_IGNORED_OPS`) because these ops may be called a different number of times by AC itself between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below).
Tensors guaranteed to be the same tensor as-is
- Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
2024-06-18 18:18:50 +00:00
6895a5804c Revert "[checkpoint] Clean up selective activation checkpoint and make public (#125795)"
This reverts commit c472cec5656b9ffb668af97a02d711bdbdf5ebec.

Reverted https://github.com/pytorch/pytorch/pull/125795 on behalf of https://github.com/soulitzer due to breaking torchtitan CI ([comment](https://github.com/pytorch/pytorch/pull/125795#issuecomment-2167036157))
2024-06-14 01:14:59 +00:00
c472cec565 [checkpoint] Clean up selective activation checkpoint and make public (#125795)
Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit

Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.

In-place:
- We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place)

Randomness, views
- Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors)

Tensor object preservation
- We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object. If the tensor does require grad, we must detach to avoid creating a reference cycle. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor.

Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error.
- The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3).
- The number of times we call the policy_fn is something documented part of public API. We call the policy function for all ops except detach because detach is itself called a different number of times by AC between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below).
Tensors guaranteed to be the same tensor as-is
- Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary.

"bc-breaking" for existing users of the private API:
- Existing policy functions must now change their return value to use the Enum.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `gen_selective_checkpoint_context_fn`. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
2024-06-12 23:57:33 +00:00
8c1247cffb [BE] Fixed CPU autocast warning (#127774)
This PR fixes
```
/data/users/andgu/pytorch/torch/utils/checkpoint.py:1398: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127774
Approved by: https://github.com/soulitzer, https://github.com/Skylion007, https://github.com/tianyu-l
2024-06-11 21:33:35 +00:00
8db9dfa2d7 Flip default value for mypy disallow_untyped_defs [9/11] (#127846)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127846
Approved by: https://github.com/ezyang
ghstack dependencies: #127842, #127843, #127844, #127845
2024-06-08 18:50:06 +00:00
d937d0db0f [SAC] fix ignored ops in eager mode to recompute (#126751)
as titled. I found that there're some issues in the eager mode SAC where
sometimes we would have recompute pop from storage of ops that are
missing, these ops are detach ops. So this PR refactors the two modes,
so that they would always recompute ignored ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126751
Approved by: https://github.com/yf225
2024-05-22 06:47:22 +00:00
d17be10df1 make torch.amp.autocast more generic (#125103)
# Motivation
As discussed in [#124479](https://github.com/pytorch/pytorch/pull/124479), `torch.amp.autocast` can NOT be completely equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast` since `torch.amp.autocast` has NOT the default `dtype` for CPU (`torch.bfloat16` by default) and CUDA (`torch.float16` by default) respectively. We would like `torch.amp.autocast` to be more generic to help the developer/customer write the device-agnostic code. Because there are not enough reasons to add device-specific autocast `torch.xxx.amp.autocast` for each device backend.

# Solution
When `None` is passed to `dtype`, we should use `torch.get_autocast_dtype` to get the related dtype for each backend. Meanwhile, `torch.get_autocast_dtype` is necessary to be supported in JIT path for BC.

# Additional Context
With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`.
Add two new UTs to cover this change in eager and jit path respectively.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125103
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui
2024-05-08 12:13:26 +00:00
fab5bd5359 [checkpoint] Improve error message when use_reentrant=True is used with .grad() (#125155)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125155
Approved by: https://github.com/albanD
2024-04-29 18:57:35 +00:00
19a83eacb5 add new API torch.amp.is_autocast_available (#124938)
# Motivation
expose `torch._is_autocast_available` to `torch.amp.is_autocast_available` as a public api.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124938
Approved by: https://github.com/albanD
2024-04-26 08:45:20 +00:00
cdc66e9dc3 refactor autocast python APIs (#124479)
# Motivation
Refactor autocast usage scenario in `torch/amp/autocast_mode.py` and `torch/utils/checkpoint.py` to fix the bug - convention conflict between `torch.xxx.get_autocast_xxx_dtype` defined in `autocast_mode.py` and `torch.xxx.get_autocast_dtype` defined in `checkpoint.py`.

# Solution
Use device-agnostic APIs like `torch.get_autocast_dtype`, ..., instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124479
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #124359
2024-04-25 14:33:33 +00:00
25f321b84f Refactor autocast C++ APIs to be device-agnostic (#124359)
# Motivation
This PR aims to refactor autocast **C++** APIs to be device-agnostic and deprecate the device-specific autocast  **C++** APIs.
In C++ side,
- `is_enabled()` -> `is_enabled(device_type)`.
- `set_enabled(new_enabled)` -> `set_enabled(device_type, new_enabled)`.
- `get_autocast_dtype()` -> `get_autocast_dtype(device_type)`
- `set_autocast_dtype(dtype)` -> `set_autocast_dtype(device_type, dtype)`

These following C++ APIs are deprecated and should be removed in PyTorch 2.5
- `is_cpu_enabled`
- `set_cpu_enabled`
- `get_autocast_cpu_dtype`
- `set_autocast_cpu_dtype`
- `is_xpu_enabled`
- `set_xpu_enabled`
- `get_autocast_xpu_dtype`
- `set_autocast_xpu_dtype`
- `is_ipu_enabled`
- `set_ipu_enabled`
- `get_autocast_ipu_dtype`
- `set_autocast_ipu_dtype`
- `is_hpu_enabled`
- `set_hpu_enabled`
- `get_autocast_hpu_dtype`
- `set_autocast_hpu_dtype`
- `is_xla_enabled`
- `set_xla_enabled`
- `get_autocast_xla_dtype`
- `set_autocast_xla_dtype`
- `is_privateuseone_enabled`
- `set_privateuseone_enabled`
- `get_autocast_privateuseone_dtype`
- `set_autocast_privateuseone_dtype`

In Python side,
provide 4 generic autocast APIs:
- `torch.is_autocast_enabled(device_type)`
- `torch.set_autocast_enabled(device_type, new_enabled)`
- `torch.get_autocast_dtype(device_type)`
- `torch.set_autocast_dtype(device_type, dtype)`

# Additional Context
We will submit another PR to refactor autocast **Python** APIs based on this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124359
Approved by: https://github.com/jgong5, https://github.com/albanD
2024-04-23 10:38:50 +00:00
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
26a9b05bce Set stacklevel on checkpoint warning (#123717)
Partially addresses https://github.com/pytorch/pytorch/issues/123626

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123717
Approved by: https://github.com/Skylion007
2024-04-10 17:25:06 +00:00
1d6fc0d4de Fixed _infer_device_type warning in checkpoint (#122726)
Previously, we were checking `len(device_types)` where `device_types` is a `list`. This meant that if there were multiple inputs, we would see something like `device_types = ["cuda", "cuda"]` and a false positive warning. We should check `len(set(device_types))`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122726
Approved by: https://github.com/soulitzer
2024-03-27 18:38:42 +00:00
512251c8f3 Use tree_map to get device ids and device types for activation checkpointing (#121462)
`get_device_states` doesn't recursively look into nested lists/dicts to find tensors. As a result, activation checkpointing for such inputs results in silent incorrect results as `get_device_states` returns an empty result and no rng is saved as a result here: https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L188 since `fwd_device_states` is empty.

Fixed this by using `tree_map` for both `get_device_states` and `_infer_device_type`. Also added appropriate unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121462
Approved by: https://github.com/soulitzer
2024-03-20 21:09:21 +00:00
4f5785b6b3 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 21:07:01 +00:00
40ece2e579 Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45ef53747e2eefffd65d91ce840b431b.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
2024-01-30 19:00:34 +00:00
4f13f69a45 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 05:08:10 +00:00
5866284d4a Make not passing use_reentrant back to warning instead of erroring and clarify docs (#116710)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116710
Approved by: https://github.com/albanD
ghstack dependencies: #116523
2024-01-09 20:58:49 +00:00
e728ebb66d Small docstring fix (#116947)
Fix a small typo in the docstring of checkpoint function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116947
Approved by: https://github.com/Skylion007, https://github.com/kit1980
2024-01-08 23:51:59 +00:00
4d6a1ad400 Activation checkpoint and checkpoint_sequential errors if use_reentrant not passed explicitly (#115868)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115868
Approved by: https://github.com/albanD
ghstack dependencies: #115438
2023-12-20 15:23:44 +00:00
dd367b7c8f check tensor subclass when using torch.compile + SAC (#115960)
as titled, when using SAC + torch.compile, it currently only check for
functional tensor, but not checking any tensor subclasses, therefore SAC
under torch.compile would ignore the tensor types like tensor
subclasses. Fixed in this PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115960
Approved by: https://github.com/bdhirsh
2023-12-18 17:49:06 +00:00
495054545c Allow preserve_rng_state=True when torch.compile + selective checkpointing + CUDA (#113718)
Fixes https://github.com/pytorch/pytorch/issues/113717.

When `preserve_rng_state=True`, we let AOTAutograd trace through `torch.random.fork_rng` op, and the tracing doesn't work under CUDA, hence the original error reported in the issue.

But since we are already doing RNG functionalization at Inductor level, we don't actually need to trace this `fork_rng` op. So we should just rewrite `preserve_rng_state` to False when we are using torch.compile (and let Inductor do its RNG functionalization which it's already been doing).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113718
Approved by: https://github.com/wanchaol
2023-12-09 01:47:25 +00:00
a7bcc78bff Make it clearer that current selective AC is PT2-only and private (#115081)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115081
Approved by: https://github.com/albanD
2023-12-04 23:01:22 +00:00