Add an SDPA dispatcher for nested tensors with jagged layouts (#114164)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114164
Approved by: https://github.com/jbschlosser
This commit is contained in:
Antoni Viros
2023-12-05 03:38:26 +00:00
committed by PyTorch MergeBot
parent fb92983c9b
commit 1dc4588c6a
7 changed files with 1180 additions and 123 deletions

View File

@ -2,6 +2,7 @@ import functools
import math
import torch
from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention
from .nested_tensor import NestedTensor
from typing import * # noqa: F403
@ -184,6 +185,8 @@ def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
def extract_kwargs(arg):
kwargs = {
"offsets": arg.offsets(),
"_max_seqlen": arg._max_seqlen,
"_min_seqlen": arg._min_seqlen,
}
return kwargs
@ -256,18 +259,10 @@ def jagged_binary_pointwise(func, *args, **kwargs):
def jagged_torch_function(func, *args, **kwargs):
# Handle SDPA specially since it's CompositeImplicit. We don't want
# the nestedness of the inputs to affect the kernel choice, so unwrap
# the NTs here before passing to SDPA -> rewrap the output as NT.
# SDPA has special kernels that handle nested tensors.
# Dispatch to the correct implementation here
if func is torch._C._nn.scaled_dot_product_attention:
t_args = [t._values if isinstance(t, NestedTensor) else t for t in args]
t_kwargs = {
k: v._values if isinstance(v, NestedTensor) else v
for k, v in kwargs.items()
}
output = func(*t_args, **t_kwargs)
return NestedTensor(output, **extract_kwargs(args[0]))
return jagged_scaled_dot_product_attention(*args, **kwargs)
# Handle flatten() here because it's CompositeImplicit.
if func.__name__ == "flatten":
@ -355,6 +350,10 @@ def is_contiguous_general(func, *args, **kwargs):
if inp.lengths() is not None:
return False
# If jagged dim is not 1 it's not contiguous
if inp._ragged_idx != 1:
return False
new_kwargs["memory_format"] = new_kwargs.get(
"memory_format", torch.contiguous_format
)
@ -537,6 +536,11 @@ def unbind_int(func, *args, **kwargs):
offsets = inp.offsets()
lengths = inp.lengths()
if inp._ragged_idx != 1:
raise RuntimeError(
"unbind(): only supported for NestedTensor when jagged dimension is 1"
)
if lengths is None:
return torch.split(values, offsets.diff().tolist())
return [
@ -713,7 +717,32 @@ def transpose_int(func, *args, **kwargs):
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
from torch._prims_common import canonicalize_dims
inp = new_kwargs.pop("input")
dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"]))
# To support the SDPA API, inputs need to have the ragged idx transposed to dim 2
# instead of 1, although the internal Flash and mem-effn implementations will
# use the inputs with raggedness in dim 1.
if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx:
if dim0 == 0 or dim1 == 0:
raise ValueError(
"Transpose is not supported on the batch dimension for jagged NT"
)
if dim0 == inp._ragged_idx:
to_dim = dim1
else:
to_dim = dim0
return NestedTensor(
inp.values().transpose(
_outer_to_inner_dim(len(inp._size), dim0),
_outer_to_inner_dim(len(inp._size), dim1),
),
**extract_kwargs(inp),
_ragged_idx=to_dim,
)
new_kwargs["dim0"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim0"], "transpose")
new_kwargs["dim1"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim1"], "transpose")