add meta for _segment_reduce_backward (#137442)

reland of https://github.com/pytorch/pytorch/pull/124988

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137442
Approved by: https://github.com/albanD
This commit is contained in:
Brian Hirsh
2024-10-08 07:34:12 -07:00
committed by PyTorch MergeBot
parent 1aac1ffce1
commit 53af729a66
3 changed files with 54 additions and 6 deletions

View File

@ -1636,6 +1636,42 @@ class TestMeta(TestCase):
)
self.assertEqual(grad_weight.to('meta'), meta_grad_weight)
def test_segment_reduce_backward(self):
grad = torch.ones(16, dtype=torch.float)
output = torch.ones(16, dtype=torch.float)
data = torch.ones(16, dtype=torch.float)
reduce_str = 'max'
lengths = torch.ones(16, dtype=torch.long)
out = torch.ops.aten._segment_reduce_backward(grad, output, data, reduce_str, lengths=lengths)
out_meta = torch.ops.aten._segment_reduce_backward(
grad.to(device='meta'),
output.to(device='meta'),
data.to(device='meta'),
reduce_str,
lengths=lengths.to(device='meta'),
)
self.assertEqual(out.shape, out_meta.shape)
self.assertEqual(out.stride(), out_meta.stride())
self.assertEqual(out.dtype, out_meta.dtype)
self.assertEqual(out.layout, out_meta.layout)
# noncontiguous
grad = torch.ones(16, 2, dtype=torch.float)[:, 1]
data = torch.ones(16, 2, dtype=torch.float)[:, 1]
out = torch.ops.aten._segment_reduce_backward(grad, output, data, reduce_str, lengths=lengths)
out_meta = torch.ops.aten._segment_reduce_backward(
grad.to(device='meta'),
output.to(device='meta'),
data.to(device='meta'),
reduce_str,
lengths=lengths.to(device='meta'),
)
self.assertEqual(out.shape, out_meta.shape)
self.assertEqual(out.stride(), out_meta.stride())
self.assertEqual(out.dtype, out_meta.dtype)
self.assertEqual(out.layout, out_meta.layout)
def test_embedding_bag_dense_backward_per_sample_weights(self):
weight = torch.randn(4, 3, requires_grad=True)
indices = torch.tensor([1, 0, 2, 1, 3])