## TODO
Check on multi indices
```Python
@cute.jit
def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers):
in_ptr4 = buffers[0]
tmp0 = tSrS_ssa
tmp1 = b_idx
tmp2 = h_idx
tmp3 = cute.make_fragment(1, cutlass.Int32)
tmp4 = tmp3.store(32*tmp1 + tmp2)
tmp5 = cute.make_fragment(1, cutlass.BFloat16)
tmp6 = tmp3[0]
tmp7 = tmp5[0] = (in_ptr4[tmp6])
tmp8 = (tmp5.load()).to(cutlass.Float32)
tmp9 = (tmp0 + tmp8)
tSrS_ssa = tmp9
return tSrS_ssa
```
I dont think that
```
tmp4 = tmp3.store(32*tmp1 + tmp2)
tmp5 = cute.make_fragment(1, cutlass.BFloat16)
tmp6 = tmp3[0]
tmp7 = tmp5[0] = (in_ptr4[tmp6]
```
is right since this tmp6 value will be larger than the actual index dim int his case its B -> see if its possible to 1d index
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162031
Approved by: https://github.com/v0i0
ghstack dependencies: #161118