Commit Graph

53 Commits

16e202a38e [dynamo] improved graph break messages for some common graph break sites [1/N] (#146525)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146525
Approved by: https://github.com/jansel
2025-02-20 00:08:13 +00:00
99dbc5b0e2 PEP585 update - test (#145176)
See #145101 for details.
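
The change is mechanical: `typing` generics are swapped for the PEP 585 builtin generics, roughly as below (the `bucket` helper is illustrative):

```python
# Before (pre-PEP 585):
# from typing import Dict, List
# def bucket(xs: List[int]) -> Dict[int, List[int]]: ...

# After: builtin generics, no typing import needed.
def bucket(xs: list[int]) -> dict[int, list[int]]:
    out: dict[int, list[int]] = {}
    for x in xs:
        out.setdefault(x % 2, []).append(x)
    return out
```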

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145176
Approved by: https://github.com/bobrenjc93
2025-01-22 04:48:28 +00:00
972d4a154d Add facility to run dynamo UTs for non-cuda devices (#140929)
This is in line with the changes introduced in https://github.com/pytorch/pytorch/pull/130714; additional files are included to support non-CUDA devices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140929
Approved by: https://github.com/kwen2501, https://github.com/EikanWang, https://github.com/guangyey
2025-01-20 05:56:38 +00:00
d25e6e623f Fix unused Python variables in test/[a-d]* (#134665)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134665
Approved by: https://github.com/albanD
2024-12-13 22:13:12 +00:00
db4e8a1d8a [ca] expose option to collect sizes as dynamic (#141153)
This addresses recompiles from eager nodes that saved dynamic activations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141153
Approved by: https://github.com/jansel
ghstack dependencies: #141152
2024-11-22 19:26:27 +00:00
cb71bcc542 Replace clone.detach with detach.clone (#140264)
Fixes #64532

As stated in the issue, replace `clone.detach` with `detach.clone`.
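
A minimal before/after sketch (the tensor `x` is illustrative):

```python
import torch

x = torch.randn(3, requires_grad=True)

# Before: x.clone().detach() -- clones the autograd-tracked tensor, then detaches.
# After: detach first, then clone, so the clone is never recorded by autograd.
y = x.detach().clone()
```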

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140264
Approved by: https://github.com/soulitzer
2024-11-13 07:01:02 +00:00
e80fe7f13a [dynamo][guards] Skip guards on empty nn module hooks (#138942)
This introduces some unsoundness in guards. Earlier, we skipped the empty nn module hooks dict guard only on built-in nn modules, but as seen in https://github.com/pytorch/pytorch/issues/138386, there could still be significant guard overhead. With this PR, we reduce the guard eval latency from 420 us to 280 us (a 1.5x reduction).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138942
Approved by: https://github.com/ezyang, https://github.com/jansel
ghstack dependencies: #139040, #138954
2024-10-29 02:11:47 +00:00
fd9f4e6770 Back out "[compiled autograd] tls access helpers (#138061)" and Back out "[compiled autograd] Compiled autograd configs in TLS (#137821)" (#139086)
Summary:
Original commit changeset: 9bf80c1492d7

Original Phabricator Diff: D64796226

Original commit changeset: aa1d9ef8f6e6

Original Phabricator Diff: D64796212

Differential Revision: D65072644

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139086
Approved by: https://github.com/malfet
2024-10-28 23:37:05 +00:00
5a13282c75 [compiled autograd] tls access helpers (#138061)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138061
Approved by: https://github.com/yf225
ghstack dependencies: #137953, #137821
2024-10-22 08:03:52 +00:00
49fa437097 [compiled autograd] Compiled autograd configs in TLS (#137821)
Multithreading doesn't work yet; this adds Python-side TLS only for the Python-side state.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137821
Approved by: https://github.com/jansel, https://github.com/yf225
ghstack dependencies: #137953
2024-10-22 08:03:52 +00:00
361f42bc42 Revert "[compiled autograd] Compiled autograd configs in TLS (#137821)"
This reverts commit 9aba0b91c8df4a15654f9ccc02abca31bdd81650.

Reverted https://github.com/pytorch/pytorch/pull/137821 on behalf of https://github.com/wdvr due to Reverting this for now, it is failing test_public_bindings in trunk ([comment](https://github.com/pytorch/pytorch/pull/137821#issuecomment-2417351788))
2024-10-16 16:38:29 +00:00
9aba0b91c8 [compiled autograd] Compiled autograd configs in TLS (#137821)
Multithreading doesn't work yet; this adds Python-side TLS only for the Python-side state.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137821
Approved by: https://github.com/jansel, https://github.com/yf225
ghstack dependencies: #137953
2024-10-16 09:28:32 +00:00
e14b58ffbd Using device-agnostic autocast api (#136613)
- using torch.autocast(device_type="cuda") instead of torch.cuda.amp.autocast()
- using torch.autocast(device_type="cpu") instead of torch.cpu.amp.autocast()
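
A rough sketch of the device-agnostic form (shapes and dtype are illustrative):

```python
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"

# Device-agnostic replacement for torch.cuda.amp.autocast() / torch.cpu.amp.autocast()
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    a = torch.randn(8, 8, device=device_type)
    b = torch.randn(8, 8, device=device_type)
    out = a @ b
```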

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136613
Approved by: https://github.com/shink, https://github.com/cyyever, https://github.com/kwen2501
2024-09-27 07:16:24 +00:00
a815611db9 [Traceable FSDP2][Partitioner] Must save AC output if output has a backward hook (#135727)
If a node is an AC region output and has a backward hook on it, we intentionally choose to save it.
This works around circular dependencies in Traceable FSDP2 + AC.
Example:
```
out = fully_shard(utils.checkpoint(module))(x)
norm_out = layer_norm(out)
```
and there is a circular dependency:
1. In backward, the grad_input of layer_norm (a.k.a. `out_grad`) actually depends on `out`.
2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights) in order to be recomputed.
3. `out`'s FSDP2 backward hook, as is the case for all eager backward hooks, depends on `out_grad`  -> circular dependency with (1)!

Solution: check whether `out` has a backward hook, and if so, intentionally save `out` in forward graph outputs. With this, we can break the above circular dependency.

----

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135727
Approved by: https://github.com/Chillee
2024-09-14 08:45:58 +00:00
920f0426ae Add None return type to init -- tests rest (#132376)
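The change itself is mechanical, e.g. (class name is illustrative):

```python
import torch

class MyTestModule(torch.nn.Module):
    def __init__(self) -> None:  # explicit None return annotation
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
```
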
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132376
Approved by: https://github.com/jamesjwu
ghstack dependencies: #132335, #132351, #132352
2024-08-01 15:44:51 +00:00
918ece4f4d [BE][Easy][11/19] enforce style for empty lines in import segments in test/dy*/ (#129762)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129762
Approved by: https://github.com/anijain2305
2024-07-27 17:43:53 +00:00
f845a7a91a [cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.

What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...

Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
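
One possible way to exclude the cuDNN backend for a given test is the public backend selector; a sketch (not necessarily what the test suite settled on):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(2, 4, 8, 16, device="cuda", dtype=torch.float16)

# Restrict SDPA to non-cuDNN backends within this region.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION,
                  SDPBackend.EFFICIENT_ATTENTION,
                  SDPBackend.MATH]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```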

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-30 19:22:16 +00:00
999eec8dea Revert "[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)"
This reverts commit b7e7a4cb01de394af7686ab6feb216a8a5c716bb.

Reverted https://github.com/pytorch/pytorch/pull/125343 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to break some test_transformer running on internal A100 and V100 ([comment](https://github.com/pytorch/pytorch/pull/125343#issuecomment-2196202003))
2024-06-28 06:03:54 +00:00
0ffb17547e [Simple FSDP] Add unit test for torch.compile + reparameterization + SAC (#129641)
This reproduces the error in https://github.com/pytorch/pytorch/issues/129684. Adding a unit test so that we hold the line on torch.compile + reparameterization + SAC always working, to pave the path for Tianyu's intern's project.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129641
Approved by: https://github.com/tianyu-l
2024-06-28 00:00:36 +00:00
b7e7a4cb01 [cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.

What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...

Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-26 00:49:18 +00:00
575bc1e3af [Reopen #114036] Allow "must recompute" in torch.compile + selective checkpointing (SAC) (#129295)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129295
Approved by: https://github.com/Chillee
2024-06-25 23:47:08 +00:00
1877b7896c [checkpoint] Clean up selective activation checkpoint and make public (#125795)
### bc-breaking for existing users of the private API:
- Existing policy functions must now return a [CheckpointPolicy](c0b40ab42e/torch/utils/checkpoint.py (L1204-L1230)) enum value instead of a bool.
   - To restore previous behavior, return `PREFER_RECOMPUTE` instead of `False` and `{PREFER,MUST}_SAVE` instead of `True`, depending on whether you prefer the compiler to override your policy.
- The policy function now accepts a `ctx` object instead of `mode` as its first argument.
   - To restore previous behavior, use `mode = "recompute" if ctx.is_recompute else "forward"`.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `create_selective_checkpoint_contexts`. The way you use the API remains the same (see the sketch below). It would have been nice to do something different (not make the user use functools.partial?), but this was the easiest to compile (it is unclear whether this should actually be a constraint).
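
A minimal sketch of the renamed API with an enum-returning policy (the op choice and shapes are illustrative):

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def policy_fn(ctx, op, *args, **kwargs):
    # Always save matmuls; prefer recomputing everything else.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def fn(x):
    return torch.sigmoid(torch.mm(x, x))

x = torch.randn(4, 4, requires_grad=True)
out = checkpoint(
    fn,
    x,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)
out.sum().backward()
```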

Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit

Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.

In-place:
- We use version counting to detect (and error) if any cached tensor has been mutated. In-place operations that do not mutate cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user cleverly also saves the output of the in-place op)

Randomness, views
- Currently in this PR, we don't do anything special for randomness or views; the author of the policy function is expected to handle them properly. (Would it be beneficial to error? We either want to save all or recompute all random tensors.)

Tensor object preservation
- ~We guarantee that if a tensor does not require grad, and it is saved, then what you get out is the same tensor object.~ UPDATE: We guarantee that if a tensor is of non-differentiable dtype AND it is not a view, and it is saved, then what you get out is the same tensor object. This is a nice guarantee for nested tensors, which care about the object identity of the offsets tensor.

Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshedding welcome). The alternative was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred because it seemed clearer that two clashing `MUST`s should error, whereas it is ambiguous whether two stacked `NON_OVERRIDABLE`s should silently ignore or error.
- On the usage of the Enum today: there is actually NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. Stacking SAC policies would be useful if someone wants to implement something like simple FSDP, but it is not perfect, because with a policy of `PREFER_SAVE` you are actually saving more than autograd would normally save (this would be fixed with AC v3).
- The number of times we call the policy_fn should be documented as part of the public API. We call the policy function for all ops except ~~detach~~ (UPDATE: metadata ops listed in `torch.utils.checkpoint.SAC_IGNORED_OPS`), because these ops may be called a different number of times by AC itself between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute; the user is expected to handle that via `is_recompute`, see below).
Tensors guaranteed to be the same tensor as-is
- The policy function signature takes a ctx object as its first argument. The ctx object encapsulates info that may be useful to the user; it currently only holds `is_recompute`. Adding this indirection gives us flexibility to add more attrs later if necessary.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
2024-06-18 18:18:50 +00:00
6895a5804c Revert "[checkpoint] Clean up selective activation checkpoint and make public (#125795)"
This reverts commit c472cec5656b9ffb668af97a02d711bdbdf5ebec.

Reverted https://github.com/pytorch/pytorch/pull/125795 on behalf of https://github.com/soulitzer due to breaking torchtitan CI ([comment](https://github.com/pytorch/pytorch/pull/125795#issuecomment-2167036157))
2024-06-14 01:14:59 +00:00
c472cec565 [checkpoint] Clean up selective activation checkpoint and make public (#125795)
Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit

Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.

In-place:
- We use version counting to detect (and error) if any cached tensor has been mutated. In-place operations that do not mutate cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user cleverly also saves the output of the in-place op)

Randomness, views
- Currently in this PR, we don't do anything special for randomness or views; the author of the policy function is expected to handle them properly. (Would it be beneficial to error? We either want to save all or recompute all random tensors.)

Tensor object preservation
- We guarantee that if a tensor does not require grad, and it is saved, then what you get out is the same tensor object. If the tensor does require grad, we must detach to avoid creating a reference cycle. This is a nice guarantee for nested tensors, which care about the object identity of the offsets tensor.

Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshedding welcome). The alternative was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred because it seemed clearer that two clashing `MUST`s should error, whereas it is ambiguous whether two stacked `NON_OVERRIDABLE`s should silently ignore or error.
- On the usage of the Enum today: there is actually NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. Stacking SAC policies would be useful if someone wants to implement something like simple FSDP, but it is not perfect, because with a policy of `PREFER_SAVE` you are actually saving more than autograd would normally save (this would be fixed with AC v3).
- The number of times we call the policy_fn is documented as part of the public API. We call the policy function for all ops except detach, because detach is itself called a different number of times by AC between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute; the user is expected to handle that via `is_recompute`, see below).
Tensors guaranteed to be the same tensor as-is
- The policy function signature takes a ctx object as its first argument. The ctx object encapsulates info that may be useful to the user; it currently only holds `is_recompute`. Adding this indirection gives us flexibility to add more attrs later if necessary.

"bc-breaking" for existing users of the private API:
- Existing policy functions must now change their return value to use the Enum.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `gen_selective_checkpoint_context_fn`. The way you use the API remains the same. It would have been nice to do something different (not make the user use functools.partial?), but this was the easiest to compile (it is unclear whether this should actually be a constraint).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
2024-06-12 23:57:33 +00:00
1507d5205a [dynamo][fsdp] Skip Dynamo tracing of __getattr__ if its top-level frame (#127263)
The generated bytecode for the first frame is below, with inlined comments about the LOAD_ATTR instructions that cause Dynamo to trigger again on `__getattr__`.

~~~
[__bytecode] MODIFIED BYTECODE fn /data/users/anijain/pytorch2/test/dynamo/test_activation_checkpointing.py line 1129
[__bytecode] 1129           0 COPY_FREE_VARS           1
[__bytecode]                2 RESUME                   0
[__bytecode]                4 PUSH_NULL
[__bytecode]                6 LOAD_GLOBAL             10 (__compiled_fn_1)
[__bytecode]               18 LOAD_FAST                0 (x)
[__bytecode]               20 LOAD_DEREF               1 (mod)
[__bytecode]               22 LOAD_ATTR                6 (_checkpoint_wrapped_module)
[__bytecode]               32 LOAD_CONST               1 (0)
[__bytecode]               34 BINARY_SUBSCR
[__bytecode]               44 LOAD_ATTR                7 (weight)
[__bytecode]               54 LOAD_DEREF               1 (mod)
[__bytecode]               56 LOAD_ATTR                6 (_checkpoint_wrapped_module)
[__bytecode]               66 LOAD_CONST               1 (0)
[__bytecode]               68 BINARY_SUBSCR
[__bytecode]               78 LOAD_ATTR                8 (bias)

# When this optimized bytecode is executed, these two lines call the __getattr__ of the ActivationWrapper module.
# Dynamo gets invoked on __getattr__.

# If we had inlined __getattr__ during tracing, we would have seen the LOAD_ATTR
# on more low-level data structures like _modules, obviating the need for CPython
# to call the Python-overridden __getattr__. But today, UnspecializedNNModuleVariable
# calls Python getattr at tracing time (instead of inlining it), resulting in LOAD_ATTR
# on the module itself.

# To prevent Dynamo from tracing __getattr__ invoked from the optimized bytecode,
# we can check whether it is the top-level frame and just skip it.

[__bytecode]               88 LOAD_DEREF               1 (mod)
[__bytecode]               90 LOAD_ATTR                0 (a)

[__bytecode]              100 PRECALL                  4
[__bytecode]              104 CALL                     4
[__bytecode]              114 UNPACK_SEQUENCE          1
[__bytecode]              118 RETURN_VALUE
~~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127263
Approved by: https://github.com/yf225
2024-05-28 08:16:53 +00:00
26f4f10ac8 [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
2024-05-27 14:49:57 +00:00
55c0ab2887 Revert "[5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126)"
This reverts commit 7763c83af67eebfdd5185dbe6ce15ece2b992a0f.

Reverted https://github.com/pytorch/pytorch/pull/127126 on behalf of https://github.com/XuehaiPan due to Broken CI ([comment](https://github.com/pytorch/pytorch/pull/127126#issuecomment-2133044286))
2024-05-27 09:22:08 +00:00
7763c83af6 [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
ghstack dependencies: #127122, #127123, #127124, #127125
2024-05-27 04:22:18 +00:00
a28bfb5ed5 [4/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort functorch (#127125)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127125
Approved by: https://github.com/Skylion007
ghstack dependencies: #127122, #127123, #127124
2024-05-25 22:45:38 +00:00
8e98fda7a9 [dynamo][easy] Add AC test and improve graph break message (#121394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121394
Approved by: https://github.com/yanboliang
2024-04-06 01:02:45 +00:00
cbde0f048b [dynamo, 3.12] enable tests disabled due to missing dynamo 3.12 support (#123300)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123300
Approved by: https://github.com/jansel, https://github.com/malfet, https://github.com/zou3519
2024-04-05 20:13:17 +00:00
b9c9f037d1 Added some checkpointing tests (#122848)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122848
Approved by: https://github.com/anijain2305
2024-03-29 03:49:19 +00:00
244b124bb8 Add linux cpu test for 3.12 (#117853)
This is continuation of work: https://github.com/pytorch/pytorch/pull/113987

Co-authored-by: albanD <desmaison.alban@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117853
Approved by: https://github.com/albanD
2024-02-14 20:52:23 +00:00
1de50f8654 [HigherOrderOp] fix stack trace to report user stack (#118826)
Fixes https://github.com/pytorch/pytorch/issues/111020

For the following code:
```python
import torch
import torch._higher_order_ops.wrap

glob = []

def f(x):
    glob.append(x)
    return x.clone()

@torch.compile(backend='eager', fullgraph=True)
def g(x):
    return torch.ops.higher_order.wrap(f, x)

x = torch.randn(3)
g(x)
```

The stacktrace now becomes:
```
[2024-02-01 15:23:34,691] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting wrap, we were unable to trace function `f` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Traceback (most recent call last):
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 381, in speculate_subgraph
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     output = f.call_function(tx, args, sub_kwargs)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 278, in call_function
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return super().call_function(tx, args, kwargs)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 86, in call_function
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return tx.inline_user_function_return(
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in inline_user_function_return
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return cls.inline_call_(parent, func, args, kwargs)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2370, in inline_call_
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     tracer.run()
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 787, in run
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     and self.step()
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 750, in step
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     getattr(self, inst.opname)(inst)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 469, in wrapper
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return inner_fn(self, inst)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1196, in CALL_FUNCTION
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     self.call_function(fn, args, {})
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 651, in call_function
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     self.push(fn.call_function(self, args, kwargs))
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/variables/misc.py", line 583, in call_function
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return self.obj.call_method(tx, self.name, args, kwargs)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/variables/lists.py", line 330, in call_method
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return super().call_method(tx, name, args, kwargs)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/variables/lists.py", line 241, in call_method
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     tx.output.side_effects.mutation(self)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/side_effects.py", line 325, in mutation
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     self.check_allowed_side_effect(var)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/side_effects.py", line 157, in check_allowed_side_effect
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     unimplemented(
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 190, in unimplemented
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     raise Unsupported(msg)
[2024-02-01 15:23:34,692] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test.py", line 219, in <module>
    g(x)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 453, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 615, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 390, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 650, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 248, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 531, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 155, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 496, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2125, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 787, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 750, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 469, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1196, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 651, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1227, in call_function
    p_args, p_kwargs, example_value, body_r, treespec, _ = self.create_wrapped_node(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1190, in create_wrapped_node
    ) = speculate_subgraph(
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 453, in speculate_subgraph
    raise ex
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 381, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 278, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 86, in call_function
    return tx.inline_user_function_return(
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2370, in inline_call_
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 787, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 750, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 469, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1196, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 651, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/misc.py", line 583, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/lists.py", line 330, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/lists.py", line 241, in call_method
    tx.output.side_effects.mutation(self)
  File "/home/yidi/local/pytorch/torch/_dynamo/side_effects.py", line 325, in mutation
    self.check_allowed_side_effect(var)
  File "/home/yidi/local/pytorch/torch/_dynamo/side_effects.py", line 157, in check_allowed_side_effect
    unimplemented(
  File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 190, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)

from user code:
   File "/home/yidi/local/pytorch/test.py", line 216, in g
    return torch.ops.higher_order.wrap(f, x)
  File "/home/yidi/local/pytorch/test.py", line 211, in f
    glob.append(x)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118826
Approved by: https://github.com/yanboliang, https://github.com/zou3519
2024-02-02 20:08:01 +00:00
865945cc1f Convert requires_cuda to full decorator (#118281)
Don't require using it as `@requires_cuda()`; use `@requires_cuda` instead. There is no need for the partial function to be invoked many times.

Split out this change from the initial large refactoring in #117741 to hopefully get merged before conflicts arise

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118281
Approved by: https://github.com/ezyang
2024-01-25 15:50:21 +00:00
e056cf5507 [ac][pattern matcher] Do not percolate tags beyond the inputs of matched portion (#118034)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118034
Approved by: https://github.com/yf225
2024-01-23 05:02:32 +00:00
8bf788c390 [SAC][Dynamo] Add support for functools.partial in CheckpointHigherOrderVariable (#117657)
# Context

In some cases, we might want to build the `context_fn` with runtime-defined policies. One way of implementing this is to make `context_fn` a partial, which holds the information that we want to pass. One concrete example is the [automatic policy selection from `xformers`](ad986981b1/xformers/checkpoint.py (L185)).

# The problem

The previous implementation wouldn't work with partials because `FunctoolsPartialVariable` doesn't have a `fn` attribute.

This PR addresses this case, but ideally we could get this solved in a more general fashion, as callable classes and `NestedUserFunctionVariable` are not supported by this PR.
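
A rough sketch of the pattern this enables; `make_contexts` is a hypothetical runtime-configured builder, shown here feeding `torch.utils.checkpoint.checkpoint` in eager (the PR's point is that the same partial now also works when the call is wrapped in torch.compile):

```python
import contextlib
import functools

import torch
from torch.utils.checkpoint import checkpoint

def make_contexts(policy_name):  # hypothetical: behavior chosen at runtime
    # checkpoint expects context_fn() to return a (forward_ctx, recompute_ctx) pair.
    return contextlib.nullcontext(), contextlib.nullcontext()

def fn(x):
    return torch.relu(x @ x)

x = torch.randn(4, 4, requires_grad=True)
out = checkpoint(
    fn,
    x,
    use_reentrant=False,
    context_fn=functools.partial(make_contexts, policy_name="auto"),
)
out.sum().backward()
```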

# Tests

I've added a basic test that mimics the tests around it. The tests could probably be simplified, but I've decided to keep changes to a minimum.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117657
Approved by: https://github.com/yf225
2024-01-18 11:59:23 +00:00
db79ceb110 [ROCm] Enabling additional UTs on ROCm (#115738)
Unskips, mostly for dynamo/inductor UTs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115738
Approved by: https://github.com/jithunnair-amd, https://github.com/malfet
2024-01-09 08:36:07 +00:00
4d6a1ad400 Activation checkpoint and checkpoint_sequential errors if use_reentrant not passed explicitly (#115868)
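For example, callers now spell the flag out explicitly (sketch):

```python
import torch
from torch.utils.checkpoint import checkpoint

def fn(x):
    return torch.relu(x @ x)

x = torch.randn(4, 4, requires_grad=True)

# use_reentrant is passed explicitly, as now required by checkpoint /
# checkpoint_sequential instead of relying on the old implicit default.
out = checkpoint(fn, x, use_reentrant=False)
out.sum().backward()
```
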
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115868
Approved by: https://github.com/albanD
ghstack dependencies: #115438
2023-12-20 15:23:44 +00:00
dd367b7c8f check tensor subclass when using torch.compile + SAC (#115960)
As titled: when using SAC + torch.compile, it currently only checks for
functional tensors, not for tensor subclasses, so SAC under torch.compile
would ignore tensor types like tensor subclasses. Fixed in this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115960
Approved by: https://github.com/bdhirsh
2023-12-18 17:49:06 +00:00
495054545c Allow preserve_rng_state=True when torch.compile + selective checkpointing + CUDA (#113718)
Fixes https://github.com/pytorch/pytorch/issues/113717.

When `preserve_rng_state=True`, we let AOTAutograd trace through the `torch.random.fork_rng` op, and that tracing doesn't work under CUDA, hence the original error reported in the issue.

But since we are already doing RNG functionalization at the Inductor level, we don't actually need to trace this `fork_rng` op. So we should just rewrite `preserve_rng_state` to False when we are using torch.compile (and let Inductor do its RNG functionalization, which it has already been doing).
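
Roughly, the user-facing call that used to fail (a sketch that assumes a CUDA device; the selective-checkpoint context_fn is omitted for brevity):

```python
import torch
from torch.utils.checkpoint import checkpoint

def fn(x):
    return torch.nn.functional.dropout(x @ x, p=0.5, training=True)

@torch.compile
def g(x):
    # preserve_rng_state=True previously errored under torch.compile on CUDA;
    # it is now effectively treated as False there, since Inductor
    # functionalizes RNG itself.
    return checkpoint(fn, x, use_reentrant=False, preserve_rng_state=True)

x = torch.randn(8, 8, device="cuda", requires_grad=True)
g(x).sum().backward()
```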

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113718
Approved by: https://github.com/wanchaol
2023-12-09 01:47:25 +00:00
a7bcc78bff Make it clearer that current selective AC is PT2-only and private (#115081)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115081
Approved by: https://github.com/albanD
2023-12-04 23:01:22 +00:00
11a3c7696b [dynamo - testing] Add repro for higher order op list inputs (#111647)
Add repro from https://github.com/pytorch/pytorch/issues/110118 now that it has been fixed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111647
Approved by: https://github.com/ezyang
2023-10-20 18:23:23 +00:00
fe11227764 [dynamo][higher order op] Fix minor bug in error msgs (#110099)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110099
Approved by: https://github.com/zou3519
2023-09-27 20:28:17 +00:00
3f3e353885 torch.compile + selective activation checkpointing (#105489)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105489

NOTE: this PR is tagged "not user facing", because it's not ready to be announced externally yet.

This PR implements torch.compile + selective activation checkpoint (SAC) integration, by using `TagActivationCheckpoint` (same backend as torch.compile + full activation checkpoint integration).

The TorchDispatchMode-based implementation cannot support including in-place ops in the checkpointed region at the moment (the reason for this needs investigation), and there is also no way to ban them (because TorchDispatchMode now only sees "after-functionalization" ops, so it can't detect whether an op is in-place). Hence we hide torch.compile + SAC behind a flag (`torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint`) and will only use it internally for cases that are known to not have in-place ops. This state won't last too long, because in-place ops will at least be able to be detected after Brian's mode reordering and related functionalization changes.
So next steps after this PR:
1. Wait for Brian's mode reordering and related functionalization changes to land, and then try to enable the "inplace ops" unit test for torch.compile + selective activation checkpoint (if it doesn't work, investigate why).
2. Unify selective- and full-checkpoint under TorchDispatchMode based implementation.

Differential Revision: D47497145

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105489
Approved by: https://github.com/anijain2305
2023-09-21 16:24:11 +00:00
78a053bad7 [activation checkpointing] Add default autocast keys to functional rng wrappers (#107934)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107934
Approved by: https://github.com/xw285cornell
2023-08-25 18:22:02 +00:00
0b11da0ccb [partitioners][ac][dynamic] Fix output signature of fwd with symints (#105771)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105771
Approved by: https://github.com/Chillee
2023-07-22 03:04:11 +00:00
90eaa98d13 dynamo : kwarg support for wrap (higher order op) (#104180)
Ref: https://github.com/pytorch/pytorch/issues/100278

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104180
Approved by: https://github.com/zou3519
2023-07-11 06:08:18 +00:00
8c191d8eef [dynamo][ac] Reland #104397 - Remove disable monkeypatching of utils.checkpoint (#104665)
NO CHANGE from before. The ancestor diff was reverted, so this diff got reverted as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104665
Approved by: https://github.com/wconstab
2023-07-06 00:48:02 +00:00
40f53912cf Revert "[dynamo][ac] Remove disable monkeypatching of utils.checkpoint (#104397)"
This reverts commit 537a6c0651edda1e1a55b90658a6c24d049ff982.

Reverted https://github.com/pytorch/pytorch/pull/104397 on behalf of https://github.com/huydhn due to This has been reverted internally by D47216591, so I need to also revert it on OSS to keep them in sync ([comment](https://github.com/pytorch/pytorch/pull/104397#issuecomment-1621086360))
2023-07-05 06:11:08 +00:00