Support int32 indices and offsets in nn.EmbeddingBag (#46758)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46758

It's in general helpful to support int32 indices and offsets, especially when such tensors are large and need to be transferred to accelerator backends. Since it may not be very useful to support the combination of int32 indices and int64 offsets, here we enforce that these two must have the same type.

Test Plan: unit tests

Reviewed By: ngimel

Differential Revision: D24470808

fbshipit-source-id: 94b8a1d0b7fc9fe3d128247aa042c04d7c227f0b
This commit is contained in:
Qi Zhou
2020-11-03 23:31:24 -08:00
committed by Facebook GitHub Bot
parent a2f9c7d4e3
commit 0ec717c830
21 changed files with 990 additions and 869 deletions

View File

@ -557,6 +557,25 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
} \
}()
#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& the_index_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _it = ::detail::scalar_type(the_index_type); \
switch (_it) { \
case at::ScalarType::Int: { \
using index_t = int32_t; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Long: { \
using index_t = int64_t; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_it), "'"); \
} \
}()
// ----------------------------------------------------------------------------
// DEPRECATED MACROS, DON'T USE THESE
// ----------------------------------------------------------------------------

View File

@ -15,7 +15,7 @@ Tensor embedding(const Tensor & weight, const Tensor & indices,
int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
TORCH_CHECK(weight.dim() >= 1, "'weight' must be at least 1-D");
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding", indices_arg, kLong);
checkScalarTypes("embedding", indices_arg, {kLong, kInt});
auto zerofill_padding = [&](Tensor& embedding) {
if (padding_idx >= 0) {
@ -57,7 +57,7 @@ Tensor embedding_sparse_backward(
int64_t padding_idx, bool scale_grad_by_freq) {
auto indices_arg = TensorArg(indices_, "indices", 2);
checkScalarType("embedding_backward", indices_arg, kLong);
checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
// TODO: implement scale_grad_by_freq
if (scale_grad_by_freq) {
@ -79,14 +79,14 @@ Tensor embedding_sparse_backward(
// check if all our grad come from padding_idx
if (grad.numel() == 0) {
return at::_sparse_coo_tensor_unsafe(at::empty({1, 0}, indices_.options()),
return at::_sparse_coo_tensor_unsafe(at::empty({1, 0}, indices_.options().dtype(kLong)),
at::empty({0, num_features}, dense_options),
weight_size);
}
auto index = indices.reshape({1, -1});
auto values = grad.reshape({-1, num_features});
return at::_sparse_coo_tensor_unsafe(index, values, weight_size);
return at::_sparse_coo_tensor_unsafe(index.to(kLong), values, weight_size);
}
Tensor embedding_dense_backward_cpu(
@ -94,50 +94,48 @@ Tensor embedding_dense_backward_cpu(
int64_t padding_idx, bool scale_grad_by_freq) {
auto indices_arg = TensorArg(indices, "indices", 2);
checkScalarType("embedding_backward", indices_arg, kLong);
checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
auto indices_contig = indices.contiguous();
auto indices_data = indices_contig.data_ptr<int64_t>();
int64_t numel = indices.numel();
std::unique_ptr<int64_t[]> counts;
if (scale_grad_by_freq) {
counts.reset(new int64_t[num_weights]);
for (int i = 0; i < numel; i++) {
counts[indices_data[i]] = 0;
}
for (int i = 0; i < numel; i++) {
counts[indices_data[i]]++;
}
}
auto grad = grad_.contiguous().view({numel, grad_.size(-1)});
auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
auto indices_contig = indices.contiguous();
int64_t numel = indices.numel();
auto grad = grad_.contiguous().view({numel, grad_.size(-1)});
auto parallel_section = [&](int64_t start, int64_t end) {
for (int64_t i = 0; i < numel; i++) {
if (indices_data[i] != padding_idx) {
int64_t k = indices_data[i];
if (k >= start && k < end) {
double scale = 1.0;
if (scale_grad_by_freq) {
scale /= counts[k];
}
grad_weight[k].add_(grad[i], scale);
}
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cpu", [&] () {
auto indices_data = indices_contig.data_ptr<index_t>();
std::unique_ptr<index_t[]> counts;
if (scale_grad_by_freq) {
counts.reset(new index_t[num_weights]);
for (int i = 0; i < numel; i++) {
counts[indices_data[i]] = 0;
}
for (int i = 0; i < numel; i++) {
counts[indices_data[i]]++;
}
}
};
if (numel > 1000) {
// The strategy is to parallelize over sections of the vocabulary, so that
// thread 1 handles updates to gradWeight[0..nVocab/nThreads]. Every thread
// has to traverse the entire input, but the dominating factor is the axpy
// BLAS call.
at::parallel_for(0, num_weights, 0, parallel_section);
} else {
parallel_section(0, num_weights);
}
auto parallel_section = [&](index_t start, index_t end) {
for (int64_t i = 0; i < numel; i++) {
if (indices_data[i] != padding_idx) {
index_t k = indices_data[i];
if (k >= start && k < end) {
double scale = 1.0;
if (scale_grad_by_freq) {
scale /= counts[k];
}
grad_weight[k].add_(grad[i], scale);
}
}
}
};
if (numel > 1000) {
at::parallel_for(0, num_weights, 0, parallel_section);
} else {
parallel_section(0, num_weights);
}
});
return grad_weight;
}
@ -147,28 +145,30 @@ Tensor & embedding_renorm_cpu_(
auto self_arg = TensorArg(self, "self", 1);
auto indices_arg = TensorArg(indices, "indices", 2);
checkDim("embedding_renorm_", self_arg, 2);
checkScalarType("embedding_renorm_", indices_arg, kLong);
checkScalarTypes("embedding_renorm_", indices_arg, {kLong, kInt});
auto indices_contig = indices.contiguous();
auto num_indices = indices.numel();
auto data_ptr = indices_contig.data_ptr<int64_t>();
auto sorted_indices = std::vector<int64_t>(data_ptr, data_ptr + num_indices);
std::sort(sorted_indices.begin(), sorted_indices.end(), std::less<int64_t>());
// Note that we cannot use at::parallel_for here because we perform operations on
// Tensor inside the loop. See github.com/pytorch/pytorch/issues/28370 for more details.
for (auto i = 0; i < num_indices; i++) {
if (i > 0 && sorted_indices[i] == sorted_indices[i - 1]) {
continue;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cpu_", [&]() {
auto data_ptr = indices_contig.data_ptr<index_t>();
auto sorted_indices = std::vector<index_t>(data_ptr, data_ptr + num_indices);
std::sort(sorted_indices.begin(), sorted_indices.end());
// Note that we cannot use at::parallel_for here because we perform operations on
// Tensor inside the loop. See github.com/pytorch/pytorch/issues/28370 for more details.
for (auto i = 0; i < num_indices; i++) {
if (i > 0 && sorted_indices[i] == sorted_indices[i - 1]) {
continue;
}
auto row = self[sorted_indices[i]];
auto norm = row.norm(norm_type).item<double>();
if (norm > max_norm) {
auto scale = max_norm / (norm + 1e-7);
row *= scale;
}
}
auto row = self[sorted_indices[i]];
auto norm = row.norm(norm_type).item<double>();
if (norm > max_norm) {
auto scale = max_norm / (norm + 1e-7);
row *= scale;
}
}
});
return self;
}

View File

@ -32,11 +32,11 @@ namespace native {
template<typename scalar_t>
scalar_t dot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy);
static void make_offset2bag(const Tensor &offsets, const Tensor &indices, Tensor& offset2bag) {
static void make_offset2bag(const Tensor &offsets, Tensor& offset2bag) {
offset2bag.index_add_(
0, offsets, at::ones_like(offsets, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); // offset2bag = [1 0 1 0 1]
offset2bag[0] -= 1; // offset2bag = [0 0 1 0 1]
offset2bag = offset2bag.cumsum(0); // offset2bag = [0 0 1 1 2]
offset2bag = offset2bag.cumsum(0, offset2bag.scalar_type()); // offset2bag = [0 0 1 1 2]
}
namespace {
@ -52,18 +52,19 @@ bool isFastPathIndexSelectScale(const Tensor& src, const Tensor& scale, Tensor&
// This function combines index_select (using select_indices as the index) and
// index_add (using add_indices as the index), without creating an intermediary
// tensor to hold the selected embeddings
template<typename T>
void index_select_add(const Tensor &select_indices,
template<typename data_t, typename index_t>
typename std::enable_if<!std::is_same<data_t, float>::value, void>::type
index_select_add(const Tensor &select_indices,
const Tensor &add_indices,
const Tensor &src,
Tensor &output,
const Tensor& /*offsets*/,
bool /*include_last_offset*/) {
AT_ASSERT(select_indices.numel() == add_indices.numel());
auto* add_indices_data = add_indices.data_ptr<int64_t>();
auto* select_indices_data = select_indices.data_ptr<int64_t>();
auto* src_data = src.data_ptr<T>();
auto* output_data = output.data_ptr<T>();
auto* add_indices_data = add_indices.data_ptr<index_t>();
auto* select_indices_data = select_indices.data_ptr<index_t>();
auto* src_data = src.data_ptr<data_t>();
auto* output_data = output.data_ptr<data_t>();
auto numel = add_indices.numel();
int64_t ddim = src.size(1);
auto src_stride0 = src.stride(0);
@ -72,29 +73,30 @@ void index_select_add(const Tensor &select_indices,
auto output_stride1 = output.stride(1);
for (int64_t i = 0; i < numel; i++) {
THBlas_axpy<T>(ddim, 1,
THBlas_axpy<data_t>(ddim, 1,
src_data + src_stride0 * select_indices_data[i], src_stride1,
output_data + output_stride0 * add_indices_data[i], output_stride1);
}
}
template<>
void index_select_add<float>(const Tensor &select_indices,
template<typename data_t, typename index_t>
typename std::enable_if<std::is_same<data_t, float>::value, void>::type
index_select_add(const Tensor &select_indices,
const Tensor &add_indices,
const Tensor &src,
Tensor &output,
const Tensor& offsets,
bool include_last_offset) {
int64_t ddim = src.size(1);
auto* select_indices_data = select_indices.data_ptr<int64_t>();
auto* select_indices_data = select_indices.data_ptr<index_t>();
auto* output_data = output.data_ptr<float>();
if (isFastPathIndexSelect(src, output)) {
auto src_contig = src.contiguous();
auto* src_data = src_contig.data_ptr<float>();
int64_t output_size = offsets.numel() - 1;
auto* offsets_data = offsets.data_ptr<int64_t>();
std::vector<int64_t> offsets_include_last;
auto* offsets_data = offsets.data_ptr<index_t>();
std::vector<index_t> offsets_include_last;
if (include_last_offset) {
output_size = offsets.numel() - 1;
@ -103,15 +105,15 @@ void index_select_add<float>(const Tensor &select_indices,
offsets_include_last.resize(offsets.numel() + 1);
std::memcpy(
offsets_include_last.data(),
offsets.data_ptr<int64_t>(),
sizeof(int64_t) * offsets.numel());
offsets.data_ptr<index_t>(),
sizeof(index_t) * offsets.numel());
offsets_include_last[offsets.numel()] = select_indices.numel();
offsets_data = offsets_include_last.data();
}
#ifdef USE_FBGEMM
auto kernel_fp32_i64 =
fbgemm::GenerateEmbeddingSpMDM<float, int64_t, int64_t>(
auto kernel_fp32_index_t =
fbgemm::GenerateEmbeddingSpMDM<float, index_t, index_t>(
/* block_size */ddim,
/* has_weight */false,
/* normalize_by_lengths */false,
@ -121,9 +123,9 @@ void index_select_add<float>(const Tensor &select_indices,
);
#endif
at::parallel_for(
0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
#ifdef USE_FBGEMM
kernel_fp32_i64(
kernel_fp32_index_t(
/* output_size */end_idx - start_idx,
/* index_size */offsets_data[end_idx] - offsets_data[start_idx],
/* data_size */src.size(0),
@ -150,7 +152,7 @@ void index_select_add<float>(const Tensor &select_indices,
} else {
AT_ASSERT(select_indices.numel() == add_indices.numel());
auto* src_data = src.data_ptr<float>();
auto* add_indices_data = add_indices.data_ptr<int64_t>();
auto* add_indices_data = add_indices.data_ptr<index_t>();
auto src_stride0 = src.stride(0);
auto src_stride1 = src.stride(1);
auto output_stride0 = output.stride(0);
@ -172,8 +174,9 @@ void index_select_add<float>(const Tensor &select_indices,
// index_select (using select_indices as the index)
// mul (scaling by per_sample_weights)
// index_add (using add_indices as the index)
template<typename T>
static void index_select_scale_add(const Tensor &select_indices,
template<typename data_t, typename index_t>
static typename std::enable_if<!std::is_same<data_t, float>::value, void>::type
index_select_scale_add(const Tensor &select_indices,
const Tensor &add_indices,
const Tensor &scale,
const Tensor &src,
@ -181,10 +184,10 @@ static void index_select_scale_add(const Tensor &select_indices,
const Tensor& /*offsets*/,
bool /*include_last_offset*/) {
AT_ASSERT(select_indices.numel() == add_indices.numel());
auto* add_indices_data = add_indices.data_ptr<int64_t>();
auto* select_indices_data = select_indices.data_ptr<int64_t>();
auto* src_data = src.data_ptr<T>();
auto* output_data = output.data_ptr<T>();
auto* add_indices_data = add_indices.data_ptr<index_t>();
auto* select_indices_data = select_indices.data_ptr<index_t>();
auto* src_data = src.data_ptr<data_t>();
auto* output_data = output.data_ptr<data_t>();
auto numel = add_indices.numel();
int64_t ddim = src.size(1);
auto src_stride0 = src.stride(0);
@ -192,7 +195,7 @@ static void index_select_scale_add(const Tensor &select_indices,
auto output_stride0 = output.stride(0);
auto output_stride1 = output.stride(1);
auto* scale_data = scale.data_ptr<T>();
auto* scale_data = scale.data_ptr<data_t>();
auto scale_stride = scale.stride(0);
for (int64_t i = 0; i < numel; i++) {
@ -205,8 +208,9 @@ static void index_select_scale_add(const Tensor &select_indices,
}
}
template<>
void index_select_scale_add<float>(const Tensor &select_indices,
template<typename data_t, typename index_t>
typename std::enable_if<std::is_same<data_t, float>::value, void>::type
index_select_scale_add(const Tensor &select_indices,
const Tensor &add_indices,
const Tensor &scale,
const Tensor &src,
@ -215,15 +219,15 @@ void index_select_scale_add<float>(const Tensor &select_indices,
bool include_last_offset) {
int64_t ddim = src.size(1);
auto* scale_data = scale.data_ptr<float>();
auto* select_indices_data = select_indices.data_ptr<int64_t>();
auto* select_indices_data = select_indices.data_ptr<index_t>();
auto* output_data = output.data_ptr<float>();
if (isFastPathIndexSelectScale(src, scale, output)) {
auto src_contig = src.contiguous();
auto* src_data = src_contig.data_ptr<float>();
int64_t output_size = offsets.numel() - 1;
auto* offsets_data = offsets.data_ptr<int64_t>();
std::vector<int64_t> offsets_include_last;
auto* offsets_data = offsets.data_ptr<index_t>();
std::vector<index_t> offsets_include_last;
if (include_last_offset) {
output_size = offsets.numel() - 1;
@ -232,15 +236,15 @@ void index_select_scale_add<float>(const Tensor &select_indices,
offsets_include_last.resize(offsets.numel() + 1);
std::memcpy(
offsets_include_last.data(),
offsets.data_ptr<int64_t>(),
sizeof(int64_t) * offsets.numel());
offsets.data_ptr<index_t>(),
sizeof(index_t) * offsets.numel());
offsets_include_last[offsets.numel()] = select_indices.numel();
offsets_data = offsets_include_last.data();
}
#ifdef USE_FBGEMM
auto kernel_fp32_i64 =
fbgemm::GenerateEmbeddingSpMDM<float, int64_t, int64_t>(
auto kernel_fp32_index_t =
fbgemm::GenerateEmbeddingSpMDM<float, index_t, index_t>(
/* block_size */ddim,
/* has_weight */true,
/* normalize_by_lengths */false,
@ -250,9 +254,9 @@ void index_select_scale_add<float>(const Tensor &select_indices,
);
#endif
at::parallel_for(
0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
#ifdef USE_FBGEMM
kernel_fp32_i64(
kernel_fp32_index_t(
/* output_size */end_idx - start_idx,
/* index_size */offsets_data[end_idx] - offsets_data[start_idx],
/* data_size */src.size(0),
@ -279,7 +283,7 @@ void index_select_scale_add<float>(const Tensor &select_indices,
} else {
AT_ASSERT(select_indices.numel() == add_indices.numel());
auto* src_data = src.data_ptr<float>();
auto* add_indices_data = add_indices.data_ptr<int64_t>();
auto* add_indices_data = add_indices.data_ptr<index_t>();
auto src_stride0 = src.stride(0);
auto src_stride1 = src.stride(1);
auto output_stride0 = output.stride(0);
@ -308,7 +312,7 @@ static at::Tensor make_bag_size(
const bool requires_grad) {
at::Tensor bag_size;
if (mode == MODE_MEAN || mode == MODE_MAX) {
bag_size = at::zeros(offsets.sizes(), indices.options());
bag_size = at::zeros(offsets.sizes(), offsets.options());
// Compute this for MODE_MEAN and MODE_MAX (latter needed for backwards)
if (offsets.size(0) != 1) {
bag_size.slice(0, 0, bag_size.size(0) - 1, 1) =
@ -318,7 +322,7 @@ static at::Tensor make_bag_size(
bag_size[-1] = indices.size(0) - offsets[-1];
} else if (requires_grad) {
// in MODE_SUM, only allocate bag_size if we need gradients
bag_size = at::empty(offsets.sizes(), indices.options());
bag_size = at::empty(offsets.sizes(), offsets.options());
}
return bag_size;
}
@ -384,35 +388,36 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
}
auto max_indices =
at::zeros({numBags, featureSize}, indices.options());
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu_max", [&] {
auto* indices_data = indices.data_ptr<index_t>();
auto* offset2bag_data = offset2bag.data_ptr<index_t>();
auto* indices_data = indices.data_ptr<int64_t>();
auto* offset2bag_data = offset2bag.data_ptr<int64_t>();
auto* max_indices_data = max_indices.data_ptr<index_t>();
auto max_indices_stride = max_indices.stride(0);
auto* max_indices_data = max_indices.data_ptr<int64_t>();
auto max_indices_stride = max_indices.stride(0);
auto* weight_data = weight.data_ptr<scalar_t>();
auto* output_data = output.data_ptr<scalar_t>();
auto weight_stride0 = weight.stride(0);
auto weight_stride1 = weight.stride(1);
auto output_stride = output.stride(0);
auto* weight_data = weight.data_ptr<scalar_t>();
auto* output_data = output.data_ptr<scalar_t>();
auto weight_stride0 = weight.stride(0);
auto weight_stride1 = weight.stride(1);
auto output_stride = output.stride(0);
for (int i = 0; i < numIndices; ++i) {
auto bag = offset2bag_data[i];
auto word_idx = indices_data[i];
for (int i = 0; i < numIndices; i++) {
auto bag = offset2bag_data[i];
auto word_idx = indices_data[i];
for (int dim = 0; dim < featureSize; dim++) {
auto& current_item = output_data[output_stride * bag + dim];
auto weight_item =
weight_data[weight_stride0 * word_idx + dim * weight_stride1];
bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag;
for (int dim = 0; dim < featureSize; dim++) {
auto& current_item = output_data[output_stride * bag + dim];
auto weight_item =
weight_data[weight_stride0 * word_idx + dim * weight_stride1];
bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag;
if (is_first_for_bag || weight_item > current_item) {
current_item = weight_item;
max_indices_data[max_indices_stride * bag + dim] = word_idx;
if (is_first_for_bag || weight_item > current_item) {
current_item = weight_item;
max_indices_data[max_indices_stride * bag + dim] = word_idx;
}
}
}
}
});
return std::tuple<Tensor, Tensor, Tensor, Tensor>(
output, offset2bag, bag_size, max_indices);
@ -429,19 +434,23 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl(
bool include_last_offset,
bool requires_grad) {
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_bag", indices_arg, kLong);
checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
auto offsets_arg = TensorArg(offsets, "offsets", 1);
checkScalarType("embedding_bag", offsets_arg, kLong);
checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt});
checkSameType("embedding_bag", indices_arg, offsets_arg);
auto weight_arg = TensorArg(weight, "weight", 1);
checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble});
int64_t offset_0 = offsets.data_ptr<int64_t>()[0];
int64_t offset_n = offsets.data_ptr<int64_t>()[offsets.size(0)-1];
TORCH_CHECK(offset_0 == 0, "offsets[0] has to be 0, i.e., the first sequence "
"in the mini-batch has to start from position 0. "
"However, got ", offsets[0]);
TORCH_CHECK(offset_n <= indices.size(0), "offsets[-1] can not "
"be greater than input's length ", indices.size(0), " but got offsets[-1] of ",
offset_n);
AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_embedding_bag_cpu_impl", [&]() {
index_t offset_0 = offsets.data_ptr<index_t>()[0];
index_t offset_n = offsets.data_ptr<index_t>()[offsets.size(0)-1];
TORCH_CHECK(offset_0 == 0, "offsets[0] has to be 0, i.e., the first sequence "
"in the mini-batch has to start from position 0. "
"However, got ", offsets[0]);
TORCH_CHECK(offset_n <= indices.size(0), "offsets[-1] can not "
"be greater than input's length ", indices.size(0), " but got offsets[-1] of ",
offset_n);
});
if (per_sample_weights.defined()) {
TORCH_CHECK(mode == MODE_SUM,
@ -494,9 +503,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl(
// throw out of bounds error. So to keep it simple we just add one more
// entry to the end then get rid of it after make_offset2bag.
offset2bag = at::zeros(
{indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
{indices.sizes()[0] + 1}, offsets.options()); // offset2bag = [0 0 0 0 0]
make_offset2bag(offsets, indices, offset2bag);
make_offset2bag(offsets, offset2bag);
offset2bag.resize_({indices.sizes()[0]});
@ -505,14 +514,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl(
}
if (mode == MODE_MEAN || mode == MODE_SUM) {
AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu", [&]() {
if (per_sample_weights.defined()) {
AT_ASSERT(mode == MODE_SUM);
index_select_scale_add<scalar_t>(
indices, offset2bag, per_sample_weights, weight, output, offsets, include_last_offset);
} else {
index_select_add<scalar_t>(indices, offset2bag, weight, output, offsets, include_last_offset);
}
// explicitly capture all required variables to work around windows build
// TODO: fix this when windows can correctly capture variables in nested lambda
AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu",
[&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode]() {
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu",
[&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode]() {
if (per_sample_weights.defined()) {
AT_ASSERT(mode == MODE_SUM);
index_select_scale_add<scalar_t, index_t>(
indices, offset2bag, per_sample_weights, weight, output, offsets, include_last_offset);
} else {
index_select_add<scalar_t, index_t>(indices, offset2bag, weight, output, offsets, include_last_offset);
}
});
});
auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
@ -598,23 +613,24 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices,
bool sparse,
const Tensor& per_sample_weights) {
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_bag", indices_arg, kLong);
checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
checkContiguous("embedding_bag", indices_arg);
auto offsets_arg = TensorArg(offsets, "offsets", 1);
checkScalarType("embedding_bag", offsets_arg, kLong);
checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt});
checkSameType("embedding_bag", indices_arg, offsets_arg);
checkContiguous("embedding_bag", offsets_arg);
Tensor offset2bag_;
if (indices.numel() != 0 && offset2bag.numel() == 0) {
offset2bag_ = at::zeros(
{indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
{indices.sizes()[0] + 1}, offsets.options()); // offset2bag = [0 0 0 0 0]
make_offset2bag(offsets, indices, offset2bag_);
make_offset2bag(offsets, offset2bag_);
offset2bag_.resize_({indices.sizes()[0]});
} else {
auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
checkScalarType("embedding_bag", offset2bag_arg, kLong);
checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt});
checkContiguous("embedding_bag", offset2bag_arg);
offset2bag_ = offset2bag;
}
@ -648,11 +664,12 @@ static Tensor _embedding_bag_dense_backward_cpu_max(
return index_grad_weight;
}
static std::vector<int64_t> compute_counts(
template<typename index_t>
static std::vector<index_t> compute_counts(
int64_t num_weights,
int64_t* indices_data,
index_t* indices_data,
int64_t indices_length) {
std::vector<int64_t> counts(num_weights, 0);
std::vector<index_t> counts(num_weights, 0);
for (int i = 0; i < indices_length; i++) {
counts[indices_data[i]]++;
}
@ -668,12 +685,13 @@ static std::vector<int64_t> compute_counts(
// counts_uniq: [3, 4, 6, 7]
//
// The unique indices can be found at index 0, 3, 4, 6.
static std::vector<int64_t> compute_counts_uniq(
template<typename index_t>
static std::vector<index_t> compute_counts_uniq(
int64_t num_weights,
int64_t* indices_data,
index_t* indices_data,
int64_t indices_length,
const std::vector<int64_t>& counts) {
std::vector<int64_t> counts_uniq;
const std::vector<index_t>& counts) {
std::vector<index_t> counts_uniq;
counts_uniq.reserve(num_weights);
int64_t o = 0;
for (int64_t i = 0; i < indices_length; i += counts[indices_data[i]]) {
@ -714,54 +732,66 @@ void _embedding_bag_dense_backward_cpu_sum_mean(
per_sample_weights_stride = per_sample_weights->stride(0);
}
auto* indices_data = indices.data_ptr<int64_t>();
auto* offsets_data = offsets_.data_ptr<int64_t>();
auto* offset2bag_data = offset2bag.data_ptr<int64_t>();
int64_t numel = indices.numel();
auto counts = compute_counts(num_weights, indices_data, numel);
auto next_unique_index_idx =
compute_counts_uniq(num_weights, indices_data, numel, counts);
// explicitly capture all required variables to work around windows build
// TODO: fix this when windows can correctly capture variables in nested lambda
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_dense_backward_cpu_sum_mean",
[&indices, &offsets_, &offset2bag, &num_weights, &numel, &per_sample_weights,
&per_sample_weights_data, &per_sample_weights_stride, &mode, &scale_grad_by_freq,
&grad, &index_grad_weight] {
auto* indices_data = indices.data_ptr<index_t>();
auto* offsets_data = offsets_.data_ptr<index_t>();
auto* offset2bag_data = offset2bag.data_ptr<index_t>();
auto loop = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; i++) {
int64_t start = i == 0 ? 0 : next_unique_index_idx[i - 1];
int64_t index = indices_data[start];
for (int64_t j = start; j < next_unique_index_idx[i]; j++) {
int64_t source = offset2bag_data[j];
double scale = 1.0;
if (per_sample_weights) {
AT_ASSERT(mode == MODE_SUM);
scale = per_sample_weights_data[*per_sample_weights_stride * j];
}
if (scale_grad_by_freq) {
scale /= counts[indices_data[i]];
}
if (mode == 1) { // MODE_MEAN
if (offsets_.size(0) == 1) {
auto bag_size = indices.size(0);
scale /= bag_size;
} else {
if (source == offsets_.size(0) - 1) {
scale /= indices.size(0) - offsets_data[offsets_.size(0) - 1];
auto counts = compute_counts(num_weights, indices_data, numel);
auto next_unique_index_idx =
compute_counts_uniq(num_weights, indices_data, numel, counts);
auto loop =
[&next_unique_index_idx, &indices_data, &offset2bag_data, &per_sample_weights,
&mode, &per_sample_weights_data, &per_sample_weights_stride, &scale_grad_by_freq,
&counts, &offsets_, &indices, &offsets_data, &grad, &index_grad_weight](index_t start, index_t end) {
for (index_t i = start; i < end; i++) {
index_t start = i == 0 ? 0 : next_unique_index_idx[i - 1];
index_t index = indices_data[start];
for (index_t j = start; j < next_unique_index_idx[i]; j++) {
index_t source = offset2bag_data[j];
double scale = 1.0;
if (per_sample_weights) {
AT_ASSERT(mode == MODE_SUM);
scale = per_sample_weights_data[*per_sample_weights_stride * j];
}
if (scale_grad_by_freq) {
scale /= counts[indices_data[i]];
}
if (mode == 1) { // MODE_MEAN
if (offsets_.size(0) == 1) {
auto bag_size = indices.size(0);
scale /= bag_size;
} else {
scale /= offsets_data[source + 1] - offsets_data[source];
if (source == offsets_.size(0) - 1) {
scale /= indices.size(0) - offsets_data[offsets_.size(0) - 1];
} else {
scale /= offsets_data[source + 1] - offsets_data[source];
}
}
}
int64_t ddim = grad.size(1);
auto igwd = index_grad_weight.data_ptr<scalar_t>();
auto gd = grad.data_ptr<scalar_t>();
THBlas_axpy<scalar_t>(ddim, (scalar_t)scale, gd + ddim * source, 1,
igwd + ddim * index, 1);
}
int64_t ddim = grad.size(1);
auto igwd = index_grad_weight.data_ptr<scalar_t>();
auto gd = grad.data_ptr<scalar_t>();
THBlas_axpy<scalar_t>(ddim, (scalar_t)scale, gd + ddim * source, 1,
igwd + ddim * index, 1);
}
};
if (numel > 1000) {
at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop);
} else {
loop(0, (int64_t)next_unique_index_idx.size());
}
};
if (numel > 1000) {
at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop);
} else {
loop(0, (int64_t)next_unique_index_idx.size());
}
});
}
Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indices_,
@ -820,20 +850,20 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
auto output = at::zeros({num_samples}, grad.options());
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_bag", indices_arg, kLong);
checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
checkContiguous("embedding_bag", indices_arg);
Tensor offset2bag_;
if (indices.numel() != 0 && offset2bag.numel() == 0) {
offset2bag_ = at::zeros(
{indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
{indices.sizes()[0] + 1}, offset2bag.options()); // offset2bag = [0 0 0 0 0]
make_offset2bag(offsets, indices, offset2bag_);
make_offset2bag(offsets, offset2bag_);
offset2bag_.resize_({indices.sizes()[0]});
} else {
auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
checkScalarType("embedding_bag", offset2bag_arg, kLong);
checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt});
checkContiguous("embedding_bag", offset2bag_arg);
offset2bag_ = offset2bag;
}
@ -846,23 +876,31 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
auto weight_stride0 = weight.stride(0);
auto weight_stride1 = weight.stride(1);
auto* indices_data = indices.data_ptr<int64_t>();
// explicitly capture all required variables to work around windows build
// TODO: fix this when windows can correctly capture variables in nested lambda
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cpu_template",
[&indices, &output, &offset2bag_, &num_samples, &embedding_features,
&grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0, &weight_stride1] () {
auto* indices_data = indices.data_ptr<index_t>();
// The following are contiguous
auto* output_data = output.data_ptr<scalar_t>();
auto* offset2bag_data = offset2bag_.data_ptr<int64_t>();
// The following are contiguous
auto* output_data = output.data_ptr<scalar_t>();
auto* offset2bag_data = offset2bag_.data_ptr<index_t>();
// XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number.
parallel_for(0, num_samples, 64, [&](int64_t begin, int64_t end) {
for (int64_t sample_idx = begin; sample_idx < end; sample_idx++) {
auto bag_idx = offset2bag_data[sample_idx];
auto embedding_idx = indices_data[sample_idx];
// XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number.
parallel_for(0, num_samples, 64,
[&embedding_features, &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0,
&weight_stride1, &offset2bag_data, &indices_data, &output_data](index_t begin, index_t end) {
for (index_t sample_idx = begin; sample_idx < end; sample_idx++) {
auto bag_idx = offset2bag_data[sample_idx];
auto embedding_idx = indices_data[sample_idx];
output_data[sample_idx] = dot_impl<scalar_t>(
embedding_features,
grad_data + grad_stride0 * bag_idx, grad_stride1,
weight_data + weight_stride0 * embedding_idx, weight_stride1);
}
output_data[sample_idx] = dot_impl<scalar_t>(
embedding_features,
grad_data + grad_stride0 * bag_idx, grad_stride1,
weight_data + weight_stride0 * embedding_idx, weight_stride1);
}
});
});
return output;
}

View File

@ -381,7 +381,8 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T
auto numel = index.numel();
TORCH_CHECK_INDEX(index.dim() <= 1, "index_add_(): Index is supposed to be a vector");
TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_add_(): Expected dtype int64 for index");
TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int,
"index_add_(): Expected dtype int32/int64 for index");
TORCH_CHECK(self.scalar_type() == source.scalar_type(),
"index_add_(): self and source must have the same scalar type");
TORCH_CHECK(dim == 0 || dim < source.dim(),
@ -394,7 +395,6 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T
at::assert_no_partial_overlap(self, source);
auto index_contig = index.contiguous();
auto index_data = index_contig.data_ptr<int64_t>();
if (self.dim() > 1) {
// Equivalent to:
@ -414,32 +414,41 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T
auto self_dim_size = self.size(dim);
auto iter = TensorIterator::binary_op(selfSlice, selfSlice, sourceSlice);
for (auto i = 0; i < numel; i++) {
auto self_i = index_data[i];
TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
auto source_data = static_cast<char*>(sourceSlice.data_ptr()) + i * source_stride_bytes;
iter.unsafe_replace_operand(0, self_data);
iter.unsafe_replace_operand(1, self_data);
iter.unsafe_replace_operand(2, source_data);
add_stub(iter.device_type(), iter, 1);
}
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cpu_", [&] () {
auto index_data = index_contig.data_ptr<index_t>();
for (auto i = 0; i < numel; i++) {
auto self_i = index_data[i];
TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
auto source_data = static_cast<char*>(sourceSlice.data_ptr()) + i * source_stride_bytes;
iter.unsafe_replace_operand(0, self_data);
iter.unsafe_replace_operand(1, self_data);
iter.unsafe_replace_operand(2, source_data);
add_stub(iter.device_type(), iter, 1);
}
});
}
else {
TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must one or zero for given self.dim() (", self.dim(), ")");
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "index_add_", [&] {
// explicitly capture all required variables to work around windows build
// TODO: fix this when windows can correctly capture variables in nested lambda
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "index_add_", [&self, &source, &dim, &index_contig, &numel] {
auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
auto source_stride = source.dim() == 0 ? 1 : source.stride(dim);
// TODO: Maybe TensorAccessor can beused here?
auto* self_ptr = self.data_ptr<scalar_t>();
auto* source_ptr = source.data_ptr<scalar_t>();
for (auto i = 0; i < numel; i++) {
auto self_i = index_data[i];
TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self.numel()), "index out of range in self");
scalar_t *self_ip = self_ptr + self_i * self_stride;
*self_ip += *(source_ptr + i * source_stride);
}
AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_add_cpu_",
[&index_contig, &numel, &self, &self_ptr, &self_stride, &source_ptr, &source_stride] {
auto index_data = index_contig.data_ptr<index_t>();
for (auto i = 0; i < numel; i++) {
auto self_i = index_data[i];
TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self.numel()), "index out of range in self");
scalar_t *self_ip = self_ptr + self_i * self_stride;
*self_ip += *(source_ptr + i * source_stride);
}
});
});
}
return self;
@ -454,7 +463,7 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
auto numel = index.numel();
TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector");
TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_select(): Expected dtype int64 for index");
TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index");
TORCH_CHECK(self.scalar_type() == result.scalar_type(),
"index_select(): self and result must have the same scalar type");
TORCH_CHECK(dim == 0 || dim < self.dim(),
@ -468,7 +477,6 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
result.resize_(result_size);
auto index_contig = index.contiguous();
auto index_data = index_contig.data_ptr<int64_t>();
if (self.dim() > 1) {
if (numel == 0 || self.numel() == 0) {
@ -492,17 +500,26 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
.build();
auto grain_size = at::internal::GRAIN_SIZE;
auto outer_loop = [&](int64_t start, int64_t end) {
auto outer_loop =
// explicitly capture all required variables to work around windows build
// TODO: fix this when windows can correctly capture variables in nested lambda
[&index_contig, &iter, &self_dim_size, &selfSlice_data, &self_stride_bytes, &resultSlice_data,
&result_stride_bytes](int64_t start, int64_t end) {
auto sub_iter = TensorIterator(iter);
for (int64_t i = start; i < end; i++) {
auto self_i = index_data[i];
TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
auto self_data = static_cast<char*>(selfSlice_data) + self_i * self_stride_bytes;
auto result_data = static_cast<char*>(resultSlice_data) + i * result_stride_bytes;
sub_iter.unsafe_replace_operand(0, result_data);
sub_iter.unsafe_replace_operand(1, self_data);
copy_stub(sub_iter.device_type(), sub_iter, false);
}
AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
[&index_contig, &start, &end, &sub_iter, &self_dim_size, &selfSlice_data, &self_stride_bytes,
&resultSlice_data, &result_stride_bytes] () {
auto index_data = index_contig.data_ptr<index_t>();
for (int64_t i = start; i < end; i++) {
auto self_i = index_data[i];
TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
auto self_data = static_cast<char*>(selfSlice_data) + self_i * self_stride_bytes;
auto result_data = static_cast<char*>(resultSlice_data) + i * result_stride_bytes;
sub_iter.unsafe_replace_operand(0, result_data);
sub_iter.unsafe_replace_operand(1, self_data);
copy_stub(sub_iter.device_type(), sub_iter, false);
};
});
};
// parallel on inner loop in case the slice is large enough;
@ -513,14 +530,23 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
// use a fast loop when self and result are contiguous and of the same data type
if (iter.is_contiguous() && self.scalar_type() == result.scalar_type()) {
auto slice_size_bytes = slice_size * elementSize(self.scalar_type());
at::parallel_for(0, numel, grain_size / slice_size, [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; i++) {
auto self_i = index_data[i];
TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
auto self_data = static_cast<char*>(selfSlice_data) + self_i * self_stride_bytes;
auto result_data = static_cast<char*>(resultSlice_data) + i * result_stride_bytes;
memcpy(result_data, self_data, slice_size_bytes);
}
// explicitly capture all required variables to work around windows build
// TODO: fix this when windows can correctly capture variables in nested lambda
at::parallel_for(0, numel, grain_size / slice_size,
[&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data,
&self_stride_bytes, &resultSlice_data, &result_stride_bytes](int64_t start, int64_t end) {
AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
[&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data,
&self_stride_bytes, &resultSlice_data, &result_stride_bytes, &start, &end] () {
auto index_data = index_contig.data_ptr<index_t>();
for (int64_t i = start; i < end; i++) {
auto self_i = index_data[i];
TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
auto self_data = static_cast<char*>(selfSlice_data) + self_i * self_stride_bytes;
auto result_data = static_cast<char*>(resultSlice_data) + i * result_stride_bytes;
memcpy(result_data, self_data, slice_size_bytes);
}
});
});
} else {
at::parallel_for(0, numel, grain_size / slice_size, outer_loop);
@ -528,20 +554,26 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
}
} else {
TORCH_CHECK(result.dim() <= 1, "result.dim() (", result.dim(), ") must one or zero for given self.dim() (", self.dim(), ")");
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "index_select", [&] {
// explicitly capture all required variables to work around windows build
// TODO: fix this when windows can correctly capture variables in nested lambda
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "index_select",
[&index_contig, &self, &result, &dim, &numel] {
auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
auto self_data_ptr = self.data_ptr<scalar_t>();
auto result_data_ptr = result.data_ptr<scalar_t>();
auto self_numel = self.numel();
for (auto i = 0; i < numel; i++) {
auto self_i = index_data[i];
TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self");
scalar_t *self_ip = self_data_ptr + self_i * self_stride;
*(result_data_ptr + i * result_stride) = *self_ip;
}
AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
[&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] {
auto index_data = index_contig.data_ptr<index_t>();
for (auto i = 0; i < numel; i++) {
auto self_i = index_data[i];
TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self");
scalar_t *self_ip = self_data_ptr + self_i * self_stride;
*(result_data_ptr + i * result_stride) = *self_ip;
}
});
});
}

View File

@ -29,9 +29,10 @@ static const int BLOCKDIMY = 32;
template
<typename scalar_t,
typename accscalar_t>
typename accscalar_t,
typename index_t>
__global__ void embedding_backward_feature_kernel
(int64_t* indices,
(index_t* indices,
const scalar_t* __restrict__ grad,
scalar_t* __restrict__ grad_weight,
int n, // OK to pass as int, we don't expect 2 billion+ samples in one shot
@ -117,10 +118,10 @@ __global__ void embedding_backward_feature_kernel
}
template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ void embedding_backward_kernel(
int64_t* input, int64_t* indices, scalar_t* grad_output, scalar_t* grad_weight,
int64_t* count, int64_t numel, int64_t stride, int padding_idx) {
index_t* input, index_t* indices, scalar_t* grad_output, scalar_t* grad_weight,
index_t* count, int64_t numel, int64_t stride, int padding_idx) {
using accscalar_t = acc_type<scalar_t, true>;
int idx = blockIdx.x * 4 + threadIdx.y;
@ -179,9 +180,9 @@ __global__ void embedding_backward_kernel(
}
/* Calculate norms of the rows of weight_ptr given by idx_ptr and capture them in norms */
template <typename scalar_t, typename accscalar_t>
template <typename scalar_t, typename accscalar_t, typename index_t>
__global__ void renorm_kernel(
scalar_t* weights, int64_t* indices, accscalar_t max_norm,
scalar_t* weights, index_t* indices, accscalar_t max_norm,
accscalar_t norm_type, int64_t dim,
int64_t weights_stride0, int64_t weights_stride1) {
@ -228,7 +229,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
bool scale_grad_by_freq) {
auto grad_arg = TensorArg(grad_, "grad", 1);
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_backward", indices_arg, kLong);
checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
checkSameGPU("embedding_backward", grad_arg, indices_arg);
auto num_indices = indices.numel();
@ -250,18 +251,20 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
{
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] {
using accscalar_t = acc_type<scalar_t, true>;
embedding_backward_feature_kernel<scalar_t, accscalar_t>
<<<grid,
block,
sizeof(accscalar_t)*C10_WARP_SIZE*BLOCKDIMY + sizeof(int)*C10_WARP_SIZE*BLOCKDIMY,
stream>>>
(indices_contig.data_ptr<int64_t>(),
grad.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
static_cast<int>(num_indices),
static_cast<int64_t>(stride),
static_cast<int>(padding_idx));
TORCH_CUDA_KERNEL_LAUNCH_CHECK();
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
embedding_backward_feature_kernel<scalar_t, accscalar_t, index_t>
<<<grid,
block,
sizeof(accscalar_t)*C10_WARP_SIZE*BLOCKDIMY + sizeof(int)*C10_WARP_SIZE*BLOCKDIMY,
stream>>>
(indices_contig.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
static_cast<int>(num_indices),
static_cast<int64_t>(stride),
static_cast<int>(padding_idx));
TORCH_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
return grad_weight;
@ -269,61 +272,63 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
using device_ptr = thrust::device_ptr<int64_t>;
// Sort the inputs into sorted with the corresponding indices; we
// don't need a stable or multidimensional sort, so just use Thrust
// directly
{
sorted_indices.copy_(indices);
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
// Fill sortedOrigIndices with sequential indices
auto count_iter = thrust::counting_iterator<int64_t>(0);
auto orig_data = device_ptr(orig_indices.data_ptr<int64_t>());
thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);
// Sort; a stable sort is not required
auto sorted_data = device_ptr(sorted_indices.data_ptr<int64_t>());
thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data,
ThrustLTOp<int64_t>());
}
Tensor count;
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
using device_ptr = thrust::device_ptr<index_t>;
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
// Sort the inputs into sorted with the corresponding indices; we
// don't need a stable or multidimensional sort, so just use Thrust
// directly
{
sorted_indices.copy_(indices);
// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = device_ptr(sorted_indices.data_ptr<int64_t>());
auto count_data = device_ptr(count.data_ptr<int64_t>());
thrust::inclusive_scan_by_key(
policy,
sorted_data,
sorted_data + num_indices,
thrust::make_constant_iterator(1),
count_data
);
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
// Take the maximum of each count per unique key in reverse:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 3 3 3 2 2 1 2 2
thrust::inclusive_scan_by_key(
policy,
thrust::make_reverse_iterator(sorted_data + num_indices),
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::equal_to<int64_t>(),
thrust::maximum<int64_t>()
);
}
// Fill sortedOrigIndices with sequential indices
auto count_iter = thrust::counting_iterator<index_t>(0);
auto orig_data = device_ptr(orig_indices.data_ptr<index_t>());
thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);
// Sort; a stable sort is not required
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data,
ThrustLTOp<index_t>());
}
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
auto count_data = device_ptr(count.data_ptr<index_t>());
thrust::inclusive_scan_by_key(
policy,
sorted_data,
sorted_data + num_indices,
thrust::make_constant_iterator(1),
count_data
);
// Take the maximum of each count per unique key in reverse:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 3 3 3 2 2 1 2 2
thrust::inclusive_scan_by_key(
policy,
thrust::make_reverse_iterator(sorted_data + num_indices),
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::equal_to<index_t>(),
thrust::maximum<index_t>()
);
}
});
return embedding_backward_cuda_kernel(grad, orig_indices,
sorted_indices, count, num_weights, padding_idx);
@ -340,31 +345,33 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
using device_ptr = thrust::device_ptr<int64_t>;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cuda_", [&] () {
using device_ptr = thrust::device_ptr<index_t>;
auto num_indices = indices.numel();
auto indices_contig = std::get<0>(indices.sort()).contiguous();
auto indices_data = device_ptr(indices_contig.data_ptr<int64_t>());
auto num_indices = indices.numel();
auto indices_contig = std::get<0>(indices.sort()).contiguous();
auto indices_data = device_ptr(indices_contig.data_ptr<index_t>());
auto unique_indices = at::empty(indices.numel(), indices.options());
auto unique_data = device_ptr(unique_indices.data_ptr<int64_t>());
auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data);
auto num_unique_indices = static_cast<int>(end - unique_data);
auto unique_indices = at::empty(indices.numel(), indices.options());
auto unique_data = device_ptr(unique_indices.data_ptr<index_t>());
auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data);
auto num_unique_indices = static_cast<int>(end - unique_data);
dim3 grid(num_unique_indices);
dim3 block(128);
int dim = self.stride(0);
dim3 grid(num_unique_indices);
dim3 block(128);
int dim = self.stride(0);
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] {
using accscalar_t = acc_type<scalar_t, true>;
renorm_kernel<<<grid, block, 128 * sizeof(accscalar_t), stream>>>(
self.data_ptr<scalar_t>(),
unique_indices.data_ptr<int64_t>(),
static_cast<accscalar_t>(max_norm),
static_cast<accscalar_t>(norm_type),
dim, self.stride(0), self.stride(1));
TORCH_CUDA_KERNEL_LAUNCH_CHECK();
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_backward", [&] {
using accscalar_t = acc_type<scalar_t, true>;
renorm_kernel<<<grid, block, 128 * sizeof(accscalar_t), stream>>>(
self.data_ptr<scalar_t>(),
unique_indices.data_ptr<index_t>(),
static_cast<accscalar_t>(max_norm),
static_cast<accscalar_t>(norm_type),
dim, self.stride(0), self.stride(1));
TORCH_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
return self;

View File

@ -40,8 +40,9 @@ int64_t ceil_div(int64_t x, int64_t y) {
return (x + y - 1) / y;
}
template <typename index_t>
__global__
void krn_partials_per_segment(int64_t *ret, const int64_t *segment_offsets,
void krn_partials_per_segment(index_t *ret, const index_t *segment_offsets,
int64_t num_of_segments, int64_t numel) {
const int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id < num_of_segments) {
@ -52,18 +53,19 @@ void krn_partials_per_segment(int64_t *ret, const int64_t *segment_offsets,
}
}
template <typename index_t>
__global__
void krn_partial_segment_offset(
int64_t *ret,
const int64_t *partials_per_segment,
const int64_t *partials_per_segment_offset,
const int64_t *segment_offsets,
index_t *ret,
const index_t *partials_per_segment,
const index_t *partials_per_segment_offset,
const index_t *segment_offsets,
int64_t num_of_segments) {
const int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id < num_of_segments) {
int64_t idx = partials_per_segment_offset[id];
const int64_t num_partials = partials_per_segment[id];
const int64_t segment_offset = segment_offsets[id];
index_t idx = partials_per_segment_offset[id];
const index_t num_partials = partials_per_segment[id];
const index_t segment_offset = segment_offsets[id];
for (int64_t i=0; i<num_partials; ++i) {
ret[idx++] = segment_offset + i * NROWS_PER_THREAD;
}
@ -71,13 +73,13 @@ void krn_partial_segment_offset(
}
template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ void compute_grad_weight_bags(
int64_t *indices, scalar_t *gradOutput,
int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
int64_t stride, int mode_mean, const int64_t *bag_size,
index_t *indices, scalar_t *gradOutput,
index_t *offset2bag, index_t *count, ptrdiff_t numel,
int64_t stride, int mode_mean, const index_t *bag_size,
scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
int64_t* segment_offsets, int64_t num_of_segments,
index_t* segment_offsets, int64_t num_of_segments,
acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t stride_warped) {
@ -113,14 +115,14 @@ __global__ void compute_grad_weight_bags(
grad_weight_per_segment[id * stride + startFeature] = weight;
}
template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ void compute_grad_weight(
int64_t *indices,
index_t *indices,
scalar_t *gradOutput,
int64_t *count,
index_t *count,
ptrdiff_t numel,
int64_t stride,
int64_t* segment_offsets,
index_t* segment_offsets,
int64_t num_of_segments,
acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t stride_warped) {
@ -140,7 +142,7 @@ __global__ void compute_grad_weight(
accscalar_t weight = 0;
for (int idx=idx_begin; idx < idx_end; ++idx) {
const int64_t target_row = indices[idx];
const index_t target_row = indices[idx];
const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
weight += gradOutput[target_row * stride + startFeature] * scale;
}
@ -148,12 +150,12 @@ __global__ void compute_grad_weight(
}
// This kernel assumes that all input tensors are contiguous.
template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ void sum_and_scatter(
int64_t *input, scalar_t *gradWeight, int64_t stride,
int64_t* segment_offsets, int64_t num_of_segments,
index_t *input, scalar_t *gradWeight, int64_t stride,
index_t* segment_offsets, int64_t num_of_segments,
const acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t *segment_sizes_offsets, int64_t num_of_partial_segments,
const index_t *segment_sizes_offsets, int64_t num_of_partial_segments,
const int64_t padding_idx,
const int64_t stride_warped) {
@ -206,118 +208,120 @@ Tensor embedding_backward_cuda_kernel(
// spawn a warp per index. In this context, a segment is a number of rows that should
// be summarized.
// Unit: index in `sorted_indices` and `orig_indices`
auto segment_offsets = at::empty({numel}, orig_indices.options());
int64_t num_of_segments;
{
auto sorted_indices_dev = thrust::device_ptr<int64_t>(sorted_indices.data_ptr<int64_t>());
auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto dummy_dev = thrust::device_ptr<int64_t>(dummy.data_ptr<int64_t>());
auto ends = thrust::unique_by_key_copy(
AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
auto segment_offsets = at::empty({numel}, orig_indices.options());
int64_t num_of_segments;
{
auto sorted_indices_dev = thrust::device_ptr<index_t>(sorted_indices.data_ptr<index_t>());
auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto dummy_dev = thrust::device_ptr<index_t>(dummy.data_ptr<index_t>());
auto ends = thrust::unique_by_key_copy(
policy,
sorted_indices_dev,
sorted_indices_dev + numel,
thrust::make_counting_iterator(0),
dummy_dev,
thrust::device_ptr<index_t>(segment_offsets.data_ptr<index_t>()));
num_of_segments = thrust::get<0>(ends) - dummy_dev;
}
// We split the segments up into sizes of `NROWS_PER_THREAD`
// Compute the number partial-segments per segment (some partial-segments
// may not be the full `NROWS_PER_THREAD` number of rows)
auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options());
{
krn_partials_per_segment<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
partials_per_segment.data_ptr<index_t>(),
segment_offsets.data_ptr<index_t>(),
num_of_segments,
numel);
}
// In order to compute `partial_segment_offset`, which is the start index
// of each partial-segment in `sorted_indices`, we need to compute the
// start position of each _segment_ in `partial_segment_offset`.
// Unit: index in `partial_segment_offset`
auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options());
thrust::exclusive_scan(
policy,
sorted_indices_dev,
sorted_indices_dev + numel,
thrust::make_counting_iterator(0),
dummy_dev,
thrust::device_ptr<int64_t>(segment_offsets.data_ptr<int64_t>()));
num_of_segments = thrust::get<0>(ends) - dummy_dev;
}
thrust::device_ptr<index_t>(partials_per_segment.data_ptr<index_t>()),
thrust::device_ptr<index_t>(partials_per_segment.data_ptr<index_t>()+num_of_segments),
thrust::device_ptr<index_t>(partials_per_segment_offset.data_ptr<index_t>()));
// We split the segments up into sizes of `NROWS_PER_THREAD`
// Compute the number partial-segments per segment (some partial-segments
// may not be the full `NROWS_PER_THREAD` number of rows)
auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options());
{
krn_partials_per_segment<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
partials_per_segment.data_ptr<int64_t>(),
segment_offsets.data_ptr<int64_t>(),
num_of_segments,
numel);
}
// The total number of partial-segments is the sum of `partials_per_segment_offset`
const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item<index_t>() +
partials_per_segment_offset[num_of_segments-1].item<index_t>();
// In order to compute `partial_segment_offset`, which is the start index
// of each partial-segment in `sorted_indices`, we need to compute the
// start position of each _segment_ in `partial_segment_offset`.
// Unit: index in `partial_segment_offset`
auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options());
thrust::exclusive_scan(
policy,
thrust::device_ptr<int64_t>(partials_per_segment.data_ptr<int64_t>()),
thrust::device_ptr<int64_t>(partials_per_segment.data_ptr<int64_t>()+num_of_segments),
thrust::device_ptr<int64_t>(partials_per_segment_offset.data_ptr<int64_t>()));
// Now we can compute the start position of each partial-segment
// Unit: index in `sorted_indices` and `orig_indices`
auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options());
{
krn_partial_segment_offset<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
partial_segment_offset.data_ptr<index_t>(),
partials_per_segment.data_ptr<index_t>(),
partials_per_segment_offset.data_ptr<index_t>(),
segment_offsets.data_ptr<index_t>(),
num_of_segments);
}
// The total number of partial-segments is the sum of `partials_per_segment_offset`
const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item<int64_t>() +
partials_per_segment_offset[num_of_segments-1].item<int64_t>();
const int stride_warped = ceil_div(stride, C10_WARP_SIZE)*C10_WARP_SIZE;
const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
const int grid = ceil_div(num_of_partial_segments*stride_warped, block);
// Now we can compute the start position of each partial-segment
// Unit: index in `sorted_indices` and `orig_indices`
auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options());
{
krn_partial_segment_offset<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
partial_segment_offset.data_ptr<int64_t>(),
partials_per_segment.data_ptr<int64_t>(),
partials_per_segment_offset.data_ptr<int64_t>(),
segment_offsets.data_ptr<int64_t>(),
num_of_segments);
}
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_backward_cuda_compute_grad_weight", [&] {
// For numerical stability, the dtype of `grad_weight_per_segment`
// should match `acc_type`
using partial_weight_t = acc_type<scalar_t, true>;
TensorOptions op;
if(grad.dtype() == at::kHalf || grad.dtype() == at::kBFloat16) {
op = grad.options().dtype(at::kFloat);
} else {
op = grad.options();
}
auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op);
// Compute the sum of each partial-segment and handle bags
if (offset2bag.defined()) {
compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
offset2bag.data_ptr<index_t>(),
count.defined() ? count.data_ptr<index_t>() : nullptr, numel, stride,
mode_mean, bag_size.data_ptr<index_t>(),
per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
partial_segment_offset.data_ptr<index_t>(),
num_of_partial_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);
} else {
compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
count.defined() ? count.data_ptr<index_t>() : nullptr,
numel, stride,
partial_segment_offset.data_ptr<index_t>(),
num_of_partial_segments,
grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);
}
AT_CUDA_CHECK(cudaGetLastError());
const int stride_warped = ceil_div(stride, C10_WARP_SIZE)*C10_WARP_SIZE;
const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
const int grid = ceil_div(num_of_partial_segments*stride_warped, block);
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_backward_cuda_compute_grad_weight", [&] {
// For numerical stability, the dtype of `grad_weight_per_segment`
// should match `acc_type`
using partial_weight_t = acc_type<scalar_t, true>;
TensorOptions op;
if(grad.dtype() == at::kHalf || grad.dtype() == at::kBFloat16) {
op = grad.options().dtype(at::kFloat);
} else {
op = grad.options();
}
auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op);
// Compute the sum of each partial-segment and handle bags
if (offset2bag.defined()) {
compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data_ptr<int64_t>(),
grad.data_ptr<scalar_t>(),
offset2bag.data_ptr<int64_t>(),
count.defined() ? count.data_ptr<int64_t>() : nullptr, numel, stride,
mode_mean, bag_size.data_ptr<int64_t>(),
per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
partial_segment_offset.data_ptr<int64_t>(),
num_of_partial_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
stride_warped);
} else {
compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data_ptr<int64_t>(),
grad.data_ptr<scalar_t>(),
count.defined() ? count.data_ptr<int64_t>() : nullptr,
numel, stride,
partial_segment_offset.data_ptr<int64_t>(),
// Finally, we sum all the partial-sums and scatter them
// into `grad_weight`.
const int grid2 = ceil_div(num_of_segments*stride_warped, block);
sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
sorted_indices.data_ptr<index_t>(),
grad_weight.data_ptr<scalar_t>(),
stride,
segment_offsets.data_ptr<index_t>(),
num_of_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
partials_per_segment_offset.data_ptr<index_t>(),
num_of_partial_segments,
grad_weight_per_segment.data_ptr<partial_weight_t>(),
padding_idx,
stride_warped);
}
AT_CUDA_CHECK(cudaGetLastError());
// Finally, we sum all the partial-sums and scatter them
// into `grad_weight`.
const int grid2 = ceil_div(num_of_segments*stride_warped, block);
sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
sorted_indices.data_ptr<int64_t>(),
grad_weight.data_ptr<scalar_t>(),
stride,
segment_offsets.data_ptr<int64_t>(),
num_of_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
partials_per_segment_offset.data_ptr<int64_t>(),
num_of_partial_segments,
padding_idx,
stride_warped);
AT_CUDA_CHECK(cudaGetLastError());
AT_CUDA_CHECK(cudaGetLastError());
});
});
});
return grad_weight;

View File

@ -31,12 +31,12 @@ constexpr int MODE_MAX = 2;
// This kernel assumes that all input tensors except `weight` and
// per_sample_weights are contiguous.
template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_updateOutputKernel(
int64_t *input, int64_t *offsets, scalar_t *weight, scalar_t *output,
int64_t *offset2bag, int64_t numIndices, int64_t numBags,
index_t *input, index_t *offsets, scalar_t *weight, scalar_t *output,
index_t *offset2bag, int64_t numIndices, int64_t numBags,
int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1,
int mode, int64_t *bag_size, int64_t *max_indices,
int mode, index_t *bag_size, index_t *max_indices,
scalar_t* per_sample_weights, int64_t per_sample_weights_stride) {
// the strategy here is that each bag x feature is handled by a single thread
@ -135,62 +135,65 @@ Tensor embedding_bag_backward_cuda_sum_avg(
auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
using device_ptr = thrust::device_ptr<int64_t>;
// Sort the inputs into sorted with the corresponding indices; we
// don't need a stable or multidimensional sort, so just use Thrust
// directly
{
sorted_indices.copy_(indices);
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
// Fill sortedOrigIndices with sequential indices
auto count_iter = thrust::counting_iterator<int64_t>(0);
auto orig_data = device_ptr(orig_indices.data_ptr<int64_t>());
thrust::copy(policy, count_iter, count_iter + numel, orig_data);
// Sort; a stable sort is not required
auto sorted_data = device_ptr(sorted_indices.data_ptr<int64_t>());
thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data,
ThrustLTOp<int64_t>());
}
Tensor count;
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
using device_ptr = thrust::device_ptr<index_t>;
// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = device_ptr(sorted_indices.data_ptr<int64_t>());
auto count_data = device_ptr(count.data_ptr<int64_t>());
thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel,
thrust::make_constant_iterator(1),
count_data);
// Sort the inputs into sorted with the corresponding indices; we
// don't need a stable or multidimensional sort, so just use Thrust
// directly
{
sorted_indices.copy_(indices);
// Take the maximum of each count per unique key in reverse:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 3 3 3 2 2 1 2 2
thrust::inclusive_scan_by_key(
policy, thrust::make_reverse_iterator(sorted_data + numel),
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + numel),
thrust::make_reverse_iterator(count_data + numel),
thrust::equal_to<int64_t>(), thrust::maximum<int64_t>());
}
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
// Fill sortedOrigIndices with sequential indices
auto count_iter = thrust::counting_iterator<index_t>(0);
auto orig_data = device_ptr(orig_indices.data_ptr<index_t>());
thrust::copy(policy, count_iter, count_iter + numel, orig_data);
// Sort; a stable sort is not required
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data,
ThrustLTOp<index_t>());
}
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
auto count_data = device_ptr(count.data_ptr<index_t>());
thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel,
thrust::make_constant_iterator(1),
count_data);
// Take the maximum of each count per unique key in reverse:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 3 3 3 2 2 1 2 2
thrust::inclusive_scan_by_key(
policy, thrust::make_reverse_iterator(sorted_data + numel),
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + numel),
thrust::make_reverse_iterator(count_data + numel),
thrust::equal_to<index_t>(), thrust::maximum<index_t>());
}
});
return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
count, num_weights, /* padding_idx= */ -1, scale_grad_by_freq,
mode == MODE_MEAN, offset2bag, bag_size, per_sample_weights);
}
template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_accGradParametersKernel_max(
int64_t *max_indices, scalar_t *gradOutput,
index_t *max_indices, scalar_t *gradOutput,
scalar_t *gradWeight, int64_t stride, int64_t numBags) {
using accscalar_t = acc_type<scalar_t, true>;
@ -205,7 +208,7 @@ __global__ void EmbeddingBag_accGradParametersKernel_max(
if (featureDim < stride) {
int64_t bag = chunk / chunksPerBag;
int64_t word_idx = max_indices[bag * stride + featureDim];
index_t word_idx = max_indices[bag * stride + featureDim];
if (word_idx >= 0) {
// If bag is empty, we have max_indices[idx] set to -1 in forward.
gpuAtomicAdd(&(gradWeight[word_idx * stride + featureDim]),
@ -236,10 +239,12 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad,
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "embedding_bag_backward_cuda_max", [&] {
EmbeddingBag_accGradParametersKernel_max<
scalar_t><<<grid, block, 0, stream>>>(
max_indices.data_ptr<int64_t>(), grad.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(), stride, numBags);
AT_DISPATCH_INDEX_TYPES(max_indices.scalar_type(), "embedding_bag_backward_cuda_max", [&] () {
EmbeddingBag_accGradParametersKernel_max<
scalar_t, index_t><<<grid, block, 0, stream>>>(
max_indices.data_ptr<index_t>(), grad.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(), stride, numBags);
});
});
AT_CUDA_CHECK(cudaGetLastError());
@ -275,9 +280,10 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
const Tensor& per_sample_weights,
bool include_last_offset) {
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_bag_cuda", indices_arg, kLong);
checkScalarTypes("embedding_bag_cuda", indices_arg, {kLong, kInt});
auto offsets_arg = TensorArg(offsets, "offsets", 1);
checkScalarType("embedding_bag_cuda", offsets_arg, kLong);
checkScalarTypes("embedding_bag_cuda", offsets_arg, {kLong, kInt});
checkSameType("embedding_bag_cuda", indices_arg, offsets_arg);
auto weight_arg = TensorArg(weight, "weight", 1);
checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg);
checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg);
@ -320,14 +326,16 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
int grid = 1024;
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_cuda", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "embedding_bag_cuda", [&] {
EmbeddingBag_updateOutputKernel<scalar_t><<<grid, block, 0, stream>>>(
indices.data_ptr<int64_t>(), offsets.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
offset2bag.data_ptr<int64_t>(), numIndices, numBags, featureSize,
weight.stride(0), weight.stride(1), mode, bag_size.data_ptr<int64_t>(),
mode == MODE_MAX ? max_indices.data_ptr<int64_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () {
EmbeddingBag_updateOutputKernel<scalar_t, index_t><<<grid, block, 0, stream>>>(
indices.data_ptr<index_t>(), offsets.data_ptr<index_t>(),
weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
offset2bag.data_ptr<index_t>(), numIndices, numBags, featureSize,
weight.stride(0), weight.stride(1), mode, bag_size.data_ptr<index_t>(),
mode == MODE_MAX ? max_indices.data_ptr<index_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.data_ptr<scalar_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0);
});
});
});
@ -387,12 +395,12 @@ static scalar_t warpReduceSum(scalar_t val) {
return val;
}
template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ static void _embedding_bag_per_sample_weights_backward_kernel(
const scalar_t* grad, int64_t grad_stride0, int64_t grad_stride1,
const scalar_t* weight, int64_t weight_stride0, int64_t weight_stride1,
const int64_t* indices, // contiguous
const int64_t* offset2bag, // contiguous
const index_t* indices, // contiguous
const index_t* offset2bag, // contiguous
int64_t num_samples,
int64_t embedding_features,
scalar_t* output) {
@ -457,16 +465,18 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
_embedding_bag_per_sample_weights_backward_kernel<scalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
grad.data_ptr<scalar_t>(), grad.stride(0), grad.stride(1),
weight.data_ptr<scalar_t>(), weight.stride(0), weight.stride(1),
indices.data_ptr<int64_t>(),
offset2bag.data_ptr<int64_t>(),
num_samples,
embedding_features,
output.data_ptr<scalar_t>());
AT_CUDA_CHECK(cudaGetLastError());
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
_embedding_bag_per_sample_weights_backward_kernel<scalar_t, index_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
grad.data_ptr<scalar_t>(), grad.stride(0), grad.stride(1),
weight.data_ptr<scalar_t>(), weight.stride(0), weight.stride(1),
indices.data_ptr<index_t>(),
offset2bag.data_ptr<index_t>(),
num_samples,
embedding_features,
output.data_ptr<scalar_t>());
AT_CUDA_CHECK(cudaGetLastError());
});
}
);
return output;

