mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
# Motivation fix https://github.com/pytorch/pytorch/issues/110040 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158189 Approved by: https://github.com/Skylion007, https://github.com/cyyever
271 lines
8.9 KiB
Python
271 lines
8.9 KiB
Python
# mypy: allow-untyped-defs
|
|
from typing import Optional, Union
|
|
|
|
import sympy
|
|
|
|
import torch
|
|
|
|
from .ir import Pointwise, ShapeAsConstantBuffer, TensorBox
|
|
from .virtualized import ops
|
|
|
|
|
|
# pyre-ignore[2,3]
|
|
def dense_idx_to_jagged_idx(batch_idx, seq_idx, offsets_loader, jagged_len):
    """
    Translate a dense (batch_idx, seq_idx) position into a jagged index.

    Args:
        batch_idx: batch index of the dense position.
        seq_idx: index inside the sequence of that batch entry.
        offsets_loader: callable that loads a value of the offsets tensor
            for a given index list.
        jagged_len: length of the jagged dimension.

    Returns:
        A pair (jagged_idx, end_idx). jagged_idx is the offset loaded at
        batch_idx plus seq_idx; end_idx is the offset loaded at
        batch_idx + 1. No bounds comparison between the two is done here;
        that is left to the caller.
    """
    seq_start_offset = offsets_loader([batch_idx])
    # The upper bound passed for the start offset is jagged_len + 1 rather
    # than jagged_len: the last sequence may have length zero.
    seq_start = ops.indirect_indexing(seq_start_offset, jagged_len + 1)
    seq_end = offsets_loader([batch_idx + 1])
    return seq_start + seq_idx, seq_end
