Add lowering for aten.searchsorted (#135701)
Adds a lowering for `aten.searchsorted`. This entails:

1. Adding support for multi-dimensional bucket tensors to `ops.bucketize`.
2. Adding support for striding to `ops.bucketize`.
3. Adding support for sorting tensors to `ops.bucketize`.
4. Adding a lowering for `aten.searchsorted.Tensor`.
5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors.
6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions.

Closes #135873

Differential Revision: [D63766514](https://our.internmc.facebook.com/intern/diff/D63766514)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135701
Approved by: https://github.com/amjames, https://github.com/eellison, https://github.com/davidberard98
committed by PyTorch MergeBot
parent 58ec6a360c
commit a968576777
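Before the diff, a minimal sketch of the user-visible behavior being lowered (shapes chosen to mirror the new tests; this snippet is illustrative and not part of the diff):

```python
import torch

def fn(seq, values, sorter):
    return torch.searchsorted(seq, values, sorter=sorter)

# sorter holds the indices that would sort seq along its last dimension,
# letting searchsorted treat an unsorted sequence as if it were sorted.
unsorted = torch.rand(16, 32)
_, sorter = torch.sort(unsorted, dim=-1)
values = torch.rand(16, 16)

eager = fn(unsorted, values, sorter)
compiled = torch.compile(fn)(unsorted, values, sorter)  # now hits the Inductor lowering
torch.testing.assert_close(eager, compiled)
```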
@@ -32,15 +32,20 @@ class TestDependencies(InductorTestCase):
         self._stack.close()
         super().tearDown()

-    def test_bucketize_dependencies(self):
+    def test_bucketize_dependencies_no_sorter(self):
         offsets = self._create_buffer("offsets", (1025,), torch.int32)

         def inner_fn(index):
             idx = index[0]
             return ops.bucketize(
                 values=idx,
-                offsets_name=offsets.get_name(),
-                offsets_size=offsets.get_size()[0],
+                boundaries=(
+                    offsets.get_name(),
+                    offsets.get_size()[-1],
+                    offsets.get_size()[0] * offsets.get_stride()[0],
+                    offsets.get_stride()[-1],
+                ),
+                boundary_indices=0,
                 indexing_dtype=torch.int32,
                 right=True,
             )
@@ -54,6 +59,39 @@ class TestDependencies(InductorTestCase):

         self.assertEqual(len(pointwise.get_reads()), 1)

+    def test_bucketize_dependencies_sorter(self):
+        offsets = self._create_buffer("offsets", (1025,), torch.int32)
+        sorter = self._create_buffer("sorter", (1025,), torch.int32)
+
+        def inner_fn(index):
+            idx = index[0]
+            return ops.bucketize(
+                values=idx,
+                boundaries=(
+                    offsets.get_name(),
+                    offsets.get_size()[-1],
+                    offsets.get_size()[0] * offsets.get_stride()[0],
+                    offsets.get_stride()[-1],
+                ),
+                boundary_indices=0,
+                indexing_dtype=torch.int32,
+                right=True,
+                sorter=(
+                    sorter.get_name(),
+                    sorter.get_stride()[-1],
+                ),
+                sorter_indices=0,
+            )
+
+        pointwise = Pointwise.create(
+            device=torch.device(GPU_TYPE),
+            dtype=torch.int32,
+            inner_fn=inner_fn,
+            ranges=[1024 * 4],
+        )
+
+        self.assertEqual(len(pointwise.get_reads()), 2)
+
     def test_get_offset(self):
         x = sympy_index_symbol("x")
         y = sympy_index_symbol("y")
@@ -10437,6 +10437,53 @@ class CommonTemplate:

         self.common(fn, (torch.randn((16, 16, 16)),), check_lowp=False)

+    @xfail_if_triton_cpu
+    def test_searchsorted(self):
+        def fn(sorted_sequence, values, out_int32, right, side, sorter):
+            return torch.searchsorted(
+                sorted_sequence,
+                values,
+                out_int32=out_int32,
+                right=right,
+                side=side,
+                sorter=sorter,
+            )
+
+        shapes = (
+            ((1,), (16, 16)),  # scalar sorted_sequence
+            ((16,), ()),  # scalar values
+            ((32,), (16, 16)),  # 1-D sorted_sequence
+            ((16, 32), (16, 16)),  # N-D sorted_sequence
+            ((3, 5), (3, 7)),  # prime dimensioned sequence, to flush out indexing bugs
+        )
+        booleans = (False, True)
+
+        for (seq_shape, value_shape), out_int32, right in itertools.product(
+            shapes, booleans, booleans
+        ):
+            unsorted_sequence = torch.rand(seq_shape)
+            sorted_sequence, sorting_indices = torch.sort(unsorted_sequence)
+            values = torch.rand(value_shape)
+
+            side = "right" if right else "left"
+            self.common(
+                fn,
+                (sorted_sequence, values, out_int32, right, side, None),
+                check_lowp=False,
+            )
+            self.common(
+                fn,
+                (
+                    unsorted_sequence,
+                    values,
+                    out_int32,
+                    right,
+                    side,
+                    sorting_indices,
+                ),
+                check_lowp=False,
+            )
+
     def test_bucketize(self):
         def fn(input, boundaries, out_int32, right):
             return torch.bucketize(input, boundaries, out_int32=out_int32, right=right)
@@ -11511,7 +11558,7 @@ class CommonTemplate:

 @dataclasses.dataclass
 class TestFailure:
-    suffixes: Tuple[str]
+    suffixes: Tuple[str, ...]
     is_skip: bool = False
     __test__: bool = False

@@ -120,7 +120,7 @@ test_failures = {
     "test_stack_dynamic_shapes": TestFailure(("cpu",)),
     "test_tensor2_dynamic_shapes": TestFailure(("cpu",)),
     "test_tensor3_dynamic_shapes": TestFailure(("cpu",)),
-    "test_to_device_constant_dynamic_shapes": TestFailure("cpu"),
+    "test_to_device_constant_dynamic_shapes": TestFailure(("cpu",)),
     "test_upsample_nearest2d_backward_dynamic_shapes": TestFailure(("cpu",)),
     "test_views3_dynamic_shapes": TestFailure(("cpu",)),
     "test_views4_dynamic_shapes": TestFailure(("cpu",)),
@@ -160,9 +160,10 @@ test_failures = {
     "test_empty1_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_empty2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_empty_strided_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
-    "test_bucketize_dynamic_shapes": TestFailure("cpu"),
-    "test_bucketize_default_kwargs_dynamic_shapes": TestFailure("cpu"),
-    "test_bucketize_int_dynamic_shapes": TestFailure("cpu"),
+    "test_bucketize_dynamic_shapes": TestFailure(("cpu",)),
+    "test_bucketize_default_kwargs_dynamic_shapes": TestFailure(("cpu",)),
+    "test_bucketize_int_dynamic_shapes": TestFailure(("cpu",)),
+    "test_searchsorted_dynamic_shapes": TestFailure(("cpu",)),
     "test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
@@ -247,7 +248,7 @@ test_failures = {
     "test_views5_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_view_detach_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_view_on_aliased_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
-    "test_linear_float64_dynamic_shapes": TestFailure("cpu"),
+    "test_linear_float64_dynamic_shapes": TestFailure(("cpu",)),
     "test_adaptive_avg_pool_with_output_size_0_dynamic_shapes": TestFailure(
         ("cpu", "cuda", "xpu")
     ),
@@ -1780,10 +1780,12 @@ class Kernel(CodeGen):
     def bucketize(
         self,
         values: CSEVariable,
-        offsets_name: str,
-        offsets_size: sympy.Expr,
+        boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
+        boundary_indices: CSEVariable,
         indexing_dtype: torch.dtype,
         right: bool,
+        sorter: Optional[Tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[CSEVariable] = None,
     ) -> CSEVariable:
         """
         See [Note: Inductor bucketize op]
@@ -2035,27 +2037,81 @@ class Kernel(CodeGen):
     @staticmethod
     def bucketize(
         values: CSEVariable,
-        offsets_name: str,
-        offsets_size: sympy.Expr,
+        boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
+        boundary_indices: CSEVariable,
         indexing_dtype: torch.dtype,
         right: bool,
+        sorter: Optional[Tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[CSEVariable] = None,
     ) -> CSEVariable:
         """
         [Note: Inductor bucketize op]

-        Given values (tensor) and offsets_name (reference to the name of a 1D
-        tensor), calculate the bucket that each value belongs to.
+        Inputs:
+        -------
+        values: the values to be bucketized.
+        boundaries: a tuple containing
+          (a) the name of the boundaries tensor (which must be sorted, unless
+          the sorting tensor is present),
+          (b) the length of the tensor in the last dimension (i.e. the length of
+          one set of boundaries),
+          (c) the number of elements in the underlying storage (i.e. the length
+          of the flattened tensor, ignoring striding), and
+          (d) the stride of the tensor in the last dimension.
+        boundary_indices: indices into a flattened version of the boundaries
+        tensor, of the same size and shape as "values". Each index points to
+        the first element in the set of boundaries to be used for the
+        corresponding value.
+        indexing_dtype: the dtype to use when indexing into the boundaries
+        tensor. This must be int64 or int32. This additionally specifies the
+        dtype of the return value.
+        right: see "Details" below.
+        sorter: an optional tuple containing
+          (a) the name of an optional sorting tensor, used to access unsorted
+          boundaries without reordering the boundaries tensor, and
+          (b) the stride of the tensor in the last dimension.
+        The values in the sorting tensor are used as indices into the *last*
+        dimension of the boundaries tensor, with all other indices matching.
+        The size of the sorting and boundaries tensors must be equivalent.
+        sorter_indices: must be present if the sorting array is present; see
+        "boundary_indices" for the equivalent definition for the boundaries
+        tensor.

-        e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
-        return = [ 0, 1, 1, 1, 1, 3, 3, 4].
+        Output:
+        -------
+        The buckets each value belongs in, within a given set of boundaries. 0
+        indicates a position before the first boundary, and len(boundaries_set)
+        represents a position after the last boundary.

-        When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
-        When right == True, bucket i refers to range [offsets[i], offsets[i+1]).
+        Details:
+        --------
+        Given a value and a set of boundaries, calculate the bucket that each
+        value belongs to. This works differently in 1-D and N-D cases.

-        Offsets must be non-decreasing or the result is undefined.
+        for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [0, 4, 4, 8], right=True
+        return = [[ 0, 1, 1, 1], [1, 3, 3, 4]].
+
+        for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [[0, 4], [4, 8]], right=True
+        return = [[ 0, 1, 1, 1], [0, 1, 1, 2]]
+
+        Note that in the N-D boundaries case, the shape of "values" and
+        "boundaries" must match in every dimension _except_ the last.
+
+        When right == False, bucket i refers to range (boundaries[i], boundaries[i+1]].
+        When right == True, bucket i refers to range [boundaries[i], boundaries[i+1]).
+
+        Boundaries must be non-decreasing, or a sorter must be provided which
+        would re-index offsets in a non-decreasing order (e.g. the second output
+        of torch.sort(offsets)). Otherwise, the result is undefined.
         """
         return self.bucketize(
-            values, offsets_name, offsets_size, indexing_dtype, right
+            values,
+            boundaries,
+            boundary_indices,
+            indexing_dtype,
+            right,
+            sorter,
+            sorter_indices,
         )

     # Use mypy to check protocol implemented correctly
@@ -1952,10 +1952,12 @@ class TritonKernel(SIMDKernel):
     def bucketize(
         self,
         values: CSEVariable,
-        offsets_name: str,
-        offsets_size: sympy.Expr,
+        boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
+        boundary_indices: CSEVariable,
         indexing_dtype: torch.dtype,
         right: bool,
+        sorter: Optional[Tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[CSEVariable] = None,
     ) -> CSEVariable:
         """
         See [Note: Inductor bucketize op]
@@ -1967,9 +1969,13 @@ class TritonKernel(SIMDKernel):
         # autotuning config with num_elements_per_warp=(warp_size) exists.
         self.autotune_hints.add(AutotuneHint.ONE_ELEMENT_PER_THREAD)

-        offsets_ptr = self.args.input(offsets_name)
+        boundaries_ptr = self.args.input(boundaries[0])
+        boundary_size = self.index_to_str(boundaries[1])
+        boundaries_underlying_numel = self.index_to_str(boundaries[2])
+        boundary_stride = self.index_to_str(boundaries[3])
+        sorter_ptr = self.args.input(sorter[0]) if sorter else "None"
+        sorter_stride = self.index_to_str(sorter[1]) if sorter else "None"
         block_size = self.dense_size_str()
-        offsets_size_str = self.index_to_str(offsets_size)

         if indexing_dtype == torch.int32:
             triton_dtype = "tl.int32"
@@ -1982,7 +1988,15 @@ class TritonKernel(SIMDKernel):

         result = self.cse.generate(
             self.compute,
-            f"triton_helpers.bucketize_binary_search({values}, {offsets_ptr}, {triton_dtype}, {right}, {offsets_size_str}, {block_size})",  # noqa: B950 line too long
+            f"triton_helpers.bucketize_binary_search({values}, "
+            f"{boundaries_ptr}, {boundary_size}, {boundaries_underlying_numel}, {boundary_stride}, "
+            f"{boundary_indices}, "
+            f"{triton_dtype}, "
+            f"{right}, "
+            f"{sorter_ptr}, {sorter_stride}, "
+            f"{sorter_indices}, "
+            f"{block_size}, "
+            ")",
         )

         return result
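For orientation, with a contiguous 1-D boundaries buffer of length 1025 and no sorter (as in the dependency tests above), the emitted compute line would look roughly like the following; the buffer and temporary names (`tmp0`, `in_ptr1`, `XBLOCK`) are hypothetical and depend on the kernel being generated:

```python
# hypothetical codegen output, not part of the diff
tmp1 = triton_helpers.bucketize_binary_search(tmp0, in_ptr1, 1025, 1025, 1, 0, tl.int32, True, None, None, None, [XBLOCK], )
```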
@@ -996,3 +996,23 @@ def adaptive_max_pool2d(
         return aten.max_pool2d_with_indices(x, kernel_size)

     return NotImplemented
+
+
+@register_decomposition(aten.searchsorted.Scalar)
+def searchsorted_scalar(
+    sorted_sequence: torch.Tensor,
+    self: torch.types.Number,
+    *,
+    out_int32: bool = False,
+    right: bool = False,
+    side: Optional[str] = None,
+    sorter: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return aten.searchsorted(
+        sorted_sequence,
+        torch.tensor([self], device=sorted_sequence.device),
+        out_int32=out_int32,
+        right=right,
+        side=side,
+        sorter=sorter,
+    )[0]
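The decomposition mirrors the eager behavior for scalar values, which can be sketched as:

```python
import torch

seq = torch.tensor([1.0, 2.0, 3.0])

# aten.searchsorted.Scalar returns a 0-dim tensor; the decomposition wraps the
# scalar into a 1-element tensor and indexes out element 0, matching eager:
a = torch.searchsorted(seq, 2.5)                     # tensor(2)
b = torch.searchsorted(seq, torch.tensor([2.5]))[0]  # tensor(2)
assert torch.equal(a, b)
```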
@@ -5,7 +5,7 @@ import itertools
 import logging
 import re
 import typing
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union
 from unittest.mock import patch

 import sympy
@@ -26,6 +26,8 @@ from .utils import (
 from .virtualized import OpsHandler, ReductionType, V


+T = TypeVar("T")
+
 log = logging.getLogger(__name__)
 is_indirect = re.compile(r"indirect|tmp").search

@@ -506,14 +508,18 @@ class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]

     def bucketize(
         self,
-        values,
-        offsets_name: str,
-        offsets_size: sympy.Expr,
+        values: T,
+        boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
+        boundary_indices: T,
         indexing_dtype: torch.dtype,
         right: bool,
-    ):
-        self._reads.add(StarDep(offsets_name))
-        return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})"
+        sorter: Optional[Tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[T] = None,
+    ) -> None:
+        """Records the names of the buffers that bucketize will read from."""
+        self._reads.add(StarDep(boundaries[0]))
+        if sorter is not None:
+            self._reads.add(StarDep(sorter[0]))


 class RecordLoadStore(V.KernelFormatterHandler):  # type: ignore[name-defined]
@@ -592,8 +598,10 @@ def extract_read_writes(
         for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
             inner.index_expr(name_to_index[entry.index_name], None)
         for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]:
+            # All that matters is that we record the buffer name, so place it in the
+            # "boundaries" name position to ensure that it's recorded.
             inner.bucketize(
-                None, entry.buffer_name, name_to_index[entry.index_name], None, None  # type: ignore[arg-type]
+                None, (entry.buffer_name, None, None, None), None, None, None  # type: ignore[arg-type]
             )
         # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
     else:
@@ -59,8 +59,13 @@ def get_inverse_offsets(
         idx = index[0]
         bucket = ops.bucketize(
             values=ops.index_expr(idx, dtype),
-            offsets_name=offsets.get_name(),
-            offsets_size=offsets.get_size()[0],
+            boundaries=(
+                offsets.get_name(),
+                offsets.get_size()[-1],
+                offsets.get_size()[0] * offsets.get_stride()[0],
+                offsets.get_stride()[-1],
+            ),
+            boundary_indices=0,
             indexing_dtype=dtype,
             right=True,
         )
@@ -5,7 +5,17 @@ import functools
 import itertools
 import re
 from enum import auto, Enum
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+)

 import sympy

@@ -20,6 +30,9 @@ from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
 from .virtualized import ops, V


+T = TypeVar("T")
+
+
 class InterpreterShim(torch.fx.Interpreter):
     @staticmethod
     @functools.lru_cache(None)
@@ -479,17 +492,51 @@ class LoopBodyBlock:

     def bucketize(
         self,
-        values,
-        offsets_name: str,
-        offsets_size: sympy.Expr,
+        values: T,
+        boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
+        boundary_indices: T,
         indexing_dtype: torch.dtype,
         right: bool,
-    ):
-        offsets_size = add_index(
-            offsets_size, MemoryUsageType.BUCKETIZE, buffer_name=offsets_name
-        )
+        sorter: Optional[Tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[T] = None,
+    ) -> T:
+        """
+        See [Note: Inductor bucketize op]
+        """
+        boundaries = (
+            boundaries[0],
+            add_index(
+                boundaries[1],
+                MemoryUsageType.BUCKETIZE,
+                buffer_name=boundaries[0],
+            ),
+            add_index(
+                boundaries[2],
+                MemoryUsageType.BUCKETIZE,
+                buffer_name=boundaries[0],
+            ),
+            add_index(
+                boundaries[3],
+                MemoryUsageType.BUCKETIZE,
+                buffer_name=boundaries[0],
+            ),
+        )
+        if sorter is not None:
+            sorter = (
+                sorter[0],
+                add_index(
+                    sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0]
+                ),
+            )
+
         return self._inner.bucketize(
-            values, offsets_name, offsets_size, indexing_dtype, right
+            values,
+            boundaries,
+            boundary_indices,
+            indexing_dtype,
+            right,
+            sorter,
+            sorter_indices,
         )

     @staticmethod
@@ -2078,6 +2078,113 @@ def inductor_randint(
     )


+def _boundaries_helper(tb: TensorBox) -> Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr]:
+    return (
+        tb.get_name(),
+        tb.get_size()[-1],
+        tb.get_size()[0] * tb.get_stride()[0],
+        tb.get_stride()[-1],
+    )
+
+
+def _sorter_helper(tb: TensorBox) -> Tuple[str, sympy.Expr]:
+    return tb.get_name(), tb.get_stride()[-1]
+
+
+@register_lowering(aten.searchsorted.Tensor, type_promotion_kind=None)
+def searchsorted(
+    sorted_sequence: TensorBox,
+    self: TensorBox,
+    *,
+    out_int32: bool = False,
+    right: bool = False,
+    side: Optional[str] = None,
+    sorter: Optional[TensorBox] = None,
+) -> TensorBox:
+    validate_bucketize = lambda tb: V.graph.has_feature(  # noqa: E731
+        tb, BackendFeature.BUCKETIZE
+    )
+    if (
+        not validate_bucketize(sorted_sequence)
+        or not validate_bucketize(self)
+        or (sorter is not None and not validate_bucketize(sorter))
+    ):
+        return fallback_handler(aten.searchsorted.Tensor, add_to_fallback_set=False)(
+            sorted_sequence,
+            self,
+            out_int32=out_int32,
+            right=right,
+            side=side,
+            sorter=sorter,
+        )
+
+    # If side is present, override the value of right if needed. This assumes that
+    # validation of the two options being non-contradictory is already done by the
+    # searchsorted meta-function.
+    if side is not None and side == "right":
+        right = True
+
+    index_dtype = torch.int32 if out_int32 else torch.int64
+    values_loader = self.make_loader()
+
+    # The entire sorted_sequence tensor needs to be used by ops.bucketize, so we need to
+    # realize it into global memory; or in other words, we can't guarantee that
+    # sorted_sequence.get_name() (used below) will exist unless we call
+    # sorted_sequence.realize().
+    sorted_sequence.realize()
+
+    if sorter is not None:
+        sorter.realize()
+
+    if len(sorted_sequence.get_size()) == 1:
+
+        def inner_fn(idx):
+            val = values_loader(idx)
+            return ops.bucketize(
+                val,
+                _boundaries_helper(sorted_sequence),
+                0,
+                index_dtype,
+                right,
+                sorter=None if sorter is None else _sorter_helper(sorter),
+                sorter_indices=None if sorter is None else 0,
+            )
+
+    else:
+
+        def inner_fn(idx):
+            val = values_loader(idx)
+
+            # Get index to the beginning of the sorted sequence within a flattened
+            # version of the array.
+            def get_flattened_index(tb: TensorBox):
+                strides = tb.get_stride()
+                return ops.index_expr(
+                    functools.reduce(
+                        operator.add, (s * i for s, i in zip(strides[:-1], idx[:-1]))
+                    ),
+                    index_dtype,
+                )
+
+            return ops.bucketize(
+                val,
+                _boundaries_helper(sorted_sequence),
+                get_flattened_index(sorted_sequence),
+                index_dtype,
+                right,
+                sorter=None if sorter is None else _sorter_helper(sorter),
+                sorter_indices=None if sorter is None else get_flattened_index(sorter),
+            )
+
+    device = self.get_device()
+    return Pointwise.create(
+        device=device,
+        dtype=index_dtype,
+        inner_fn=inner_fn,
+        ranges=self.shape,
+    )
+
+
 @register_lowering(aten.bucketize, type_promotion_kind=None)
 def bucketize(
     input: TensorBox,
@@ -2101,7 +2208,6 @@ def bucketize(
     # guarantee that boundaries.get_name() (used below) will exist unless
     # we call boundaries.realize().
     boundaries.realize()
-    boundaries_size = boundaries.get_size()[0]
     device = input.get_device()
     input_loader = input.make_loader()

@@ -2111,8 +2217,8 @@ def bucketize(
         val = input_loader(index)
         indices = ops.bucketize(
             val,
-            boundaries.get_name(),
-            boundaries_size,
+            _boundaries_helper(boundaries),
+            0,
             index_dtype,
             right,
         )
@@ -2246,7 +2352,6 @@ make_fallback(aten.uniform, warn=False)
 make_fallback(aten.exponential.default, warn=False)  # (fails accuracy on test_torch.py)
 make_fallback(aten._pdist_forward)  # Has decomp. Needs benchmarks
 make_fallback(aten.soft_margin_loss_backward, warn=False)  # py_impl?
-make_fallback(aten.searchsorted)  # bucketized is implemented (see eager impl)


 # 1.5) Easy or Impossible
@@ -294,10 +294,12 @@ class OpsHandler(Protocol[T]):
     def bucketize(
         self,
         values: T,
-        offsets_name: str,
-        offsets_size: sympy.Expr,
+        boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
+        boundary_indices: T,
         indexing_dtype: torch.dtype,
         right: bool,
+        sorter: Optional[Tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[T] = None,
     ) -> T:
         # See [Note: Inductor bucketize op]
         ...
@@ -1016,18 +1018,31 @@ class OpCounterCSE:

     def bucketize(
         self,
-        values,
-        offsets_name: str,
-        offsets_size: sympy.Expr,
+        values: T,
+        boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
+        boundary_indices: T,
         indexing_dtype: torch.dtype,
         right: bool,
-    ):
+        sorter: Optional[Tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[T] = None,
+    ) -> T:
+        """
+        See [Note: Inductor bucketize op]
+        """
         val = self.parent_handler.bucketize(
-            values, offsets_name, offsets_size, indexing_dtype, right
+            values,
+            boundaries,
+            boundary_indices,
+            indexing_dtype,
+            right,
+            sorter,
+            sorter_indices,
         )
         if val not in self.var_names:
             self._used_ops.add("bucketize")
-            self._read_names.append(offsets_name)
+            self._read_names.append(boundaries[0])
+            if sorter is not None:
+                self._read_names.append(sorter[0])
         return self._update_count(val)

     def getvalue(self):
@@ -226,25 +226,71 @@ def any(a, dim):

 @triton.jit
 def bucketize_binary_search(
-    values,  # 1D tensor
-    offsets_ptr,
-    indexing_dtype,
-    right,  # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]
-    OFFSETS_SIZE: int,
-    BLOCK_SHAPE,  # tuple/list of block shape
+    values: tl.tensor,
+    boundaries_ptr: tl.tensor,
+    BOUNDARIES_SIZE: int,
+    BOUNDARIES_UNDERLYING_NUMEL: int,
+    BOUNDARIES_STRIDE: int,
+    boundary_indices: tl.tensor,
+    indexing_dtype: tl.dtype,
+    right: "bool",  # triton can't handle the unquoted bool annotation
+    sorter_ptr: tl.tensor,
+    SORTER_STRIDE: int,
+    sorter_indices: tl.tensor,
+    BLOCK_SHAPE,
 ):
+    """
+    See [Note: Inductor bucketize op]
+
+    Inputs:
+    -------
+    values: the values to bucketize.
+    boundaries_ptr: a pointer to the beginning of the boundaries tensor, in 1-D.
+    BOUNDARIES_SIZE: the length of the last dimension of the boundaries tensor (i.e. one
+    individual set of boundaries).
+    BOUNDARIES_UNDERLYING_NUMEL: the length of the boundaries tensor, in 1-D, ignoring
+    any striding.
+    BOUNDARIES_STRIDE: the stride of the last dimension of the boundaries tensor
+    boundary_indices: a tensor of the same size as "values"; each element is an index
+    into a 1-D, un-strided boundaries tensor, pointing to the first element in the set
+    of boundaries used for that value.
+    indexing_dtype: the dtype used for indexing into the boundaries tensor, and the
+    return dtype.
+    right: if true, use boundary intervals closed on the left; otherwise use intervals
+    closed on the right.
+    sorter_ptr: an optional pointer to a sorter tensor of the same shape as boundaries,
+    but potentially different striding. If present, this allows us to treat boundaries
+    as sorted even if the elements of boundaries are unsorted.
+    SORTER_STRIDE: must be present if sorter_ptr is non-None; the stride of the last
+    dimension of the sorter tensor.
+    sorter_indices: must be present if sorter_ptr is non-None; see "boundary_indices".
+    BLOCK_SHAPE: the shape of the data block being processed.
+    """
+
     low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)
-    high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)
+    high = tl.full(BLOCK_SHAPE, BOUNDARIES_SIZE, dtype=indexing_dtype)

-    full_range = OFFSETS_SIZE + 1
+    full_range = BOUNDARIES_SIZE + 1
     while full_range > 1:
         mid = (high + low) // 2
-        mask = mid < OFFSETS_SIZE
-        bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask, other=0.0)
+        mask = (
+            mid * BOUNDARIES_STRIDE + boundary_indices
+        ) < BOUNDARIES_UNDERLYING_NUMEL and mid < BOUNDARIES_SIZE
+        mid_indices = (
+            mid
+            if sorter_ptr is None or SORTER_STRIDE is None
+            else tl.load(
+                sorter_ptr + sorter_indices + SORTER_STRIDE * mid,
+                mask=mask,
+                other=0,
+            )
+        )
+
+        bucket_upper_bound = tl.load(
+            boundaries_ptr + boundary_indices + BOUNDARIES_STRIDE * mid_indices,
+            mask=mask,
+            other=0,
+        )
         if right:
             is_above = values >= bucket_upper_bound
         else:
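For intuition, the search this kernel performs can be modeled in pure Python for a single value; this is a sketch rather than the Triton code (it ignores masking and blocking, and reorders the parameters):

```python
def bucketize_scalar(value, boundaries, boundaries_size, boundaries_stride,
                     boundary_index, right, sorter=None, sorter_stride=1,
                     sorter_index=0):
    # Classic lower/upper-bound binary search over one (possibly strided,
    # possibly indirectly sorted) set of boundaries.
    low, high = 0, boundaries_size
    while low < high:
        mid = (low + high) // 2
        i = mid if sorter is None else sorter[sorter_index + sorter_stride * mid]
        bound = boundaries[boundary_index + boundaries_stride * i]
        if (value >= bound) if right else (value > bound):
            low = mid + 1
        else:
            high = mid
    return low

# Matches the docstring example: right=True, boundaries [0, 4, 4, 8]
assert [bucketize_scalar(v, [0, 4, 4, 8], 4, 1, 0, True)
        for v in [-1, 0, 1, 2, 3, 4, 5, 9]] == [0, 1, 1, 1, 1, 3, 3, 4]
```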
@@ -6247,6 +6247,36 @@ def meta_searchsorted(
     side=None,
     sorter=None,
 ):
+    # If the sorted_sequence is not one-dimensional, its shape must match that of values
+    # in all but the last dimension.
+    torch._check(
+        len(sorted_sequence.shape) <= 1
+        or sorted_sequence.shape[:-1] == self.shape[:-1],
+        lambda: (
+            "torch.searchsorted(): boundaries tensor should be 1 dimension or the "
+            "first N-1 dimensions of boundaries tensor and input value tensor must "
+            f"match, but we got boundaries tensor {list(sorted_sequence.shape)} and "
+            f"input value tensor {list(self.shape)}"
+        ),
+    )
+
+    # If a sorter array is provided, its dimensions must exactly match sorted_sequence.
+    torch._check(
+        sorter is None or sorted_sequence.shape == sorter.shape,
+        lambda: (
+            "torch.searchsorted(): boundary and sorter must have the same size, but "
+            f"got boundary tensor {list(sorted_sequence.shape)} and got sorter tensor "
+            f"{list(sorter.shape) if sorter is not None else []}"
+        ),
+    )
+
+    # Per the docs, if side == "left" and right is True, we error.
+    torch._check(
+        side != "left" or not right,
+        "torch.searchsorted(): side and right can't be set to opposites, got side of "
+        "left while right was True",
+    )
+
     dtype = torch.int32 if out_int32 else torch.int64
     if isinstance(self, torch.Tensor):
         return torch.empty_like(self, dtype=dtype).contiguous()
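The first check corresponds to the following eager failure mode (a sketch, not part of the diff):

```python
import torch

seq = torch.rand(4, 8).sort(dim=-1).values  # N-D sorted_sequence
vals = torch.rand(3, 5)                     # leading dims (3,) != (4,)

# Raises: "first N-1 dimensions of boundaries tensor and input value tensor
# must match"; only the last dimensions of the two tensors may differ.
try:
    torch.searchsorted(seq, vals)
except RuntimeError as e:
    print(e)
```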