Add an SDPA dispatcher for nested tensors with jagged layouts (#114164)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114164
Approved by: https://github.com/jbschlosser
committed by PyTorch MergeBot
parent fb92983c9b
commit 1dc4588c6a
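A rough usage sketch of what this dispatcher enables (not part of the commit; it assumes the public jagged-layout constructor torch.nested.nested_tensor(..., layout=torch.jagged), and backend selection may require a CUDA device):

import torch
import torch.nn.functional as F

H, D = 4, 16       # heads, head dim
seq_lens = [3, 5]  # ragged sequence lengths per batch element

def make_nt():
    # (B, j1, H, D) jagged nested tensor, raggedness in dim 1
    return torch.nested.nested_tensor(
        [torch.randn(L, H, D) for L in seq_lens], layout=torch.jagged
    )

# SDPA expects (B, H, seq, D); this PR's transpose support moves the
# ragged idx to dim 2 before dispatch
q, k, v = (make_nt().transpose(1, 2) for _ in range(3))

out = F.scaled_dot_product_attention(q, k, v)  # routed to the jagged impl
for t in out.transpose(1, 2).unbind():         # per-sequence outputs
    print(t.shape)                             # (3, 4, 16), then (5, 4, 16)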
@@ -2,6 +2,7 @@ import functools
 import math
 
 import torch
+from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention
 
 from .nested_tensor import NestedTensor
 from typing import *  # noqa: F403
@@ -184,6 +185,8 @@ def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
 def extract_kwargs(arg):
     kwargs = {
         "offsets": arg.offsets(),
+        "_max_seqlen": arg._max_seqlen,
+        "_min_seqlen": arg._min_seqlen,
     }
     return kwargs
 
@@ -256,18 +259,10 @@ def jagged_binary_pointwise(func, *args, **kwargs):
 
 
 def jagged_torch_function(func, *args, **kwargs):
-    # Handle SDPA specially since it's CompositeImplicit. We don't want
-    # the nestedness of the inputs to affect the kernel choice, so unwrap
-    # the NTs here before passing to SDPA -> rewrap the output as NT.
+    # SDPA has special kernels that handle nested tensors.
+    # Dispatch to the correct implementation here
     if func is torch._C._nn.scaled_dot_product_attention:
-        t_args = [t._values if isinstance(t, NestedTensor) else t for t in args]
-        t_kwargs = {
-            k: v._values if isinstance(v, NestedTensor) else v
-            for k, v in kwargs.items()
-        }
-
-        output = func(*t_args, **t_kwargs)
-        return NestedTensor(output, **extract_kwargs(args[0]))
+        return jagged_scaled_dot_product_attention(*args, **kwargs)
 
     # Handle flatten() here because it's CompositeImplicit.
     if func.__name__ == "flatten":
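The dispatch above sits at the __torch_function__ level because SDPA is CompositeImplicit: it would decompose into constituent ops before ever reaching __torch_dispatch__. A toy illustration of that interception pattern (LoggingTensor here is a hypothetical stand-in, not the real NestedTensor subclass):

import torch

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # same identity check the dispatcher above uses for SDPA
        if func is torch._C._nn.scaled_dot_product_attention:
            print("intercepted SDPA before it decomposes")
        return super().__torch_function__(func, types, args, kwargs)

q = torch.randn(2, 4, 5, 8).as_subclass(LoggingTensor)
out = torch.nn.functional.scaled_dot_product_attention(q, q, q)  # prints once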
@@ -355,6 +350,10 @@ def is_contiguous_general(func, *args, **kwargs):
     if inp.lengths() is not None:
         return False
 
+    # If jagged dim is not 1 it's not contiguous
+    if inp._ragged_idx != 1:
+        return False
+
     new_kwargs["memory_format"] = new_kwargs.get(
         "memory_format", torch.contiguous_format
     )
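A hedged illustration of the contiguity rule added above (exact behavior depends on this PR's NestedTensor subclass; transposing the ragged dim away from dim 1 should now report non-contiguous):

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 8), torch.randn(5, 8)], layout=torch.jagged
)
print(nt.is_contiguous())                  # True: ragged idx is 1, no lengths
print(nt.transpose(1, 2).is_contiguous())  # False: ragged idx moved off dim 1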
@@ -537,6 +536,11 @@ def unbind_int(func, *args, **kwargs):
     offsets = inp.offsets()
     lengths = inp.lengths()
 
+    if inp._ragged_idx != 1:
+        raise RuntimeError(
+            "unbind(): only supported for NestedTensor when jagged dimension is 1"
+        )
+
     if lengths is None:
         return torch.split(values, offsets.diff().tolist())
     return [
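The offsets-to-components mapping that unbind relies on, shown in plain torch ops (values and offsets here are stand-ins for the NT internals):

import torch

values = torch.randn(10, 4)            # packed (sum of seq lens, feature) buffer
offsets = torch.tensor([0, 3, 7, 10])  # boundaries of 3 jagged components

# offsets.diff() yields per-component lengths; torch.split recovers the pieces
pieces = torch.split(values, offsets.diff().tolist())
print([tuple(p.shape) for p in pieces])  # [(3, 4), (4, 4), (3, 4)]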
@@ -713,7 +717,32 @@ def transpose_int(func, *args, **kwargs):
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
 
+    from torch._prims_common import canonicalize_dims
+
     inp = new_kwargs.pop("input")
+    dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"]))
+
+    # To support the SDPA API, inputs need to have the ragged idx transposed to dim 2
+    # instead of 1, although the internal Flash and mem-effn implementations will
+    # use the inputs with raggedness in dim 1.
+    if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx:
+        if dim0 == 0 or dim1 == 0:
+            raise ValueError(
+                "Transpose is not supported on the batch dimension for jagged NT"
+            )
+        if dim0 == inp._ragged_idx:
+            to_dim = dim1
+        else:
+            to_dim = dim0
+        return NestedTensor(
+            inp.values().transpose(
+                _outer_to_inner_dim(len(inp._size), dim0),
+                _outer_to_inner_dim(len(inp._size), dim1),
+            ),
+            **extract_kwargs(inp),
+            _ragged_idx=to_dim,
+        )
 
     new_kwargs["dim0"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim0"], "transpose")
     new_kwargs["dim1"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim1"], "transpose")
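A sketch of the outer-to-inner dim mapping the transpose above relies on: the packed values buffer has no batch dim, so outer dim d corresponds to inner dim d - 1 (a simplified stand-in for the internal _outer_to_inner_dim helper, which also clamps dim 0):

def outer_to_inner_dim(ndim: int, dim: int) -> int:
    # the batch dim (0) has no equivalent in the packed values buffer
    assert 0 < dim < ndim
    return dim - 1

# transposing outer dims (1, 2) of a (B, j1, H, D) jagged NT touches
# dims (0, 1) of the packed (sum(j1), H, D) values buffer
print(outer_to_inner_dim(4, 1), outer_to_inner_dim(4, 2))  # 0 1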