[pytorch] Add native support for segment reduce step1: API definition (#53727)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53727

This is the first diff to add native support for segment reduction in PyTorch. It provides functionality similar to torch.scatter or numpy.ufunc.reduceat.
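For orientation, a minimal usage sketch of the op this stack is building toward (illustrative only; the call matches the schema added below, and only 'max' over 1-D data is implemented in this diff):

import torch

data = torch.tensor([1.0, 2.0, 5.0, 3.0, 4.0])
lengths = torch.tensor([2, 3])  # segment sizes; must sum to data.numel()

# max over [1, 2] and over [5, 3, 4]
out = torch.segment_reduce(data, "max", lengths=lengths, axis=0, unsafe=False)
# out: tensor([2., 5.])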

This diff mainly focuses on the API layer, to make sure future improvements will not cause backward-compatibility issues. Once the API is settled, here are the next steps I am planning:
- Add support for other major reduction types (e.g. min, sum) for 1-D tensors
- Add CUDA support
- Backward support
- Documentation for the op
- Performance optimizations and a benchmark utility
- Support for multi-dimensional tensors (for both data and lengths) (not high priority)
- Support for 'indices' (not high priority; see the numpy sketch below for how offsets relate to lengths)
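For comparison, numpy's reduceat expresses the same segmentation through start offsets, which is presumably what the 'indices' item above refers to; the lengths-based form added here is the dual of that. A small sketch (numpy used purely for comparison):

import numpy as np

data = np.array([1.0, 2.0, 5.0, 3.0, 4.0])

# offset-based ('indices'-style): segments begin at positions 0 and 2
np.maximum.reduceat(data, np.array([0, 2]))  # -> array([2., 5.])

# lengths-based (this op): the same segmentation is lengths = [2, 3]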

Test Plan: Added unit test

Reviewed By: ngimel

Differential Revision: D26952075

fbshipit-source-id: 8040ec96def3013e7240cf675d499ee424437560
Author: Serhat Yilmaz
Date: 2021-03-23 15:56:00 -07:00
Committed by: Facebook GitHub Bot
Parent: 591084abb8
Commit: 7e3cf1ee24
7 changed files with 149 additions and 0 deletions

aten/src/ATen/core/aten_interned_strings.h

@@ -618,6 +618,7 @@ _(aten, rrelu_with_noise_forward) \
_(aten, rsqrt) \
_(aten, scatter) \
_(aten, scatter_add) \
_(aten, segment_reduce) \
_(aten, select) \
_(aten, selu) \
_(aten, set) \

aten/src/ATen/native/SegmentReduce.cpp

@@ -0,0 +1,84 @@
#include <ATen/native/SegmentReduce.h>

#include <ATen/ATen.h>
#include <ATen/NumericUtils.h>

#include <map>
#include <string>

namespace at {
namespace native {

DEFINE_DISPATCH(segment_reduce_stub);

enum ReductionType { MAX };

const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"max", MAX},
};

Tensor _segment_reduce_cpu(
    const Tensor& data,
    std::string reduce,
    const c10::optional<Tensor>& lengths,
    const c10::optional<Tensor>& indices,
    int64_t axis,
    bool unsafe) {
  axis = maybe_wrap_dim(axis, data.ndimension());
  TORCH_CHECK(axis == 0, "Currently only dim=0 is supported!");
  TORCH_CHECK(data.dim() == 1);
  TORCH_CHECK(data.numel() > 0);
  TORCH_CHECK(
      reduce2REDUCE.at(reduce) == MAX,
      "Currently only 'max' reduction is supported!");

  // lengths-related checks
  TORCH_CHECK(
      lengths.has_value() && !indices.has_value(),
      "Currently only lengths based reduction is supported!");
  const auto& lengths_value = lengths.value();
  TORCH_CHECK(lengths_value.dim() == 1);
  TORCH_CHECK(data.get_device() == lengths_value.get_device());
  TORCH_CHECK(data.dim() >= lengths_value.dim());

  const auto lengths_contig = lengths_value.contiguous();
  const auto data_contig = data.contiguous();

  int64_t batch_size = lengths_contig.numel();
  auto output = at::empty({batch_size}, data.options());
  const auto* lengths_data = lengths_contig.data_ptr<int64_t>();

  if (!unsafe) {
    // Validate that the lengths are positive and cover the data exactly.
    int64_t sum = 0;
    for (int64_t i = 0; i < batch_size; ++i) {
      TORCH_CHECK(lengths_data[i] > 0);
      sum += lengths_data[i];
    }
    TORCH_CHECK(sum == data.numel());
  }

  AT_DISPATCH_ALL_TYPES_AND2(
      kBFloat16,
      kHalf,
      data_contig.scalar_type(),
      "_segment_reduce_cpu",
      ([&]() {
        auto* output_data = output.data_ptr<scalar_t>();
        const auto* values_data = data_contig.data_ptr<scalar_t>();
        int64_t k = 0;
        for (int64_t i = 0; i < batch_size; ++i) {
          scalar_t reduction = std::numeric_limits<scalar_t>::lowest();
          for (int64_t j = 0; j < lengths_data[i]; ++j) {
            const auto data = values_data[k];
            // NaN propagates: once reduction is NaN, later std::max
            // comparisons keep returning it.
            reduction =
                at::_isnan(data) ? data : std::max<scalar_t>(reduction, data);
            k++;
          }
          // If unsafe is false, the checks on lengths (or, later, indices)
          // rule out non-positive segment lengths. If unsafe is true, an
          // empty segment simply yields the numeric limit for the reduction.
          output_data[i] = reduction;
        }
      }));

  return output;
}

} // namespace native
} // namespace at
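As a readability aid, the kernel's semantics in equivalent Python (a reference sketch only, not part of the diff; assumes floating-point data so the NaN check applies):

import torch

def segment_max_reference(data, lengths, unsafe=False):
    # Mirrors _segment_reduce_cpu for 1-D float data and "max" reduction.
    if not unsafe:
        assert (lengths > 0).all() and int(lengths.sum()) == data.numel()
    out, k = [], 0
    for n in lengths.tolist():
        seg = data[k:k + n]
        # NaN propagates, matching the at::_isnan branch in the kernel.
        nan = torch.tensor(float("nan"), dtype=data.dtype)
        out.append(nan if torch.isnan(seg).any() else seg.max())
        k += n
    return torch.stack(out)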

aten/src/ATen/native/SegmentReduce.h

@@ -0,0 +1,20 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/Optional.h>

namespace at {
namespace native {

using segment_reduce_fn = void (*)(
    const Tensor&,
    std::string,
    const c10::optional<Tensor>&,
    const c10::optional<Tensor>&,
    int64_t,
    bool);
DECLARE_DISPATCH(segment_reduce_fn, segment_reduce_stub);

} // namespace native
} // namespace at

aten/src/ATen/native/native_functions.yaml

@@ -9111,3 +9111,8 @@
- func: _test_ambiguous_defaults.b(Tensor dummy, int a=2, str b="2") -> Tensor
  cpp_no_default_args: ['a', 'b']
  python_module: nn

- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False) -> Tensor
  variants: function
  dispatch:
    CPU: _segment_reduce_cpu
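One consequence of the `*` in this schema: in Python, every argument after it is keyword-only, so lengths/indices/axis/unsafe must be passed by name (data and lengths as in the sketch near the top):

out = torch.segment_reduce(data, "max", lengths=lengths)   # ok
# torch.segment_reduce(data, "max", lengths)               # TypeError: keyword-only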

test/test_segment_reductions.py

@@ -0,0 +1,37 @@
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    dtypes,
)
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)


class TestSegmentReductions(TestCase):
    @onlyCPU
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
    def test_max_simple_1d(self, device, dtype):
        lengths = torch.tensor([1, 2, 3], device=device)
        data = torch.tensor([1, float("nan"), 3, 4, 5, 6], device=device, dtype=dtype)
        expected_result = torch.tensor([1, float("nan"), 6], device=device, dtype=dtype)
        actual_result = torch.segment_reduce(
            data=data, reduce="max", lengths=lengths, axis=0, unsafe=False
        )
        self.assertEqual(
            expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
        )

        # axis=-1 wraps to 0 for 1-D input, so the result is the same
        actual_result = torch.segment_reduce(
            data=data, reduce="max", lengths=lengths, axis=-1, unsafe=False
        )
        self.assertEqual(
            expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
        )


instantiate_device_type_tests(TestSegmentReductions, globals())

if __name__ == "__main__":
    run_tests()
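Note that instantiate_device_type_tests expands the test into per-device, per-dtype variants; with the decorators shown, the generated names follow the framework's usual convention, roughly:

# Generated variants look like:
#   TestSegmentReductionsCPU.test_max_simple_1d_cpu_float16
#   TestSegmentReductionsCPU.test_max_simple_1d_cpu_float32
#   ...
# and can be run individually, e.g.:
#   python test_segment_reductions.py TestSegmentReductionsCPU.test_max_simple_1d_cpu_float32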

tools/build_variables.bzl

@@ -923,6 +923,7 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/ReplicationPadding.cpp",
"aten/src/ATen/native/Resize.cpp",
"aten/src/ATen/native/RowwisePrune.cpp",
"aten/src/ATen/native/SegmentReduce.cpp",
"aten/src/ATen/native/Scalar.cpp",
"aten/src/ATen/native/SobolEngineOps.cpp",
"aten/src/ATen/native/SobolEngineOpsUtils.cpp",

torch/overrides.py

@@ -794,6 +794,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.scatter: lambda input, dim, index, src: -1,
torch.scatter_add: lambda input, dim, index, src: -1,
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
torch.select: lambda input, dim, index: -1,
torch.selu: lambda input, inplace=False: -1,
torch.sigmoid: lambda input, out=None: -1,
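The get_testing_overrides entry gives the __torch_function__ coverage machinery a dummy with the same signature as the real op. A quick sketch of inspecting the new entry (assumes a build where torch.segment_reduce exists):

import torch
from torch.overrides import get_testing_overrides

# Each overridable API maps to a lambda that mirrors its signature
# and simply returns -1.
stub = get_testing_overrides()[torch.segment_reduce]
print(stub(torch.zeros(3), "max"))  # -1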