We previously only supported the case where the value head dim (v_head_dim) equals the query/key head dim (qk_head_dim). When adding support for different head dims, I accidentally kept using the query strides for the output. That is wrong once v_head_dim differs from qk_head_dim, because the output's last dimension has size v_head_dim rather than qk_head_dim. This PR fixes that bug and also ensures that the output is always produced in the same stride order as the input query.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135882
Approved by: https://github.com/yanboliang, https://github.com/Chillee
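The stride-order fix can be illustrated with a small pure-Python sketch. It assumes a 4-D query of logical shape [B, H, Q_LEN, QK_HEAD_DIM] and an output of shape [B, H, Q_LEN, V_HEAD_DIM], with the head dim last. The helper names (`stride_order`, `strides_with_order`, `output_strides_like_query`) are hypothetical, and this is not the code from the PR. It first computes the query's stride order (innermost dimension first), then builds dense strides for the output shape that follow that same order.

```python
from typing import List, Sequence, Tuple


def stride_order(strides: Sequence[int]) -> List[int]:
    """Dimension indices sorted from innermost (smallest stride) to outermost.

    Ties (common for size-1 dims) are broken so that higher-numbered dims
    count as inner, matching a row-major (contiguous) layout.
    """
    return sorted(range(len(strides)), key=lambda d: (strides[d], -d))


def strides_with_order(shape: Sequence[int], order: Sequence[int]) -> List[int]:
    """Dense strides for `shape`, laying out dims in `order` (innermost first)."""
    strides = [0] * len(shape)
    running = 1
    for d in order:
        strides[d] = running
        # Treat size-0 dims as size 1 so later strides stay non-zero.
        running *= max(shape[d], 1)
    return strides


def output_strides_like_query(
    q_shape: Sequence[int], q_strides: Sequence[int], v_head_dim: int
) -> Tuple[List[int], List[int]]:
    """Shape and strides for an output [..., v_head_dim] in the query's stride order.

    Hypothetical helper: assumes the head dim is the last dimension.
    """
    out_shape = list(q_shape[:-1]) + [v_head_dim]
    return out_shape, strides_with_order(out_shape, stride_order(q_strides))


if __name__ == "__main__":
    # Query logically [B=2, H=4, L=8, D=16] but physically laid out as
    # [B, L, H, D], e.g. produced by transposing a [B, L, H, D] tensor.
    q_shape, q_strides = (2, 4, 8, 16), (512, 16, 64, 1)
    out_shape, out_strides = output_strides_like_query(q_shape, q_strides, 32)
    print(out_shape, out_strides)  # [2, 4, 8, 32] [1024, 32, 128, 1]
```

In this example, simply reusing the query strides (512, 16, 64, 1) for an output with head dim 32 would make distinct elements alias: out[0, 0, 0, 16] and out[0, 1, 0, 0] would both land at offset 16. Deriving the output strides from the query's stride order avoids that while keeping the same physical layout. In PyTorch, the resulting shape and strides could be passed to `torch.empty_strided` to allocate the output.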