[MPS] Add embedding_bag forward pass (#163012)

Part of #162270

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163012
Approved by: https://github.com/kulinseth, https://github.com/malfet
This commit is contained in:
Kurt Mohler
2025-09-16 13:48:08 -05:00
committed by PyTorch MergeBot
parent 167ad09be5
commit 5236007806
8 changed files with 423 additions and 5 deletions

View File

@ -1,3 +1,4 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <cstdint>

View File

@ -0,0 +1,25 @@
#pragma once
#include <c10/metal/common.h>
#ifdef __METAL__
enum class EmbeddingBagMode { SUM = 0, MEAN, MAX };
#else
#include <ATen/native/EmbeddingBag.h>
using at::native::EmbeddingBagMode;
#endif
template <typename idx_type_t = uint32_t>
struct EmbeddingBagParams {
::c10::metal::array<idx_type_t, 2> weight_strides;
::c10::metal::array<idx_type_t, 2> output_strides;
::c10::metal::array<idx_type_t, 2> max_indices_strides;
idx_type_t per_sample_weights_strides;
idx_type_t num_indices;
idx_type_t num_bags;
idx_type_t feature_size;
EmbeddingBagMode mode;
int64_t padding_idx;
};

View File

@ -0,0 +1,212 @@
#include <ATen/native/mps/kernels/EmbeddingBag.h>
#include <c10/metal/utils.h>
#include <metal_array>
#include <metal_stdlib>
using namespace metal;
using namespace c10::metal;
template <EmbeddingBagMode M, typename T>
struct ReductionOpInit {
inline opmath_t<T> operator()() {
return 0;
}
};
template <typename T>
struct ReductionOpInit<EmbeddingBagMode::MAX, T> {
inline opmath_t<T> operator()() {
return static_cast<opmath_t<T>>(-INFINITY);
}
};
template <EmbeddingBagMode M, typename T>
struct ReductionOp {
inline opmath_t<T> operator()(
T weight_val,
opmath_t<T> out_val,
uint32_t per_sample_weights_index,
constant T* per_sample_weights,
uint32_t per_sample_weights_strides);
};
template <typename T>
struct ReductionOp<EmbeddingBagMode::SUM, T> {
inline opmath_t<T> operator()(
T weight_val,
opmath_t<T> out_val,
uint32_t per_sample_weights_index,
constant T* per_sample_weights,
uint32_t per_sample_weights_strides) {
if (per_sample_weights_strides) {
T per_sample_weight = per_sample_weights
[per_sample_weights_strides * per_sample_weights_index];
return static_cast<opmath_t<T>>(per_sample_weight) *
static_cast<opmath_t<T>>(weight_val) +
out_val;
} else {
return static_cast<opmath_t<T>>(weight_val) + out_val;
}
}
};
template <typename T>
struct ReductionOp<EmbeddingBagMode::MEAN, T> {
inline opmath_t<T> operator()(
T weight_val,
opmath_t<T> out_val,
uint32_t,
constant T*,
uint32_t) {
return static_cast<opmath_t<T>>(weight_val) + out_val;
}
};
template <typename T>
struct ReductionOp<EmbeddingBagMode::MAX, T> {
inline opmath_t<T> operator()(
T weight_val,
opmath_t<T> out_val,
uint32_t,
constant T*,
uint32_t) {
return max(static_cast<opmath_t<T>>(weight_val), out_val);
}
};
template <EmbeddingBagMode M, typename T>
struct ReductionOpFinal {
inline T operator()(opmath_t<T> val, uint32_t) {
return static_cast<T>(val);
}
};
template <typename T>
struct ReductionOpFinal<EmbeddingBagMode::MEAN, T> {
inline T operator()(opmath_t<T> val, uint32_t count) {
auto out = val / count;
return static_cast<T>((count == 0) ? 0 : out);
}
};
template <typename T>
struct ReductionOpFinal<EmbeddingBagMode::MAX, T> {
inline T operator()(opmath_t<T> val, uint32_t count) {
return static_cast<T>((count == 0) ? 0 : val);
}
};
template <EmbeddingBagMode M, typename T, typename I>
void embedding_bag_impl(
constant T* weight,
constant I* indices,
constant I* offsets,
constant T* per_sample_weights,
device T* output,
device I* offset2bag,
device I* bag_size,
device I* max_indices,
constant EmbeddingBagParams<uint32_t>& params,
uint tid) {
auto num_indices = params.num_indices;
auto num_bags = params.num_bags;
auto feature_size = params.feature_size;
auto padding_idx = params.padding_idx;
auto per_sample_weights_strides = params.per_sample_weights_strides;
constant auto& output_strides = params.output_strides;
constant auto& weight_strides = params.weight_strides;
constant auto& max_indices_strides = params.max_indices_strides;
auto bag_idx = tid / feature_size;
auto feature_idx = tid % feature_size;
output += bag_idx * output_strides[0] + feature_idx * output_strides[1];
uint32_t offsets_end = min(bag_idx + 1, num_bags - 1);
bool is_last_bag = bag_idx + 1 == num_bags;
uint32_t indices_start = static_cast<uint32_t>(offsets[bag_idx]);
uint32_t indices_end = is_last_bag * (num_indices) +
(!is_last_bag) * (static_cast<uint32_t>(offsets[offsets_end]));
auto out_val = ReductionOpInit<M, T>()();
uint32_t bag_size_ = 0;
for (uint32_t indices_idx = indices_start; indices_idx < indices_end;
indices_idx++) {
I weight_idx = indices[indices_idx];
bool pad = (weight_idx == padding_idx);
T weight_val = weight
[static_cast<uint32_t>(weight_idx) * weight_strides[0] +
feature_idx * weight_strides[1]];
bag_size_ += static_cast<uint32_t>(!pad);
auto tmp_val = ReductionOp<M, T>()(
weight_val,
out_val,
indices_idx,
per_sample_weights,
per_sample_weights_strides);
out_val = pad ? out_val : tmp_val;
}
*output = ReductionOpFinal<M, T>()(out_val, bag_size_);
}
#define DISPATCH_IMPL(MODE) \
return embedding_bag_impl<MODE>( \
weight, \
indices, \
offsets, \
per_sample_weights, \
output, \
offset2bag, \
bag_size, \
max_indices, \
params, \
tid)
template <typename T, typename I>
kernel void embedding_bag(
constant T* weight [[buffer(0)]],
constant I* indices [[buffer(1)]],
constant I* offsets [[buffer(2)]],
constant T* per_sample_weights [[buffer(3)]],
device T* output [[buffer(4)]],
device I* offset2bag [[buffer(5)]],
device I* bag_size [[buffer(6)]],
device I* max_indices [[buffer(7)]],
constant EmbeddingBagParams<uint32_t>& params [[buffer(8)]],
uint tid [[thread_position_in_grid]]) {
switch (params.mode) {
case EmbeddingBagMode::SUM:
DISPATCH_IMPL(EmbeddingBagMode::SUM);
case EmbeddingBagMode::MEAN:
DISPATCH_IMPL(EmbeddingBagMode::MEAN);
case EmbeddingBagMode::MAX:
DISPATCH_IMPL(EmbeddingBagMode::MAX);
}
}
#define REGISTER_EMBEDDING_BAG_OP(T, I) \
template [[host_name("embedding_bag_" #T "_" #I)]] \
kernel void embedding_bag<T, I>( \
constant T * weight [[buffer(0)]], \
constant I * indices [[buffer(1)]], \
constant I * offsets [[buffer(2)]], \
constant T * per_sample_weights [[buffer(3)]], \
device T * output [[buffer(4)]], \
device I * offset2bag [[buffer(5)]], \
device I * bag_size [[buffer(6)]], \
device I * max_indices [[buffer(7)]], \
constant EmbeddingBagParams<uint32_t> & params [[buffer(8)]], \
uint tid [[thread_position_in_grid]]);
REGISTER_EMBEDDING_BAG_OP(float, int);
REGISTER_EMBEDDING_BAG_OP(float, long);
REGISTER_EMBEDDING_BAG_OP(half, int);
REGISTER_EMBEDDING_BAG_OP(half, long);
REGISTER_EMBEDDING_BAG_OP(bfloat, int);
REGISTER_EMBEDDING_BAG_OP(bfloat, long);

View File

@ -0,0 +1,179 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/EmbeddingBag.h>
#include <ATen/native/Pool.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/kernels/EmbeddingBag.h>
#include <fmt/format.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_embedding_bag_forward_only_native.h>
#include <ATen/ops/_embedding_bag_native.h>
#include <ATen/ops/empty.h>
#endif
namespace at::native {
#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/EmbeddingBag_metallib.h>
#endif
namespace {
std::pair<Tensor, Tensor> promoteIndicesAndOffsets(const Tensor& indices, const Tensor& offsets) {
const auto commonType = promoteTypes(offsets.scalar_type(), indices.scalar_type());
return {indices.scalar_type() == commonType ? indices : indices.toType(commonType),
offsets.scalar_type() == commonType ? offsets : offsets.toType(commonType)};
}
} // namespace
namespace mps {
static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps_impl(
const Tensor& weight,
const Tensor& indices_,
const Tensor& offsets_,
const bool scale_grad_by_freq,
const int64_t mode,
bool sparse,
const std::optional<Tensor>& per_sample_weights_opt,
bool include_last_offset,
int64_t padding_idx) {
TORCH_CHECK(indices_.dim() == 1, "input has to be a 1D Tensor, but got Tensor of dimension ", indices_.dim());
if (indices_.dim() == 1) {
TORCH_CHECK(offsets_.dim() == 1, "offsets has to be a 1D Tensor, but got Tensor of dimension ", offsets_.dim());
}
TORCH_CHECK(weight.dim() == 2, "weight has to be a 2D Tensor, but got Tensor of dimension ", weight.dim());
Tensor indices, offsets;
std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_);
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarTypes("embedding_bag_mps", indices_arg, {kLong, kInt});
auto offsets_arg = TensorArg(offsets, "offsets", 1);
checkScalarTypes("embedding_bag_mps", offsets_arg, {kLong, kInt});
checkSameType("embedding_bag_mps", indices_arg, offsets_arg);
auto weight_arg = TensorArg(weight, "weight", 1);
int64_t num_indices = indices.size(0);
int64_t num_bags = offsets.size(0);
if (include_last_offset) {
num_bags -= 1;
}
int64_t feature_size = weight.size(1);
auto bag_size = at::empty(offsets.sizes(), indices.options());
auto offset2bag = at::empty({indices.size(0)}, indices.options());
auto output = at::empty({num_bags, feature_size}, weight.options());
Tensor max_indices;
if (mode == EmbeddingBagMode::MAX) {
max_indices = at::empty({num_bags, feature_size}, indices.options());
} else {
max_indices = at::empty({0}, indices.options());
}
EmbeddingBagParams<uint32_t> params;
for (const auto dim : c10::irange(weight.dim())) {
params.weight_strides[dim] = safe_downcast<uint32_t, int64_t>(weight.stride(dim));
params.output_strides[dim] = safe_downcast<uint32_t, int64_t>(output.stride(dim));
if (mode == EmbeddingBagMode::MAX) {
params.max_indices_strides[dim] = safe_downcast<uint32_t, int64_t>(max_indices.stride(dim));
}
}
bool use_per_sample_weights = per_sample_weights_opt.has_value() && per_sample_weights_opt->defined();
params.per_sample_weights_strides = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0;
params.num_indices = num_indices;
params.num_bags = num_bags;
params.feature_size = feature_size;
params.mode = static_cast<EmbeddingBagMode>(mode);
params.padding_idx = padding_idx;
auto num_threads = output.numel();
MPSStream* stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
auto pipeline_state = lib.getPipelineStateForFunc(
fmt::format("embedding_bag_{}_{}", scalarToMetalTypeString(weight), scalarToMetalTypeString(indices)));
getMPSProfiler().beginProfileKernel(pipeline_state, "embedding_bag", {weight, indices, offsets});
[computeEncoder setComputePipelineState:pipeline_state];
mtl_setArgs(computeEncoder,
weight,
indices,
offsets,
use_per_sample_weights ? per_sample_weights_opt : std::nullopt,
output,
offset2bag,
bag_size,
max_indices,
params);
mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads);
getMPSProfiler().endProfileKernel(pipeline_state);
}
});
return std::tuple<Tensor, Tensor, Tensor, Tensor>(
std::move(output), std::move(offset2bag), std::move(bag_size), std::move(max_indices));
}
} // namespace mps
std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps(const Tensor& weight,
const Tensor& indices,
const Tensor& offsets,
const bool scale_grad_by_freq,
const int64_t mode,
bool sparse,
const std::optional<Tensor>& per_sample_weights_opt,
bool include_last_offset,
int64_t padding_idx) {
return mps::_embedding_bag_mps_impl(weight,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights_opt,
include_last_offset,
padding_idx);
}
std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_forward_only_mps(
const Tensor& weight,
const Tensor& indices,
const Tensor& offsets,
const bool scale_grad_by_freq,
const int64_t mode,
bool sparse,
const std::optional<Tensor>& per_sample_weights_opt,
bool include_last_offset,
int64_t padding_idx) {
return _embedding_bag_mps(weight,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights_opt,
include_last_offset,
padding_idx);
}
} // namespace at::native

