This reverts commit 7743149b2be4a9eba7e0997ccdc6abe552bec266.
Reverts
* https://github.com/pytorch/pytorch/pull/135503
* https://github.com/pytorch/pytorch/pull/135502
* https://github.com/pytorch/pytorch/pull/135422
This passes this test. Earlier, the getitem would stay like a getitem in the Fx graph. But now the fake tensor propagations fails saying that .item is called. It seems that torch function is not getting triggered while fake tensor propagation.
```
import torch
from torch.nn.attention.flex_attention import BlockMask, _mask_mod_signature, _score_mod_signature, flex_attention
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.virtualized import ops
from torch.nn.attention.flex_attention import create_block_mask
torch.set_default_device('cuda')
flex_attention = torch.compile(flex_attention, dynamic=False)
prefix_lengths = torch.arange(8)
def prefix_lm(b, h, q, kv):
return prefix_lengths[b] >= kv
mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136590
Approved by: https://github.com/Chillee
This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode)
Typically the bytecode for a context manager looks like this during a graph break:
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call
resume fn structure:
1. enter context
2. jump
...
3. exit context
The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack).
So for torch function modes the structure of our output code is this:
1. graph call
2. mutate tf mode stack to replay mutations
4. unsupported code
5. on exception restore stack
6. resume function
Then our resume fn looks like this:
1. no-op enter torch function mode
2. jump
3. exit tf mode
To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context).
Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444
This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode)
Typically the bytecode for a context manager looks like this during a graph break:
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call
resume fn structure:
1. enter context
2. jump
...
3. exit context
The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack).
So for torch function modes the structure of our output code is this:
1. graph call
2. mutate tf mode stack to replay mutations
4. unsupported code
5. on exception restore stack
6. resume function
Then our resume fn looks like this:
1. no-op enter torch function mode
2. jump
3. exit tf mode
To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context).
Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444
This PR implements tracing of with contexts with TorchFunction modes which have the default enter/exit behavior (ie pushing/popping the mode)
Typically the bytecode for a context manager looks like this during a graph break:
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call
resume fn structure:
1. enter context
2. jump
...
3. exit context
The issue with torch function modes is that side effects will replay any mutations to the torch function stack performed during tracing. So, we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we don't need to enter again in the resume function (the mode that was pushed from the side effects bytecode would still be on the stack).
So for torch function modes the structure of our output code is this:
1. graph call
2. mutate tf mode stack to replay mutations
4. unsupported code
5. on exception restore stack
6. resume function
Then our resume fn looks like this:
1. no-op enter torch function mode
2. jump
3. exit tf mode
To implement the no-op enter of the torch function mode I added torch function mode in polyfill which no-op enters, but normally exits. This is needed because we still want to trace the with context in the resume function, and exit properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context).
Separately from the bytecode, dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally at a graph break, we exit these block stacks to properly reset the contexts entirely, so that we can re-enter around the unsupported code soundly. However once again, in the torch function mode case, in the event of a graph we do not want to perform any exit side effects because we want to preserve the state of the mode stack as is so that we will properly update the stack with bytecode mentioned in the first section. If we exited here, dynamo would pop the mode off of the symbolic stack, and not update the true python torch function mode stack with the suffix bytecode. All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135422
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444
reland of https://github.com/pytorch/pytorch/pull/133113
I have to create a new PR because the previous reverted PR could not either be rebased, or imported successfully :(
----
Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next PRs)
* To preserve the BC for users still using the torch.distributed._tensor, I added a shim script to redirect old path calls to the new module
The BC preserving is evidented by the fact that all DTensor tests are still working without changing the public imports. So it's safe to land the changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
Before the fix, the unit test will fail at forward Dynamo tracing:
```
File "/data/users/willfeng/pytorch/test/distributed/_composable/test_replicate_with_compiler.py", line 415, in test_ddp_tp
loss = compiled_replicate_model(data).sum()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...
torch._dynamo.exc.InternalTorchDynamoError: SymNodeVariable() is not a constant
from user code:
File "/data/users/willfeng/pytorch/torch/distributed/tensor/parallel/_data_parallel_utils.py", line 34, in _unflatten_tensor
result = DTensor.from_local(
```
After the fix, the compilation fails at a later step (Compiled Autograd tracing), due to needing "pre-dispatch tracing of backward graph" feature (see details at https://github.com/pytorch/pytorch/issues/127797#issuecomment-2291695474).
I believe this PR is a net improvement, because it should also fix the 1D Traceable FSDP2 failure case on internal models (https://github.com/pytorch/pytorch/issues/130978#issuecomment-2319476690), which is much harder to build a minimal unit test for.
Fixes https://github.com/pytorch/pytorch/issues/130978.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135315
Approved by: https://github.com/bdhirsh
Fixes https://github.com/pytorch/pytorch/issues/114389
Previously, dynamo would attempt to trace through the `__init__` of traceable tensor subclasses, since their constructors are AOT dispatcher traceable by definition, dynamo should automatically put these in the graph like we do for any other tensors. Not doing this is difficult because dynamo would need to apply mutations post tensor subclass creation in the graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135151
Approved by: https://github.com/bdhirsh
Context: Adding support for the beta parameters to be tensors
Details:
In this PR similarly to the previous, foreach_pow calls item() on the first argument when it is a scalar tensor. In this case, we broadcast that scalar tensor into a list of aliases of that tensor to avoid the item() call, and this results in a device copy of the scalar tensor. Once again, I dont think we can change the foreach_pow API due to BC concerns, so this op rewrite allows us to avoid a graph break, generate semantically the same code, and not affect eager.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134167
Approved by: https://github.com/anijain2305
ghstack dependencies: #134166
Context: Adding support for the beta parameters to be tensors
Details:
In order to add support for the beta params to be tensors without graph breaks in the Adam family of optimizers it is necessary to support foreach_lerp(x, y, s) where s is a scalar tensor. Today, this isn't possible because when `s` is a scalar, internally the aten op calls item() on it to extract the value and distribute it to each of the ops on the individual list indices. To support this in dynamo without graph breaks, I decompose the lerp into its constituent ops which support a scalar tensor in the list argument positions which do not result in an item() call. To be clear the item() call is more performant for eager I think and for BC I don't think we can modify that API, so this allows us to have performance in eager and no graph breaks in compile.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134166
Approved by: https://github.com/anijain2305
Changes:
1. Move `polyfill.py` -> `polyfills/__init__.py`. It can be used as `polyfill.xxx` -> `polyfills.xxx`.
2. Move submodule loading from `polyfills/__init__.py` to `polyfills/loader.py`.
Merge `polyfill.py` and `polyfills/` packages. Each polyfill module have its own namespace for better code organization.
The ultimate goal is make `polyfills/__init__.py` empty and all polyfill functions move to its own namespace.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133977
Approved by: https://github.com/jansel
This PR adds support `torch._C._push_on_torch_function_stack()` by updating `torch.py` to push onto the symbolic torch function mode stack when a push is encountered. The same side effects infra used in the previous PR is used to track the mutation of the torch function mode stack and add bytecode to update it if it is mutated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133132
Approved by: https://github.com/williamwen42
ghstack dependencies: #133130, #133729, #133131
This PR adds support for tracing `torch._C._pop_torch_function_stack()` without graph breaking and in order to verify the state change also adds replay of mutations to the torch function mode stack via side_effects appending supplemental bytecode as we do for other python mutable objects.
Details:
To represent the torch function mode stack symbolically a deque field is added to the instruction translator. When the InstructionTranslator is initialized, all modes are read from the current torch function mode stack, and stashed in a global weak ref for later access (using existing sources) without needing to push/pop the python/cpp torch function mode stack.
During tracing, when `_pop_torch_function_stack` is encountered a value is popped from this deque and the variable tracker representing the mode is returned. To ensure the true torch function mode stack matches this state, `TorchFunctionModeStackVariable`, a singleton, is marked as mutated, this adds it to side effects, where during final codegen, side effects will codegen a call to a python helper which will update the python torch function mode stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133131
Approved by: https://github.com/jansel
ghstack dependencies: #133130, #133729
Moving DTensor to be in the public namespace, to formally add the
documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next
PRs)
* To preserve the BC for users still using the `torch.distributed._tensor`,
I added a shim script to redirect old path calls to the new module
The BC preserving is evidented by the fact that all DTensor tests are still
working without changing the public imports. So it's safe to land the
changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
Significant bytecode generation API change!
The new suggested convention to generating bytecode to call a function is now to wrap instructions that push a callable to the stack with `add_push_null`, then that callable is called with `create_call_function` with `push_null=False` (see diff for examples).
In Python 3.13, NULL is now expected to be pushed after the callable. In <=3.12, the NULL was pushed before the callable. This change abstracts away the exact placement of the NULL, but the developer must be aware that a NULL may be needed when codegen'ing a callable.
This abstraction also reduces the need for the `push_null=True` option in `create_call_function`, which removes the need to rotate a NULL to the right place on the stack with a sequence of `SWAP` instructions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129172
Approved by: https://github.com/jansel
Improve Dynamo to support the FSDP2 `use_training_state()` context manager.
Test command:
`
pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_dynamo_trace_use_training_state
`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127854
Approved by: https://github.com/yanboliang