add meta for _segment_reduce_backward (#137442)
Reland of https://github.com/pytorch/pytorch/pull/124988

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137442
Approved by: https://github.com/albanD
committed by: PyTorch MergeBot
parent: 1aac1ffce1
commit: 53af729a66
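Background (not part of the patch itself): a meta kernel computes only output metadata, so an op that has one registered can run on "meta" tensors, which carry shape/dtype/stride but no storage. A minimal sketch using a stock op (torch.add, chosen purely for illustration):

    import torch

    # "meta" tensors allocate no storage; dispatching an op on them runs
    # only its meta kernel, which propagates metadata.
    x = torch.empty(4, 3, device="meta")
    y = torch.add(x, x)  # no data is touched
    print(y.shape, y.dtype, y.device)  # torch.Size([4, 3]) torch.float32 meta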
@@ -6307,12 +6307,6 @@ symbolic_aot_autograd_failures = {
     xfail(
         "nn.functional.nll_loss", ""
     ),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail(
-        "_segment_reduce", "lengths"
-    ),  # aten.segment_reduce.default - couldn't find symbolic meta functio...
-    xfail(
-        "_segment_reduce", "offsets"
-    ),  # aten.segment_reduce.default - couldn't find symbolic meta functio...
     xfail("trace", ""),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail(
         "_upsample_bilinear2d_aa"
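These two xfails can be dropped because the backward op now has a meta kernel (added below), so AOTAutograd's symbolic tracing no longer fails on it. A hedged sketch of the path this unblocks, shown here via FakeTensorMode for brevity (illustrative only; the test suite reaches it through AOTAutograd):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Under fake-tensor tracing every op needs a meta/fake implementation;
    # before this patch, this call failed for lack of one.
    with FakeTensorMode():
        grad = torch.ones(16)
        output = torch.ones(16)
        data = torch.ones(16)
        lengths = torch.ones(16, dtype=torch.long)
        out = torch.ops.aten._segment_reduce_backward(
            grad, output, data, "max", lengths=lengths
        )
        print(out.shape)  # torch.Size([16]), computed without the real kernel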
@@ -1636,6 +1636,42 @@ class TestMeta(TestCase):
         )
         self.assertEqual(grad_weight.to('meta'), meta_grad_weight)
 
+    def test_segment_reduce_backward(self):
+        grad = torch.ones(16, dtype=torch.float)
+        output = torch.ones(16, dtype=torch.float)
+        data = torch.ones(16, dtype=torch.float)
+        reduce_str = 'max'
+        lengths = torch.ones(16, dtype=torch.long)
+
+        out = torch.ops.aten._segment_reduce_backward(grad, output, data, reduce_str, lengths=lengths)
+        out_meta = torch.ops.aten._segment_reduce_backward(
+            grad.to(device='meta'),
+            output.to(device='meta'),
+            data.to(device='meta'),
+            reduce_str,
+            lengths=lengths.to(device='meta'),
+        )
+        self.assertEqual(out.shape, out_meta.shape)
+        self.assertEqual(out.stride(), out_meta.stride())
+        self.assertEqual(out.dtype, out_meta.dtype)
+        self.assertEqual(out.layout, out_meta.layout)
+
+        # noncontiguous
+        grad = torch.ones(16, 2, dtype=torch.float)[:, 1]
+        data = torch.ones(16, 2, dtype=torch.float)[:, 1]
+        out = torch.ops.aten._segment_reduce_backward(grad, output, data, reduce_str, lengths=lengths)
+        out_meta = torch.ops.aten._segment_reduce_backward(
+            grad.to(device='meta'),
+            output.to(device='meta'),
+            data.to(device='meta'),
+            reduce_str,
+            lengths=lengths.to(device='meta'),
+        )
+        self.assertEqual(out.shape, out_meta.shape)
+        self.assertEqual(out.stride(), out_meta.stride())
+        self.assertEqual(out.dtype, out_meta.dtype)
+        self.assertEqual(out.layout, out_meta.layout)
+
     def test_embedding_bag_dense_backward_per_sample_weights(self):
         weight = torch.randn(4, 3, requires_grad=True)
         indices = torch.tensor([1, 0, 2, 1, 3])
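The noncontiguous half of the test pins down the .contiguous() calls in the meta kernel below: the eager backward produces a contiguous gradient even from strided inputs, so the meta kernel must report matching (contiguous) strides. A small standalone check of that invariant:

    import torch

    data = torch.ones(16, 2)[:, 1]             # noncontiguous view
    print(data.stride())                       # (2,)
    out = torch.empty_like(data.contiguous())  # what the meta kernel returns
    print(out.stride())                        # (1,), matching the eager result

Assuming this hunk is in test/test_meta.py (where TestMeta lives), the test can be run with: python test/test_meta.py -k test_segment_reduce_backward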
@@ -5984,6 +5984,24 @@ def topk_meta(self, k, dim=-1, largest=True, sorted=True):
     return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
 
 
+@register_meta(aten._segment_reduce_backward)
+@out_wrapper()
+def meta__segment_reduce_backward(
+    grad, output, data, reduce, lengths=None, offsets=None, axis=0, initial=None
+):
+    assert (
+        lengths is not None or offsets is not None
+    ), "segment_reduce(): Either lengths or offsets must be defined"
+    data_contig = data.contiguous()
+    grad_contig = grad.contiguous()
+    return torch.empty_like(
+        data_contig,
+        dtype=grad_contig.dtype,
+        device=grad_contig.device,
+        layout=grad_contig.layout,
+    )
+
+
 @register_meta([aten.kthvalue.default, aten.kthvalue.values])
 @out_wrapper("values", "indices")
 def kthvalue_meta(self, k, dim=-1, keepdim=False):
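With the registration in place, the op can be called directly on meta tensors, which is exactly what the new test does. A minimal usage sketch (same shapes as the test):

    import torch

    grad = torch.ones(16, device="meta")
    output = torch.ones(16, device="meta")
    data = torch.ones(16, device="meta")
    lengths = torch.ones(16, dtype=torch.long, device="meta")

    # Dispatches to meta__segment_reduce_backward: the result mirrors the
    # contiguous data's shape with grad's dtype/device/layout.
    out = torch.ops.aten._segment_reduce_backward(
        grad, output, data, "max", lengths=lengths
    )
    print(out.shape, out.stride(), out.device)  # torch.Size([16]) (1,) meta

The @out_wrapper() decorator additionally handles out= variants of the registration.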