[MPS] Compute offset2bag/bag_size/max_indices in _embedding_bag (#163281)

Part of #162270

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163281
Approved by: https://github.com/malfet
authored by Kurt Mohler on 2025-09-19 10:55:25 -05:00
committed by PyTorch MergeBot
parent b879ef7c0d
commit 20149080f2
6 changed files with 173 additions and 62 deletions
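
In user-facing terms, the point of the change is that on MPS all four outputs of torch._embedding_bag (output, offset2bag, bag_size, max_indices) should now agree with the CPU reference, not just the pooled output. A minimal sketch of that contract, with made-up shapes rather than anything taken from the PR:

import torch

weight = torch.randn(10, 4)            # 10 words, feature size 4
indices = torch.randint(0, 10, (12,))  # flattened bag contents
offsets = torch.tensor([0, 3, 7])      # bags [0, 3), [3, 7), [7, 12)

cpu = torch._embedding_bag(weight, indices, offsets, mode=2)  # 2 == 'max'
mps = torch._embedding_bag(weight.to("mps"), indices.to("mps"),
                           offsets.to("mps"), mode=2)

# With this change, offset2bag, bag_size and max_indices are computed on MPS
# as well, so all four results should match the CPU implementation.
for ref, res in zip(cpu, mps):
    torch.testing.assert_close(ref, res.cpu())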


@@ -14,7 +14,7 @@ struct EmbeddingBagParams {
::c10::metal::array<idx_type_t, 2> output_strides;
::c10::metal::array<idx_type_t, 2> max_indices_strides;
idx_type_t per_sample_weights_strides;
idx_type_t per_sample_weights_stride;
idx_type_t num_indices;
idx_type_t num_bags;


@@ -23,54 +23,72 @@ struct ReductionOpInit<EmbeddingBagMode::MAX, T> {
template <EmbeddingBagMode M, typename T>
struct ReductionOp {
inline opmath_t<T> operator()(
T weight_val,
opmath_t<T> weight_val,
opmath_t<T> out_val,
uint32_t per_sample_weights_index,
constant T* per_sample_weights,
uint32_t per_sample_weights_strides);
};
template <typename T>
struct ReductionOp<EmbeddingBagMode::SUM, T> {
inline opmath_t<T> operator()(
T weight_val,
opmath_t<T> out_val,
uint32_t per_sample_weights_index,
constant T* per_sample_weights,
uint32_t per_sample_weights_strides) {
if (per_sample_weights_strides) {
T per_sample_weight = per_sample_weights
[per_sample_weights_strides * per_sample_weights_index];
return static_cast<opmath_t<T>>(per_sample_weight) *
static_cast<opmath_t<T>>(weight_val) +
out_val;
} else {
return static_cast<opmath_t<T>>(weight_val) + out_val;
}
}
};
template <typename T>
struct ReductionOp<EmbeddingBagMode::MEAN, T> {
inline opmath_t<T> operator()(
T weight_val,
opmath_t<T> out_val,
uint32_t,
constant T*,
uint32_t) {
return static_cast<opmath_t<T>>(weight_val) + out_val;
bool is_first) {
return weight_val + out_val;
}
};
template <typename T>
struct ReductionOp<EmbeddingBagMode::MAX, T> {
inline opmath_t<T> operator()(
T weight_val,
opmath_t<T> weight_val,
opmath_t<T> out_val,
uint32_t,
constant T*,
uint32_t) {
return max(static_cast<opmath_t<T>>(weight_val), out_val);
bool is_first) {
return (is_first || weight_val > out_val) ? weight_val : out_val;
}
};
template <EmbeddingBagMode M, typename T>
struct MaybeApplyPerSampleWeight {
inline opmath_t<T> operator()(
opmath_t<T> weight_val,
uint32_t per_sample_weights_index,
constant T* per_sample_weights,
uint32_t per_sample_weights_stride) {
return weight_val;
}
};
template <typename T>
struct MaybeApplyPerSampleWeight<EmbeddingBagMode::SUM, T> {
inline opmath_t<T> operator()(
opmath_t<T> weight_val,
uint32_t per_sample_weights_index,
constant T* per_sample_weights,
uint32_t per_sample_weights_stride) {
if (per_sample_weights_stride) {
T per_sample_weight = per_sample_weights
[per_sample_weights_stride * per_sample_weights_index];
return static_cast<opmath_t<T>>(per_sample_weight) * weight_val;
} else {
return weight_val;
}
}
};
template <EmbeddingBagMode M, typename T, typename I>
struct MaybeCalcMaxIndex {
inline void operator()(
opmath_t<T> weight_val,
opmath_t<T> out_val,
bool is_first,
thread I& max_idx,
I weight_idx,
bool pad) {}
};
template <typename T, typename I>
struct MaybeCalcMaxIndex<EmbeddingBagMode::MAX, T, I> {
inline void operator()(
opmath_t<T> weight_val,
opmath_t<T> out_val,
bool is_first,
thread I& max_idx,
I weight_idx,
bool pad) {
max_idx = !pad && (is_first || weight_val > out_val) ? weight_idx : max_idx;
}
};
@@ -96,6 +114,30 @@ struct ReductionOpFinal<EmbeddingBagMode::MAX, T> {
}
};
template <EmbeddingBagMode M, typename I>
struct MaybeWriteMaxIndex {
inline void operator()(
device I*,
const constant ::c10::metal::array<uint32_t, 2>&,
uint32_t,
uint32_t,
I) {}
};
template <typename I>
struct MaybeWriteMaxIndex<EmbeddingBagMode::MAX, I> {
inline void operator()(
device I* max_indices,
const constant ::c10::metal::array<uint32_t, 2>& max_indices_strides,
uint32_t bag_idx,
uint32_t feature_idx,
I max_idx) {
max_indices
[bag_idx * max_indices_strides[0] +
feature_idx * max_indices_strides[1]] = max_idx;
}
};
template <EmbeddingBagMode M, typename T, typename I>
void embedding_bag_impl(
constant T* weight,
@@ -112,7 +154,7 @@ void embedding_bag_impl(
auto num_bags = params.num_bags;
auto feature_size = params.feature_size;
auto padding_idx = params.padding_idx;
auto per_sample_weights_strides = params.per_sample_weights_strides;
auto per_sample_weights_stride = params.per_sample_weights_stride;
constant auto& output_strides = params.output_strides;
constant auto& weight_strides = params.weight_strides;
constant auto& max_indices_strides = params.max_indices_strides;
@@ -120,8 +162,6 @@ void embedding_bag_impl(
auto bag_idx = tid / feature_size;
auto feature_idx = tid % feature_size;
output += bag_idx * output_strides[0] + feature_idx * output_strides[1];
uint32_t offsets_end = min(bag_idx + 1, num_bags - 1);
bool is_last_bag = bag_idx + 1 == num_bags;
uint32_t indices_start = static_cast<uint32_t>(offsets[bag_idx]);
@@ -131,28 +171,37 @@
auto out_val = ReductionOpInit<M, T>()();
uint32_t bag_size_ = 0;
I max_idx = 0;
for (uint32_t indices_idx = indices_start; indices_idx < indices_end;
indices_idx++) {
I weight_idx = indices[indices_idx];
bool pad = (weight_idx == padding_idx);
T weight_val = weight
[static_cast<uint32_t>(weight_idx) * weight_strides[0] +
feature_idx * weight_strides[1]];
auto weight_val = static_cast<opmath_t<T>>(
weight
[static_cast<uint32_t>(weight_idx) * weight_strides[0] +
feature_idx * weight_strides[1]]);
weight_val = MaybeApplyPerSampleWeight<M, T>()(
weight_val, indices_idx, per_sample_weights, per_sample_weights_stride);
auto new_out_val = ReductionOp<M, T>()(weight_val, out_val, bag_size_ == 0);
MaybeCalcMaxIndex<M, T, I>()(
weight_val, out_val, bag_size_ == 0, max_idx, weight_idx, pad);
out_val = pad ? out_val : new_out_val;
offset2bag[indices_idx] = bag_idx;
bag_size_ += static_cast<uint32_t>(!pad);
auto tmp_val = ReductionOp<M, T>()(
weight_val,
out_val,
indices_idx,
per_sample_weights,
per_sample_weights_strides);
out_val = pad ? out_val : tmp_val;
}
*output = ReductionOpFinal<M, T>()(out_val, bag_size_);
output[bag_idx * output_strides[0] + feature_idx * output_strides[1]] =
ReductionOpFinal<M, T>()(out_val, bag_size_);
bag_size[bag_idx] = bag_size_;
MaybeWriteMaxIndex<M, I>()(
max_indices, max_indices_strides, bag_idx, feature_idx, max_idx);
}
#define DISPATCH_IMPL(MODE) \
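
To make the rewritten kernel easier to follow outside of Metal, here is a rough pure-Python model of what each (bag, feature) position now computes, including the three newly populated outputs. It is only a sketch of the logic above (the helper names in the comments refer to the structs in the diff, and the function name is mine); offsets here is assumed to already have exactly one entry per bag, i.e. any trailing "last offset" has been stripped on the host side.

def embedding_bag_sketch(weight, indices, offsets, mode,
                         padding_idx=-1, per_sample_weights=None):
    # weight: list of rows; indices/offsets: lists of ints; mode: 'sum'/'mean'/'max'.
    num_indices, num_bags = len(indices), len(offsets)
    feature_size = len(weight[0])
    output = [[0.0] * feature_size for _ in range(num_bags)]
    offset2bag = [0] * num_indices
    bag_size = [0] * num_bags
    max_indices = [[0] * feature_size for _ in range(num_bags)]

    for bag in range(num_bags):
        start = offsets[bag]
        end = num_indices if bag + 1 == num_bags else offsets[bag + 1]
        for feat in range(feature_size):
            out_val, count, max_idx = 0.0, 0, 0
            for i in range(start, end):
                widx = indices[i]
                pad = (widx == padding_idx)
                val = weight[widx][feat]
                if mode == 'sum' and per_sample_weights is not None:
                    val *= per_sample_weights[i]      # MaybeApplyPerSampleWeight
                is_first = (count == 0)
                if mode == 'max':
                    new_out = val if (is_first or val > out_val) else out_val
                    if not pad and (is_first or val > out_val):
                        max_idx = widx                # MaybeCalcMaxIndex
                else:                                 # 'sum' and 'mean' accumulate
                    new_out = out_val + val
                if not pad:                           # padding contributes nothing
                    out_val = new_out
                    count += 1
                offset2bag[i] = bag                   # every index records its bag
            if mode == 'mean' and count:
                out_val /= count                      # ReductionOpFinal
            output[bag][feat] = out_val
            bag_size[bag] = count                     # non-padding entries only
            if mode == 'max':
                max_indices[bag][feat] = max_idx      # MaybeWriteMaxIndex
    return output, offset2bag, bag_size, max_indices

Note that bag_size counts only non-padding indices, which mirrors the bag_size_ += static_cast<uint32_t>(!pad) line in the kernel.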


@@ -66,11 +66,12 @@ static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps_impl(
int64_t num_indices = indices.size(0);
int64_t num_bags = offsets.size(0);
if (include_last_offset) {
TORCH_CHECK(num_bags >= 1, "include_last_offset: number of offsets should be at least 1");
num_bags -= 1;
}
int64_t feature_size = weight.size(1);
auto bag_size = at::empty(offsets.sizes(), indices.options());
auto bag_size = at::empty({num_bags}, indices.options());
auto offset2bag = at::empty({indices.size(0)}, indices.options());
auto output = at::empty({num_bags, feature_size}, weight.options());
@@ -94,7 +95,7 @@ static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps_impl(
}
bool use_per_sample_weights = per_sample_weights_opt.has_value() && per_sample_weights_opt->defined();
params.per_sample_weights_strides = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0;
params.per_sample_weights_stride = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0;
params.num_indices = num_indices;
params.num_bags = num_bags;
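
A note on the host-side changes above: num_bags is decremented when include_last_offset is set (with a new TORCH_CHECK rejecting an empty offsets tensor in that case), and bag_size is now allocated with that adjusted num_bags instead of offsets.sizes(). The offsets convention this accounts for, illustrated with made-up shapes (this snippet assumes a machine with an MPS device):

import torch

weight = torch.randn(10, 4, device="mps")
indices = torch.randint(0, 10, (12,), device="mps")
offsets = torch.tensor([0, 3, 7, 12], device="mps")  # trailing entry == indices.numel()

# With include_last_offset=True the final offset only marks the end of the
# last bag, so these offsets describe 3 bags ([0,3), [3,7), [7,12)), not 4.
out, _, bag_size, _ = torch._embedding_bag(
    weight, indices, offsets, mode=1, include_last_offset=True)  # 1 == 'mean'

print(out.shape)       # torch.Size([3, 4])
print(bag_size.shape)  # torch.Size([3]) -- one entry per bag, not per offset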


@@ -7274,8 +7274,6 @@ GPU_TEST_FAILURES = {
}
MPS_TEST_FAILURES = {
# aten::_embedding_bag backward is not currently implemented for the MPS device.
"test_embedding_bag": fail_mps(),
# aten::_scaled_dot_product_efficient_attention is not currently implemented for the MPS device.
"test_scaled_dot_product_efficient_attention": fail_mps(),
# aten::_int_mm is not implemented for MPS backend


@@ -5675,7 +5675,6 @@ class CommonTemplate:
(torch.randn([2, 4, 4, 8]),),
)
@xfail_if_mps_unimplemented
def test_embedding_bag(self):
def fn(w, i, o):
return aten._embedding_bag(w, i, o, False, 0, False, None)


@@ -6940,6 +6940,70 @@ class TestMPS(TestCaseMPS):
with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
helper(22, 0, [])
# TODO: This test can be removed once the backward pass of embedding_bag is
# implemented and tested
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@parametrize("idx_dtype", [torch.long, torch.int])
@parametrize("padding_idx", [-1, 1])
@parametrize("include_last_offset", [True, False])
@parametrize("mode", ['sum', 'mean', 'max'])
def test__embedding_bag(self, dtype, idx_dtype, padding_idx, include_last_offset, mode):
import time
torch.manual_seed(time.time() * 1000)
mode_num = {'sum': 0, 'mean': 1, 'max': 2}[mode]
num_words = 10
feature_size = 7
num_indices = 40
num_bags = 5
weight_cpu = torch.randn(num_words, feature_size, dtype=dtype)
# Test nan value behavior.
# Set second element of each word to nan.
weight_cpu[:, 1] = float('nan')
# Set third element of a randomized half of the words to nan.
weight_cpu[torch.randperm(num_words)[:num_words // 2], 2] = float('nan')
# Set fourth element of one randomized word to nan.
weight_cpu[torch.randint(0, num_words, ()), 3] = float('nan')
input_cpu = torch.randint(0, num_words, (num_indices,), dtype=idx_dtype)
offsets_cpu = torch.tensor(
[0] + (torch.randperm(num_indices - 1)[:num_bags - 1].sort()[0] + 1).tolist(),
dtype=idx_dtype)
if include_last_offset:
offsets_cpu[-1] = input_cpu.numel()
per_sample_weights_cpu = torch.randn(num_indices, dtype=dtype) if mode == 'sum' else None
r_cpu, offset2bag_cpu, bag_size_cpu, max_indices_cpu = torch._embedding_bag(
weight_cpu,
input_cpu,
offsets_cpu,
per_sample_weights=per_sample_weights_cpu,
mode=mode_num,
padding_idx=padding_idx,
include_last_offset=include_last_offset,
)
r_mps, offset2bag_mps, bag_size_mps, max_indices_mps = torch._embedding_bag(
weight_cpu.to('mps'),
input_cpu.to('mps'),
offsets_cpu.to('mps'),
per_sample_weights=per_sample_weights_cpu.to('mps') if per_sample_weights_cpu is not None else None,
mode=mode_num,
padding_idx=padding_idx,
include_last_offset=include_last_offset,
)
self.assertEqual(r_cpu, r_mps)
if mode != 'sum':
self.assertEqual(offset2bag_cpu, offset2bag_mps)
self.assertEqual(bag_size_cpu, bag_size_mps)
if mode == 'max':
self.assertEqual(max_indices_cpu, max_indices_mps)
def test_embedding_dense_backward(self):
def helper(n, d, m, idx):
embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps')