doc: graph: support SDPA training
BIN doc/graph/fusion_patterns/images/sdpa_backward.png (new file, 58 KiB; binary not shown)
BIN doc/graph/fusion_patterns/images/sdpa_forward.png (new file, 20 KiB; binary not shown)
@@ -31,7 +31,7 @@ SDPA graph, getting partition from the graph, and optimizing the kernels
 underneath. In general, an SDPA pattern is defined as a directed acyclic
 graph (DAG) using oneDNN Graph API.
 
-### Floating-point SDPA
+### Floating-point SDPA for Inference
 
 oneDNN defines floating-point (f32, bf16, or f16) SDPA as follows. The blue
 nodes are required when defining an SDPA pattern while the brown parts are
@@ -89,6 +89,52 @@ optional.
 
 
 
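The required chain in the pattern above (MatMul -> Scale -> SoftMax -> MatMul) can be built with the oneDNN Graph C++ API roughly as in the following sketch. The tensor shapes, IDs, and the f32 data type are illustrative assumptions, not requirements of the pattern.

```cpp
// Minimal sketch: required nodes of the inference SDPA pattern.
// Shapes and IDs are illustrative assumptions (N=1, H=16, S=384, D=64).
#include <oneapi/dnnl/dnnl_graph.hpp>

using namespace dnnl::graph;

int main() {
    using dt = logical_tensor::data_type;
    using lt = logical_tensor::layout_type;
    const logical_tensor::dims qkv {1, 16, 384, 64};    // (N, H, S, D)
    const logical_tensor::dims score {1, 16, 384, 384}; // (N, H, S, S)

    logical_tensor q {0, dt::f32, qkv, lt::strided};
    logical_tensor k {1, dt::f32, qkv, lt::strided};
    logical_tensor s0 {2, dt::f32, score, lt::strided}; // Q x K^T
    logical_tensor scale {3, dt::f32, logical_tensor::dims {1}, lt::strided};
    logical_tensor s1 {4, dt::f32, score, lt::strided}; // scaled score
    logical_tensor p {5, dt::f32, score, lt::strided};  // probabilities
    logical_tensor v {6, dt::f32, qkv, lt::strided};
    logical_tensor o {7, dt::f32, qkv, lt::strided};    // attention output

    op mm1 {0, op::kind::MatMul, {q, k}, {s0}, "qk_matmul"};
    mm1.set_attr<bool>(op::attr::transpose_b, true);
    op div {1, op::kind::Divide, {s0, scale}, {s1}, "scale"};
    op sm {2, op::kind::SoftMax, {s1}, {p}, "softmax"};
    sm.set_attr<int64_t>(op::attr::axis, -1);
    op mm2 {3, op::kind::MatMul, {p, v}, {o}, "v_matmul"};

    graph g {dnnl::engine::kind::cpu};
    for (const op *x : {&mm1, &div, &sm, &mm2}) g.add_op(*x);
    g.finalize();
    auto parts = g.get_partitions(); // ideally one fused SDPA partition
    return 0;
}
```

Finalizing the graph and calling `get_partitions()` is what lets the backend return the fused SDPA partition described at the start of this document.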
+### Floating-point SDPA for Training Forward Propagation
+
+oneDNN defines floating-point (f32, bf16, or f16) SDPA for training forward
+propagation as follows. The blue nodes are required while the brown nodes are
+optional.
+
+
+
+The only difference between the inference and training forward propagation
+patterns is that, for training forward propagation, the `Stats` output of the
+SoftMax operation is needed. See [SoftMax](@ref dev_guide_op_softmax) in Graph
+API for more details.
+
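Relative to the inference sketch above, only the SoftMax op changes for training forward propagation: it gains a second output for `Stats`. A hedged sketch extending the previous snippet, where the (N, H, S, 1) stats shape is an assumption for illustration:

```cpp
// Training forward: SoftMax additionally produces the `Stats` output,
// consumed later by the backward pattern. Shape assumed to be (N, H, S, 1).
logical_tensor stats {8, dt::f32, {1, 16, 384, 1}, lt::strided};
op sm_fwd {2, op::kind::SoftMax, {s1}, {p, stats}, "softmax_fwd"};
sm_fwd.set_attr<int64_t>(op::attr::axis, -1);
```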
+### Floating-point SDPA for Training Backpropagation
+
+oneDNN defines floating-point (f32, bf16, or f16) SDPA for training
+backpropagation as follows. The blue nodes are required while the brown nodes
+are optional.
+
+
+
+1. The first MatMul computes the score between Query and Key, similar to
+   inference and training forward propagation. See
+   [MatMul](@ref dev_guide_op_matmul) in Graph API.
+2. The Scale node is optional and scales the output of the first MatMul using
+   a scaling factor. It can be implemented with
+   [Multiply](@ref dev_guide_op_multiply) or [Divide](@ref dev_guide_op_divide)
+   in Graph API.
+3. The Mask node is optional and applies an attention mask to the output of
+   the previous Scale node. For training backpropagation, only explicit
+   user-generated masks are currently supported. The mask definition is the
+   same as in inference and training forward propagation.
+4. The Subtract and Exp operations take the masked output and `Stats` as
+   inputs and recover the probabilities computed by SoftMax in the training
+   forward propagation (see the sketch after this list). See
+   [Subtract](@ref dev_guide_op_subtract) and [Exp](@ref dev_guide_op_exp) in
+   Graph API.
+5. The TypeCast and MatMul operations after Exp compute the gradients with
+   respect to Value. TypeCast is required for bf16 and f16 training scenarios.
+   See [TypeCast](@ref dev_guide_op_typecast) in Graph API.
+6. The next MatMul takes the output gradients (`dO`) and Value as inputs to
+   compute the gradients of the probabilities.
+7. The SoftMaxBackward operation computes the gradients of the scaled score.
+   See [SoftMaxBackward](@ref dev_guide_op_softmaxbackward) in Graph API.
+8. The Scale node after SoftMaxBackward corresponds to the forward Scale node
+   and computes the gradients of the score.
+9. The TypeCast and two MatMul operations after the Scale node compute the
+   gradients with respect to Query and Key, respectively. TypeCast is required
+   for bf16 and f16 training scenarios.
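To make steps 4 and 7 concrete, here is a sketch of the probability-recovery subgraph, continuing the illustrative names and shapes of the earlier snippets (including the assumed (N, H, S, 1) `Stats` shape):

```cpp
// Steps 4 and 7 (sketch): P = Exp(masked_score - Stats), then
// SoftMaxBackward(dP, P) -> gradient of the scaled score.
logical_tensor masked {10, dt::f32, score, lt::strided}; // fwd masked score
logical_tensor shifted {11, dt::f32, score, lt::strided};
logical_tensor p_rec {12, dt::f32, score, lt::strided};  // recovered probs
logical_tensor d_p {13, dt::f32, score, lt::strided};    // grads of probs
logical_tensor d_s {14, dt::f32, score, lt::strided};    // grads of scaled score

op sub {10, op::kind::Subtract, {masked, stats}, {shifted}, "sub_stats"};
op exp {11, op::kind::Exp, {shifted}, {p_rec}, "exp"};
op smb {12, op::kind::SoftMaxBackward, {d_p, p_rec}, {d_s}, "softmax_bwd"};
smb.set_attr<int64_t>(op::attr::axis, -1);
```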
 
 ## Data Types
 
@@ -96,12 +142,12 @@ oneDNN supports the floating-point SDPA pattern with data types f32, bf16, and
 f16. You can specify the data type via the input and output logical tensors'
 data type fields for each operation.
 
-oneDNN supports bf16 or f16 SDPA with f32 intermediate type, which means the
-Q/K/V tensors have bf16 or f16 data type while the output of the first MatMul,
-Scale, Mask, and the input of SoftMax are in f32 data type.
-
-oneDNN supports the quantized SDPA pattern with int8-f32 mixed precision,
-int8-bf16 mixed precision, and int8-f16 mixed precision data types.
+oneDNN supports bf16 or f16 SDPA with f32 intermediate type. For inference and
+training forward propagation, the Q, K, and V tensors use bf16 or f16 data
+types, while the outputs of the first MatMul, Scale, Mask, and the input of
+SoftMax are in f32. Similarly, in training backpropagation, the Q, K, V, dO,
+dQ, dK, and dV tensors use bf16 or f16, while the Stats input uses f32. The
+intermediate tensors are in f32, except those after TypeCast, which are cast
+to bf16 or f16.
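As an illustration of the bf16-with-f32-intermediate configuration, the precision mix is expressed entirely through the logical tensors' data type fields. A sketch extending the earlier snippets; IDs and shapes are again assumed:

```cpp
// bf16 inputs, f32 intermediate: declare the first MatMul's output as f32 so
// Scale/Mask/SoftMax run in f32; a TypeCast back to bf16 then precedes the
// second MatMul.
logical_tensor q16 {20, dt::bf16, qkv, lt::strided};
logical_tensor k16 {21, dt::bf16, qkv, lt::strided};
logical_tensor s32 {22, dt::f32, score, lt::strided};

op mm_bf16 {20, op::kind::MatMul, {q16, k16}, {s32}, "qk_bf16"};
mm_bf16.set_attr<bool>(op::attr::transpose_b, true);
```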
 
 The definition of the data types and support status on different CPU and GPU
 platforms follow the general description in @ref dev_guide_data_types.
@@ -122,20 +168,22 @@ platforms follow the general description in @ref dev_guide_data_types.
    Divide, and Select operations require the input tensors to have the same
    shape or the shapes can be properly broadcasted based on the operation
    attribute.
-3. CPU
-   - Optimized implementation is available for 4D Q/K tensors with shape defined
-     as (N, H, S, D_qk) and V tensor with shape defined as (N, H, S, D_v).
-   - Optimized implementation is available for OpenMP runtime and Threadpool
+3. Dropout is currently not supported in SDPA training.
+4. CPU
+   - Optimized implementation for inference is available for 4D Q/K tensors
+     with shape defined as (N, H, S, D_qk) and V tensor with shape defined as
+     (N, H, S, D_v).
+   - Optimized implementation for inference is available for OpenMP runtime and Threadpool
      runtime on Intel Architecture Processors.
    - Specifically for OpenMP runtime, the optimized implementation requires `N *
      H > 2 * thread number` to get enough parallelism.
-4. GPU
-   - Optimized implementation is available for 4D Q/K tensors with shape defined
-     as (N, H, S, D_qk) and V tensor with shape defined as (N, H, S, D_v) where
-     D_qk equals D_v.
-   - Optimized implementation is available for `f16` or `bf16` SDPA with `f32`
-     intermediate data type and `D <= 512` on Intel Graphics Products with
-     Intel(R) Xe Matrix Extensions (Intel(R) XMX) support.
+5. GPU
+   - Optimized implementation for inference is available for 4D Q/K tensors
+     with shape defined as (N, H, S, D_qk) and V tensor with shape defined as
+     (N, H, S, D_v) where D_qk equals D_v.
+   - Optimized implementation for inference is available for `f16` or `bf16`
+     SDPA with `f32` intermediate data type and `D <= 512` on Intel Graphics
+     Products with Intel(R) Xe Matrix Extensions (Intel(R) XMX) support.
 
 ## Example