View File

@ -2351,6 +2351,7 @@
dispatch:
CPU: _embedding_bag_forward_only_cpu
CUDA: _embedding_bag_forward_only_cuda
MPS: _embedding_bag_forward_only_mps
autogen: _embedding_bag_forward_only.out
- func: _rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)
@ -2372,6 +2373,7 @@
dispatch:
CPU: _embedding_bag_cpu
CUDA: _embedding_bag_cuda
MPS: _embedding_bag_mps
autogen: _embedding_bag.out
tags: core

View File

@ -7277,11 +7277,8 @@ GPU_TEST_FAILURES = {
}
MPS_TEST_FAILURES = {
# aten::_embedding_bag is not currently implemented for the MPS device.
# aten::_embedding_bag backward is not currently implemented for the MPS device.
"test_embedding_bag": fail_mps(),
# aten::_embedding_bag is not currently implemented for the MPS device.
"test_misc_1_max_autotune_False": fail_mps(),
"test_misc_1_max_autotune_True": fail_mps(),
# aten::_scaled_dot_product_efficient_attention is not currently implemented for the MPS device.
"test_scaled_dot_product_efficient_attention": fail_mps(),
# aten::_int_mm is not implemented for MPS backend

View File

@ -15,6 +15,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__adaptive_avg_pool2d(AtenTensorH
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__adaptive_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__embedding_bag_forward_only(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);

View File

@ -348,7 +348,6 @@ if torch.backends.mps.is_available():
"nn.functional.interpolatearea": None,
"nn.functional.interpolatebicubic": [torch.uint8],
"nn.functional.ctc_loss": None,
"nn.functional.embedding_bag": None,
"nn.functional.multi_margin_loss": None,
"nn.functional.multilabel_margin_loss": None,
"nn.functional.pdist": None,
@ -740,6 +739,7 @@ if torch.backends.mps.is_available():
"equal": [torch.float16, torch.float32],
# 'float' object is not iterable
"item": [torch.float16, torch.float32],
"nn.functional.embedding_bag": None,
# "smooth_l1_backward_cpu_out" not implemented for 'Half'
"nn.functional.smooth_l1_loss": [torch.float16],
# cpu error: grad requires non-empty inputs