View File

@ -308,10 +308,10 @@ static ptrdiff_t getSliceSize(const Tensor & dst,
// the number of indices chosen is large, then the
// indexAddLargeIndex kernel is a better choice to increase
// parallelism.
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim>
template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim>
__global__ void indexAddSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst,
cuda::detail::TensorInfo<T, IndexType> src,
cuda::detail::TensorInfo<int64_t, IndexType> indices,
cuda::detail::TensorInfo<IndicesType, IndexType> indices,
int dstAddDim,
int srcAddDim,
IndexType innerSize,
@ -324,7 +324,7 @@ __global__ void indexAddSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst,
for (IndexType srcIndex = 0; srcIndex < indices.sizes[0]; ++srcIndex) {
// Lua indices begin at 1
IndexType dstIndex =
indices.data[cuda::detail::IndexToOffset<int64_t, IndexType, IdxDim>::get(srcIndex, indices)];
indices.data[cuda::detail::IndexToOffset<IndicesType, IndexType, IdxDim>::get(srcIndex, indices)];
CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize);
// We stride over the output ignoring the indexed dimension
@ -351,11 +351,11 @@ __global__ void indexAddSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst,
// the number of indices chosen is small, then the
// indexAddSmallIndex kernel is a better choice to reduce memory
// accesses.
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim,
bool IndexIsMajor>
__global__ void indexAddLargeIndex(cuda::detail::TensorInfo<T, IndexType> dst,
cuda::detail::TensorInfo<T, IndexType> src,
cuda::detail::TensorInfo<int64_t, IndexType> indices,
cuda::detail::TensorInfo<IndicesType, IndexType> indices,
int dstAddDim,
int srcAddDim,
IndexType totalSize,
@ -378,7 +378,7 @@ __global__ void indexAddLargeIndex(cuda::detail::TensorInfo<T, IndexType> dst,
// Lua indices begin at 1
IndexType dstIndex =
indices.data[cuda::detail::IndexToOffset<int64_t, IndexType, IdxDim>::get(srcIndex, indices)];
indices.data[cuda::detail::IndexToOffset<IndicesType, IndexType, IdxDim>::get(srcIndex, indices)];
CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize);
IndexType dstOffset =
@ -438,7 +438,7 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const
checkAllSameGPU("index_add", {self_arg, index_arg, source_arg});
TORCH_CHECK_INDEX(index.dim() <= 1, "index_add_(): Index is supposed to be a vector");
TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_add_(): Expected dtype int64 for index");
TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_add_(): Expected dtype int32/int64 for index");
TORCH_CHECK(self.scalar_type() == source.scalar_type(),
"index_add_(): self and source must have the same scalar type");
TORCH_CHECK(dim == 0 || dim < source.dim(),
@ -476,15 +476,15 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const
int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
#define SMALL_INDEX(TENSOR_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \
indexAddSmallIndex<TENSOR_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM> \
#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \
indexAddSmallIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM> \
<<<smallIndexGrid, smallIndexBlock, 0, stream>>>( \
selfInfo, sourceInfo, indexInfo, \
selfAddDim, sourceAddDim, sliceSize, selfAddDimSize);
#define LARGE_INDEX(TENSOR_TYPE, TYPE, \
#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \
SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR) \
indexAddLargeIndex<TENSOR_TYPE, TYPE, \
indexAddLargeIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, \
SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR> \
<<<largeIndexGrid, largeIndexBlock, 0, stream>>>( \
selfInfo, sourceInfo, indexInfo, \
@ -507,49 +507,50 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const
cuda::detail::getTensorInfo<scalar_t, unsigned int>(self_);
int selfAddDim = selfInfo.collapseDims(dim);
selfInfo.reduceDim(selfAddDim);
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () {
auto sourceInfo =
cuda::detail::getTensorInfo<scalar_t, unsigned int>(source_);
int sourceAddDim = sourceInfo.collapseDims(dim);
sourceInfo.reduceDim(sourceAddDim);
auto sourceInfo =
cuda::detail::getTensorInfo<scalar_t, unsigned int>(source_);
int sourceAddDim = sourceInfo.collapseDims(dim);
sourceInfo.reduceDim(sourceAddDim);
auto indexInfo =
cuda::detail::getTensorInfo<index_t, unsigned int>(index);
indexInfo.collapseDims();
auto indexInfo =
cuda::detail::getTensorInfo<int64_t, unsigned int>(index);
indexInfo.collapseDims();
// A reasonable choice for when to have each thread iterate over
// index to choose
if (numIndex <= 16) {
if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
SMALL_INDEX(scalar_t, unsigned int, 1, 1, -2);
} else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
SMALL_INDEX(scalar_t, unsigned int, 2, 2, -2);
} else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
SMALL_INDEX(scalar_t, unsigned int, 3, 3, -2);
} else {
SMALL_INDEX(scalar_t, unsigned int, -1, -1, -1);
}
} else {
bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim);
if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
LARGE_INDEX(scalar_t, unsigned int, 1, 1, -2, true);
} else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
if (indexIsMajor) {
LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, true);
// A reasonable choice for when to have each thread iterate over
// index to choose
if (numIndex <= 16) {
if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2);
} else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2);
} else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2);
} else {
LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, false);
}
} else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
if (indexIsMajor) {
LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, true);
} else {
LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, false);
SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1);
}
} else {
LARGE_INDEX(scalar_t, unsigned int, -1, -1, -1, true);
bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim);
if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true);
} else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
if (indexIsMajor) {
LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true);
} else {
LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false);
}
} else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
if (indexIsMajor) {
LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true);
} else {
LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false);
}
} else {
LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true);
}
}
}
});
});
});
} else {
@ -565,11 +566,13 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const
int sourceAddDim = sourceInfo.collapseDims(dim);
sourceInfo.reduceDim(sourceAddDim);
cuda::detail::TensorInfo<int64_t, uint64_t> indexInfo =
cuda::detail::getTensorInfo<int64_t, uint64_t>(index);
indexInfo.collapseDims();
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () {
cuda::detail::TensorInfo<index_t, uint64_t> indexInfo =
cuda::detail::getTensorInfo<index_t, uint64_t>(index);
indexInfo.collapseDims();
LARGE_INDEX(scalar_t, uint64_t, -1, -1, -1, true);
LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true);
});
});
});
}
@ -586,10 +589,10 @@ namespace {
// the number of indices chosen is large, then the
// indexSelectLargeIndex kernel is a better choice to increase
// parallelism.
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim>
template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim>
__global__ void indexSelectSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst,
cuda::detail::TensorInfo<T, IndexType> src,
cuda::detail::TensorInfo<int64_t, IndexType> indices,
cuda::detail::TensorInfo<IndicesType, IndexType> indices,
int dstSelectDim,
int srcSelectDim,
IndexType innerSize,
@ -601,7 +604,7 @@ __global__ void indexSelectSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst
// re-accessing indices in addition to src elements can be slow.
for (IndexType dstIndex = 0; dstIndex < indices.sizes[0]; ++dstIndex) {
IndexType srcIndex =
indices.data[cuda::detail::IndexToOffset<int64_t, IndexType, IdxDim>::get(dstIndex, indices)];
indices.data[cuda::detail::IndexToOffset<IndicesType, IndexType, IdxDim>::get(dstIndex, indices)];
CUDA_KERNEL_ASSERT(srcIndex < srcSelectDimSize);
// We stride over the output ignoring the indexed dimension
@ -628,11 +631,11 @@ __global__ void indexSelectSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst
// the number of indices chosen is small, then the
// indexSelectSmallIndex kernel is a better choice to reduce memory
// accesses.
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim,
bool IndexIsMajor>
__global__ void indexSelectLargeIndex(cuda::detail::TensorInfo<T, IndexType> dst,
cuda::detail::TensorInfo<T, IndexType> src,
cuda::detail::TensorInfo<int64_t, IndexType> indices,
cuda::detail::TensorInfo<IndicesType, IndexType> indices,
int dstSelectDim,
int srcSelectDim,
IndexType totalSize,
@ -654,7 +657,7 @@ __global__ void indexSelectLargeIndex(cuda::detail::TensorInfo<T, IndexType> dst
}
IndexType srcIndex =
indices.data[cuda::detail::IndexToOffset<int64_t, IndexType, IdxDim>::get(dstIndex, indices)];
indices.data[cuda::detail::IndexToOffset<IndicesType, IndexType, IdxDim>::get(dstIndex, indices)];
CUDA_KERNEL_ASSERT(srcIndex < srcSelectDimSize);
IndexType dstOffset =
@ -722,16 +725,16 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim,
int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \
indexSelectSmallIndex<TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM> \
#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \
indexSelectSmallIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM> \
<<<smallIndexGrid, smallIndexBlock, 0, stream>>>( \
outInfo, selfInfo, indicesInfo, \
outSelectDim, selfSelectDim, static_cast<TYPE>(sliceSize), \
selfSelectDimSize);
#define LARGE_INDEX(TENSOR_TYPE, TYPE, \
#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \
DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \
indexSelectLargeIndex<TENSOR_TYPE, TYPE, \
indexSelectLargeIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, \
DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR> \
<<<largeIndexGrid, largeIndexBlock, 0, stream>>>( \
outInfo, selfInfo, indicesInfo, \
@ -755,42 +758,44 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim,
int selfSelectDim = selfInfo.collapseDims(dim);
selfInfo.reduceDim(selfSelectDim);
auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<int64_t, unsigned int>(index));
indicesInfo.collapseDims();
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_cuda_impl", [&] () {
auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<index_t, unsigned int>(index));
indicesInfo.collapseDims();
// A reasonable choice for when to have each thread iterate over
// indices to choose
if (numIndices <= 16) {
if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) {
SMALL_INDEX(scalar_t, unsigned int, 1, 1, -2);
} else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) {
SMALL_INDEX(scalar_t, unsigned int, 2, 2, -2);
} else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) {
SMALL_INDEX(scalar_t, unsigned int, 3, 3, -2);
} else {
SMALL_INDEX(scalar_t, unsigned int, -1, -1, -1);
}
} else {
bool indexIsMajor = indexShouldBeMajor(outInfo, outSelectDim);
if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) {
LARGE_INDEX(scalar_t, unsigned int, 1, 1, -2, true);
} else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) {
if (indexIsMajor) {
LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, true);
// A reasonable choice for when to have each thread iterate over
// indices to choose
if (numIndices <= 16) {
if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) {
SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2);
} else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) {
SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2);
} else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) {
SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2);
} else {
LARGE_INDEX(scalar_t, unsigned int, 2, 2, -2, false);
}
} else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) {
if (indexIsMajor) {
LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, true);
} else {
LARGE_INDEX(scalar_t, unsigned int, 3, 3, -2, false);
SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1);
}
} else {
LARGE_INDEX(scalar_t, unsigned int, -1, -1, -1, true);
bool indexIsMajor = indexShouldBeMajor(outInfo, outSelectDim);
if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) {
LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true);
} else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) {
if (indexIsMajor) {
LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true);
} else {
LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false);
}
} else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) {
if (indexIsMajor) {
LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true);
} else {
LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false);
}
} else {
LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true);
}
}
}
});
} else {
auto outInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<scalar_t, uint64_t>(out));
int outSelectDim = outInfo.collapseDims(dim);
@ -799,11 +804,12 @@ void index_select_out_cuda_impl(Tensor& out, const Tensor& self, long dim,
auto selfInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<scalar_t, uint64_t>(self));
int selfSelectDim = selfInfo.collapseDims(dim);
selfInfo.reduceDim(selfSelectDim);
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_cuda_impl", [&] () {
auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<index_t, uint64_t>(index));
indicesInfo.collapseDims();
auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<int64_t, uint64_t>(index));
indicesInfo.collapseDims();
LARGE_INDEX(scalar_t, uint64_t, -1, -1, -1, true);
LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true);
});
}
#undef SMALL_INDEX
#undef LARGE_INDEX

