In oneDNN v3.7, SDPA has the following defects:

1. The dtype of intermediate values is the same as QKV, while PyTorch uses FP32 for intermediate values to ensure better accuracy.
2. Only head dim sizes <= 256 are supported.
3. Implicit causal masks are not supported when QKV is FP32, so we have to build an attention mask explicitly with aten ops.

oneDNN v3.8 fixes these defects. Since the changes are tiny, I decided to put them in a single PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152091
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/drisspg
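As an illustration of the workaround mentioned for the causal-mask limitation, here is a minimal sketch of building an explicit additive causal mask and passing it to SDPA instead of `is_causal=True`. The helper name and the tensor shapes are hypothetical, chosen only for the example; this is not the code from the PR.

```python
import torch
import torch.nn.functional as F

def explicit_causal_mask(len_q: int, len_kv: int, device="cpu") -> torch.Tensor:
    # Additive mask: 0.0 where attention is allowed, -inf strictly above
    # the diagonal so future positions are masked out after softmax.
    return torch.triu(
        torch.full((len_q, len_kv), float("-inf"), device=device), diagonal=1
    )

# Hypothetical shapes for illustration: (batch, heads, seq_len, head_dim), FP32.
q = torch.randn(1, 4, 8, 64)
k = torch.randn(1, 4, 8, 64)
v = torch.randn(1, 4, 8, 64)

# Workaround: pass the explicit mask via attn_mask instead of is_causal=True.
mask = explicit_causal_mask(8, 8)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

With implicit causal-mask support in oneDNN v3.8, the `attn_mask` construction above becomes unnecessary and `is_causal=True` can be used directly even for FP32 QKV.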