|
|
|
|
|
|
def get_inverse_offsets(
    offsets: TensorBox,
    jagged_len: Union[int, sympy.Expr],
    realize: bool = True,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """
    Build (or fetch from cache) the "inverse_offsets" of an offsets tensor.

    offsets goes from a batch index (dense) to a jagged index, i.e. to an
    offset into the jagged tensor; inverse_offsets goes the other way, from
    a jagged index to the batch index.

    Example: offsets [0, 3, 4, 9, 10] gives
    inverse_offsets = [0, 0, 0, 1, 2, 2, 2, 2, 2, 3]

    The result is stored on the offsets object as the attribute
    `inverse_offsets` the first time it is computed, and later calls with
    the same offsets object return that stored value.

    NOTE: the cache lookup only checks for that attribute; jagged_len and
    realize are not consulted on a cache hit.

    Args:
        offsets: the offsets TensorBox; it is realized here on a cache miss.
        jagged_len: the single range of the generated inverse_offsets node.
        realize: when True, the freshly created node is realized before it
            is cached and returned.

    Returns:
        The cached or newly created inverse_offsets node.
    """
    try:
        # Already computed for these offsets: hand back the stored node.
        return offsets.inverse_offsets
    except AttributeError:
        pass

    # ops.bucketize is given offsets.get_name(), which Pointwise kernels do
    # not have. The offsets therefore have to be realized first, i.e. placed
    # in global memory, so the binary search can run over the whole tensor.
    offsets.realize()
    device: torch.device = offsets.get_device_or_error()
    dtype: torch.dtype = offsets.get_dtype()

    # pyre-ignore[2,3]
    def inner_fn(index):
        position = ops.index_expr(index[0], dtype)
        buffer_name = offsets.get_name()
        sizes = offsets.get_size()
        strides = offsets.get_stride()
        one_based_bucket = ops.bucketize(
            values=position,
            boundaries=(
                buffer_name,
                sizes[-1],
                sizes[0] * strides[0],
                strides[-1],
            ),
            boundary_indices=0,
            indexing_dtype=dtype,
            right=True,
        )
        # The bucket indices coming out of ops.bucketize are 1-based while
        # batch indices are 0-based, so shift down by one.
        return one_based_bucket - 1

    inverse_offsets = Pointwise.create(
        device=device,
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=[jagged_len],
    )

    if realize:
        # "Freeze" the node so that downstream users do not inline it.
        inverse_offsets.realize()

    # Remember the result on the offsets object for later calls.
    offsets.inverse_offsets = inverse_offsets  # type: ignore[attr-defined]

    return inverse_offsets
|
|
|
|
|
|
def jagged_idx_to_dense_idx(
    jagged_idx,  # pyre-ignore[2]
    inverse_offsets_loader,  # pyre-ignore[2]
    offsets_loader,  # pyre-ignore[2]
    batch_size: Union[int, sympy.Expr],
    max_seq_len: Union[int, sympy.Expr],
    offsets_dtype: torch.dtype,
) -> tuple[sympy.Expr, sympy.Expr]:
    """
    Translate a jagged index into a dense (batch_idx, seq_idx) position.

    Args:
        jagged_idx: index along the jagged dimension.
        inverse_offsets_loader: callable loading the inverse_offsets value
            for a given index list.
        offsets_loader: callable loading the offsets value for a given
            index list.
        batch_size: batch size; batch_size + 1 is the bound passed when the
            loaded batch index is turned into an indexing expression.
        max_seq_len: bound passed for the sequence index (not checked).
        offsets_dtype: dtype used when materializing jagged_idx as a value.

    Returns:
        A pair (batch_idx, seq_idx), where seq_idx is jagged_idx minus the
        offset loaded at batch_idx.
    """
    owning_batch = inverse_offsets_loader([jagged_idx])
    batch_idx = ops.indirect_indexing(owning_batch, batch_size + 1)

    first_row_of_batch = offsets_loader([batch_idx])
    flat_position = ops.index_expr(jagged_idx, offsets_dtype)
    offset_in_seq = flat_position - first_row_of_batch

    # Sequences may be longer than max_seq_len, hence check=False.
    seq_idx = ops.indirect_indexing(offset_in_seq, max_seq_len, check=False)
    return batch_idx, seq_idx
|
|
|
|
|
|
def register_jagged_ops():
    """
    Register the lowerings for the jagged <-> padded dense aten ops.

    Lowerings are registered for
    torch.ops.aten._jagged_to_padded_dense_forward.default and
    torch.ops.aten._padded_dense_to_jagged_forward. Inputs that do not match
    the pattern checked inside each lowering are handed to fallback_handler
    for the corresponding aten op.
    """
    # Imported here rather than at module level to avoid a circular import.
    from .lowering import fallback_handler, is_integer_type, register_lowering

    # pyre-ignore[56]
    @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default)
    def _jagged_to_padded_dense_forward(
        jagged_values: TensorBox,
        jagged_offsets: list[TensorBox],
        max_lengths: list[int],  # list of ints/SymInts
        padding_value: float = 0.0,
    ) -> Union[TensorBox, ShapeAsConstantBuffer]:
        device = jagged_values.get_device_or_error()
        dtype = jagged_values.get_dtype()
        jagged_values_size = jagged_values.get_size()

        # Only the common case of a single jagged dimension is lowered here;
        # everything else goes through the fallback.
        supported = (
            len(jagged_offsets) == 1
            and device.type == "cuda"
            and device == jagged_offsets[0].get_device()
            and len(jagged_values_size) == 2
            and len(jagged_offsets[0].get_size()) == 1
            and len(max_lengths) == len(jagged_offsets)
            and is_integer_type(jagged_offsets[0])
        )
        if not supported:
            fallback = fallback_handler(
                torch.ops.aten._jagged_to_padded_dense_forward.default,
                add_to_fallback_set=False,
            )
            return fallback(
                jagged_values,
                jagged_offsets,
                max_lengths,
                padding_value,
            )

        offsets = jagged_offsets[0]
        offsets_dtype = offsets.get_dtype()
        batch_size = offsets.get_size()[0] - 1
        max_seq_len = max_lengths[0]
        jagged_len, embedding_len = jagged_values_size

        values_loader = jagged_values.make_loader()
        offsets_loader = offsets.make_loader()

        # pyre-ignore[2,3,53]
        def inner_fn(index):
            # dense tensor size: [B, N, D]
            batch_idx, seq_idx, emb_idx = index
            jagged_idx, end_idx = dense_idx_to_jagged_idx(
                batch_idx=batch_idx,
                seq_idx=seq_idx,
                offsets_loader=offsets_loader,
                jagged_len=jagged_len,
            )
            # Mask is jagged_idx < end_idx; the values load is wrapped in
            # the mask and padding_value is the alternative value.
            inside_sequence = ops.lt(
                ops.index_expr(jagged_idx, offsets_dtype),
                end_idx,
            )
            return ops.masked(
                inside_sequence,
                lambda: values_loader([jagged_idx, emb_idx]),
                padding_value,
            )

        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=[batch_size, max_seq_len, embedding_len],
        )

    def _dense_to_jagged_forward_impl(
        fallback_op,  # pyre-ignore[2]
        dense: TensorBox,
        jagged_offsets: list[TensorBox],
        jagged_len: Optional[int] = None,
    ) -> Union[TensorBox, ShapeAsConstantBuffer]:
        device = dense.get_device_or_error()
        dtype = dense.get_dtype()
        dense_size = dense.get_size()

        # Only the common case of a single jagged dimension is lowered here;
        # everything else goes through the fallback.
        supported = (
            len(jagged_offsets) == 1
            and device.type == "cuda"
            and device == jagged_offsets[0].get_device()
            and len(jagged_offsets[0].get_size()) == 1
            and len(dense_size) == 3
            and jagged_len is not None
            and is_integer_type(jagged_offsets[0])
        )
        if not supported:
            fallback = fallback_handler(fallback_op, add_to_fallback_set=False)
            return fallback(
                dense,
                jagged_offsets,
                jagged_len,
            )

        offsets = jagged_offsets[0]
        offsets_dtype = offsets.get_dtype()
        batch_size, max_seq_len, embedding_len = dense_size

        dense_loader = dense.make_loader()
        offsets_loader = offsets.make_loader()

        # NOTE: get_inverse_offsets may realize `offsets`; the loaders above
        # are created before that call.
        inverse_offsets = get_inverse_offsets(
            offsets=offsets,
            jagged_len=jagged_len,
        )
        inverse_offsets_loader = inverse_offsets.make_loader()

        # pyre-ignore[2,3,53]
        def inner_fn(index):
            # jagged tensor size: [sum_B(N_B), D]
            jagged_idx, emb_idx = index
            batch_idx, seq_idx = jagged_idx_to_dense_idx(
                jagged_idx=jagged_idx,
                offsets_loader=offsets_loader,
                inverse_offsets_loader=inverse_offsets_loader,
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                offsets_dtype=offsets_dtype,
            )
            # Mask is seq_idx < max_seq_len, both taken as offsets_dtype.
            within_max_len = ops.lt(
                ops.index_expr(seq_idx, offsets_dtype),
                ops.index_expr(max_seq_len, offsets_dtype),
            )
            return ops.masked(
                within_max_len,
                lambda: dense_loader([batch_idx, seq_idx, emb_idx]),
                0.0,  # jagged sequence longer than max_seq_len
            )

        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=[jagged_len, embedding_len],
        )

    # pyre-ignore[56]
    @register_lowering(torch.ops.aten._padded_dense_to_jagged_forward)
    def _dense_to_jagged_forward(
        dense: TensorBox,
        jagged_offsets: list[TensorBox],
        jagged_len: Optional[int] = None,
    ) -> Union[TensorBox, ShapeAsConstantBuffer]:
        # Thin entry point: all the work is done by the shared impl, with
        # the .default overload used as the fallback op.
        return _dense_to_jagged_forward_impl(
            fallback_op=torch.ops.aten._padded_dense_to_jagged_forward.default,
            dense=dense,
            jagged_offsets=jagged_offsets,
            jagged_len=jagged_len,
        )
|