add meta for _segment_reduce_backward (#137442)

reland of https://github.com/pytorch/pytorch/pull/124988

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137442
Approved by: https://github.com/albanD
This commit is contained in:
Brian Hirsh
2024-10-08 07:34:12 -07:00
committed by PyTorch MergeBot
parent 1aac1ffce1
commit 53af729a66
3 changed files with 54 additions and 6 deletions

View File

@ -6307,12 +6307,6 @@ symbolic_aot_autograd_failures = {
xfail(
"nn.functional.nll_loss", ""
), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail(
"_segment_reduce", "lengths"
), # aten.segment_reduce.default - couldn't find symbolic meta functio...
xfail(
"_segment_reduce", "offsets"
), # aten.segment_reduce.default - couldn't find symbolic meta functio...
xfail("trace", ""), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail(
"_upsample_bilinear2d_aa"

View File

@ -1636,6 +1636,42 @@ class TestMeta(TestCase):
)
self.assertEqual(grad_weight.to('meta'), meta_grad_weight)
def test_segment_reduce_backward(self):
grad = torch.ones(16, dtype=torch.float)
output = torch.ones(16, dtype=torch.float)
data = torch.ones(16, dtype=torch.float)
reduce_str = 'max'
lengths = torch.ones(16, dtype=torch.long)
out = torch.ops.aten._segment_reduce_backward(grad, output, data, reduce_str, lengths=lengths)
out_meta = torch.ops.aten._segment_reduce_backward(
grad.to(device='meta'),
output.to(device='meta'),
data.to(device='meta'),
reduce_str,
lengths=lengths.to(device='meta'),
)
self.assertEqual(out.shape, out_meta.shape)
self.assertEqual(out.stride(), out_meta.stride())
self.assertEqual(out.dtype, out_meta.dtype)
self.assertEqual(out.layout, out_meta.layout)
# noncontiguous
grad = torch.ones(16, 2, dtype=torch.float)[:, 1]
data = torch.ones(16, 2, dtype=torch.float)[:, 1]
out = torch.ops.aten._segment_reduce_backward(grad, output, data, reduce_str, lengths=lengths)
out_meta = torch.ops.aten._segment_reduce_backward(
grad.to(device='meta'),
output.to(device='meta'),
data.to(device='meta'),
reduce_str,
lengths=lengths.to(device='meta'),
)
self.assertEqual(out.shape, out_meta.shape)
self.assertEqual(out.stride(), out_meta.stride())
self.assertEqual(out.dtype, out_meta.dtype)
self.assertEqual(out.layout, out_meta.layout)
def test_embedding_bag_dense_backward_per_sample_weights(self):
weight = torch.randn(4, 3, requires_grad=True)
indices = torch.tensor([1, 0, 2, 1, 3])

View File

@ -5984,6 +5984,24 @@ def topk_meta(self, k, dim=-1, largest=True, sorted=True):
return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
@register_meta(aten._segment_reduce_backward)
@out_wrapper()
def meta__segment_reduce_backward(
grad, output, data, reduce, lengths=None, offsets=None, axis=0, initial=None
):
assert (
lengths is not None or offsets is not None
), "segment_reduce(): Either lengths or offsets must be defined"
data_contig = data.contiguous()
grad_contig = grad.contiguous()
return torch.empty_like(
data_contig,
dtype=grad_contig.dtype,
device=grad_contig.device,
layout=grad_contig.layout,
)
@register_meta([aten.kthvalue.default, aten.kthvalue.values])
@out_wrapper("values", "indices")
def kthvalue_meta(self, k, dim=-1, keepdim=False):