reland Add offsets-based reduction to segment_reduce (CPU, CUDA)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79725

Approved by: https://github.com/george-qi
This commit is contained in:
Mikayla Gawarecki
2022-06-16 17:32:10 +00:00
committed by PyTorch MergeBot
parent 64f3742b2b
commit 7360b53ff3
9 changed files with 491 additions and 175 deletions

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: scatter & gather ops"]
from itertools import product
from functools import partial
import numpy as np
import torch
@ -52,6 +53,11 @@ class TestSegmentReductions(TestCase):
lengths_dtype=torch.int,
):
lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
# generate offsets from lengths
zeros_shape = list(lengths.shape)
zeros_shape[-1] = 1
offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)
data = torch.tensor(
data_arr,
device=device,
@ -60,52 +66,56 @@ class TestSegmentReductions(TestCase):
)
expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
actual_result = torch.segment_reduce(
data=data,
reduce=reduction,
lengths=lengths,
axis=axis,
unsafe=unsafe,
initial=initial_value,
)
self.assertEqual(
expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
)
if not check_backward:
return
# Test backward
actual_result.sum().backward()
self.assertEqual(
expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
)
# gradcheck does not work well with bfloat16 or fp16 cpu types
# also there is small numerical difference with fp32
if dtype not in [torch.half, torch.bfloat16, torch.float]:
# gradcheck does not like "nan" input, setting to random 10
d_non_nan = np.nan_to_num(data_arr, nan=10)
data = torch.tensor(
# [10 if v == float("nan") else v for v in data],
d_non_nan,
device=device,
dtype=dtype,
requires_grad=True,
for mode in ['lengths', 'offsets']:
segment_reduce_kwargs = dict(
axis=axis,
unsafe=unsafe,
initial=initial_value)
if (mode == 'lengths'):
segment_reduce_kwargs['lengths'] = lengths
else:
segment_reduce_kwargs['offsets'] = offsets
actual_result = torch.segment_reduce(
data=data,
reduce=reduction,
**segment_reduce_kwargs
)
self.assertTrue(
gradcheck(
lambda x: torch.segment_reduce(
data=x,
reduce=reduction,
lengths=lengths,
axis=axis,
unsafe=unsafe,
initial=initial_value,
),
(data,),
self.assertEqual(
expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
)
if not check_backward:
return
# Test backward
actual_result.sum().backward()
self.assertEqual(
expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
)
data = data.clone().detach().requires_grad_(True)
# gradcheck does not work well with bfloat16 or fp16 cpu types
# also there is small numerical difference with fp32
if dtype not in [torch.half, torch.bfloat16, torch.float]:
# gradcheck does not like "nan" input, setting to random 10
d_non_nan = np.nan_to_num(data_arr, nan=10)
new_data = torch.tensor(
# [10 if v == float("nan") else v for v in data],
d_non_nan,
device=device,
dtype=dtype,
requires_grad=True,
)
self.assertTrue(
gradcheck(
lambda x: torch.segment_reduce(
data=x,
reduce=reduction,
**segment_reduce_kwargs
),
(new_data,),
)
)
)
@dtypes(
*product(
@ -384,8 +394,18 @@ class TestSegmentReductions(TestCase):
)
self.assertEqual(actual_result, expected)
# test offsets
actual_result = torch.segment_reduce(
data=data,
reduce=reduce,
offsets=indptr,
axis=dim,
unsafe=True,
)
self.assertEqual(actual_result, expected)
if val_dtype == torch.float64:
def fn(x):
def fn(x, mode='lengths'):
initial = 1
# supply initial values to prevent gradcheck from failing for 0 length segments
# where nan/inf are reduction identities that produce nans when calculating the numerical jacobian
@ -393,8 +413,16 @@ class TestSegmentReductions(TestCase):
initial = 1000
elif reduce == 'max':
initial = -1000
return torch.segment_reduce(x, reduce, lengths=lengths, axis=dim, unsafe=True, initial=initial)
self.assertTrue(gradcheck(fn, (data.clone().detach().requires_grad_(True))))
segment_reduce_args = {x, reduce}
segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
if mode == 'lengths':
segment_reduce_kwargs[mode] = lengths
elif mode == 'offsets':
segment_reduce_kwargs[mode] = indptr
return torch.segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)
self.assertTrue(gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True))))
self.assertTrue(gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True))))
@dtypes(
*product(