Add embedding_bag meta functions

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78997

Approved by: https://github.com/Chillee, https://github.com/Lezcano
This commit is contained in:
Edward Z. Yang
2022-06-08 11:31:04 -07:00
committed by PyTorch MergeBot
parent 7710d872fc
commit 50f2af84da
2 changed files with 77 additions and 4 deletions

View File

@ -432,7 +432,6 @@ meta_function_expected_failures = {
torch.nn.functional.conv_transpose2d: {f32, f64, i64},
torch.nn.functional.conv_transpose3d: {f32, f64, i64},
torch.nn.functional.ctc_loss: {f32, f64},
torch.nn.functional.embedding_bag: {f16, f32, f64}, # aten::_embedding_bag_forward_only
torch.nn.functional.gaussian_nll_loss: {bf16, f32, f64}, # aten::_local_scalar_dense
torch.nn.functional.grid_sample: {f32, f64}, # aten::grid_sampler_2d, aten::grid_sampler_3d
torch.nn.functional.max_pool3d: {f32, f64}, # aten::max_pool3d_with_indices
@ -552,7 +551,6 @@ meta_function_device_expected_failures['cuda'] = {
torch.nn.functional.conv_transpose1d: {bf16, f16},
torch.nn.functional.conv_transpose2d: {bf16, f16},
torch.nn.functional.conv_transpose3d: {bf16, f16},
torch.nn.functional.embedding_bag: {bf16}, # aten::_embedding_bag_forward_only
torch.nn.functional.gaussian_nll_loss: {f16}, # aten::_local_scalar_dense
torch.nn.functional.grid_sample: {f16}, # aten::grid_sampler_2d, aten::grid_sampler_3d
torch.nn.functional.max_pool3d: {bf16, f16}, # aten::max_pool3d_with_indices
@ -635,7 +633,6 @@ meta_dispatch_expected_failures = {
aten._conj_physical.default: {c32},
aten._convolution.default: {c64, i64, f64, c128, bf16, f32},
aten._ctc_loss.default: {f64, f32},
aten._embedding_bag_forward_only.default: {f16, f64, f32},
aten._fft_r2c.default: {i64, u8, b8, f32, i8, f64, i16, i32},
aten._histogramdd_bin_edges.default: {f64, f32},
aten._histogramdd_from_bin_cts.default: {f64, f32},
@ -758,7 +755,6 @@ meta_dispatch_device_skips = defaultdict(dict)
meta_dispatch_device_expected_failures['cuda'] = {
aten._conj_physical.default: {f16}, # aten::conj_physical.out
aten._convolution.default: {f16, c32},
aten._embedding_bag_forward_only.default: {bf16}, # aten::_embedding_bag_forward_only
aten._fft_c2c.default: {c32, f16}, # aten::_fft_c2c
aten._fft_c2c.out: {c32, f16}, # aten::_fft_c2c.out
aten._fft_c2r.default: {c32, f16}, # aten::_fft_c2r

View File

@ -329,3 +329,80 @@ def meta_cdist_forward(x1, x2, p, compute_mode):
output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
output_shape.extend([r1, r2])
return x1.new_empty(output_shape)
@torch.library.impl(meta_lib, "_embedding_bag")
def meta_embedding_bag(
    weight, indices, offsets, scale_grad_by_freq=False, mode=0,
    sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=-1
):
    """Meta implementation of ``_embedding_bag``: validate inputs, allocate outputs.

    Nothing is computed. The function only checks dtypes/sizes and returns
    uninitialised tensors created with ``new_empty``.

    Args:
        weight: embedding table; must have a floating point dtype. Its
            ``size(1)`` becomes the second dimension of ``output``.
        indices: tensor of dtype long or int.
        offsets: tensor of dtype long or int. ``offsets.size(0)`` is the number
            of bags, minus one when ``include_last_offset`` is true.
        scale_grad_by_freq: accepted for signature compatibility; not read here.
        mode: 0 (sum), 1 (mean) or 2 (max).
        sparse: accepted for signature compatibility; not read here.
        per_sample_weights: optional 1-D tensor with the same dtype as
            ``weight`` and as many elements as ``indices``; only accepted when
            ``mode`` is sum.
        include_last_offset: when true, the last entry of ``offsets`` does not
            start a bag, so at least one offset is required.
        padding_idx: a non-negative value disables the fast path used to
            decide the size of ``offset2bag`` on CPU.

    Returns:
        Tuple ``(output, offset2bag, bag_size, max_indices)`` of empty tensors.

    Failed validations are reported through ``check`` (defined elsewhere in
    this file).
    """
    check(indices.dtype in (torch.long, torch.int), lambda: f"expected indices to be long or int, got {indices.dtype}")
    check(offsets.dtype in (torch.long, torch.int), lambda: f"expected offsets to be long or int, got {offsets.dtype}")
    check(
        utils.is_float_dtype(weight.dtype),
        lambda: f"expected weight to be floating point type, got {weight.dtype}"
    )

    num_bags = offsets.size(0)
    if include_last_offset:
        check(num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1")
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)

    if per_sample_weights is not None:
        check(mode == MODE_SUM, lambda: "embedding_bag: per_sample_weights only supported with mode='sum'")
        check(
            per_sample_weights.dtype == weight.dtype,
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype"
        )
        check(per_sample_weights.ndim == 1, lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D")
        check(
            per_sample_weights.numel() == indices.numel(),
            # The closing parenthesis after the first count was missing from
            # the original message.
            lambda: (
                f"expected per_sample_weights.numel() ({per_sample_weights.numel()}) "
                f"to be the same as indices.numel() ({indices.numel()})"
            )
        )

    def is_fast_path_index_select(src, output, padding_idx):
        # float/half table, unit stride along dim 1 for both tensors, no padding index.
        return (
            src.dtype in (torch.float, torch.half)
            and src.stride(1) == 1
            and output.stride(1) == 1
            and padding_idx < 0
        )

    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
        # Same as above, and the per-sample scale must have unit stride too.
        return is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1

    def is_fast_path(src, scale, output, padding_idx):
        if scale is not None:
            return is_fast_path_index_select_scale(src, scale, output, padding_idx)
        return is_fast_path_index_select(src, output, padding_idx)

    if offsets.device.type != "cpu":
        # Non-CPU layout: bookkeeping tensors take their dtype/device from ``indices``.
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
    else:
        # CPU layout: bookkeeping tensors take their dtype/device from ``offsets``.
        fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
        if mode in (MODE_MEAN, MODE_MAX) or not fast_path_sum:
            offset2bag = offsets.new_empty(indices.size(0))
        else:
            # The fast sum path allocates an empty offset2bag.
            offset2bag = offsets.new_empty(0)
        bag_size = offsets.new_empty(num_bags)
        if mode == MODE_MAX:
            # Max mode records one index per (bag, embedding column), as the
            # non-CPU branch above already does; previously this was always
            # shaped like ``bag_size``.
            # NOTE(review): intended to mirror make_max_indices_out in ATen's
            # EmbeddingBag.cpp -- confirm against the CPU kernel.
            max_indices = offsets.new_empty(num_bags, weight.size(1))
        else:
            max_indices = offsets.new_empty(bag_size.size())
    return output, offset2bag, bag_size, max_indices
@torch.library.impl(meta_lib, "_embedding_bag_forward_only")
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
    """Meta implementation of ``_embedding_bag_forward_only``.

    Delegates to ``meta_embedding_bag`` with the same arguments. When
    ``offsets`` reports a CPU device, the returned ``bag_size`` is replaced by
    an empty tensor with the same size as ``offsets``; otherwise the delegated
    result is returned unchanged.

    Returns:
        Tuple ``(output, offset2bag, bag_size, max_indices)``.
    """
    out, bag_of_index, sizes, argmax = meta_embedding_bag(weight, indices, offsets, *args)
    if offsets.device.type != "cpu":
        return out, bag_of_index, sizes, argmax
    # CPU: bag_size has one entry per offset rather than one per bag.
    return out, bag_of_index, offsets.new_empty(offsets.size()), argmax