[MPS] Compute offset2bag/bag_size/max_indices in _embedding_bag (#163281)

Part of #162270

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163281
Approved by: https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: b879ef7c0d
Commit: 20149080f2
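For reference, torch._embedding_bag returns four tensors: the reduced output, offset2bag (which bag each index position belongs to), bag_size (the number of non-padding indices per bag), and max_indices (for mode='max', the weight row that produced each per-feature maximum). This change makes the MPS kernel fill in the last three alongside the output. A minimal sketch of what the four outputs mean, run against the CPU implementation as the reference (the tensors below are illustrative and not taken from the PR):

import torch

weight = torch.arange(12., dtype=torch.float32).reshape(4, 3)  # 4 words, feature size 3
indices = torch.tensor([0, 2, 1, 3, 2])
offsets = torch.tensor([0, 2])  # bag 0 = indices[0:2], bag 1 = indices[2:5]

# mode=2 selects 'max'; after this PR the MPS kernel fills all four outputs.
out, offset2bag, bag_size, max_indices = torch._embedding_bag(
    weight, indices, offsets, mode=2)

print(out)          # per-bag, per-feature maxima
print(offset2bag)   # tensor([0, 0, 1, 1, 1]) -- bag id for each index position
print(bag_size)     # tensor([2, 3])          -- non-padding indices per bag
print(max_indices)  # weight row that attained each maximum, shape (num_bags, feature_size)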
@@ -14,7 +14,7 @@ struct EmbeddingBagParams {
   ::c10::metal::array<idx_type_t, 2> output_strides;
   ::c10::metal::array<idx_type_t, 2> max_indices_strides;

-  idx_type_t per_sample_weights_strides;
+  idx_type_t per_sample_weights_stride;

   idx_type_t num_indices;
   idx_type_t num_bags;
@@ -23,54 +23,72 @@ struct ReductionOpInit<EmbeddingBagMode::MAX, T> {
 template <EmbeddingBagMode M, typename T>
 struct ReductionOp {
   inline opmath_t<T> operator()(
-      T weight_val,
+      opmath_t<T> weight_val,
       opmath_t<T> out_val,
-      uint32_t per_sample_weights_index,
-      constant T* per_sample_weights,
-      uint32_t per_sample_weights_strides);
+      bool is_first);
 };

 template <typename T>
 struct ReductionOp<EmbeddingBagMode::SUM, T> {
   inline opmath_t<T> operator()(
-      T weight_val,
+      opmath_t<T> weight_val,
       opmath_t<T> out_val,
-      uint32_t per_sample_weights_index,
-      constant T* per_sample_weights,
-      uint32_t per_sample_weights_strides) {
-    if (per_sample_weights_strides) {
-      T per_sample_weight = per_sample_weights
-          [per_sample_weights_strides * per_sample_weights_index];
-      return static_cast<opmath_t<T>>(per_sample_weight) *
-          static_cast<opmath_t<T>>(weight_val) +
-          out_val;
-    } else {
-      return static_cast<opmath_t<T>>(weight_val) + out_val;
-    }
+      bool is_first) {
+    return weight_val + out_val;
   }
 };

 template <typename T>
 struct ReductionOp<EmbeddingBagMode::MEAN, T> {
   inline opmath_t<T> operator()(
-      T weight_val,
+      opmath_t<T> weight_val,
       opmath_t<T> out_val,
-      uint32_t,
-      constant T*,
-      uint32_t) {
-    return static_cast<opmath_t<T>>(weight_val) + out_val;
+      bool is_first) {
+    return weight_val + out_val;
   }
 };

 template <typename T>
 struct ReductionOp<EmbeddingBagMode::MAX, T> {
   inline opmath_t<T> operator()(
-      T weight_val,
+      opmath_t<T> weight_val,
       opmath_t<T> out_val,
-      uint32_t,
-      constant T*,
-      uint32_t) {
-    return max(static_cast<opmath_t<T>>(weight_val), out_val);
+      bool is_first) {
+    return (is_first || weight_val > out_val) ? weight_val : out_val;
   }
 };
+
+template <EmbeddingBagMode M, typename T>
+struct MaybeApplyPerSampleWeight {
+  inline opmath_t<T> operator()(
+      opmath_t<T> weight_val,
+      uint32_t per_sample_weights_index,
+      constant T* per_sample_weights,
+      uint32_t per_sample_weights_stride) {
+    return weight_val;
+  }
+};
+
+template <typename T>
+struct MaybeApplyPerSampleWeight<EmbeddingBagMode::SUM, T> {
+  inline opmath_t<T> operator()(
+      opmath_t<T> weight_val,
+      uint32_t per_sample_weights_index,
+      constant T* per_sample_weights,
+      uint32_t per_sample_weights_stride) {
+    if (per_sample_weights_stride) {
+      T per_sample_weight = per_sample_weights
+          [per_sample_weights_stride * per_sample_weights_index];
+      return static_cast<opmath_t<T>>(per_sample_weight) * weight_val;
+    } else {
+      return weight_val;
+    }
+  }
+};
+
+template <EmbeddingBagMode M, typename T, typename I>
+struct MaybeCalcMaxIndex {
+  inline void operator()(
+      opmath_t<T> weight_val,
+      opmath_t<T> out_val,
+      bool is_first,
+      thread I& max_idx,
+      I weight_idx,
+      bool pad) {}
+};
+
+template <typename T, typename I>
+struct MaybeCalcMaxIndex<EmbeddingBagMode::MAX, T, I> {
+  inline void operator()(
+      opmath_t<T> weight_val,
+      opmath_t<T> out_val,
+      bool is_first,
+      thread I& max_idx,
+      I weight_idx,
+      bool pad) {
+    max_idx = !pad && (is_first || weight_val > out_val) ? weight_idx : max_idx;
+  }
+};
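The MAX specialization above swaps a plain max() for an explicit comparison gated by is_first: the accumulator only changes when the new value compares strictly greater, so with NaN in a bag the result depends on which element is visited first. A rough Python analogue of that reduction (bag_max is a hypothetical helper for illustration, not kernel code):

def bag_max(vals):
    # Comparison-style max, mirroring the new ReductionOp<EmbeddingBagMode::MAX>:
    # keep the accumulator unless the new value compares strictly greater.
    out, is_first = None, True
    for v in vals:
        out = v if (is_first or v > out) else out
        is_first = False
    return out

print(bag_max([1.0, 5.0, 3.0]))           # 5.0
print(bag_max([float('nan'), 5.0, 3.0]))  # nan: `5.0 > nan` is False, so the leading nan sticks
print(bag_max([1.0, float('nan'), 3.0]))  # 3.0: a later nan never compares greater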
@@ -96,6 +114,30 @@ struct ReductionOpFinal<EmbeddingBagMode::MAX, T> {
   }
 };

+template <EmbeddingBagMode M, typename I>
+struct MaybeWriteMaxIndex {
+  inline void operator()(
+      device I*,
+      const constant ::c10::metal::array<uint32_t, 2>&,
+      uint32_t,
+      uint32_t,
+      I) {}
+};
+
+template <typename I>
+struct MaybeWriteMaxIndex<EmbeddingBagMode::MAX, I> {
+  inline void operator()(
+      device I* max_indices,
+      const constant ::c10::metal::array<uint32_t, 2>& max_indices_strides,
+      uint32_t bag_idx,
+      uint32_t feature_idx,
+      I max_idx) {
+    max_indices
+        [bag_idx * max_indices_strides[0] +
+         feature_idx * max_indices_strides[1]] = max_idx;
+  }
+};
+
 template <EmbeddingBagMode M, typename T, typename I>
 void embedding_bag_impl(
     constant T* weight,
@@ -112,7 +154,7 @@ void embedding_bag_impl(
   auto num_bags = params.num_bags;
   auto feature_size = params.feature_size;
   auto padding_idx = params.padding_idx;
-  auto per_sample_weights_strides = params.per_sample_weights_strides;
+  auto per_sample_weights_stride = params.per_sample_weights_stride;
   constant auto& output_strides = params.output_strides;
   constant auto& weight_strides = params.weight_strides;
   constant auto& max_indices_strides = params.max_indices_strides;
@@ -120,8 +162,6 @@ void embedding_bag_impl(
   auto bag_idx = tid / feature_size;
   auto feature_idx = tid % feature_size;

-  output += bag_idx * output_strides[0] + feature_idx * output_strides[1];
-
-  uint32_t offsets_end = min(bag_idx + 1, num_bags - 1);
+  bool is_last_bag = bag_idx + 1 == num_bags;
   uint32_t indices_start = static_cast<uint32_t>(offsets[bag_idx]);
@@ -131,28 +171,37 @@ void embedding_bag_impl(
   auto out_val = ReductionOpInit<M, T>()();

   uint32_t bag_size_ = 0;
+  I max_idx = 0;

   for (uint32_t indices_idx = indices_start; indices_idx < indices_end;
        indices_idx++) {
     I weight_idx = indices[indices_idx];
     bool pad = (weight_idx == padding_idx);
-    T weight_val = weight
-        [static_cast<uint32_t>(weight_idx) * weight_strides[0] +
-         feature_idx * weight_strides[1]];
+    auto weight_val = static_cast<opmath_t<T>>(
+        weight
+            [static_cast<uint32_t>(weight_idx) * weight_strides[0] +
+             feature_idx * weight_strides[1]]);
+
+    weight_val = MaybeApplyPerSampleWeight<M, T>()(
+        weight_val, indices_idx, per_sample_weights, per_sample_weights_stride);
+
+    auto new_out_val = ReductionOp<M, T>()(weight_val, out_val, bag_size_ == 0);
+
+    MaybeCalcMaxIndex<M, T, I>()(
+        weight_val, out_val, bag_size_ == 0, max_idx, weight_idx, pad);

+    out_val = pad ? out_val : new_out_val;
+    offset2bag[indices_idx] = bag_idx;
     bag_size_ += static_cast<uint32_t>(!pad);
-
-    auto tmp_val = ReductionOp<M, T>()(
-        weight_val,
-        out_val,
-        indices_idx,
-        per_sample_weights,
-        per_sample_weights_strides);
-
-    out_val = pad ? out_val : tmp_val;
   }

-  *output = ReductionOpFinal<M, T>()(out_val, bag_size_);
+  output[bag_idx * output_strides[0] + feature_idx * output_strides[1]] =
+      ReductionOpFinal<M, T>()(out_val, bag_size_);
+
+  bag_size[bag_idx] = bag_size_;
+
+  MaybeWriteMaxIndex<M, I>()(
+      max_indices, max_indices_strides, bag_idx, feature_idx, max_idx);
 }

 #define DISPATCH_IMPL(MODE) \
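The loop above now does the bookkeeping for offset2bag and bag_size alongside the reduction: every visited index position records its bag, and only non-padding indices count toward the bag size. A simplified pure-PyTorch sketch of that bookkeeping (a hypothetical helper; it omits the include_last_offset handling the kernel receives from the host side):

import torch

def reference_offset2bag_bag_size(indices, offsets, num_bags, padding_idx=-1):
    # Pure-Python sketch of the per-index bookkeeping in the kernel loop.
    offset2bag = torch.zeros_like(indices)
    bag_size = torch.zeros(num_bags, dtype=indices.dtype)
    for bag in range(num_bags):
        start = int(offsets[bag])
        end = int(offsets[bag + 1]) if bag + 1 < len(offsets) else indices.numel()
        for i in range(start, end):
            offset2bag[i] = bag                                   # offset2bag[indices_idx] = bag_idx
            bag_size[bag] += int(indices[i] != padding_idx)       # bag_size_ += !pad
    return offset2bag, bag_size

indices = torch.tensor([0, 2, 1, 3, 2])
offsets = torch.tensor([0, 2])
print(reference_offset2bag_bag_size(indices, offsets, num_bags=2))
# (tensor([0, 0, 1, 1, 1]), tensor([2, 3]))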
@@ -66,11 +66,12 @@ static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps_impl(
   int64_t num_indices = indices.size(0);
   int64_t num_bags = offsets.size(0);
   if (include_last_offset) {
+    TORCH_CHECK(num_bags >= 1, "include_last_offset: number of offsets should be at least 1");
     num_bags -= 1;
   }
   int64_t feature_size = weight.size(1);

-  auto bag_size = at::empty(offsets.sizes(), indices.options());
+  auto bag_size = at::empty({num_bags}, indices.options());
   auto offset2bag = at::empty({indices.size(0)}, indices.options());
   auto output = at::empty({num_bags, feature_size}, weight.options());
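The host-side change above sizes bag_size as {num_bags} instead of offsets.sizes(), which matters when include_last_offset is set and the final offset only terminates the last bag. A small usage sketch of that semantics (illustrative values, not from the PR):

import torch

weight = torch.randn(10, 3)
indices = torch.randint(0, 10, (8,))

# 3 offsets -> 3 bags when include_last_offset is False (the default).
out, _, _, _ = torch._embedding_bag(weight, indices, torch.tensor([0, 3, 6]), mode=1)
print(out.shape)  # torch.Size([3, 3])

# With include_last_offset=True the last offset only marks the end of the final
# bag, so 3 offsets -> 2 bags (hence num_bags -= 1 and the num_bags-sized bag_size).
out, _, _, _ = torch._embedding_bag(
    weight, indices, torch.tensor([0, 3, 8]), mode=1, include_last_offset=True)
print(out.shape)  # torch.Size([2, 3])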
@@ -94,7 +95,7 @@ static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps_impl(
   }

   bool use_per_sample_weights = per_sample_weights_opt.has_value() && per_sample_weights_opt->defined();
-  params.per_sample_weights_strides = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0;
+  params.per_sample_weights_stride = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0;

   params.num_indices = num_indices;
   params.num_bags = num_bags;
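per_sample_weights_stride is non-zero only when per-sample weights are supplied, and PyTorch accepts per_sample_weights only with mode='sum', which is why only the SUM specialization of MaybeApplyPerSampleWeight applies the weight. A quick sanity sketch against the CPU path (illustrative, not part of the PR):

import torch

weight = torch.randn(10, 3)
indices = torch.tensor([0, 2, 1, 3])
offsets = torch.tensor([0, 2])
psw = torch.randn(4)

# mode=0 is 'sum'; each looked-up row is scaled by its per-sample weight before summing.
out, *_ = torch._embedding_bag(weight, indices, offsets, mode=0, per_sample_weights=psw)

manual = torch.stack([
    (weight[indices[0:2]] * psw[0:2, None]).sum(0),
    (weight[indices[2:4]] * psw[2:4, None]).sum(0),
])
print(torch.allclose(out, manual))  # True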
@@ -7274,8 +7274,6 @@ GPU_TEST_FAILURES = {
 }

 MPS_TEST_FAILURES = {
-    # aten::_embedding_bag backward is not currently implemented for the MPS device.
-    "test_embedding_bag": fail_mps(),
     # aten::_scaled_dot_product_efficient_attention is not currently implemented for the MPS device.
     "test_scaled_dot_product_efficient_attention": fail_mps(),
     # aten::_int_mm is not implemented for MPS backend
@@ -5675,7 +5675,6 @@ class CommonTemplate:
             (torch.randn([2, 4, 4, 8]),),
         )

-    @xfail_if_mps_unimplemented
     def test_embedding_bag(self):
         def fn(w, i, o):
             return aten._embedding_bag(w, i, o, False, 0, False, None)
@@ -6940,6 +6940,70 @@ class TestMPS(TestCaseMPS):
         with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
             helper(22, 0, [])

+    # TODO: This test can be removed once the backward pass of embedding_bag is
+    # implemented and tested
+    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+    @parametrize("idx_dtype", [torch.long, torch.int])
+    @parametrize("padding_idx", [-1, 1])
+    @parametrize("include_last_offset", [True, False])
+    @parametrize("mode", ['sum', 'mean', 'max'])
+    def test__embedding_bag(self, dtype, idx_dtype, padding_idx, include_last_offset, mode):
+        import time
+        torch.manual_seed(time.time() * 1000)
+        mode_num = {'sum': 0, 'mean': 1, 'max': 2}[mode]
+        num_words = 10
+        feature_size = 7
+        num_indices = 40
+        num_bags = 5
+
+        weight_cpu = torch.randn(num_words, feature_size, dtype=dtype)
+
+        # Test nan value behavior.
+        # Set second element of each word to nan.
+        weight_cpu[:, 1] = float('nan')
+        # Set third element of a randomized half of the words to nan.
+        weight_cpu[torch.randperm(num_words)[:num_words // 2], 2] = float('nan')
+        # Set fourth element of one randomized word to nan.
+        weight_cpu[torch.randint(0, num_words, ()), 3] = float('nan')
+
+        input_cpu = torch.randint(0, num_words, (num_indices,), dtype=idx_dtype)
+        offsets_cpu = torch.tensor(
+            [0] + (torch.randperm(num_indices - 1)[:num_bags - 1].sort()[0] + 1).tolist(),
+            dtype=idx_dtype)
+
+        if include_last_offset:
+            offsets_cpu[-1] = input_cpu.numel()
+
+        per_sample_weights_cpu = torch.randn(num_indices, dtype=dtype) if mode == 'sum' else None
+
+        r_cpu, offset2bag_cpu, bag_size_cpu, max_indices_cpu = torch._embedding_bag(
+            weight_cpu,
+            input_cpu,
+            offsets_cpu,
+            per_sample_weights=per_sample_weights_cpu,
+            mode=mode_num,
+            padding_idx=padding_idx,
+            include_last_offset=include_last_offset,
+        )
+        r_mps, offset2bag_mps, bag_size_mps, max_indices_mps = torch._embedding_bag(
+            weight_cpu.to('mps'),
+            input_cpu.to('mps'),
+            offsets_cpu.to('mps'),
+            per_sample_weights=per_sample_weights_cpu.to('mps') if per_sample_weights_cpu is not None else None,
+            mode=mode_num,
+            padding_idx=padding_idx,
+            include_last_offset=include_last_offset,
+        )
+
+        self.assertEqual(r_cpu, r_mps)
+
+        if mode != 'sum':
+            self.assertEqual(offset2bag_cpu, offset2bag_mps)
+            self.assertEqual(bag_size_cpu, bag_size_mps)
+
+        if mode == 'max':
+            self.assertEqual(max_indices_cpu, max_indices_mps)
+
     def test_embedding_dense_backward(self):
         def helper(n, d, m, idx):
             embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps')