Add lowering for aten.searchsorted (#135701)

Adds lowering for `aten.searchsorted`. This entails:

1. Adding support for multi-dimensional boundary tensors to `ops.bucketize`.
2. Adding support for strided boundary tensors to `ops.bucketize`.
3. Adding support for sorter tensors to `ops.bucketize`, so that unsorted boundaries can be searched.
4. Adding a lowering for `aten.searchsorted.Tensor`.
5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors.
6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions.
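
For reference, a minimal sketch of the eager-mode semantics the lowering must reproduce (shapes and values below are illustrative only):

```python
import torch

seq = torch.tensor([1.0, 3.0, 5.0, 7.0])
vals = torch.tensor([2.0, 6.0])
torch.searchsorted(seq, vals)  # tensor([1, 3]): first index where seq[i] >= val

# An unsorted sequence is legal when a sorter (the argsort of the sequence)
# is supplied; the result matches searching the sorted sequence.
unsorted = torch.tensor([5.0, 1.0, 7.0, 3.0])
sorter = torch.argsort(unsorted)
torch.searchsorted(unsorted, vals, sorter=sorter)  # tensor([1, 3])
```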

Closes #135873

Differential Revision: [D63766514](https://our.internmc.facebook.com/intern/diff/D63766514)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135701
Approved by: https://github.com/amjames, https://github.com/eellison, https://github.com/davidberard98
Benjamin Glass
2024-10-03 21:11:19 +00:00
committed by PyTorch MergeBot
parent 58ec6a360c
commit a968576777
13 changed files with 498 additions and 66 deletions

View File

@@ -32,15 +32,20 @@ class TestDependencies(InductorTestCase):
self._stack.close()
super().tearDown()
def test_bucketize_dependencies(self):
def test_bucketize_dependencies_no_sorter(self):
offsets = self._create_buffer("offsets", (1025,), torch.int32)
def inner_fn(index):
idx = index[0]
return ops.bucketize(
values=idx,
offsets_name=offsets.get_name(),
offsets_size=offsets.get_size()[0],
boundaries=(
offsets.get_name(),
offsets.get_size()[-1],
offsets.get_size()[0] * offsets.get_stride()[0],
offsets.get_stride()[-1],
),
boundary_indices=0,
indexing_dtype=torch.int32,
right=True,
)
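As a sanity check on the boundaries tuple used above, i.e. (buffer name, last-dim size, underlying numel, last-dim stride), this is what it evaluates to for a contiguous 1-D buffer (a sketch; the buffer name is illustrative):

```python
import torch

t = torch.empty(1025, dtype=torch.int32)
boundaries = ("offsets", t.shape[-1], t.shape[0] * t.stride(0), t.stride(-1))
assert boundaries == ("offsets", 1025, 1025, 1)
```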
@@ -54,6 +59,39 @@ class TestDependencies(InductorTestCase):
self.assertEqual(len(pointwise.get_reads()), 1)
def test_bucketize_dependencies_sorter(self):
offsets = self._create_buffer("offsets", (1025,), torch.int32)
sorter = self._create_buffer("sorter", (1025,), torch.int32)
def inner_fn(index):
idx = index[0]
return ops.bucketize(
values=idx,
boundaries=(
offsets.get_name(),
offsets.get_size()[-1],
offsets.get_size()[0] * offsets.get_stride()[0],
offsets.get_stride()[-1],
),
boundary_indices=0,
indexing_dtype=torch.int32,
right=True,
sorter=(
sorter.get_name(),
sorter.get_stride()[-1],
),
sorter_indices=0,
)
pointwise = Pointwise.create(
device=torch.device(GPU_TYPE),
dtype=torch.int32,
inner_fn=inner_fn,
ranges=[1024 * 4],
)
self.assertEqual(len(pointwise.get_reads()), 2)
def test_get_offset(self):
x = sympy_index_symbol("x")
y = sympy_index_symbol("y")

View File

@@ -10437,6 +10437,53 @@ class CommonTemplate:
self.common(fn, (torch.randn((16, 16, 16)),), check_lowp=False)
@xfail_if_triton_cpu
def test_searchsorted(self):
def fn(sorted_sequence, values, out_int32, right, side, sorter):
return torch.searchsorted(
sorted_sequence,
values,
out_int32=out_int32,
right=right,
side=side,
sorter=sorter,
)
shapes = (
((1,), (16, 16)), # single-element sorted_sequence
((16,), ()), # scalar values
((32,), (16, 16)), # 1-D sorted_sequence
((16, 32), (16, 16)), # N-D sorted_sequence
((3, 5), (3, 7)), # prime dimensioned sequence, to flush out indexing bugs
)
booleans = (False, True)
for (seq_shape, value_shape), out_int32, right in itertools.product(
shapes, booleans, booleans
):
unsorted_sequence = torch.rand(seq_shape)
sorted_sequence, sorting_indices = torch.sort(unsorted_sequence)
values = torch.rand(value_shape)
side = "right" if right else "left"
self.common(
fn,
(sorted_sequence, values, out_int32, right, side, None),
check_lowp=False,
)
self.common(
fn,
(
unsorted_sequence,
values,
out_int32,
right,
side,
sorting_indices,
),
check_lowp=False,
)
def test_bucketize(self):
def fn(input, boundaries, out_int32, right):
return torch.bucketize(input, boundaries, out_int32=out_int32, right=right)
@@ -11511,7 +11558,7 @@ class CommonTemplate:
@dataclasses.dataclass
class TestFailure:
suffixes: Tuple[str]
suffixes: Tuple[str, ...]
is_skip: bool = False
__test__: bool = False

View File

@@ -120,7 +120,7 @@ test_failures = {
"test_stack_dynamic_shapes": TestFailure(("cpu",)),
"test_tensor2_dynamic_shapes": TestFailure(("cpu",)),
"test_tensor3_dynamic_shapes": TestFailure(("cpu",)),
"test_to_device_constant_dynamic_shapes": TestFailure("cpu"),
"test_to_device_constant_dynamic_shapes": TestFailure(("cpu",)),
"test_upsample_nearest2d_backward_dynamic_shapes": TestFailure(("cpu",)),
"test_views3_dynamic_shapes": TestFailure(("cpu",)),
"test_views4_dynamic_shapes": TestFailure(("cpu",)),
@@ -160,9 +160,10 @@ test_failures = {
"test_empty1_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_empty2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_empty_strided_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_bucketize_dynamic_shapes": TestFailure("cpu"),
"test_bucketize_default_kwargs_dynamic_shapes": TestFailure("cpu"),
"test_bucketize_int_dynamic_shapes": TestFailure("cpu"),
"test_bucketize_dynamic_shapes": TestFailure(("cpu",)),
"test_bucketize_default_kwargs_dynamic_shapes": TestFailure(("cpu",)),
"test_bucketize_int_dynamic_shapes": TestFailure(("cpu",)),
"test_searchsorted_dynamic_shapes": TestFailure(("cpu",)),
"test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
@@ -247,7 +248,7 @@ test_failures = {
"test_views5_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_view_detach_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_view_on_aliased_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_linear_float64_dynamic_shapes": TestFailure("cpu"),
"test_linear_float64_dynamic_shapes": TestFailure(("cpu",)),
"test_adaptive_avg_pool_with_output_size_0_dynamic_shapes": TestFailure(
("cpu", "cuda", "xpu")
),

View File

@@ -1780,10 +1780,12 @@ class Kernel(CodeGen):
def bucketize(
self,
values: CSEVariable,
offsets_name: str,
offsets_size: sympy.Expr,
boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
boundary_indices: CSEVariable,
indexing_dtype: torch.dtype,
right: bool,
sorter: Optional[Tuple[str, sympy.Expr]] = None,
sorter_indices: Optional[CSEVariable] = None,
) -> CSEVariable:
"""
See [Note: Inductor bucketize op]
@@ -2035,27 +2037,81 @@ class Kernel(CodeGen):
@staticmethod
def bucketize(
values: CSEVariable,
offsets_name: str,
offsets_size: sympy.Expr,
boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
boundary_indices: CSEVariable,
indexing_dtype: torch.dtype,
right: bool,
sorter: Optional[Tuple[str, sympy.Expr]] = None,
sorter_indices: Optional[CSEVariable] = None,
) -> CSEVariable:
"""
[Note: Inductor bucketize op]
Given values (tensor) and offsets_name (reference to the name of a 1D
tensor), calculate the bucket that each value belongs to.
Inputs:
-------
values: the values to be bucketized.
boundaries: a tuple containing
(a) the name of the boundaries tensor (which must be sorted, unless
the sorting tensor is present),
(b) the length of the tensor in the last dimension (i.e. the length of
one set of boundaries),
(c) the number of elements in the underlying storage (i.e. the length
of the flattened tensor, ignoring striding), and
(d) the stride of the tensor in the last dimension.
boundary_indices: indices into a flattened version of the boundaries
tensor, of the same size and shape as "values". Each index points to
the first element in the set of boundaries to be used for the
corresponding value.
indexing_dtype: the dtype to use when indexing into the boundaries
tensor. This must be int64 or int32. This additionally specifies the
dtype of the return value.
right: see "Details" below.
sorter: an optional tuple containing
(a) the name of an optional sorting tensor, used to access unsorted
boundaries without reordering the boundaries tensor, and
(b) the stride of the tensor in the last dimension.
The values in the sorting tensor are used as indices into the *last*
dimension of the boundaries tensor, with all other indices matching.
The size of the sorting and boundaries tensors must be equivalent.
sorter_indices: must be present if the sorting array is present; see
"boundary_indices" for the equivalent definition for the boundaries
tensor.
e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
return = [ 0, 1, 1, 1, 1, 3, 3, 4].
Output:
-------
The buckets each value belongs in, within a given set of boundaries. 0
indicates a position before the first boundary, and len(boundaries_set)
represents a position after the last boundary.
When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
When right == True, bucket i refers to range [offsets[i], offsets[i+1]).
Details:
--------
Given a value and a set of boundaries, calculate the bucket that each
value belongs to. This works differently in 1-D and N-D cases.
Offsets must be non-decreasing or the result is undefined.
for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [0, 4, 4, 8], right=True
return = [[ 0, 1, 1, 1], [1, 3, 3, 4]].
for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [[0, 4], [4, 8]], right=True
return = [[ 0, 1, 1, 1], [0, 1, 1, 2]]
Note that in the N-D boundaries case, the shape of "values" and
"boundaries" must match in every dimension _except_ the last.
When right == False, bucket i refers to range (boundaries[i], boundaries[i+1]].
When right == True, bucket i refers to range [boundaries[i], boundaries[i+1]).
Boundaries must be non-decreasing, or a sorter must be provided which
re-indexes the boundaries in non-decreasing order (e.g. the second output
of torch.sort(boundaries)). Otherwise, the result is undefined.
"""
return self.bucketize(
values, offsets_name, offsets_size, indexing_dtype, right
values,
boundaries,
boundary_indices,
indexing_dtype,
right,
sorter,
sorter_indices,
)
# Use mypy to check protocol implemented correctly
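To make the docstring above concrete, here is a minimal pure-Python model of the search it describes (illustrative only; in the real op the boundaries and sorter live in named buffers, not Python lists):

```python
def bucketize_ref(val, flat_boundaries, boundary_index, set_len, stride, right,
                  flat_sorter=None, sorter_index=None, sorter_stride=None):
    # Binary-search one value against one set of boundaries; boundary_index
    # points at the first element of that set within the flattened storage.
    low, high = 0, set_len
    while low < high:
        mid = (low + high) // 2
        i = mid if flat_sorter is None else flat_sorter[sorter_index + sorter_stride * mid]
        bound = flat_boundaries[boundary_index + stride * i]
        # right=True counts boundaries <= val; right=False counts boundaries < val.
        if (bound <= val) if right else (bound < val):
            low = mid + 1
        else:
            high = mid
    return low

# Reproduces the docstring example:
assert [bucketize_ref(v, [0, 4, 4, 8], 0, 4, 1, True)
        for v in [-1, 0, 1, 2, 3, 4, 5, 9]] == [0, 1, 1, 1, 1, 3, 3, 4]
```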

View File

@@ -1952,10 +1952,12 @@ class TritonKernel(SIMDKernel):
def bucketize(
self,
values: CSEVariable,
offsets_name: str,
offsets_size: sympy.Expr,
boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
boundary_indices: CSEVariable,
indexing_dtype: torch.dtype,
right: bool,
sorter: Optional[Tuple[str, sympy.Expr]] = None,
sorter_indices: Optional[CSEVariable] = None,
) -> CSEVariable:
"""
See [Note: Inductor bucketize op]
@@ -1967,9 +1969,13 @@ class TritonKernel(SIMDKernel):
# autotuning config with num_elements_per_warp=(warp_size) exists.
self.autotune_hints.add(AutotuneHint.ONE_ELEMENT_PER_THREAD)
offsets_ptr = self.args.input(offsets_name)
boundaries_ptr = self.args.input(boundaries[0])
boundary_size = self.index_to_str(boundaries[1])
boundaries_underlying_numel = self.index_to_str(boundaries[2])
boundary_stride = self.index_to_str(boundaries[3])
sorter_ptr = self.args.input(sorter[0]) if sorter else "None"
sorter_stride = self.index_to_str(sorter[1]) if sorter else "None"
block_size = self.dense_size_str()
offsets_size_str = self.index_to_str(offsets_size)
if indexing_dtype == torch.int32:
triton_dtype = "tl.int32"
@@ -1982,7 +1988,15 @@ class TritonKernel(SIMDKernel):
result = self.cse.generate(
self.compute,
f"triton_helpers.bucketize_binary_search({values}, {offsets_ptr}, {triton_dtype}, {right}, {offsets_size_str}, {block_size})", # noqa: B950 line too long
f"triton_helpers.bucketize_binary_search({values}, "
f"{boundaries_ptr}, {boundary_size}, {boundaries_underlying_numel}, {boundary_stride}, "
f"{boundary_indices}, "
f"{triton_dtype}, "
f"{right}, "
f"{sorter_ptr}, {sorter_stride}, "
f"{sorter_indices}, "
f"{block_size}, "
")",
)
return result

View File

@@ -996,3 +996,23 @@ def adaptive_max_pool2d(
return aten.max_pool2d_with_indices(x, kernel_size)
return NotImplemented
@register_decomposition(aten.searchsorted.Scalar)
def searchsorted_scalar(
sorted_sequence: torch.Tensor,
self: torch.types.Number,
*,
out_int32: bool = False,
right: bool = False,
side: Optional[str] = None,
sorter: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return aten.searchsorted(
sorted_sequence,
torch.tensor([self], device=sorted_sequence.device),
out_int32=out_int32,
right=right,
side=side,
sorter=sorter,
)[0]
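A quick illustration of what this decomposition computes (values are illustrative):

```python
import torch

seq = torch.tensor([1.0, 3.0, 5.0])
# The Scalar overload wraps the scalar in a one-element tensor, searches it,
# and takes element 0, yielding a zero-dimensional result.
torch.searchsorted(seq, 4.0)  # tensor(2)
```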

View File

@@ -5,7 +5,7 @@ import itertools
import logging
import re
import typing
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union
from unittest.mock import patch
import sympy
@@ -26,6 +26,8 @@ from .utils import (
from .virtualized import OpsHandler, ReductionType, V
T = TypeVar("T")
log = logging.getLogger(__name__)
is_indirect = re.compile(r"indirect|tmp").search
@@ -506,14 +508,18 @@ class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
def bucketize(
self,
values,
offsets_name: str,
offsets_size: sympy.Expr,
values: T,
boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
boundary_indices: T,
indexing_dtype: torch.dtype,
right: bool,
):
self._reads.add(StarDep(offsets_name))
return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})"
sorter: Optional[Tuple[str, sympy.Expr]] = None,
sorter_indices: Optional[T] = None,
) -> None:
"""Records the names of the buffers that bucketize will read from."""
self._reads.add(StarDep(boundaries[0]))
if sorter is not None:
self._reads.add(StarDep(sorter[0]))
class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined]
@@ -592,8 +598,10 @@ def extract_read_writes(
for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
inner.index_expr(name_to_index[entry.index_name], None)
for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]:
# All that matters is that we record the buffer name, so place it in the
# "boundaries" name position to ensure that it's recorded.
inner.bucketize(
None, entry.buffer_name, name_to_index[entry.index_name], None, None # type: ignore[arg-type]
None, (entry.buffer_name, None, None, None), None, None, None # type: ignore[arg-type]
)
# fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
else:

View File

@@ -59,8 +59,13 @@ def get_inverse_offsets(
idx = index[0]
bucket = ops.bucketize(
values=ops.index_expr(idx, dtype),
offsets_name=offsets.get_name(),
offsets_size=offsets.get_size()[0],
boundaries=(
offsets.get_name(),
offsets.get_size()[-1],
offsets.get_size()[0] * offsets.get_stride()[0],
offsets.get_stride()[-1],
),
boundary_indices=0,
indexing_dtype=dtype,
right=True,
)

View File

@@ -5,7 +5,17 @@ import functools
import itertools
import re
from enum import auto, Enum
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
TypeVar,
)
import sympy
@@ -20,6 +30,9 @@ from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
from .virtualized import ops, V
T = TypeVar("T")
class InterpreterShim(torch.fx.Interpreter):
@staticmethod
@functools.lru_cache(None)
@@ -479,17 +492,51 @@ class LoopBodyBlock:
def bucketize(
self,
values,
offsets_name: str,
offsets_size: sympy.Expr,
values: T,
boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
boundary_indices: T,
indexing_dtype: torch.dtype,
right: bool,
):
offsets_size = add_index(
offsets_size, MemoryUsageType.BUCKETIZE, buffer_name=offsets_name
sorter: Optional[Tuple[str, sympy.Expr]] = None,
sorter_indices: Optional[T] = None,
) -> T:
"""
See [Note: Inductor bucketize op]
"""
boundaries = (
boundaries[0],
add_index(
boundaries[1],
MemoryUsageType.BUCKETIZE,
buffer_name=boundaries[0],
),
add_index(
boundaries[2],
MemoryUsageType.BUCKETIZE,
buffer_name=boundaries[0],
),
add_index(
boundaries[3],
MemoryUsageType.BUCKETIZE,
buffer_name=boundaries[0],
),
)
if sorter is not None:
sorter = (
sorter[0],
add_index(
sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0]
),
)
return self._inner.bucketize(
values, offsets_name, offsets_size, indexing_dtype, right
values,
boundaries,
boundary_indices,
indexing_dtype,
right,
sorter,
sorter_indices,
)
@staticmethod

View File

@@ -2078,6 +2078,113 @@ def inductor_randint(
)
def _boundaries_helper(tb: TensorBox) -> Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr]:
return (
tb.get_name(),
tb.get_size()[-1],
tb.get_size()[0] * tb.get_stride()[0],
tb.get_stride()[-1],
)
def _sorter_helper(tb: TensorBox) -> Tuple[str, sympy.Expr]:
return tb.get_name(), tb.get_stride()[-1]
@register_lowering(aten.searchsorted.Tensor, type_promotion_kind=None)
def searchsorted(
sorted_sequence: TensorBox,
self: TensorBox,
*,
out_int32: bool = False,
right: bool = False,
side: Optional[str] = None,
sorter: Optional[TensorBox] = None,
) -> TensorBox:
validate_bucketize = lambda tb: V.graph.has_feature( # noqa: E731
tb, BackendFeature.BUCKETIZE
)
if (
not validate_bucketize(sorted_sequence)
or not validate_bucketize(self)
or (sorter is not None and not validate_bucketize(sorter))
):
return fallback_handler(aten.searchsorted.Tensor, add_to_fallback_set=False)(
sorted_sequence,
self,
out_int32=out_int32,
right=right,
side=side,
sorter=sorter,
)
# If side is present, override the value of right if needed. This assumes that
# validation of the two options being non-contradictory is already done by the
# searchsorted meta-function.
if side is not None and side == "right":
right = True
index_dtype = torch.int32 if out_int32 else torch.int64
values_loader = self.make_loader()
# The entire sorted_sequence tensor needs to be used by ops.bucketize, so we need to
# realize it into global memory; or in other words, we can't guarantee that
# sorted_sequence.get_name() (used below) will exist unless we call
# sorted_sequence.realize().
sorted_sequence.realize()
if sorter is not None:
sorter.realize()
if len(sorted_sequence.get_size()) == 1:
def inner_fn(idx):
val = values_loader(idx)
return ops.bucketize(
val,
_boundaries_helper(sorted_sequence),
0,
index_dtype,
right,
sorter=None if sorter is None else _sorter_helper(sorter),
sorter_indices=None if sorter is None else 0,
)
else:
def inner_fn(idx):
val = values_loader(idx)
# Get index to the beginning of the sorted sequence within a flattened
# version of the array.
def get_flattened_index(tb: TensorBox):
strides = tb.get_stride()
return ops.index_expr(
functools.reduce(
operator.add, (s * i for s, i in zip(strides[:-1], idx[:-1]))
),
index_dtype,
)
return ops.bucketize(
val,
_boundaries_helper(sorted_sequence),
get_flattened_index(sorted_sequence),
index_dtype,
right,
sorter=None if sorter is None else _sorter_helper(sorter),
sorter_indices=None if sorter is None else get_flattened_index(sorter),
)
device = self.get_device()
return Pointwise.create(
device=device,
dtype=index_dtype,
inner_fn=inner_fn,
ranges=self.shape,
)
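As a concrete check of get_flattened_index above, assuming a contiguous (3, 5) sorted_sequence with strides (5, 1), the boundary set for a value at index (2, j) starts at flat element 10:

```python
import functools
import operator

strides, idx = (5, 1), (2, 3)  # illustrative strides and value index
start = functools.reduce(
    operator.add, (s * i for s, i in zip(strides[:-1], idx[:-1]))
)
assert start == 10
```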
@register_lowering(aten.bucketize, type_promotion_kind=None)
def bucketize(
input: TensorBox,
@@ -2101,7 +2208,6 @@ def bucketize(
# guarantee that boundaries.get_name() (used below) will exist unless
# we call boundaries.realize().
boundaries.realize()
boundaries_size = boundaries.get_size()[0]
device = input.get_device()
input_loader = input.make_loader()
@@ -2111,8 +2217,8 @@ def bucketize(
val = input_loader(index)
indices = ops.bucketize(
val,
boundaries.get_name(),
boundaries_size,
_boundaries_helper(boundaries),
0,
index_dtype,
right,
)
@@ -2246,7 +2352,6 @@ make_fallback(aten.uniform, warn=False)
make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py)
make_fallback(aten._pdist_forward) # Has decomp. Needs benchmarks
make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl?
make_fallback(aten.searchsorted) # bucketized is implemented (see eager impl)
# 1.5) Easy or Impossible

View File

@@ -294,10 +294,12 @@ class OpsHandler(Protocol[T]):
def bucketize(
self,
values: T,
offsets_name: str,
offsets_size: sympy.Expr,
boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
boundary_indices: T,
indexing_dtype: torch.dtype,
right: bool,
sorter: Optional[Tuple[str, sympy.Expr]] = None,
sorter_indices: Optional[T] = None,
) -> T:
# See [Note: Inductor bucketize op]
...
@@ -1016,18 +1018,31 @@ class OpCounterCSE:
def bucketize(
self,
values,
offsets_name: str,
offsets_size: sympy.Expr,
values: T,
boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
boundary_indices: T,
indexing_dtype: torch.dtype,
right: bool,
):
sorter: Optional[Tuple[str, sympy.Expr]] = None,
sorter_indices: Optional[T] = None,
) -> T:
"""
See [Note: Inductor bucketize op]
"""
val = self.parent_handler.bucketize(
values, offsets_name, offsets_size, indexing_dtype, right
values,
boundaries,
boundary_indices,
indexing_dtype,
right,
sorter,
sorter_indices,
)
if val not in self.var_names:
self._used_ops.add("bucketize")
self._read_names.append(offsets_name)
self._read_names.append(boundaries[0])
if sorter is not None:
self._read_names.append(sorter[0])
return self._update_count(val)
def getvalue(self):

View File

@@ -226,25 +226,71 @@ def any(a, dim):
@triton.jit
def bucketize_binary_search(
values, # 1D tensor
offsets_ptr,
indexing_dtype,
right, # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]
OFFSETS_SIZE: int,
BLOCK_SHAPE, # tuple/list of block shape
values: tl.tensor,
boundaries_ptr: tl.tensor,
BOUNDARIES_SIZE: int,
BOUNDARIES_UNDERLYING_NUMEL: int,
BOUNDARIES_STRIDE: int,
boundary_indices: tl.tensor,
indexing_dtype: tl.dtype,
right: "bool", # triton can't handle the unquoted bool annotation
sorter_ptr: tl.tensor,
SORTER_STRIDE: int,
sorter_indices: tl.tensor,
BLOCK_SHAPE,
):
"""
See [Note: Inductor bucketize op]
Inputs:
-------
values: the values to bucketize.
boundaries_ptr: a pointer to the beginning of the boundaries tensor, in 1-D.
BOUNDARIES_SIZE: the length of the last dimension of the boundaries tensor (i.e. one
individual set of boundaries).
BOUNDARIES_UNDERLYING_NUMEL: the length of the boundaries tensor, in 1-D, ignoring
any striding.
BOUNDARIES_STRIDE: the stride of the last dimension of the boundaries tensor
boundary_indices: a tensor of the same size as "values"; each element is an index
into a 1-D, un-strided boundaries tensor, pointing to the first element in the set
of boundaries used for that value.
indexing_dtype: the dtype used for indexing into the boundaries tensor, and the
return dtype.
right: if true, use boundary intervals closed on the left; otherwise use intervals
closed on the right.
sorter_ptr: an optional pointer to a sorter tensor of the same shape as boundaries,
but potentially different striding. If present, this allows us to treat boundaries
as sorted even if the elements of boundaries are unsorted.
SORTER_STRIDE: must be present if sorter_ptr is non-None; the stride of the last
dimension of the sorter tensor.
sorter_indices: must be present if sorter_ptr is non-None; see "boundary_indices".
BLOCK_SHAPE: the shape of the data block being processed.
"""
low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)
high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)
high = tl.full(BLOCK_SHAPE, BOUNDARIES_SIZE, dtype=indexing_dtype)
full_range = OFFSETS_SIZE + 1
full_range = BOUNDARIES_SIZE + 1
while full_range > 1:
mid = (high + low) // 2
mask = mid < OFFSETS_SIZE
bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask, other=0.0)
mask = (
mid * BOUNDARIES_STRIDE + boundary_indices
) < BOUNDARIES_UNDERLYING_NUMEL and mid < BOUNDARIES_SIZE
mid_indices = (
mid
if sorter_ptr is None or SORTER_STRIDE is None
else tl.load(
sorter_ptr + sorter_indices + SORTER_STRIDE * mid,
mask=mask,
other=0,
)
)
bucket_upper_bound = tl.load(
boundaries_ptr + boundary_indices + BOUNDARIES_STRIDE * mid_indices,
mask=mask,
other=0,
)
if right:
is_above = values >= bucket_upper_bound
else:

View File

@@ -6247,6 +6247,36 @@ def meta_searchsorted(
side=None,
sorter=None,
):
# If the sorted_sequence is not one-dimensional, its shape must match that of values
# in all but the last dimension.
torch._check(
len(sorted_sequence.shape) <= 1
or sorted_sequence.shape[:-1] == self.shape[:-1],
lambda: (
"torch.searchsorted(): boundaries tensor should be 1 dimension or the "
"first N-1 dimensions of boundaries tensor and input value tensor must "
f"match, but we got boundaries tensor {list(sorted_sequence.shape)} and "
f"input value tensor {list(self.shape)}"
),
)
# If a sorter array is provided, its dimensions must exactly match sorted_sequence.
torch._check(
sorter is None or sorted_sequence.shape == sorter.shape,
lambda: (
"torch.searchsorted(): boundary and sorter must have the same size, but "
f"got boundary tensor {list(sorted_sequence.shape)} and got sorter tensor "
f"{list(sorter.shape) if sorter is not None else []}"
),
)
# Per the docs, if side == "left" and right is True, we error.
torch._check(
side != "left" or not right,
"torch.searchsorted(): side and right can't be set to opposites, got side of "
"left while right was True",
)
dtype = torch.int32 if out_int32 else torch.int64
if isinstance(self, torch.Tensor):
return torch.empty_like(self, dtype=dtype).contiguous()
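
A shape-level illustration of the checks above, mirroring the eager error conditions (a hedged sketch):

```python
import torch

seq = torch.rand(4, 8).sort(-1).values
torch.searchsorted(seq, torch.rand(4, 3))  # OK: leading dims match
# torch.searchsorted(seq, torch.rand(5, 3))  # error: first N-1 dims differ
# torch.searchsorted(seq, torch.rand(4, 3), side="left", right=True)  # error: contradictory options
```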