mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add meta for _segment_reduce_backward (#137442)
reland of https://github.com/pytorch/pytorch/pull/124988 Pull Request resolved: https://github.com/pytorch/pytorch/pull/137442 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
1aac1ffce1
commit
53af729a66
@ -1636,6 +1636,42 @@ class TestMeta(TestCase):
|
||||
)
|
||||
self.assertEqual(grad_weight.to('meta'), meta_grad_weight)
|
||||
|
||||
def test_segment_reduce_backward(self):
|
||||
grad = torch.ones(16, dtype=torch.float)
|
||||
output = torch.ones(16, dtype=torch.float)
|
||||
data = torch.ones(16, dtype=torch.float)
|
||||
reduce_str = 'max'
|
||||
lengths = torch.ones(16, dtype=torch.long)
|
||||
|
||||
out = torch.ops.aten._segment_reduce_backward(grad, output, data, reduce_str, lengths=lengths)
|
||||
out_meta = torch.ops.aten._segment_reduce_backward(
|
||||
grad.to(device='meta'),
|
||||
output.to(device='meta'),
|
||||
data.to(device='meta'),
|
||||
reduce_str,
|
||||
lengths=lengths.to(device='meta'),
|
||||
)
|
||||
self.assertEqual(out.shape, out_meta.shape)
|
||||
self.assertEqual(out.stride(), out_meta.stride())
|
||||
self.assertEqual(out.dtype, out_meta.dtype)
|
||||
self.assertEqual(out.layout, out_meta.layout)
|
||||
|
||||
# noncontiguous
|
||||
grad = torch.ones(16, 2, dtype=torch.float)[:, 1]
|
||||
data = torch.ones(16, 2, dtype=torch.float)[:, 1]
|
||||
out = torch.ops.aten._segment_reduce_backward(grad, output, data, reduce_str, lengths=lengths)
|
||||
out_meta = torch.ops.aten._segment_reduce_backward(
|
||||
grad.to(device='meta'),
|
||||
output.to(device='meta'),
|
||||
data.to(device='meta'),
|
||||
reduce_str,
|
||||
lengths=lengths.to(device='meta'),
|
||||
)
|
||||
self.assertEqual(out.shape, out_meta.shape)
|
||||
self.assertEqual(out.stride(), out_meta.stride())
|
||||
self.assertEqual(out.dtype, out_meta.dtype)
|
||||
self.assertEqual(out.layout, out_meta.layout)
|
||||
|
||||
def test_embedding_bag_dense_backward_per_sample_weights(self):
|
||||
weight = torch.randn(4, 3, requires_grad=True)
|
||||
indices = torch.tensor([1, 0, 2, 1, 3])
|
||||
|
Reference in New Issue
Block a user