mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Reland: Add offsets-based reduction to segment_reduce (CPU, CUDA)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79725
Approved by: https://github.com/george-qi
This commit is contained in:
committed by
PyTorch MergeBot
parent
64f3742b2b
commit
7360b53ff3
@ -1,6 +1,7 @@
|
||||
# Owner(s): ["module: scatter & gather ops"]
|
||||
|
||||
from itertools import product
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -52,6 +53,11 @@ class TestSegmentReductions(TestCase):
|
||||
lengths_dtype=torch.int,
|
||||
):
|
||||
lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
|
||||
# generate offsets from lengths
|
||||
zeros_shape = list(lengths.shape)
|
||||
zeros_shape[-1] = 1
|
||||
offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)
|
||||
|
||||
data = torch.tensor(
|
||||
data_arr,
|
||||
device=device,
|
||||
@ -60,52 +66,56 @@ class TestSegmentReductions(TestCase):
|
||||
)
|
||||
expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
|
||||
expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
|
||||
actual_result = torch.segment_reduce(
|
||||
data=data,
|
||||
reduce=reduction,
|
||||
lengths=lengths,
|
||||
axis=axis,
|
||||
unsafe=unsafe,
|
||||
initial=initial_value,
|
||||
)
|
||||
self.assertEqual(
|
||||
expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
|
||||
)
|
||||
|
||||
if not check_backward:
|
||||
return
|
||||
|
||||
# Test backward
|
||||
actual_result.sum().backward()
|
||||
self.assertEqual(
|
||||
expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
|
||||
)
|
||||
|
||||
# gradcheck does not work well with bfloat16 or fp16 cpu types
|
||||
# also there is small numerical difference with fp32
|
||||
if dtype not in [torch.half, torch.bfloat16, torch.float]:
|
||||
# gradcheck does not like "nan" input, setting to random 10
|
||||
d_non_nan = np.nan_to_num(data_arr, nan=10)
|
||||
data = torch.tensor(
|
||||
# [10 if v == float("nan") else v for v in data],
|
||||
d_non_nan,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
requires_grad=True,
|
||||
for mode in ['lengths', 'offsets']:
|
||||
segment_reduce_kwargs = dict(
|
||||
axis=axis,
|
||||
unsafe=unsafe,
|
||||
initial=initial_value)
|
||||
if (mode == 'lengths'):
|
||||
segment_reduce_kwargs['lengths'] = lengths
|
||||
else:
|
||||
segment_reduce_kwargs['offsets'] = offsets
|
||||
actual_result = torch.segment_reduce(
|
||||
data=data,
|
||||
reduce=reduction,
|
||||
**segment_reduce_kwargs
|
||||
)
|
||||
self.assertTrue(
|
||||
gradcheck(
|
||||
lambda x: torch.segment_reduce(
|
||||
data=x,
|
||||
reduce=reduction,
|
||||
lengths=lengths,
|
||||
axis=axis,
|
||||
unsafe=unsafe,
|
||||
initial=initial_value,
|
||||
),
|
||||
(data,),
|
||||
self.assertEqual(
|
||||
expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
|
||||
)
|
||||
|
||||
if not check_backward:
|
||||
return
|
||||
|
||||
# Test backward
|
||||
actual_result.sum().backward()
|
||||
self.assertEqual(
|
||||
expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
|
||||
)
|
||||
data = data.clone().detach().requires_grad_(True)
|
||||
|
||||
# gradcheck does not work well with bfloat16 or fp16 cpu types
|
||||
# also there is small numerical difference with fp32
|
||||
if dtype not in [torch.half, torch.bfloat16, torch.float]:
|
||||
# gradcheck does not like "nan" input, setting to random 10
|
||||
d_non_nan = np.nan_to_num(data_arr, nan=10)
|
||||
new_data = torch.tensor(
|
||||
# [10 if v == float("nan") else v for v in data],
|
||||
d_non_nan,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
self.assertTrue(
|
||||
gradcheck(
|
||||
lambda x: torch.segment_reduce(
|
||||
data=x,
|
||||
reduce=reduction,
|
||||
**segment_reduce_kwargs
|
||||
),
|
||||
(new_data,),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@dtypes(
|
||||
*product(
|
||||
@ -384,8 +394,18 @@ class TestSegmentReductions(TestCase):
|
||||
)
|
||||
self.assertEqual(actual_result, expected)
|
||||
|
||||
# test offsets
|
||||
actual_result = torch.segment_reduce(
|
||||
data=data,
|
||||
reduce=reduce,
|
||||
offsets=indptr,
|
||||
axis=dim,
|
||||
unsafe=True,
|
||||
)
|
||||
self.assertEqual(actual_result, expected)
|
||||
|
||||
if val_dtype == torch.float64:
|
||||
def fn(x):
|
||||
def fn(x, mode='lengths'):
|
||||
initial = 1
|
||||
# supply initial values to prevent gradcheck from failing for 0 length segments
|
||||
# where nan/inf are reduction identities that produce nans when calculating the numerical jacobian
|
||||
@ -393,8 +413,16 @@ class TestSegmentReductions(TestCase):
|
||||
initial = 1000
|
||||
elif reduce == 'max':
|
||||
initial = -1000
|
||||
return torch.segment_reduce(x, reduce, lengths=lengths, axis=dim, unsafe=True, initial=initial)
|
||||
self.assertTrue(gradcheck(fn, (data.clone().detach().requires_grad_(True))))
|
||||
segment_reduce_args = {x, reduce}
|
||||
segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
|
||||
if mode == 'lengths':
|
||||
segment_reduce_kwargs[mode] = lengths
|
||||
elif mode == 'offsets':
|
||||
segment_reduce_kwargs[mode] = indptr
|
||||
return torch.segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)
|
||||
self.assertTrue(gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True))))
|
||||
self.assertTrue(gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True))))
|
||||
|
||||
|
||||
@dtypes(
|
||||
*product(
|
||||
|
||||
Reference in New Issue
Block a user