2 Commits

Author SHA1 Message Date
0747d95994 Add Loads from fixed inputs (#162031)
## TODO
Check on multi indices
```Python

    @cute.jit
    def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers):
        in_ptr4 = buffers[0]
        tmp0 = tSrS_ssa
        tmp1 = b_idx
        tmp2 = h_idx
        tmp3 = cute.make_fragment(1, cutlass.Int32)
        tmp4 = tmp3.store(32*tmp1 + tmp2)
        tmp5 = cute.make_fragment(1, cutlass.BFloat16)
        tmp6 = tmp3[0]
        tmp7 = tmp5[0] = (in_ptr4[tmp6])
        tmp8 = (tmp5.load()).to(cutlass.Float32)
        tmp9 = (tmp0 + tmp8)
        tSrS_ssa = tmp9

        return tSrS_ssa

 ```

I dont think that
```
        tmp4 = tmp3.store(32*tmp1 + tmp2)
        tmp5 = cute.make_fragment(1, cutlass.BFloat16)
        tmp6 = tmp3[0]
        tmp7 = tmp5[0] = (in_ptr4[tmp6]

```

 is right since this tmp6 value will be larger than the actual index dim int his case its B -> see if its possible to 1d index

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162031
Approved by: https://github.com/v0i0
ghstack dependencies: #161118
2025-10-10 01:23:37 +00:00
0a2cde2f06 Add Flash Attention support to FlexAttention (#161118)
Relies on this PR in Flash Attention: https://github.com/Dao-AILab/flash-attention/pull/1840

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161118
Approved by: https://github.com/v0i0
2025-10-10 01:23:37 +00:00