mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add embedding_bag meta functions
Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/78997 Approved by: https://github.com/Chillee, https://github.com/Lezcano
This commit is contained in:
committed by
PyTorch MergeBot
parent
7710d872fc
commit
50f2af84da
@ -432,7 +432,6 @@ meta_function_expected_failures = {
|
||||
torch.nn.functional.conv_transpose2d: {f32, f64, i64},
|
||||
torch.nn.functional.conv_transpose3d: {f32, f64, i64},
|
||||
torch.nn.functional.ctc_loss: {f32, f64},
|
||||
torch.nn.functional.embedding_bag: {f16, f32, f64}, # aten::_embedding_bag_forward_only
|
||||
torch.nn.functional.gaussian_nll_loss: {bf16, f32, f64}, # aten::_local_scalar_dense
|
||||
torch.nn.functional.grid_sample: {f32, f64}, # aten::grid_sampler_2d, aten::grid_sampler_3d
|
||||
torch.nn.functional.max_pool3d: {f32, f64}, # aten::max_pool3d_with_indices
|
||||
@ -552,7 +551,6 @@ meta_function_device_expected_failures['cuda'] = {
|
||||
torch.nn.functional.conv_transpose1d: {bf16, f16},
|
||||
torch.nn.functional.conv_transpose2d: {bf16, f16},
|
||||
torch.nn.functional.conv_transpose3d: {bf16, f16},
|
||||
torch.nn.functional.embedding_bag: {bf16}, # aten::_embedding_bag_forward_only
|
||||
torch.nn.functional.gaussian_nll_loss: {f16}, # aten::_local_scalar_dense
|
||||
torch.nn.functional.grid_sample: {f16}, # aten::grid_sampler_2d, aten::grid_sampler_3d
|
||||
torch.nn.functional.max_pool3d: {bf16, f16}, # aten::max_pool3d_with_indices
|
||||
@ -635,7 +633,6 @@ meta_dispatch_expected_failures = {
|
||||
aten._conj_physical.default: {c32},
|
||||
aten._convolution.default: {c64, i64, f64, c128, bf16, f32},
|
||||
aten._ctc_loss.default: {f64, f32},
|
||||
aten._embedding_bag_forward_only.default: {f16, f64, f32},
|
||||
aten._fft_r2c.default: {i64, u8, b8, f32, i8, f64, i16, i32},
|
||||
aten._histogramdd_bin_edges.default: {f64, f32},
|
||||
aten._histogramdd_from_bin_cts.default: {f64, f32},
|
||||
@ -758,7 +755,6 @@ meta_dispatch_device_skips = defaultdict(dict)
|
||||
meta_dispatch_device_expected_failures['cuda'] = {
|
||||
aten._conj_physical.default: {f16}, # aten::conj_physical.out
|
||||
aten._convolution.default: {f16, c32},
|
||||
aten._embedding_bag_forward_only.default: {bf16}, # aten::_embedding_bag_forward_only
|
||||
aten._fft_c2c.default: {c32, f16}, # aten::_fft_c2c
|
||||
aten._fft_c2c.out: {c32, f16}, # aten::_fft_c2c.out
|
||||
aten._fft_c2r.default: {c32, f16}, # aten::_fft_c2r
|
||||
|
@ -329,3 +329,80 @@ def meta_cdist_forward(x1, x2, p, compute_mode):
|
||||
output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
|
||||
output_shape.extend([r1, r2])
|
||||
return x1.new_empty(output_shape)
|
||||
|
||||
@torch.library.impl(meta_lib, "_embedding_bag")
|
||||
def meta_embedding_bag(
|
||||
weight, indices, offsets, scale_grad_by_freq=False, mode=0,
|
||||
sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=-1
|
||||
):
|
||||
check(indices.dtype in (torch.long, torch.int), lambda: f"expected indices to be long or int, got {indices.dtype}")
|
||||
check(offsets.dtype in (torch.long, torch.int), lambda: f"expected offsets to be long or int, got {offsets.dtype}")
|
||||
check(
|
||||
utils.is_float_dtype(weight.dtype),
|
||||
lambda: f"expected weight to be floating point type, got {weight.dtype}"
|
||||
)
|
||||
|
||||
num_bags = offsets.size(0)
|
||||
if include_last_offset:
|
||||
check(num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1")
|
||||
num_bags -= 1
|
||||
|
||||
output = weight.new_empty(num_bags, weight.size(1))
|
||||
MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
|
||||
|
||||
if per_sample_weights is not None:
|
||||
check(mode == MODE_SUM, lambda: "embedding_bag: per_sample_weights only supported with mode='sum'")
|
||||
check(
|
||||
per_sample_weights.dtype == weight.dtype,
|
||||
lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype"
|
||||
)
|
||||
check(per_sample_weights.ndim == 1, lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D")
|
||||
check(
|
||||
per_sample_weights.numel() == indices.numel(),
|
||||
lambda: (
|
||||
f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
|
||||
f"to be the same as indices.numel() ({indices.numel()})"
|
||||
)
|
||||
)
|
||||
|
||||
def is_fast_path_index_select_scale(src, scale, output, padding_idx):
|
||||
return is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
|
||||
|
||||
def is_fast_path_index_select(src, output, padding_idx):
|
||||
return (
|
||||
(src.dtype == torch.float or src.dtype == torch.half)
|
||||
and src.stride(1) == 1
|
||||
and output.stride(1) == 1
|
||||
and padding_idx < 0
|
||||
)
|
||||
|
||||
def is_fast_path(src, scale, output, padding_idx):
|
||||
if scale is not None:
|
||||
return is_fast_path_index_select_scale(src, scale, output, padding_idx)
|
||||
else:
|
||||
return is_fast_path_index_select(src, output, padding_idx)
|
||||
|
||||
|
||||
if offsets.device.type != "cpu":
|
||||
offset2bag = indices.new_empty(indices.size(0))
|
||||
bag_size = indices.new_empty(offsets.size())
|
||||
if mode == MODE_MAX:
|
||||
max_indices = indices.new_empty(num_bags, weight.size(1))
|
||||
else:
|
||||
max_indices = indices.new_empty(0)
|
||||
else:
|
||||
fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
|
||||
if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
|
||||
offset2bag = offsets.new_empty(indices.size(0))
|
||||
else:
|
||||
offset2bag = offsets.new_empty(0)
|
||||
bag_size = offsets.new_empty(num_bags)
|
||||
max_indices = offsets.new_empty(bag_size.size())
|
||||
return output, offset2bag, bag_size, max_indices
|
||||
|
||||
@torch.library.impl(meta_lib, "_embedding_bag_forward_only")
|
||||
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
|
||||
output, offset2bag, bag_size, max_indices = meta_embedding_bag(weight, indices, offsets, *args)
|
||||
if offsets.device.type == "cpu":
|
||||
bag_size = offsets.new_empty(offsets.size())
|
||||
return output, offset2bag, bag_size, max_indices
|
||||
|
Reference in New Issue
Block a user