View File

@ -17,7 +17,7 @@ static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float__avx2_fma(
const int64_t data_size,
const float* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
bool normalize_by_lengths,
float* out) {
@ -401,7 +401,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float_false__avx2_fma(
const int64_t data_size,
const float* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
bool normalize_by_lengths,
float* out) {
@ -425,7 +425,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_float_float_true__avx2_fma(
const int64_t data_size,
const float* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
bool normalize_by_lengths,
float* out) {
@ -883,7 +883,7 @@ static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float__avx2_fma(
const int64_t data_size,
const at::Half* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
bool normalize_by_lengths,
float* out) {
@ -1387,7 +1387,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float_false__avx2_fma(
const int64_t data_size,
const at::Half* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
bool normalize_by_lengths,
float* out) {
@ -1410,7 +1410,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_half_float_true__avx2_fma(
const int64_t data_size,
const at::Half* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
bool normalize_by_lengths,
float* out) {
@ -1987,7 +1987,7 @@ static bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma(
const int64_t data_size,
const uint8_t* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
bool normalize_by_lengths,
float* out) {
@ -2514,7 +2514,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma(
const int64_t data_size,
const uint8_t* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
bool normalize_by_lengths,
float* out) {
@ -2538,7 +2538,7 @@ bool Fused8BitRowwiseEmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma(
const int64_t data_size,
const uint8_t* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
bool normalize_by_lengths,
float* out) {

View File

@ -23,7 +23,7 @@ static bool EmbeddingLookupGenericSlowIdx(
const int64_t data_size,
const InType* input,
const IndexType* indices,
const int64_t* offsets,
const IndexType* offsets,
const float* weights, // optional, can be null for sum reducer
const float* scale_bias, // optional scale & bias params for uint8 input
bool normalize_by_lengths,
@ -85,7 +85,7 @@ static bool EmbeddingLookupGenericSlowIdx(
const int64_t data_size, \
const InType* input, \
const IndexType* indices, \
const int64_t* offsets, \
const IndexType* offsets, \
const float* weights, \
const float* scale_bias, \
bool normalize_by_lengths, \
@ -118,7 +118,7 @@ static bool EmbeddingLookupGenericSlowIdx(
const int64_t data_size, \
const InType* input, \
const IndexType* indices, \
const int64_t* offsets, \
const IndexType* offsets, \
const float* weights, \
const float* scale_bias, \
bool normalize_by_lengths, \
@ -163,7 +163,7 @@ static bool EmbeddingLookupGenericSlowIdx(
const int64_t data_size, \
const InType* input, \
const IndexType* indices, \
const int64_t* offsets, \
const IndexType* offsets, \
const float* weights, \
const float* scale_bias, \
bool normalize_by_lengths, \

View File

@ -48,7 +48,7 @@ void EmbeddingLookupIdx(
const std::int64_t data_size,
const InType* input,
const IndexType* indices,
const int64_t* offsets,
const IndexType* offsets,
const float* weights, // optional, can be null for non-weighted sum
const float* scale_bias, // optional scale & bias params for uint8 input
bool normalize_by_lengths,

View File

@ -17,7 +17,7 @@ static bool EmbeddingLookupIdx_int32_t_float_float__avx2_fma(
const int64_t data_size,
const float* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
@ -402,7 +402,7 @@ bool EmbeddingLookupIdx_int32_t_float_float_false__avx2_fma(
const int64_t data_size,
const float* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
@ -427,7 +427,7 @@ bool EmbeddingLookupIdx_int32_t_float_float_true__avx2_fma(
const int64_t data_size,
const float* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
@ -891,7 +891,7 @@ static bool EmbeddingLookupIdx_int32_t_half_float__avx2_fma(
const int64_t data_size,
const at::Half* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
@ -1396,7 +1396,7 @@ bool EmbeddingLookupIdx_int32_t_half_float_false__avx2_fma(
const int64_t data_size,
const at::Half* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
@ -1421,7 +1421,7 @@ bool EmbeddingLookupIdx_int32_t_half_float_true__avx2_fma(
const int64_t data_size,
const at::Half* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
@ -2005,7 +2005,7 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma(
const int64_t data_size,
const uint8_t* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
@ -2523,7 +2523,7 @@ bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma(
const int64_t data_size,
const uint8_t* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,
@ -2548,7 +2548,7 @@ bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma(
const int64_t data_size,
const uint8_t* input,
const int* indices,
const int64_t* offsets,
const int* offsets,
const float* weights,
const float* scale_bias,
bool normalize_by_lengths,

View File

@ -22,7 +22,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx(
const int64_t data_size,
const InType* input,
const IndexType* indices,
const int64_t* offsets,
const IndexType* offsets,
const float* weights, // optional, can be null for sum reducer
bool normalize_by_lengths,
OutType* out) {
@ -88,7 +88,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx(
const int64_t data_size, \
const uint8_t* input, \
const IndexType* indices, \
const int64_t* offsets, \
const IndexType* offsets, \
const float* weights, \
bool normalize_by_lengths, \
OutType* out) { \
@ -118,7 +118,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx(
const int64_t data_size, \
const uint8_t* input, \
const IndexType* indices, \
const int64_t* offsets, \
const IndexType* offsets, \
const float* weights, \
bool normalize_by_lengths, \
OutType* out) { \
@ -160,7 +160,7 @@ static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx(
const int64_t data_size, \
const uint8_t* input, \
const IndexType* indices, \
const int64_t* offsets, \
const IndexType* offsets, \
const float* weights, \
bool normalize_by_lengths, \
OutType* out) { \

View File

@ -50,7 +50,7 @@ void Fused8BitRowwiseEmbeddingLookupIdx(
const std::int64_t data_size,
const InType* input,
const IndexType* indices,
const int64_t* offsets,
const IndexType* offsets,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
OutType* out);

View File

@ -450,7 +450,7 @@ for o in options:
args.append(" const " + InType + "* input,")
args.append(" const " + IndexType + "* indices,")
if opts.use_offsets:
args.append(" const int64_t* offsets,")
args.append(" const " + IndexType + "* offsets,")
else:
args.append(" const int* lengths,")
args.append(" const float* weights,")

View File

@ -11610,25 +11610,27 @@ class TestNNDeviceType(NNTestCase):
self.assertRaises(RuntimeError, lambda: F.adaptive_max_pool2d(t, []))
self.assertRaises(RuntimeError, lambda: F.adaptive_max_pool3d(t, []))
def test_embedding_bag_empty_input(self, device):
@dtypes(torch.int, torch.long)
def test_embedding_bag_empty_input(self, device, dtype):
m = 4
n = 3
x = torch.tensor([], device=device, dtype=torch.long)
x = torch.tensor([], device=device, dtype=dtype)
for sparse in [True, False]:
Embed = torch.nn.EmbeddingBag(m, n, sparse=sparse)
Embed.to(device)
output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=torch.long))
output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=dtype))
self.assertEqual(output, torch.zeros_like(output))
output = Embed(input=x, offsets=torch.tensor([0, 0], device=device, dtype=torch.long))
output = Embed(input=x, offsets=torch.tensor([0, 0], device=device, dtype=dtype))
self.assertEqual(output, torch.zeros_like(output))
def test_EmbeddingBag_per_sample_weights_failures(self, device):
@dtypes(torch.int, torch.long)
def test_EmbeddingBag_per_sample_weights_failures(self, device, dtype):
# Failure 1: mismatched embeddings / per_sample_weights dtype
es = nn.EmbeddingBag(5, 2, mode='sum').to(dtype=torch.float, device=device)
input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device)
offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device)
input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtype, device=device)
offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtype, device=device)
per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device)
if device == 'cpu':
with self.assertRaisesRegex(RuntimeError, 'have the same type as'):
@ -11638,14 +11640,14 @@ class TestNNDeviceType(NNTestCase):
es(input, offsets, per_sample_weights)
# Failure 2.1: input/per_sample_weights have different sizes (1d input)
input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long, device=device)
offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long, device=device)
input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtype, device=device)
offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtype, device=device)
per_sample_weights = torch.randn(5, dtype=torch.float, device=device)
with self.assertRaisesRegex(ValueError, 'same shape as the input'):
es(input, offsets, per_sample_weights)
# Failure 2.2: input/per_sample_weights have different sizes (2d input)
input = torch.randint(5, (7, 3), dtype=torch.long, device=device)
input = torch.randint(5, (7, 3), dtype=dtype, device=device)
offsets = None
per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device)
with self.assertRaisesRegex(ValueError, 'same shape as the input'):
@ -11655,7 +11657,7 @@ class TestNNDeviceType(NNTestCase):
for unsupported_mode in ('max', 'mean'):
es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to(
dtype=torch.float, device=device)
input = torch.randint(5, (7, 3), dtype=torch.long, device=device)
input = torch.randint(5, (7, 3), dtype=dtype, device=device)
offsets = None
per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device)
with self.assertRaisesRegex(NotImplementedError,
@ -11673,7 +11675,8 @@ class TestNNDeviceType(NNTestCase):
assert input.numel() == per_sample_weights.numel()
bags = []
embeddings = weight.index_select(0, input) * per_sample_weights.unsqueeze(1)
long_input = input.to(torch.long)
embeddings = weight.index_select(0, long_input) * per_sample_weights.unsqueeze(1)
if include_last_offset:
for index in range(len(offsets) - 1):
offset = offsets[index]
@ -11698,7 +11701,7 @@ class TestNNDeviceType(NNTestCase):
if index + 1 < len(offsets):
next_offset = offsets[index + 1]
else:
next_offset = len(input)
next_offset = len(long_input)
length = next_offset - offset
if length == 0:
bags.append(
@ -11716,16 +11719,18 @@ class TestNNDeviceType(NNTestCase):
bags.append(embeddings.narrow(0, offset, length).max(0)[0])
return torch.stack(bags)
def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device):
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
@dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes):
# Test empty input and per sample weight, and backward pass. There was a CUDA
# invalid configuration bug (more context in #46572)
def test_per_sample_weights(mode, dtype, trainable_scale):
es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtype, device=device)
def test_per_sample_weights(mode, trainable_scale):
es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[1], device=device)
es.weight.data.copy_(
torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight))
input = torch.tensor([], device=device, dtype=torch.long)
offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=torch.long)
per_sample_weights = torch.randn_like(input, dtype=dtype) \
torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight))
input = torch.tensor([], device=device, dtype=dtypes[0])
offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=dtypes[0])
per_sample_weights = torch.randn_like(input, dtype=dtypes[1]) \
.requires_grad_(trainable_scale)
ref_per_sample_weights = \
per_sample_weights.detach().requires_grad_(trainable_scale)
@ -11734,7 +11739,7 @@ class TestNNDeviceType(NNTestCase):
expected = self._embedding_bag_reference_impl(
input, reference_weights, offsets, mode, ref_per_sample_weights)
result = es(input, offsets, per_sample_weights)
self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
grad = torch.randn_like(expected)
result.backward(grad)
@ -11742,29 +11747,27 @@ class TestNNDeviceType(NNTestCase):
# simply be a zero tensor
ref_weights_grad = torch.zeros_like(es.weight)
self.assertEqual(es.weight.grad, ref_weights_grad,
atol=dtype2prec_DONTUSE[dtype], rtol=0)
atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
if trainable_scale:
ref_per_sample_weights_grad = torch.empty_like(per_sample_weights)
self.assertEqual(per_sample_weights.grad, ref_per_sample_weights_grad,
atol=dtype2prec_DONTUSE[dtype], rtol=0)
atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
if device == 'cuda':
dtypes = (torch.float, torch.double, torch.half)
else:
dtypes = (torch.float, torch.double)
modes = ('sum',)
trainable_scale = (True, False)
for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale):
test_per_sample_weights(mode, dtype, trainable)
for mode, trainable in itertools.product(modes, trainable_scale):
test_per_sample_weights(mode, trainable)
def test_EmbeddingBag_per_sample_weights_and_offsets(self, device):
def test_per_sample_weights(mode, dtype, trainable_scale):
es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtype, device=device)
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
@dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes):
def test_per_sample_weights(mode, trainable_scale):
es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[1], device=device)
es.weight.data.copy_(
torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight))
input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long)
offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long)
per_sample_weights = torch.randn_like(input, dtype=dtype) \
torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight))
input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0])
offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[0])
per_sample_weights = torch.randn_like(input, dtype=dtypes[1]) \
.requires_grad_(trainable_scale)
ref_per_sample_weights = \
per_sample_weights.detach().requires_grad_(trainable_scale)
@ -11773,39 +11776,37 @@ class TestNNDeviceType(NNTestCase):
expected = self._embedding_bag_reference_impl(
input, reference_weights, offsets, mode, ref_per_sample_weights)
result = es(input, offsets, per_sample_weights)
self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
grad = torch.randn_like(expected)
result.backward(grad)
expected.backward(grad)
self.assertEqual(es.weight.grad, reference_weights.grad,
atol=dtype2prec_DONTUSE[dtype], rtol=0)
atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
if trainable_scale:
self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
atol=dtype2prec_DONTUSE[dtype], rtol=0)
atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
if device == 'cuda':
dtypes = (torch.float, torch.double, torch.half)
else:
dtypes = (torch.float, torch.double)
modes = ('sum',)
trainable_scale = (True, False)
for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale):
test_per_sample_weights(mode, dtype, trainable)
for mode, trainable in itertools.product(modes, trainable_scale):
test_per_sample_weights(mode, trainable)
def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device):
def test_per_sample_weights_new_offsets(mode, dtype, trainable_scale, include_last_offset, has_weight=True):
es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtype, device=device)
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
@dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes):
def test_per_sample_weights_new_offsets(mode, trainable_scale, include_last_offset, has_weight=True):
es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtypes[1], device=device)
es.weight.data.copy_(
torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight))
input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long)
offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long)
torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight))
input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0])
offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[0])
if include_last_offset:
offsets = torch.cat((offsets, torch.tensor([input.size(0)], device=device, dtype=torch.long)), 0)
offsets = torch.cat((offsets, torch.tensor([input.size(0)], device=device, dtype=dtypes[0])), 0)
if has_weight:
per_sample_weights = torch.randn_like(input, device=device, dtype=dtype) \
per_sample_weights = torch.randn_like(input, device=device, dtype=dtypes[1]) \
.requires_grad_(trainable_scale)
ref_per_sample_weights = \
per_sample_weights.detach().requires_grad_(trainable_scale)
@ -11818,51 +11819,48 @@ class TestNNDeviceType(NNTestCase):
expected = self._embedding_bag_reference_impl(
input, reference_weights, offsets, mode, ref_per_sample_weights, include_last_offset)
result = es(input, offsets, per_sample_weights)
self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
grad = torch.randn_like(expected)
result.backward(grad)
expected.backward(grad)
self.assertEqual(es.weight.grad, reference_weights.grad,
atol=dtype2prec_DONTUSE[dtype], rtol=0)
atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
if has_weight and trainable_scale:
self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
atol=dtype2prec_DONTUSE[dtype], rtol=0)
atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
if device == 'cuda':
dtypes = (torch.float, torch.double, torch.half)
else:
dtypes = (torch.float, torch.double)
trainable_scale = (True, False)
include_last_offset = (True, False)
modes = (('sum', False), ('sum', True), ('max', False), ('mean', False))
for dtype, (mode, has_weight), trainable, include_last_offset in itertools.product(
dtypes, modes, trainable_scale, include_last_offset
for (mode, has_weight), trainable, include_last_offset in itertools.product(
modes, trainable_scale, include_last_offset
):
test_per_sample_weights_new_offsets(
mode, dtype, trainable, include_last_offset, has_weight
mode, trainable, include_last_offset, has_weight
)
def _test_EmbeddingBag_vs_Embedding(self, N, D, B, L, max_norm=None,
mode='mean',
device='cpu',
dtype=torch.float,
wdtype=torch.float,
dtype=torch.long,
test_per_sample_weights=False,
trainable_per_sample_weights=False,
sparse=False,
test_backward=True,
backward_prec=None):
es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, dtype)
e = nn.Embedding(N, D, max_norm=max_norm).to(device, dtype)
es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse, max_norm=max_norm).to(device, wdtype)
e = nn.Embedding(N, D, max_norm=max_norm).to(device, wdtype)
e.weight.data.copy_(es.weight)
input = torch.randint(N, (B, L), device=device, dtype=torch.long)
offsets = torch.arange(0, B, device=device, dtype=torch.long).mul_(L)
grad_output = torch.rand(B, D, device=device, dtype=dtype)
input = torch.randint(N, (B, L), device=device, dtype=dtype)
offsets = torch.arange(0, B, device=device, dtype=dtype).mul_(L)
grad_output = torch.rand(B, D, device=device, dtype=wdtype)
if test_per_sample_weights:
# To prevent large gradients, weights should sum to 1 for each bag
per_sample_weights = \
torch.randn(B, L, device=device, dtype=dtype).softmax(dim=-1)
torch.randn(B, L, device=device, dtype=wdtype).softmax(dim=-1)
per_sample_weights_reference = \
per_sample_weights.clone().requires_grad_(trainable_per_sample_weights)
per_sample_weights.requires_grad_(trainable_per_sample_weights)
@ -11884,7 +11882,7 @@ class TestNNDeviceType(NNTestCase):
assert not test_per_sample_weights
ref_output = e(input).max(1)[0]
self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(output, ref_output, atol=dtype2prec_DONTUSE[wdtype], rtol=0)
if not test_backward:
return
@ -11897,7 +11895,7 @@ class TestNNDeviceType(NNTestCase):
# We have more floating point error here because we are dealing with larger numbers
if backward_prec is None:
needed_prec = dtype2prec_DONTUSE[dtype] * 3
needed_prec = dtype2prec_DONTUSE[wdtype] * 3
else:
needed_prec = backward_prec
@ -11905,13 +11903,15 @@ class TestNNDeviceType(NNTestCase):
if test_per_sample_weights and trainable_per_sample_weights:
self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad,
atol=dtype2prec_DONTUSE[dtype], rtol=0)
atol=dtype2prec_DONTUSE[wdtype], rtol=0)
@skipCUDAIf(True, "Temporarily disabled. See t54369166")
def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device):
def run_tests(dtype, mode, sparse, trainable_per_sample_weights):
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.half, torch.float, torch.double)))
@dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes):
def run_tests(mode, sparse, trainable_per_sample_weights):
kwargs = dict(test_per_sample_weights=True, device=device,
mode=mode, dtype=dtype, sparse=sparse,
mode=mode, wdtype=dtypes[1], dtype=dtypes[0], sparse=sparse,
trainable_per_sample_weights=trainable_per_sample_weights)
# Simple case
@ -11926,78 +11926,76 @@ class TestNNDeviceType(NNTestCase):
# Large embedding_dim
self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs)
dtypes = (torch.float, torch.double)
modes = ('sum',)
sparsity = (True, False)
trainable_scale = (True, False)
for dtype, mode, sparse, trainable_per_sample_weights in \
itertools.product(dtypes, modes, sparsity, trainable_scale):
run_tests(dtype, mode, sparse, trainable_per_sample_weights)
for mode, sparse, trainable_per_sample_weights in \
itertools.product(modes, sparsity, trainable_scale):
run_tests(mode, sparse, trainable_per_sample_weights)
# Test CUDA Dense on half precision
if device == 'cuda':
dtypes = (torch.half,)
modes = ('sum',)
sparsity = (False,)
trainable_scale = (True, False)
for dtype, mode, sparse, trainable_per_sample_weights in \
itertools.product(dtypes, modes, sparsity, trainable_scale):
run_tests(dtype, mode, sparse, trainable_per_sample_weights)
for mode, sparse, trainable_per_sample_weights in \
itertools.product(modes, sparsity, trainable_scale):
run_tests(mode, sparse, trainable_per_sample_weights)
def _test_EmbeddingBag(self, device, mode, sparse, dtype=torch.double, test_backward=True):
def _test_EmbeddingBag(self, device, mode, sparse, wdtype=torch.double, dtype=torch.long, test_backward=True):
# check a known test example
es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, dtype)
es.weight.data.copy_(torch.arange(1, 11, device=device, dtype=dtype).view_as(es.weight))
input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=torch.long)
offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=torch.long)
es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, wdtype)
es.weight.data.copy_(torch.arange(1, 11, device=device, dtype=wdtype).view_as(es.weight))
input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtype)
offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtype)
grad_output = torch.tensor(
[1, 2,
3, 4], device=device, dtype=dtype).view(2, 2)
3, 4], device=device, dtype=wdtype).view(2, 2)
grad_output_with_empty = torch.tensor(
[99, 99,
1, 2,
99, 99,
3, 4,
99, 99], device=device, dtype=dtype).view(5, 2)
99, 99], device=device, dtype=wdtype).view(5, 2)
if mode == "sum" or mode == "mean":
denominator = 1 if mode == "sum" else 3
expected_output = torch.tensor(
[[13, 16],
[13, 16]], device=device, dtype=dtype) / denominator
[13, 16]], device=device, dtype=wdtype) / denominator
expected_output_with_empty = torch.tensor(
[[0, 0],
[13, 16],
[0, 0],
[13, 16],
[0, 0]], device=device, dtype=dtype) / denominator
[0, 0]], device=device, dtype=wdtype) / denominator
expected_grad_weight = torch.tensor(
[[3, 4],
[5, 8],
[0, 0],
[1, 2],
[3, 4]], device=device, dtype=dtype) / denominator
[3, 4]], device=device, dtype=wdtype) / denominator
elif mode == "max":
expected_output = torch.tensor(
[[7, 8],
[9, 10]], device=device, dtype=dtype)
[9, 10]], device=device, dtype=wdtype)
expected_output_with_empty = torch.tensor(
[[0, 0],
[7, 8],
[0, 0],
[9, 10],
[0, 0]], device=device, dtype=dtype)
[0, 0]], device=device, dtype=wdtype)
expected_grad_weight = torch.tensor(
[[0, 0],
[0, 0],
[0, 0],
[1, 2],
[3, 4]], device=device, dtype=dtype)
[3, 4]], device=device, dtype=wdtype)
output = es(input, offsets)
output.backward(grad_output_with_empty)
@ -12005,7 +12003,7 @@ class TestNNDeviceType(NNTestCase):
if sparse:
es_weight_grad = es.weight.grad.to_dense()
self.assertEqual(output, expected_output_with_empty)
self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[wdtype], rtol=0)
# check same example except as 2D (2 x 3)
input = input.view(2, -1)
@ -12017,12 +12015,12 @@ class TestNNDeviceType(NNTestCase):
if sparse:
es_weight_grad = es.weight.grad.to_dense()
self.assertEqual(output, expected_output)
self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(es_weight_grad, expected_grad_weight, atol=dtype2prec_DONTUSE[wdtype], rtol=0)
# test all empty bags
es.zero_grad()
inputs = torch.tensor([], dtype=torch.long, device=device)
offsets = torch.tensor([0, 0, 0, 0], device=device)
inputs = torch.tensor([], dtype=dtype, device=device)
offsets = torch.tensor([0, 0, 0, 0], dtype=dtype, device=device)
es(inputs, offsets).sum().backward()
dense_grad = es.weight.grad
if dense_grad.is_sparse:
@ -12031,7 +12029,7 @@ class TestNNDeviceType(NNTestCase):
# now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length
N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50)
kwargs = dict(mode=mode, sparse=sparse, device=device, dtype=dtype, test_backward=test_backward)
kwargs = dict(mode=mode, sparse=sparse, device=device, wdtype=wdtype, dtype=dtype, test_backward=test_backward)
self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs)
for max_norm in (None, 3):
for p in itertools.product([1, 2], repeat=4):
@ -12039,8 +12037,8 @@ class TestNNDeviceType(NNTestCase):
# check that giving illegal input combos raises error
es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
input = torch.ones(3, 4, dtype=torch.long)
offset = torch.arange(0, 3)
input = torch.ones(3, 4, dtype=dtype)
offset = torch.arange(0, 3, dtype=dtype)
self.assertRaises(ValueError, lambda: es(input, offset))
self.assertRaises(ValueError, lambda: es(input.view(-1)))
offset[0] = 1
@ -12050,35 +12048,35 @@ class TestNNDeviceType(NNTestCase):
offset[-1] = 100
self.assertRaises(RuntimeError, lambda: es(input.view(-1), offset))
@dtypesIfCUDA(torch.half, torch.float, torch.double)
@dtypes(torch.float, torch.double)
def test_embedding_bag_device(self, device, dtype):
self._test_EmbeddingBag(device, 'sum', False, dtype)
self._test_EmbeddingBag(device, 'mean', False, dtype)
self._test_EmbeddingBag(device, 'max', False, dtype)
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
@dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
def test_embedding_bag_device(self, device, dtypes):
self._test_EmbeddingBag(device, 'sum', False, wdtype=dtypes[1], dtype=dtypes[0])
self._test_EmbeddingBag(device, 'mean', False, wdtype=dtypes[1], dtype=dtypes[0])
self._test_EmbeddingBag(device, 'max', False, wdtype=dtypes[1], dtype=dtypes[0])
test_backward = False
if self.device_type == 'cuda':
# see 'todo' in test_embedding_bag.
test_backward = dtype is not torch.float16
test_backward = dtypes[1] is not torch.float16
elif self.device_type == 'cpu':
# TODO: figure out why precision on sparse embeddings isn't the
# same as for dense.
test_backward = dtype is not torch.float
test_backward = dtypes[1] is not torch.float
self._test_EmbeddingBag(device, 'sum', True, dtype, test_backward=test_backward)
self._test_EmbeddingBag(device, 'mean', True, dtype, test_backward=test_backward)
self._test_EmbeddingBag(device, 'sum', True, wdtype=dtypes[1], dtype=dtypes[0], test_backward=test_backward)
self._test_EmbeddingBag(device, 'mean', True, wdtype=dtypes[1], dtype=dtypes[0], test_backward=test_backward)
@dtypesIfCUDA(torch.half, torch.float, torch.double)
@dtypes(torch.float, torch.double)
def test_embedding_bag_non_contiguous_weight(self, device, dtype):
weight_tensor = torch.randn(3, 4, dtype=dtype, device=device)
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
@dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
def test_embedding_bag_non_contiguous_weight(self, device, dtypes):
weight_tensor = torch.randn(3, 4, dtype=dtypes[1], device=device)
weight_tensor_non_contig = weight_tensor[:, :3] # This is non-contiguous strided.
weight_tensor_contig = weight_tensor_non_contig.clone().contiguous() # Contig-strided.
index = torch.tensor([0, 1, 2], device=device)
offsets = torch.tensor([0, 2], device=device)
index = torch.tensor([0, 1, 2], dtype=dtypes[0], device=device)
offsets = torch.tensor([0, 2], dtype=dtypes[0], device=device)
for mode in ['sum', 'mean', 'max']:
output_non_contig = F.embedding_bag(
input=index,
@ -12097,9 +12095,10 @@ class TestNNDeviceType(NNTestCase):
@onlyCUDA
@skipCUDAIfNotRocm
def test_embedding_bag_bfloat16(self, device):
self._test_EmbeddingBag(device, 'sum', True, dtype=torch.bfloat16, test_backward=True)
self._test_EmbeddingBag(device, 'mean', True, dtype=torch.bfloat16, test_backward=True)
@dtypes(torch.int, torch.long)
def test_embedding_bag_bfloat16(self, device, dtype):
self._test_EmbeddingBag(device, 'sum', True, wdtype=torch.bfloat16, dtype=dtype, test_backward=True)
self._test_EmbeddingBag(device, 'mean', True, wdtype=torch.bfloat16, dtype=dtype, test_backward=True)
@onlyCUDA

View File

@ -1618,23 +1618,25 @@ class AbstractTestCases:
reference[0.0, :, 0.0] = 1
def test_index_add(self):
for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
for other_sizes in ((), (4, 5)):
num_copy, num_dest = 3, 3
dest = torch.randn(num_dest, *other_sizes)
if not dest_contig:
dest = torch.testing.make_non_contiguous(dest)
src = torch.randn(num_copy, *other_sizes)
if not src_contig:
src = torch.testing.make_non_contiguous(src)
idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
if not index_contig:
idx = torch.testing.make_non_contiguous(idx)
dest2 = dest.clone()
dest.index_add_(0, idx, src)
for i in range(idx.size(0)):
dest2[idx[i]] += src[i]
self.assertEqual(dest, dest2)
for device in torch.testing.get_all_device_types():
for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
for other_sizes in ((), (4, 5)):
for dtype in [torch.int, torch.long]:
num_copy, num_dest = 3, 3
dest = torch.randn(num_dest, *other_sizes, device=device)
if not dest_contig:
dest = torch.testing.make_non_contiguous(dest)
src = torch.randn(num_copy, *other_sizes, device=device)
if not src_contig:
src = torch.testing.make_non_contiguous(src)
idx = torch.randperm(num_dest, dtype=dtype, device=device).narrow(0, 0, num_copy)
if not index_contig:
idx = torch.testing.make_non_contiguous(idx)
dest2 = dest.clone()
dest.index_add_(0, idx, src)
for i in range(idx.size(0)):
dest2[idx[i]] += src[i]
self.assertEqual(dest, dest2)
# add coverage for issue with atomic add that appeared only for
# specific dtypes on cuda:
@ -1642,23 +1644,24 @@ class AbstractTestCases:
def test_index_add_all_dtypes(self):
for device in torch.testing.get_all_device_types():
for dtype in torch.testing.get_all_math_dtypes(device):
size = [5, 5]
if dtype.is_floating_point or dtype.is_complex:
tensor = torch.rand(size, dtype=dtype, device=device)
elif dtype.is_signed:
tensor = torch.randint(-5, 15, size, dtype=dtype, device=device)
else:
tensor = torch.randint(0, 10, size, dtype=dtype, device=device)
for idx_dtype in [torch.int, torch.long]:
size = [5, 5]
if dtype.is_floating_point or dtype.is_complex:
tensor = torch.rand(size, dtype=dtype, device=device)
elif dtype.is_signed:
tensor = torch.randint(-5, 15, size, dtype=dtype, device=device)
else:
tensor = torch.randint(0, 10, size, dtype=dtype, device=device)
# index_add calls atomicAdd on cuda.
zeros = torch.zeros(size, dtype=dtype, device=device)
# index_add calls atomicAdd on cuda.
zeros = torch.zeros(size, dtype=dtype, device=device)
# index_add is not supported for complex dtypes on cuda yet
if device.startswith('cuda') and dtype.is_complex:
continue
# index_add is not supported for complex dtypes on cuda yet
if device.startswith('cuda') and dtype.is_complex:
continue
added = zeros.index_add(0, torch.arange(0, size[0], dtype=torch.long, device=device), tensor)
self.assertEqual(added, tensor)
added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor)
self.assertEqual(added, tensor)
def test_t(self):
# Test 0D tensors
@ -12735,36 +12738,37 @@ class TestTorchDeviceType(TestCase):
self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dt, device=device))
def test_index_select(self, device):
src = torch.randn(3, 4, 5, device=device)
# Index can be duplicated.
idx = torch.tensor([2, 1, 0, 1, 2], dtype=torch.long, device=device)
dest = torch.index_select(src, 0, idx)
self.assertEqual(dest.shape, (5, 4, 5))
for i in range(idx.size(0)):
self.assertEqual(dest[i], src[idx[i]])
for dtype in [torch.int, torch.long]:
src = torch.randn(3, 4, 5, device=device)
# Index can be duplicated.
idx = torch.tensor([2, 1, 0, 1, 2], dtype=dtype, device=device)
dest = torch.index_select(src, 0, idx)
self.assertEqual(dest.shape, (5, 4, 5))
for i in range(idx.size(0)):
self.assertEqual(dest[i], src[idx[i]])
# Check that 'out' is used correctly.
out = torch.randn(5 * 4 * 5, device=device)
dest = torch.index_select(src, 0, idx, out=out.view(5, 4, 5))
self.assertEqual(dest.shape, (5, 4, 5))
for i in range(idx.size(0)):
self.assertEqual(dest[i], src[idx[i]])
out.fill_(0.123)
self.assertEqual(out, dest.view(-1)) # Must point to the same storage.
# Check that 'out' is used correctly.
out = torch.randn(5 * 4 * 5, device=device)
dest = torch.index_select(src, 0, idx, out=out.view(5, 4, 5))
self.assertEqual(dest.shape, (5, 4, 5))
for i in range(idx.size(0)):
self.assertEqual(dest[i], src[idx[i]])
out.fill_(0.123)
self.assertEqual(out, dest.view(-1)) # Must point to the same storage.
# Bool tensor
src = torch.tensor([False, True, False, False], device=device, dtype=torch.bool)
idx = torch.tensor([1], dtype=torch.long, device=device)
dest = torch.index_select(src, 0, idx)
self.assertEqual(torch.tensor([True]), dest)
# Bool tensor
src = torch.tensor([False, True, False, False], device=device, dtype=torch.bool)
idx = torch.tensor([1], dtype=dtype, device=device)
dest = torch.index_select(src, 0, idx)
self.assertEqual(torch.tensor([True]), dest)
# Complex Tensor
src = torch.randn(3, 4, 5, dtype=torch.complex64, device=device)
idx = torch.tensor([2, 1, 0, 1, 2], dtype=torch.long, device=device)
dest = torch.index_select(src, 0, idx)
self.assertEqual(dest.shape, (5, 4, 5))
for i in range(idx.size(0)):
self.assertEqual(dest[i], src[idx[i]])
# Complex Tensor
src = torch.randn(3, 4, 5, dtype=torch.complex64, device=device)
idx = torch.tensor([2, 1, 0, 1, 2], dtype=dtype, device=device)
dest = torch.index_select(src, 0, idx)
self.assertEqual(dest.shape, (5, 4, 5))
for i in range(idx.size(0)):
self.assertEqual(dest[i], src[idx[i]])
def test_take_empty(self, device):
for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:

