Add meta for _embedding_bag_dense_backward and _embedding_bag_per_sample_weights_backward (#125785)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125785
Approved by: https://github.com/albanD
Author: Edward Z. Yang
Date: 2024-05-08 17:04:32 -07:00
Committed by: PyTorch MergeBot
Parent: ed48ea9997
Commit: aaa2f93a4f
3 changed files with 103 additions and 3 deletions
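
The meta kernels themselves land in torch/_meta_registrations.py. A minimal sketch of their likely shape, assuming the file's usual register_meta decorator and eliding the input validation the real registrations perform: the dense backward produces a (num_weights, embedding_dim) gradient, and the per-sample-weights backward produces one scalar per looked-up index.

import torch
from torch._meta_registrations import register_meta

aten = torch.ops.aten

@register_meta(aten._embedding_bag_dense_backward)
def meta_embedding_bag_dense_backward(
    grad, indices, offset2bag, bag_size, maximum_indices,
    num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx=-1,
):
    # Dense weight gradient: one row per embedding, same width as grad.
    return grad.new_empty((num_weights, grad.size(1)))

@register_meta(aten._embedding_bag_per_sample_weights_backward)
def meta_embedding_bag_per_sample_weights_backward(
    grad, weight, indices, offsets, offset2bag, mode, padding_idx=-1,
):
    # One gradient scalar per looked-up index.
    return grad.new_empty((indices.size(0),))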


@@ -1602,6 +1602,66 @@ class TestMeta(TestCase):
        self.assertEqual(eb.dtype, torch.float32)
        self.assertEqual(eb.untyped_storage().data_ptr(), 0)

    # Tests mean and max.
    # Can't easily test sum, because there is a fast path for sum which
    # causes offset2bag to not get allocated... but the backward function
    # needs it, and the offset2bag computation lives inside the
    # derivatives.yaml formula directly, so there is no way to access it.
    # To test sum, offset2bag would have to be computed manually; see the
    # sketch after this test.
    @parametrize("mode", [1, 2])
    def test_embedding_bag_dense_backward(self, mode):
        weight = torch.randn(4, 3, requires_grad=True)
        indices = torch.tensor([1, 0, 2, 1, 3])
        offsets = torch.tensor([0, 2, 3, 5])
        scale_grad_by_freq = False
        sparse = False
        per_sample_weights = None
        include_last_offset = False
        padding_idx = -1
        output, offset2bag, bag_size, maximum_indices = torch.ops.aten._embedding_bag.default(
            weight, indices, offsets, scale_grad_by_freq, mode, sparse,
            per_sample_weights, include_last_offset, padding_idx
        )
        grad = torch.randn_like(output)
        # Run the backward eagerly, then on meta tensors, and check they agree.
        grad_weight = torch.ops.aten._embedding_bag_dense_backward.default(
            grad, indices, offset2bag, bag_size, maximum_indices, weight.size(0),
            scale_grad_by_freq, mode, per_sample_weights, padding_idx
        )
        meta_grad_weight = torch.ops.aten._embedding_bag_dense_backward.default(
            grad.to('meta'), indices.to('meta'), offset2bag.to('meta'), bag_size.to('meta'),
            maximum_indices.to('meta'), weight.size(0),
            scale_grad_by_freq, mode, per_sample_weights, padding_idx
        )
        self.assertEqual(grad_weight.to('meta'), meta_grad_weight)
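
    # Hypothetical sketch (not part of this commit): to cover sum (mode=0),
    # offset2bag would have to be rebuilt by hand, mirroring ATen's
    # make_offset2bag in aten/src/ATen/native/EmbeddingBag.cpp. The helper
    # name below is made up for illustration.
    def _sketch_embedding_bag_dense_backward_sum(self):
        weight = torch.randn(4, 3, requires_grad=True)
        indices = torch.tensor([1, 0, 2, 1, 3])
        offsets = torch.tensor([0, 2, 3, 5])
        mode = 0  # sum: the fast path skips allocating offset2bag
        output, _, bag_size, maximum_indices = torch.ops.aten._embedding_bag.default(
            weight, indices, offsets, False, mode, False, None, False, -1
        )
        # offset2bag[i] = bag that indices[i] belongs to, here [0, 0, 1, 2, 2]
        offset2bag = torch.zeros(indices.numel() + 1, dtype=torch.long)
        offset2bag.index_add_(0, offsets, torch.ones_like(offsets))
        offset2bag[0] -= 1
        offset2bag = offset2bag.cumsum(0)[:indices.numel()]
        grad = torch.randn_like(output)
        grad_weight = torch.ops.aten._embedding_bag_dense_backward.default(
            grad, indices, offset2bag, bag_size, maximum_indices, weight.size(0),
            False, mode, None, -1
        )
        meta_grad_weight = torch.ops.aten._embedding_bag_dense_backward.default(
            grad.to('meta'), indices.to('meta'), offset2bag.to('meta'),
            bag_size.to('meta'), maximum_indices.to('meta'), weight.size(0),
            False, mode, None, -1
        )
        self.assertEqual(grad_weight.to('meta'), meta_grad_weight)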
    def test_embedding_bag_dense_backward_per_sample_weights(self):
        weight = torch.randn(4, 3, requires_grad=True)
        indices = torch.tensor([1, 0, 2, 1, 3])
        offsets = torch.tensor([0, 2, 3, 5])
        scale_grad_by_freq = False
        sparse = False
        mode = 0
        per_sample_weights = torch.randn(5, requires_grad=True)
        include_last_offset = False
        padding_idx = -1
        output, offset2bag, bag_size, maximum_indices = torch.ops.aten._embedding_bag.default(
            weight, indices, offsets, scale_grad_by_freq, mode, sparse,
            per_sample_weights, include_last_offset, padding_idx
        )
        grad = torch.randn_like(output)
        # Run the per-sample-weights backward eagerly, then on meta, and compare.
        grad_weight = torch.ops.aten._embedding_bag_per_sample_weights_backward.default(
            grad, weight, indices, offsets, offset2bag, mode, padding_idx
        )
        meta_grad_weight = torch.ops.aten._embedding_bag_per_sample_weights_backward.default(
            grad.to('meta'), weight.to('meta'), indices.to('meta'),
            offsets.to('meta'), offset2bag.to('meta'), mode, padding_idx
        )
        self.assertEqual(grad_weight.to('meta'), meta_grad_weight)
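
    # Hypothetical follow-on (not in this commit): with the meta kernels
    # registered, the backward also runs under FakeTensorMode, which dispatches
    # to meta implementations. A minimal sketch using the private but widely
    # used torch._subclasses.fake_tensor API.
    def _sketch_embedding_bag_dense_backward_fake(self):
        from torch._subclasses.fake_tensor import FakeTensorMode
        fake_mode = FakeTensorMode()
        weight = fake_mode.from_tensor(torch.randn(4, 3))
        indices = fake_mode.from_tensor(torch.tensor([1, 0, 2, 1, 3]))
        offsets = fake_mode.from_tensor(torch.tensor([0, 2, 3, 5]))
        with fake_mode:
            output, offset2bag, bag_size, maximum_indices = torch.ops.aten._embedding_bag.default(
                weight, indices, offsets, False, 1, False, None, False, -1
            )
            grad_weight = torch.ops.aten._embedding_bag_dense_backward.default(
                torch.randn_like(output), indices, offset2bag, bag_size,
                maximum_indices, weight.size(0), False, 1, None, -1
            )
        # Shapes propagate without any real computation or memory traffic.
        self.assertEqual(grad_weight.shape, torch.Size([4, 3]))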
    # The opinfo test uses aten.fill_; it does not cover aten.fill.
    @onlyCUDA
    def test_fill_stride(self):