Users encountered unexpected behaviour when using FlexAttention with learnable biases, including assertion errors (#157677).
We traced the root cause to how subgraph buffers were registered: the registration caused inconsistent buffer naming and, ultimately, incorrect retrieval later on. The problem only arose when the model was compiled as a whole (i.e., with @torch.compile), since only then could the buffer names conflict.
In this PR, we register the buffers with the base graph to solve this issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161170
Approved by: https://github.com/drisspg
# Summary
This is the follow-up to https://github.com/pytorch/pytorch/pull/137526. In this PR, we update the lowerings for the flex_attention backward kernel to generate fused backward gradient calculations for any captured buffers that require grads.
We do this with tl.atomic_add, scattering the correct gradients into a zeroed-out buffer for each captured buffer that requires grads. We added many test cases and, along the way, found some masking bugs.
There are likely some performance cliffs here, specifically across dtypes and on different GPUs; we plan to profile the current strategy in a follow-up. For now we are explicitly choosing reduced memory over increased performance.
By using atomics, we do not need to realize a full attention-scores matrix. However, this comes with two downsides: it is potentially slower in some cases, and the gradient calculation for any captured buffers is non-deterministic, as the sketch below illustrates.
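To make the non-determinism concrete, here is a minimal sketch (illustrative only; `index_add_` stands in for the generated `tl.atomic_add` and is not the kernel code):
```Python
import torch

# Floating-point atomic adds commute only up to rounding, so the order in
# which threads land determines the last few bits of the result. index_add_
# on CUDA is likewise atomic-based and non-deterministic for float inputs.
idx = torch.randint(0, 16, (1_000_000,), device="cuda")
vals = torch.randn(1_000_000, device="cuda")
grad_buf = torch.zeros(16, device="cuda")  # zeroed grad buffer, as in the PR
grad_buf.index_add_(0, idx, vals)          # scatter-add; run-to-run variation
```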
## Worked Example
Let's take the case where you read from one buffer that doesn't require grad (`offset`) and use it to index into another that does (`bias`).
ScoreMod:
```Python
bias = torch.randn(
    params.seq_length,
    device=self.device,
    dtype=params.dtype,
    requires_grad=True,
)
offset = torch.randint(
    0,
    params.seq_length,
    (params.seq_length,),
    device=self.device,
)

def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[offset[q_idx]]
```
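For context, a hedged usage sketch of how this score_mod would be exercised end to end (hypothetical shapes, reusing `params`/`self.device` from the snippet above; `flex_attention` and `torch.compile` are the real APIs, the rest is illustrative). With this PR, the compiled backward also fills `bias.grad`:
```Python
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical batch/head/head-dim sizes for illustration.
flex = torch.compile(flex_attention)
q, k, v = (
    torch.randn(2, 4, params.seq_length, 64,
                device=self.device, dtype=params.dtype, requires_grad=True)
    for _ in range(3)
)
out = flex(q, k, v, score_mod=score_mod)
out.sum().backward()
assert bias.grad is not None  # scattered via tl.atomic_add in the fused kernel
```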
Below, everything except the new subgraph injected into the backward kernel has been removed:
```Python
dsT = pT * (dpT - Di[None, :])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
grad_scores = (dsT)
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
idx_b = off_z
idx_h = off_hq
idx_m = m
idx_n = n
scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN
tmp4 = (dsT).to(tl.float32)
# Here in_ptr16 corresponds to `offset` and out_ptr1 to the zeroed fp32 grad
# buffer for `bias`: load offset[q_idx], then atomically scatter the score
# gradients into the grad buffer at that position.
tl.atomic_add(out_ptr1 + (tl.broadcast_to(tl.load(in_ptr16 + idx_m), tmp4.shape)), tmp4, scatter_mask, sem='relaxed')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
```
## Key points
* We always accumulate into float32 grad buffers regardless of the dtype used in the forward. We normally do all intra-kernel computation with fp32 accumulation, and we want the same behavior for atomic additions (see the sketch after this list).
* We are currently restricted to one scatter per kernel. I have some ideas on fx rewrites that would remove this restriction, but for now we raise a clear error message with a workaround and leave the rewrite as a follow-up.
* We will do more extensive performance/memory profiling in a follow-up.
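A minimal sketch of the fp32 accumulation point (assumed names; the actual buffers are allocated by the lowering, not by user code):
```Python
import torch

# Assumed illustration: the grad buffer is allocated and zeroed in float32,
# tl.atomic_add accumulates into it during the backward kernel, and the
# result is cast back to the forward dtype only at the end.
grad_bias_fp32 = torch.zeros_like(bias, dtype=torch.float32)
# ... backward kernel scatters into grad_bias_fp32 via tl.atomic_add ...
bias_grad = grad_bias_fp32.to(bias.dtype)
```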
### Toy E2E example
I have a toy E2E training example PR in the gym for now: https://github.com/pytorch-labs/attention-gym/pull/84/
I plan to update it to a realistic learnable bias before landing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137452
Approved by: https://github.com/Chillee
Part of #134054.
This corresponds to the pytorch mypy changes from D61493706. Updating takes so long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change, so we are landing these `type: ignore` comments for pytorch in advance of their actually being needed.
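For illustration, the suppressions are error-code-scoped comments of this form (a hypothetical line, not taken from the diff):
```Python
def fetch(obj):
    # mypy would flag this attribute access after the upgrade; the bracketed
    # code silences only that specific error on this line.
    return obj.nonexistent_attr  # type: ignore[attr-defined]
```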
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
This allows `associative_scan` to take an arbitrary pytree of tensors, which is flattened to its leaves before calling the `associative_scan` higher-order operator.
I also add support in Inductor to generate code for scanning over sequences of tensors.
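A minimal usage sketch, assuming the experimental frontend in `torch._higher_order_ops.associative_scan` (the exact import path and signature may differ between versions):
```Python
import torch
from torch._higher_order_ops.associative_scan import associative_scan

# combine_fn must be associative; it receives and returns pytrees with the
# same structure as the input, operating leaf-wise.
def combine(x, y):
    return {"sum": x["sum"] + y["sum"], "prod": x["prod"] * y["prod"]}

xs = {
    "sum": torch.randn(8, 4, device="cuda"),
    "prod": torch.rand(8, 4, device="cuda") + 0.5,
}
# Scans over dim 0; the result has the same pytree structure and shapes as xs.
out = associative_scan(combine, xs, dim=0)
```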
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122137
Approved by: https://github.com/lezcano, https://github.com/Chillee
ghstack dependencies: #119430