doc: graph: support SDPA training

This commit is contained in:
Bao, Yixin
2025-07-02 18:02:11 -07:00
committed by Tao Lv
parent c121df46f6
commit 66a749f80d
3 changed files with 66 additions and 18 deletions


@ -31,7 +31,7 @@ SDPA graph, getting partition from the graph, and optimizing the kernels
underneath. In general, an SDPA pattern is defined as a directed acyclic
graph (DAG) using oneDNN Graph API.
### Floating-point SDPA
### Floating-point SDPA for Inference
oneDNN defines floating-point (f32, bf16, or f16) SDPA as follows. The blue
nodes are required when defining an SDPA pattern while the brown parts are
@ -89,6 +89,52 @@ optional.
![SDPA-Reorder](images/sdpa-reorder.png)
### Floating-point SDPA for Training Forward Propagation
oneDNN defines floating-point (f32, bf16, or f16) SDPA for training forward
propagation as follows. The blue nodes are required while the brown nodes are optional.
![SDPA pattern](images/sdpa_forward.png)
The only difference between the inference and training forward propagation
patterns is that training forward propagation additionally requires the
`Stats` output of the SoftMax operation. See [SoftMax](@ref dev_guide_op_softmax)
in Graph API for more details.
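
Concretely, the training forward pattern can be built with the same Graph API
calls as the inference pattern, plus the extra output on the SoftMax op. Below
is a minimal C++ sketch; the helper name `build_sdpa_forward()`, the shapes,
the ids, and the bf16/f32 split are illustrative assumptions, and the optional
Scale and Mask nodes are omitted for brevity.

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

// Minimal sketch: MatMul -> SoftMax (with Stats) -> TypeCast -> MatMul.
graph build_sdpa_forward() {
    using dt = logical_tensor::data_type;
    using ltype = logical_tensor::layout_type;

    const logical_tensor::dims qkv = {1, 16, 384, 64}; // (N, H, S, D), assumed
    const logical_tensor::dims score = {1, 16, 384, 384}; // (N, H, S, S)
    const logical_tensor::dims stats = {1, 16, 384, 1}; // one per row, assumed

    logical_tensor q {0, dt::bf16, qkv, ltype::strided};
    logical_tensor k {1, dt::bf16, qkv, ltype::strided};
    logical_tensor s {2, dt::f32, score, ltype::strided};
    op mm1 {0, op::kind::MatMul, {q, k}, {s}, "score"};
    mm1.set_attr<bool>(op::attr::transpose_b, true);

    // SoftMax with the additional Stats output required for training.
    logical_tensor p_f32 {3, dt::f32, score, ltype::strided};
    logical_tensor st {4, dt::f32, stats, ltype::strided};
    op sm {1, op::kind::SoftMax, {s}, {p_f32, st}, "softmax"};
    sm.set_attr<int64_t>(op::attr::axis, -1);

    // Cast the probabilities back to bf16 for the second MatMul.
    logical_tensor p_bf16 {5, dt::bf16, score, ltype::strided};
    op tc {2, op::kind::TypeCast, {p_f32}, {p_bf16}, "cast_p"};

    logical_tensor v {6, dt::bf16, qkv, ltype::strided};
    logical_tensor o {7, dt::bf16, qkv, ltype::strided};
    op mm2 {3, op::kind::MatMul, {p_bf16, v}, {o}, "context"};

    graph g {engine::kind::cpu};
    g.add_op(mm1);
    g.add_op(sm);
    g.add_op(tc);
    g.add_op(mm2);
    g.finalize();
    return g;
}
```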
### Floating-point SDPA for Training Backpropagation
oneDNN defines floating-point (f32, bf16, or f16) SDPA for training
backpropagation as follows. The blue nodes are required while the brown nodes
are optional. A condensed code sketch of steps 4 and 7 follows the list below.
![SDPA backward pattern](images/sdpa_backward.png)
1. The first MatMul computes the score between Query and Key, similar to
inference and training forward propagation. See
[MatMul](@ref dev_guide_op_matmul) in Graph API.
2. The Scale node is optional and multiplies or divides the output of the first
MatMul by a scaling factor. This can be implemented using [Multiply](@ref dev_guide_op_multiply)
or [Divide](@ref dev_guide_op_divide) in Graph API.
3. The Mask node is optional and applies an attention mask to the output of the
previous Scale node. For training backpropagation, only explicit user-generated
masks are currently supported. The mask definition is the same as in
inference and training forward propagation.
4. The Subtract and Exp operations take the masked output and `Stats` as inputs
and recover the probabilities computed by SoftMax in the training forward
propagation. See [Subtract](@ref dev_guide_op_subtract) and [Exp](@ref dev_guide_op_exp)
in Graph API.
5. The TypeCast and MatMul operations after Exp are used to compute the
gradients with respect to Value. TypeCast is required for bf16 and f16
training scenarios. See [TypeCast](@ref dev_guide_op_typecast) in Graph API.
6. Another MatMul takes the output gradients (`dO`) and the Value as inputs to
compute the gradients of the probabilities.
7. The SoftMaxBackward operation computes the gradients of the scaled score.
See [SoftMaxBackward](@ref dev_guide_op_softmaxbackward) in Graph API.
8. The Scale node after SoftMaxBackward corresponds to the forward Scale node
and is used to compute the gradients of the score.
9. The TypeCast and two MatMul operations after the Scale node compute the
gradients with respect to Query and Key, respectively. TypeCast is required
for bf16 and f16 training scenarios.
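
As a concrete illustration of steps 4 and 7, the condensed sketch below builds
the probability-recovery chain (Subtract + Exp) and the SoftMaxBackward node.
All ids, shapes, and data types are illustrative assumptions; the surrounding
MatMul, Scale, Mask, and TypeCast nodes are constructed in the same style.

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

// Adds the middle portion of the backward pattern to an existing graph:
// Subtract + Exp recover the forward probabilities P from Stats (step 4),
// and SoftMaxBackward maps dP back to the gradient of the scaled score
// (step 7).
void add_backward_core(graph &g) {
    using dt = logical_tensor::data_type;
    using ltype = logical_tensor::layout_type;
    const logical_tensor::dims score = {1, 16, 384, 384}; // (N, H, S, S)
    const logical_tensor::dims stats = {1, 16, 384, 1};   // from forward pass

    logical_tensor masked  {100, dt::f32, score, ltype::strided};
    logical_tensor st      {101, dt::f32, stats, ltype::strided};
    logical_tensor shifted {102, dt::f32, score, ltype::strided};
    // Stats is broadcast over the last dimension of the masked score.
    op sub {100, op::kind::Subtract, {masked, st}, {shifted}, "sub_stats"};

    logical_tensor p {103, dt::f32, score, ltype::strided};
    op exp {101, op::kind::Exp, {shifted}, {p}, "recover_p"};

    // Inputs are (diff_dst = dP, dst = P); the output is the gradient of
    // the scaled score.
    logical_tensor dp {104, dt::f32, score, ltype::strided};
    logical_tensor ds {105, dt::f32, score, ltype::strided};
    op smb {102, op::kind::SoftMaxBackward, {dp, p}, {ds}, "softmax_bwd"};
    smb.set_attr<int64_t>(op::attr::axis, -1);

    g.add_op(sub);
    g.add_op(exp);
    g.add_op(smb);
}
```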
## Data Types
@ -96,12 +142,12 @@ oneDNN supports the floating-point SDPA pattern with data types f32, bf16, and
f16. You can specify the data type via the input and output logical tensors'
data type fields for each operation.
oneDNN supports bf16 or f16 SDPA with f32 intermediate type, which means the
Q/K/V tensors have bf16 or f16 data type while the output of the first MatMul,
Scale, Mask, and the input of SoftMax are in f32 data type.
oneDNN supports the quantized SDPA pattern with int8-f32 mixed precision,
int8-bf16 mixed precision, and int8-f16 mixed precision data types.
oneDNN supports bf16 or f16 SDPA with f32 intermediate type. For inference and
training forward propagation, the Q, K, and V tensors use bf16 or f16 data
types, while the outputs of the first MatMul, Scale, and Mask, and the input of
SoftMax are in f32. Similarly, in training backpropagation, the Q, K, V, dO,
dQ, dK, and dV tensors use bf16 or f16, while the Stats input uses f32. The
intermediate tensors are in f32, except those after TypeCast, which are cast
to bf16 or f16.
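
As a hypothetical illustration of this convention for bf16 training
backpropagation, the user-visible gradient tensors are declared as bf16 while
`Stats` stays in f32; all ids and shapes below are assumptions.

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

// Hypothetical I/O declarations for a bf16 SDPA backpropagation graph.
void declare_backward_io() {
    using dt = logical_tensor::data_type;
    using ltype = logical_tensor::layout_type;
    const logical_tensor::dims qkv   = {1, 16, 384, 64}; // (N, H, S, D)
    const logical_tensor::dims stats = {1, 16, 384, 1};  // saved by forward

    // Gradients exchanged with the user are bf16 ...
    logical_tensor dO {300, dt::bf16, qkv, ltype::strided};
    logical_tensor dQ {301, dt::bf16, qkv, ltype::strided};
    logical_tensor dK {302, dt::bf16, qkv, ltype::strided};
    logical_tensor dV {303, dt::bf16, qkv, ltype::strided};
    // ... while Stats, like the internal intermediates, stays in f32.
    logical_tensor st {304, dt::f32, stats, ltype::strided};
}
```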
The definition of the data types and support status on different CPU and GPU
platforms follow the general description in @ref dev_guide_data_types.
@ -122,20 +168,22 @@ platforms follow the general description in @ref dev_guide_data_types.
Divide, and Select operations require the input tensors to have the same
shape or the shapes can be properly broadcasted based on the operation
attribute.
3. CPU
- Optimized implementation is available for 4D Q/K tensors with shape defined
as (N, H, S, D_qk) and V tensor with shape defined as (N, H, S, D_v).
- Optimized implementation is available for OpenMP runtime and Threadpool
3. Dropout is currently not supported in SDPA training.
4. CPU
- Optimized implementation for inference is available for 4D Q/K tensors with
shape defined as (N, H, S, D_qk) and V tensor with shape defined as
(N, H, S, D_v).
- Optimized implementation for inference is available for OpenMP runtime and Threadpool
runtime on Intel Architecture Processors.
- Specifically for OpenMP runtime, the optimized implementation requires `N *
H > 2 * thread number` to get enough parallelism.
4. GPU
- Optimized implementation is available for 4D Q/K tensors with shape defined
as (N, H, S, D_qk) and V tensor with shape defined as (N, H, S, D_v) where
D_qk equals D_v.
- Optimized implementation is available for `f16` or `bf16` SDPA with `f32`
intermediate data type and `D <= 512` on Intel Graphics Products with
Intel(R) Xe Matrix Extensions (Intel(R) XMX) support.
5. GPU
- Optimized implementation for inference is available for 4D Q/K tensors with
shape defined as (N, H, S, D_qk) and V tensor with shape defined as (N, H,
S, D_v) where D_qk equals D_v.
- Optimized implementation for inference is available for `f16` or `bf16`
SDPA with `f32` intermediate data type and `D <= 512` on Intel Graphics
Products with Intel(R) Xe Matrix Extensions (Intel(R) XMX) support.
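
Because optimized support depends on the shapes, data types, and platform
listed above, a run-time sanity check is to confirm that the whole pattern was
fused into a single supported partition. A hypothetical sketch, reusing the
`build_sdpa_forward()` helper from the forward-propagation sketch above:

```cpp
#include <vector>

#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

graph build_sdpa_forward(); // hypothetical helper sketched earlier

// Returns true when the SDPA graph collapses into one supported partition,
// which is what hitting an optimized implementation is expected to produce.
bool sdpa_is_fused() {
    graph g = build_sdpa_forward();
    std::vector<partition> parts = g.get_partitions();
    return parts.size() == 1 && parts[0].is_supported();
}
```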
## Example