View File

@ -1678,7 +1678,7 @@ Note:
Args:
dim (int): dimension along which to index
index (LongTensor): indices of :attr:`tensor` to select from
index (IntTensor or LongTensor): indices of :attr:`tensor` to select from
tensor (Tensor): the tensor containing values to add
Example::

View File

@ -3410,7 +3410,7 @@ of :attr:`index`; other dimensions have the same size as in the original tensor.
Args:
{input}
dim (int): the dimension in which we index
index (LongTensor): the 1-D tensor containing the indices to index
index (IntTensor or LongTensor): the 1-D tensor containing the indices to index
Keyword args:
{out}

View File

@ -1940,7 +1940,7 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
" fixed length sequences. However, found "
"offsets of type {}".format(type_str))
offsets = torch.arange(0, input.numel(), input.size(1),
dtype=torch.long, device=input.device)
dtype=input.dtype, device=input.device)
input = input.reshape(-1)
if per_sample_weights is not None:

View File

@ -34,7 +34,7 @@ class Embedding(Module):
initialized from :math:`\mathcal{N}(0, 1)`
Shape:
- Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
- Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
.. note::
@ -54,7 +54,7 @@ class Embedding(Module):
When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the
:attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be
modified in-place, performing a differentiable operation on ``Embedding.weight`` before
calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when
calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when
:attr:`max_norm` is not ``None``. For example::
n, d, m = 3, 5, 7
@ -62,7 +62,7 @@ class Embedding(Module):
W = torch.randn((m, d), requires_grad=True)
idx = torch.tensor([1, 2])
a = embedding.weight.clone() @ W.t() # weight must be cloned for this to be differentiable
b = embedding(idx) @ W.t() # modifies weight in-place
b = embedding(idx) @ W.t() # modifies weight in-place
out = (a.unsqueeze(0) + b.unsqueeze(1))
loss = out.sigmoid().prod()
loss.backward()
@ -246,9 +246,11 @@ class EmbeddingBag(Module):
weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
initialized from :math:`\mathcal{N}(0, 1)`.
Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and
Inputs: :attr:`input` (IntTensor or LongTensor), :attr:`offsets` (IntTensor or LongTensor, optional), and
:attr:`per_index_weights` (Tensor, optional)
- :attr:`input` and :attr:`offsets` have to be of the same type, either int or long
- If :attr:`input` is 2D of shape `(B, N)`,
it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and