Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Move EmbeddingBag into ATen (#4856)
This diff adds native EmbeddingBag code to ATen (CPU and CUDA). It also allows sparse gradients for EmbeddingBag.
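For orientation, here is a minimal usage sketch of the sparse-gradient path this change enables (the sizes, indices, and offsets below are illustrative, not taken from the diff):

import torch
import torch.nn as nn
from torch.autograd import Variable

# EmbeddingBag with sparse=True produces a sparse gradient for its weight.
es = nn.EmbeddingBag(10, 3, mode='sum', sparse=True)
input = Variable(torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9]))
offsets = Variable(torch.LongTensor([0, 4]))   # two bags: positions [0, 4) and [4, 8)
output = es(input, offsets)                    # shape (2, 3)
output.sum().backward()
dense_grad = es.weight.grad.data.to_dense()    # grad.data is sparse; densify it to inspect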
@@ -115,7 +115,7 @@ Tensor embedding_backward_cpu(
     int64_t end = start + (num_weights/nthreads + 1);
     for (int64_t i = 0; i < numel; i++) {
       if (indices_data[i] != padding_idx) {
-        int64_t k = indices_data[i] - TH_INDEX_BASE;
+        int64_t k = indices_data[i];
         if (k >= start && k < end) {
           double scale = 1.0;
           if (scale_grad_by_freq) {
aten/src/ATen/native/EmbeddingBag.cpp  (new file, 236 lines)
@@ -0,0 +1,236 @@
#include "ATen/ATen.h"
#include "ATen/Check.h"
#include "ATen/NativeFunctions.h"

#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <vector>

#include <TH/THBlas.h>

#ifdef _OPENMP
#include <omp.h>
#endif

namespace at {
namespace native {

static void make_offset2bag(const Tensor &offsets, const Tensor &indices,
                            Tensor &offset2bag) {
  offset2bag.index_fill_(0, offsets, 1); // offset2bag = [1 0 1 0 1]
  offset2bag[0] = 0;                     // offset2bag = [0 0 1 0 1]
  offset2bag = offset2bag.cumsum(0);     // offset2bag = [0 0 1 1 2]
}

static void make_bag_size(const Tensor &offsets, const Tensor &indices,
                          const int64_t mode, Tensor &bag_size) {
  if (mode == 1) { // MODE_MEAN
    if (offsets.sizes()[0] != 1) {
      bag_size.slice(0, 0, bag_size.sizes()[0] - 1, 1) =
          offsets.slice(0, 1, offsets.sizes()[0], 1) -
          offsets.slice(0, 0, offsets.sizes()[0] - 1, 1);
      bag_size[-1] = indices.sizes()[0] - offsets[-1];
    }
  }
}

static Tensor apply_bag_size(const Tensor &offsets, const Tensor &indices,
                             const int64_t mode, Tensor &output,
                             const Tensor &bag_size) {
  if (mode == 1) { // MODE_MEAN
    if (offsets.sizes()[0] == 1) {
      auto bag_size_ = indices.sizes()[0];
      output /= bag_size_;
    } else {
      auto bag_size_ =
          bag_size.toType(output.type()).unsqueeze(1).expand_as(output);
      output /= bag_size_;
    }
  }
  return output;
}

static Tensor apply_bag_size_backward(const Tensor &offsets,
                                      const Tensor &indices, const int64_t mode,
                                      Tensor &output, const Tensor &offset2bag,
                                      const Tensor &bag_size) {
  if (mode == 1) { // MODE_MEAN
    if (offsets.sizes()[0] == 1) {
      auto bag_size_ = indices.sizes()[0];
      output /= bag_size_;
    } else {
      auto bag_size_ = bag_size.toType(output.type())
                           .unsqueeze(1)
                           .index_select(0, offset2bag);
      output /= bag_size_;
    }
  }
  return output;
}

std::tuple<Tensor, Tensor, Tensor>
embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
                  const Tensor &offsets__, const bool scale_grad_by_freq,
                  const int64_t mode, bool sparse) {
  auto indices_arg = TensorArg(indices__, "indices__", 1);
  checkScalarType("embedding_bag", indices_arg, kLong);
  auto offsets_arg = TensorArg(offsets__, "offsets__", 1);
  checkScalarType("embedding_bag", offsets_arg, kLong);
  Tensor indices = indices__.contiguous();
  Tensor offsets = offsets__.contiguous();

  auto bag_size = indices.type().zeros(offsets.sizes());
  auto offset2bag =
      indices__.type().zeros({indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]
  make_offset2bag(offsets, indices, offset2bag);
  auto output = weight.type().zeros({offsets.sizes()[0], weight.sizes()[1]});
  auto index_output = weight.index_select(0, indices);
  output.index_add_(0, offset2bag, index_output);
  make_bag_size(offsets, indices, mode, bag_size);
  auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
  return std::tuple<Tensor, Tensor, Tensor>(ret, offset2bag, bag_size);
}

Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__,
                              const Tensor &offsets__,
                              const Tensor &offset2bag__,
                              const Tensor &bag_size_, int64_t num_weights,
                              bool scale_grad_by_freq, int64_t mode,
                              bool sparse) {
  auto indices_arg = TensorArg(indices__, "indices__", 1);
  checkScalarType("embedding_bag", indices_arg, kLong);
  auto offsets_arg = TensorArg(offsets__, "offsets__", 1);
  checkScalarType("embedding_bag", offsets_arg, kLong);
  auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 1);
  checkScalarType("embedding_bag", offset2bag_arg, kLong);
  checkContiguous("embedding_bag", offset2bag_arg);
  Tensor indices = indices__.contiguous();
  Tensor offsets = offsets__.contiguous();

  if (sparse) {
    return at::embedding_bag_sparse_backward(
        grad_, indices, offsets, offset2bag__, bag_size_, num_weights,
        scale_grad_by_freq, mode);
  } else {
    return at::embedding_bag_dense_backward(
        grad_, indices, offsets, offset2bag__, bag_size_, num_weights,
        scale_grad_by_freq, mode);
  }
}

Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
                                  const Tensor &offsets__,
                                  const Tensor &offset2bag__,
                                  const Tensor &bag_size_, int64_t num_weights,
                                  bool scale_grad_by_freq, int64_t mode) {
  auto grad = grad_.contiguous();
  auto indices_arg = TensorArg(indices__, "indices__", 1);
  checkScalarType("embedding_bag", indices_arg, kLong);
  auto offsets_arg = TensorArg(offsets__, "offsets__", 1);
  checkScalarType("embedding_bag", offsets_arg, kLong);
  auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 1);
  checkScalarType("embedding_bag", offset2bag_arg, kLong);
  checkContiguous("embedding_bag", offset2bag_arg);
  Tensor indices_ = indices__.contiguous();
  Tensor offsets_ = offsets__.contiguous();

  Tensor &offset2bag_ = const_cast<Tensor &>(offset2bag__);

  auto ind_sort_ = indices_.sort();
  auto indices = std::get<0>(ind_sort_);
  auto ind_sort = std::get<1>(ind_sort_);
  auto offset2bag = offset2bag_.index_select(0, ind_sort);

  auto indices_data = indices.data<int64_t>();
  auto offsets_data = offsets_.data<int64_t>();
  auto offset2bag_data = offset2bag.data<int64_t>();
  int64_t numel = indices.numel();

  std::vector<int64_t> counts(num_weights);
  for (int i = 0; i < numel; i++) {
    counts[indices_data[i]] = 0;
  }
  for (int i = 0; i < numel; i++) {
    counts[indices_data[i]]++;
  }

  std::vector<int64_t> counts_uniq;
  counts_uniq.reserve(num_weights);
  int64_t o = 0;
  for (int64_t i = 0; i < numel; i += counts[indices_data[i]]) {
    counts_uniq.push_back(counts[indices_data[i]]);
    if (o > 0) {
      counts_uniq[o] += counts_uniq[o - 1];
    }
    o++;
  }

  auto index_grad_weight =
      grad.type().zeros({num_weights, grad.sizes()[1]}).contiguous();

#pragma omp parallel for if (numel > 1000)
  for (int64_t i = 0; i < (int64_t)counts_uniq.size(); i++) {
    int64_t start = i == 0 ? 0 : counts_uniq[i - 1];
    int64_t index = indices_data[start];
    for (int64_t j = start; j < counts_uniq[i]; j++) {
      int64_t source = offset2bag_data[j];
      double scale = 1.0;
      if (scale_grad_by_freq) {
        scale /= counts[indices_data[i]];
      }
      if (mode == 1) { // MODE_MEAN
        if (offsets_.sizes()[0] == 1) {
          auto bag_size = indices.sizes()[0];
          scale /= bag_size;
        } else {
          if (source == offsets_.sizes()[0] - 1) {
            scale /= indices.sizes()[0] - offsets_data[offsets_.sizes()[0] - 1];
          } else {
            scale /= offsets_data[source + 1] - offsets_data[source];
          }
        }
      }
      int64_t ddim = grad.sizes()[1];
      if (grad.type().scalarType() == kFloat) {
        auto igwd = index_grad_weight.data<float>();
        auto gd = grad.data<float>();
        THFloatBlas_axpy(ddim, (float)scale, gd + ddim * source, 1,
                         igwd + ddim * index, 1);
      } else if (grad.type().scalarType() == kDouble) {
        auto igwd = index_grad_weight.data<double>();
        auto gd = grad.data<double>();
        THDoubleBlas_axpy(ddim, (double)scale, gd + ddim * source, 1,
                          igwd + ddim * index, 1);
      } else {
        index_grad_weight[index].add_(grad[source], scale);
      }
    }
  }

  return index_grad_weight;
}
Tensor embedding_bag_sparse_backward(
    const Tensor &grad_, const Tensor &indices__, const Tensor &offsets__,
    const Tensor &offset2bag__, const Tensor &bag_size_, int64_t num_weights,
    bool scale_grad_by_freq, int64_t mode) {
  auto indices_arg = TensorArg(indices__, "indices__", 1);
  checkScalarType("embedding_bag", indices_arg, kLong);
  auto offsets_arg = TensorArg(offsets__, "offsets__", 1);
  checkScalarType("embedding_bag", offsets_arg, kLong);
  auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 1);
  checkScalarType("embedding_bag", offset2bag_arg, kLong);
  Tensor indices = indices__.contiguous();
  Tensor offsets = offsets__.contiguous();
  Tensor offset2bag = offset2bag__.contiguous();

  Tensor grad = grad_;
  Tensor index_grad = grad_.index_select(0, offset2bag);
  index_grad = apply_bag_size_backward(offsets, indices, mode, index_grad,
                                       offset2bag, bag_size_);
  return native::embedding_backward(index_grad, indices, num_weights, -1,
                                    scale_grad_by_freq, true);
}
}
} // namespace at::native
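As a reading aid, here is a rough Python sketch of the offset2bag bookkeeping the CPU forward above relies on (sum mode only; it assumes 1-D LongTensor indices and offsets with offsets[0] == 0):

import torch

def make_offset2bag(offsets, indices):
    # Mirrors the comments above for 5 indices and offsets = [0, 2, 4]
    offset2bag = torch.zeros(indices.size(0)).long()   # [0 0 0 0 0]
    offset2bag.index_fill_(0, offsets, 1)              # [1 0 1 0 1]  mark where each bag starts
    offset2bag[0] = 0                                  # [0 0 1 0 1]  position 0 always belongs to bag 0
    return offset2bag.cumsum(0)                        # [0 0 1 1 2]  bag id for every index position

def embedding_bag_sum(weight, indices, offsets):
    # Gather the selected rows once, then scatter-add every row into its bag.
    offset2bag = make_offset2bag(offsets, indices)
    output = torch.zeros(offsets.size(0), weight.size(1))
    output.index_add_(0, offset2bag, weight.index_select(0, indices))
    return output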
aten/src/ATen/native/cuda/EmbeddingBag.cu  (new file, 283 lines)
@@ -0,0 +1,283 @@
#include "ATen/ATen.h"
#include "ATen/Check.h"
#include "ATen/Dispatch.h"
#include "ATen/NativeFunctions.h"

#include "ATen/cuda/AccumulateType.h"
#include "ATen/cuda/CUDATensorMethods.cuh"

#include <THC/THCDeviceUtils.cuh>
#include <THC/THCNumerics.cuh>
#include <THC/THCTensorMathReduce.cuh>
#include <THC/THCTensorSort.cuh>
#include <THC/THCThrustAllocator.cuh>
#include <THCUNN/THCHalfAutoNumerics.cuh>

#include <thrust/execution_policy.h>
#include <thrust/unique.h>

const int WARP_SIZE = 32;
const int MODE_SUM = 0;
const int MODE_MEAN = 1;

namespace at {
namespace native {

namespace {

template <typename scalar_t>
__global__ void EmbeddingBag_updateOutputKernel(
    int64_t *input, int64_t *offsets, scalar_t *weight, scalar_t *output,
    int64_t *offset2bag, int64_t numIndices, int64_t numBags, int64_t stride,
    int mode, int64_t *bag_size) {

  // the strategy here is that each bag x feature is handled by a single thread

  using accscalar_t = cuda::acc_type<scalar_t>;
  int64_t chunksPerBag = THCCeilDiv(stride, (int64_t)blockDim.x);
  int64_t numChunks = numBags * chunksPerBag;
  int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
  int64_t chunkStride = gridDim.x * blockDim.y;

  for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < stride) {
      int64_t bag = chunk / chunksPerBag;
      scalar_t *weightFeat = weight + featureDim;
      int64_t begin = offsets[bag];
      int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices;
      assert(end >= begin);
      accscalar_t weightFeatSum = scalar_cast<accscalar_t>(0);
      int64_t bag_size_ = 0;
      for (int64_t emb = begin; emb < end; emb++) {
        const int weightRow = ((int)input[emb]) * stride;
        weightFeatSum += scalar_cast<accscalar_t>(weightFeat[weightRow]);
        bag_size_++;
        if (featureDim == 0) {
          offset2bag[emb] = bag;
        }
      }
      if (mode == MODE_MEAN) {
        weightFeatSum = weightFeatSum / scalar_cast<accscalar_t>(bag_size_);
        bag_size[bag] = bag_size_;
      }
      (void)MODE_SUM; // silence warnings about unused MODE_SUM;
      output[bag * stride + featureDim] = scalar_cast<scalar_t>(weightFeatSum);
    }
  }
}

// FIXME: removed the accGradParametersKernelByFeature case present in
// LookupTable. That kernel is faster at small sizes (<768 indices), which
// does not need EmbeddingBag (LookupTable + Sum works fine), but would
// still be nice to not be slow in that case.

template <typename scalar_t>
__global__ void EmbeddingBag_accGradParametersKernel(
    int64_t *input, int64_t *indices, scalar_t *gradOutput,
    scalar_t *gradWeight, int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
    int64_t stride, int mode, int64_t *bag_size) {

  using accscalar_t = cuda::acc_type<scalar_t>;
  int idx = blockIdx.x * 4 + threadIdx.y;

  // Each warp is responsible for an input into the LookupTable.
  // If the preceding input has the same value as this input, then the warp
  // exits immediately. The warp also processes subsequent inputs with the
  // same value.
  //
  // Input Warp
  // 1     <warp 1>
  // 1     <warp 1> (<warp 2> exits without doing any work)
  // 5     <warp 3>
  // 8     <warp 4>

  // Number of values processed by each thread (grain size)
  const int SZ = 4;
  if (idx < numel && (idx == 0 || input[idx] != input[idx - 1])) {
    do {
      const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
      const int weightRow = ((int)input[idx]) * stride;

      // Note: only this line changes from LookupTable_accgradParametersKernel
      const int origRow = ((int)indices[idx]);
      const int seq_number = offset2bag[origRow];
      const int gradOutputRow = ((int)seq_number) * stride;

      const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;

      accscalar_t gradient[SZ];
      accscalar_t weight[SZ];

#pragma unroll
      for (int ii = 0; ii < SZ; ii++) {
        int featureDim = startFeature + ii * WARP_SIZE;
        if (featureDim < stride) {
          gradient[ii] =
              scalar_cast<accscalar_t>(gradOutput[gradOutputRow + featureDim]);
          if (mode == MODE_MEAN) {
            gradient[ii] /= bag_size[seq_number];
          }
          weight[ii] =
              scalar_cast<accscalar_t>(gradWeight[weightRow + featureDim]);
        }
      }

#pragma unroll
      for (int ii = 0; ii < SZ; ii++) {
        weight[ii] += gradient[ii] * scale;
      }

#pragma unroll
      for (int ii = 0; ii < SZ; ii++) {
        int featureDim = startFeature + ii * WARP_SIZE;
        if (featureDim < stride) {
          gradWeight[weightRow + featureDim] =
              scalar_cast<scalar_t>(weight[ii]);
        }
      }

      idx++;
    } while (idx < numel && input[idx] == input[idx - 1]);
  }
}
}

std::tuple<Tensor, Tensor, Tensor>
embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
                   const Tensor &offsets, const bool scale_grad_by_freq,
                   const int64_t mode, bool sparse) {
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarType("embedding_bag_cuda", indices_arg, kLong);
  checkContiguous("embedding_bag_cuda", indices_arg);
  auto offsets_arg = TensorArg(offsets, "offsets", 1);
  checkScalarType("embedding_bag_cuda", offsets_arg, kLong);
  checkContiguous("embedding_bag_cuda", offsets_arg);
  auto weight_arg = TensorArg(weight, "weight", 1);
  checkContiguous("embedding_bag_cuda", weight_arg);
  checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg);
  checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg);

  int64_t numIndices = indices.sizes()[0];
  int64_t numBags = offsets.sizes()[0];
  int64_t stride = weight.sizes()[1];

  auto bag_size = indices.type().zeros(offsets.sizes());
  auto offset2bag =
      indices.type().zeros({indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]

  cudaStream_t stream = globalContext().getCurrentCUDAStream();

  auto output = weight.type().zeros({offsets.sizes()[0], weight.sizes()[1]});

  dim3 block = dim3(32, 8);
  int grid = 1024;
  DISPATCH_ALL_FLOATING_TYPES(weight.type(), "embedding_bag_cuda", [&]() {
    EmbeddingBag_updateOutputKernel<scalar_t><<<grid, block, 0, stream>>>(
        indices.data<int64_t>(), offsets.data<int64_t>(),
        weight.data<scalar_t>(), output.data<scalar_t>(),
        offset2bag.data<int64_t>(), numIndices, numBags, stride, mode,
        bag_size.data<int64_t>());
  });

  THCudaCheck(cudaGetLastError());
  return std::tuple<Tensor, Tensor, Tensor>(output, offset2bag, bag_size);
}

Tensor embedding_bag_backward_cuda(const Tensor &grad_, const Tensor &indices,
                                   const Tensor &offsets,
                                   const Tensor &offset2bag,
                                   const Tensor &bag_size_, int64_t num_weights,
                                   bool scale_grad_by_freq, int64_t mode) {
  Tensor grad = grad_.contiguous();
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarType("embedding_bag_cuda", indices_arg, kLong);
  checkContiguous("embedding_bag_cuda", indices_arg);
  auto offsets_arg = TensorArg(offsets, "offsets", 1);
  checkScalarType("embedding_bag_cuda", offsets_arg, kLong);
  checkContiguous("embedding_bag_cuda", offsets_arg);
  auto grad_arg = TensorArg(grad, "grad", 1);
  checkContiguous("embedding_bag_cuda", grad_arg);
  checkSameGPU("embedding_bag_cuda", grad_arg, offsets_arg);
  checkSameGPU("embedding_bag_cuda", grad_arg, indices_arg);

  Tensor &bag_size = const_cast<Tensor &>(bag_size_);

  auto grad_weight = grad_.type().zeros({num_weights, grad.sizes()[1]});

  int nDim = indices.ndimension();

  ptrdiff_t numel = indices.numel();
  int64_t stride = grad_weight.stride(0);

  cudaStream_t stream = globalContext().getCurrentCUDAStream();

  auto sorted_indices = indices.type().tensor(indices.sizes());
  auto orig_indices = indices.type().tensor(indices.sizes());
  using device_ptr = thrust::device_ptr<int64_t>;

  // Sort the inputs into sorted with the corresponding indices; we
  // don't need a stable or multidimensional sort, so just use Thrust
  // directly
  {
    sorted_indices.copy_(indices);

    auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
    auto policy = thrust::cuda::par(allocator).on(stream);

    // Fill sortedOrigIndices with sequential indices
    auto count_iter = thrust::counting_iterator<int64_t>(0);
    auto orig_data = device_ptr(orig_indices.data<int64_t>());
    thrust::copy(policy, count_iter, count_iter + numel, orig_data);

    // Sort; a stable sort is not required
    auto sorted_data = device_ptr(sorted_indices.data<int64_t>());
    thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data,
                        ThrustLTOp<int64_t>());
  }

  Tensor count;
  if (scale_grad_by_freq) {
    count = indices.type().tensor(indices.sizes());

    auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
    auto policy = thrust::cuda::par(allocator).on(stream);

    // Compute an increasing sequence per unique item in sortedIndices:
    // sorted: 2 5 5 5 7 7 8 9 9
    //  count: 1 1 2 3 1 2 1 1 2
    auto sorted_data = device_ptr(sorted_indices.data<int64_t>());
    auto count_data = device_ptr(count.data<int64_t>());
    thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel,
                                  thrust::make_constant_iterator(1),
                                  count_data);

    // Take the maximum of each count per unique key in reverse:
    // sorted: 2 5 5 5 7 7 8 9 9
    //  count: 1 3 3 3 2 2 1 2 2
    thrust::inclusive_scan_by_key(
        policy, thrust::make_reverse_iterator(sorted_data + numel),
        thrust::make_reverse_iterator(sorted_data),
        thrust::make_reverse_iterator(count_data + numel),
        thrust::make_reverse_iterator(count_data + numel),
        thrust::equal_to<int64_t>(), thrust::maximum<int64_t>());
  }

  dim3 grid(THCCeilDiv(numel, (ptrdiff_t)4), THCCeilDiv(stride, (int64_t)128));
  dim3 block(32, 4);
  DISPATCH_ALL_FLOATING_TYPES(
      grad.type(), "embedding_bag_backward_cuda", [&]() {
        EmbeddingBag_accGradParametersKernel<
            scalar_t><<<grid, block, 0, stream>>>(
            sorted_indices.data<int64_t>(), orig_indices.data<int64_t>(),
            grad.data<scalar_t>(), grad_weight.data<scalar_t>(),
            offset2bag.data<int64_t>(),
            count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
            mode, bag_size.data<int64_t>());
      });

  THCudaCheck(cudaGetLastError());
  return grad_weight;
}
}
}
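The two inclusive_scan_by_key calls in the CUDA backward above compute, for every position of the sorted index list, how many times that index occurs in total; the kernel then scales each contribution by 1/count when scale_grad_by_freq is set. A small Python sketch of the same two-pass idea, using the worked example from the comments:

def occurrence_counts(sorted_indices):
    # Pass 1 (forward): running count inside each run of equal values
    # sorted: 2 5 5 5 7 7 8 9 9   ->   running: 1 1 2 3 1 2 1 1 2
    running = []
    for i, v in enumerate(sorted_indices):
        running.append(1 if i == 0 or v != sorted_indices[i - 1] else running[-1] + 1)
    # Pass 2 (reverse): propagate the maximum of each run back over the whole run
    # running: 1 1 2 3 1 2 1 1 2   ->   counts: 1 3 3 3 2 2 1 2 2
    counts = [0] * len(sorted_indices)
    best = 0
    for i in range(len(sorted_indices) - 1, -1, -1):
        if i == len(sorted_indices) - 1 or sorted_indices[i] != sorted_indices[i + 1]:
            best = running[i]
        counts[i] = best
    return counts

print(occurrence_counts([2, 5, 5, 5, 7, 7, 8, 9, 9]))   # [1, 3, 3, 3, 2, 2, 1, 2, 2]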
@@ -187,6 +187,24 @@
 - func: empty_like(Tensor self) -> Tensor
   variants: function
 
+- func: embedding_bag(Tensor weight, IndexTensor indices, IndexTensor offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false) -> (Tensor, Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU: embedding_bag_cpu
+    CUDA: embedding_bag_cuda
+
+- func: embedding_bag_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse) -> Tensor
+  variants: function
+
+- func: embedding_bag_sparse_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor
+  variants: function
+
+- func: embedding_bag_dense_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor
+  variants: function
+  dispatch:
+    CPU: embedding_bag_backward_cpu
+    CUDA: embedding_bag_backward_cuda
+
 - func: expand(Tensor self, IntList size) -> Tensor
   variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
@@ -1205,10 +1205,11 @@ class TestNN(NNTestCase):
     def test_gumbel_softmax_st_cuda(self):
         self._test_gumbel_softmax_st(True)
 
-    def _test_EmbeddingBag(self, cuda, mode):
+    def _test_EmbeddingBag(self, cuda, mode, sparse):
         # check a known test example
-        es = nn.EmbeddingBag(5, 2, mode=mode)
+        es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse)
         es.weight.data.copy_(torch.arange(1, 11).resize_as_(es.weight.data))
 
         input = Variable(torch.LongTensor([3, 1, 1, 1, 4, 0]))
         offsets = Variable(torch.LongTensor([0, 3]))
         grad_output = torch.arange(1, 5).view(2, 2).type(torch.Tensor)
@@ -1245,8 +1246,11 @@ class TestNN(NNTestCase):
         output = es(input, offsets)
         output.backward(grad_output)
 
+        es_weight_grad = es.weight.grad.data
+        if sparse:
+            es_weight_grad = es.weight.grad.data.to_dense()
         self.assertEqual(output.data, expected_output)
-        self.assertEqual(es.weight.grad.data, expected_grad_weight)
+        self.assertEqual(es_weight_grad, expected_grad_weight)
 
         # check same example except as 2D (2 x 3)
         input = Variable(input.data.view(2, -1))
@@ -1254,12 +1258,15 @@ class TestNN(NNTestCase):
         output = es(input)
         output.backward(grad_output)
 
+        es_weight_grad = es.weight.grad.data
+        if sparse:
+            es_weight_grad = es.weight.grad.data.to_dense()
         self.assertEqual(output.data, expected_output)
-        self.assertEqual(es.weight.grad.data, expected_grad_weight)
+        self.assertEqual(es_weight_grad, expected_grad_weight)
 
         # now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length
         def _test_vs_Embedding(N, D, B, L):
-            es = nn.EmbeddingBag(N, D, mode=mode)
+            es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse)
             e = nn.Embedding(N, D)
             e.weight.data.copy_(es.weight.data)
             input = Variable(torch.rand(B, L).mul(N).long())
@@ -1283,7 +1290,10 @@ class TestNN(NNTestCase):
 
             output.backward(grad_output)
             ref_output.backward(grad_output)
-            self.assertEqual(es.weight.grad, e.weight.grad)
+            es_weight_grad = es.weight.grad.data
+            if sparse:
+                es_weight_grad = es.weight.grad.data.to_dense()
+            self.assertEqual(es_weight_grad, e.weight.grad.data)
 
         N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50)
         _test_vs_Embedding(N, D, B, L)
@@ -1291,7 +1301,7 @@ class TestNN(NNTestCase):
             _test_vs_Embedding(*p)
 
         # check that giving illegal input combos raises error
-        es = nn.EmbeddingBag(10, 20, mode=mode)
+        es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
         input = Variable(torch.ones(3, 4))
         offset = Variable(torch.arange(0, 3))
         self.assertRaises(ValueError, lambda: es(input, offset))
@@ -1348,14 +1358,18 @@ class TestNN(NNTestCase):
         F.conv_transpose2d(Variable(x), Variable(torch.randn(16, 1, 1, 1)).cuda())
         F.conv2d(Variable(x), Variable(torch.randn(1, 16, 1, 1)).cuda())
 
-    def test_EmbeddingBag(self):
-        self._test_EmbeddingBag(False, 'sum')
-        self._test_EmbeddingBag(False, 'mean')
+    def test_embedding_bag(self):
+        self._test_EmbeddingBag(False, 'sum', False)
+        self._test_EmbeddingBag(False, 'mean', False)
+        self._test_EmbeddingBag(False, 'sum', True)
+        self._test_EmbeddingBag(False, 'mean', True)
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_EmbeddingBag_cuda(self):
-        self._test_EmbeddingBag(True, 'sum')
-        self._test_EmbeddingBag(True, 'mean')
+    def test_embedding_bag_cuda(self):
+        self._test_EmbeddingBag(True, 'sum', False)
+        self._test_EmbeddingBag(True, 'mean', False)
+        self._test_EmbeddingBag(True, 'sum', True)
+        self._test_EmbeddingBag(True, 'mean', True)
 
     def test_fractional_max_pool2d(self):
         x = Variable(torch.randn(1, 2, 7, 7), requires_grad=True)
@@ -5386,6 +5400,20 @@ new_module_tests = [
         jacobian_input=False,
         check_gradgrad=False,
     ),
+    dict(
+        module_name='EmbeddingBag',
+        constructor_args=(4, 3),
+        input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)),
+        jacobian_input=False,
+        check_gradgrad=False,
+    ),
+    dict(
+        module_name='EmbeddingBag_sparse',
+        constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True),
+        input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)),
+        jacobian_input=False,
+        check_gradgrad=False,
+    ),
     dict(
         constructor=lambda: nn.Embedding(4, 3, sparse=True),
         input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)),
@@ -674,6 +674,9 @@
 - name: embedding(Tensor weight, Tensor indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse)
   weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)
 
+- name: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse)
+  weight: embedding_bag_backward(grad, indices, offsets, result1, result2, weight.size(0), scale_grad_by_freq, mode, sparse)
+
 - name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
   self: not_implemented("embedding_renorm")
 
@@ -1068,7 +1068,7 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2,
 
 
 def embedding_bag(embedding_matrix, indices, offsets=None,
-                  max_norm=None, norm_type=2, scale_grad_by_freq=False, mode='mean'):
+                  max_norm=None, norm_type=2, scale_grad_by_freq=False, mode='mean', sparse=False):
     r"""Computes sums or means of 'bags' of embeddings, without instantiating the
     intermediate embeddings.
 
@@ -1093,6 +1093,8 @@ def embedding_bag(embedding_matrix, indices, offsets=None,
         scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the frequency of
                                                 the words in the dictionary.
         mode (string, optional): 'sum' | 'mean'. Specifies the way to reduce the bag. Default: 'mean'
+        sparse (boolean, optional): if ``True``, gradient w.r.t. weight matrix will be a sparse tensor. See Notes
+                                    for more details regarding sparse gradients.
 
     Shape:
         - Embedding_matrix: FloatTensor `(V, embedding_dim)`,
@@ -1131,19 +1133,42 @@ def embedding_bag(embedding_matrix, indices, offsets=None,
             offsets = Variable(torch.arange(0, indices.numel(), indices.size(1),
                                             out=indices.data.new().long()))
             indices = indices.view(-1)
 
-    elif indices.dim() != 1:
+    elif indices.dim() == 1:
+        if offsets is None:
+            raise ValueError("offsets has to be a 1D Tensor but got None")
+        if offsets.dim() != 1:
+            raise ValueError("offsets has to be a 1D Tensor")
+        if offsets[0] != 0:
+            raise ValueError("offsets[0] has to be 0, i.e. the first sequence"
+                             " in the mini-batch has to start from position 0."
+                             "However, got {}".format(offsets[0]))
+        if offsets[-1] > indices.size(0):
+            raise ValueError("offsets[-1] has to be smaller than indices's length"
+                             " ({}), but got offsets[-1] of {}"
+                             .format(indices.size(0), offsets[-1]))
+    else:
         raise ValueError("input has to be 1D or 2D Tensor,"
                          " but got Tensor of dimension {}".format(indices.dim()))
 
-    if offsets is None:
-        raise ValueError("offsets has to be a 1D Tensor but got None")
     if mode == 'sum':
         mode = 0
     elif mode == 'mean':
         mode = 1
     else:
         raise ValueError("mode has to be one of sum or mean")
 
-    return _functions.thnn.EmbeddingBag.apply(
-        embedding_matrix, indices, offsets,
-        max_norm, norm_type,
-        scale_grad_by_freq, mode
-    )
+    if max_norm is not None:
+        with torch.no_grad():
+            torch._C._VariableFunctions.embedding_renorm_(weight, input, max_norm, norm_type)
+
+    ret, _, _ = torch._C._VariableFunctions.embedding_bag(
+        embedding_matrix,
+        indices,
+        offsets,
+        scale_grad_by_freq,
+        mode,
+        sparse)
+    return ret
 
 
 def instance_norm(input, weight, bias, saved_running_mean, saved_running_var,
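A short sketch of the two input layouts the functional form accepts after this change (the weight values and indices here are illustrative):

import torch
import torch.nn.functional as F
from torch.autograd import Variable

weight = Variable(torch.rand(10, 3), requires_grad=True)

# 1-D indices with explicit offsets: two bags, positions [0, 3) and [3, 6)
indices = Variable(torch.LongTensor([1, 2, 4, 5, 4, 3]))
offsets = Variable(torch.LongTensor([0, 3]))
out = F.embedding_bag(weight, indices, offsets, mode='sum', sparse=True)   # shape (2, 3)

# 2-D indices (a mini-batch of fixed-length bags): offsets must stay None
indices_2d = Variable(torch.LongTensor([[1, 2, 4], [5, 4, 3]]))
out_2d = F.embedding_bag(weight, indices_2d, mode='mean')                  # shape (2, 3)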
@@ -139,6 +139,8 @@ class EmbeddingBag(Module):
         scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the frequency of
                                                 the words in the dictionary.
         mode (string, optional): 'sum' | 'mean'. Specifies the way to reduce the bag. Default: 'mean'
+        sparse (boolean, optional): if ``True``, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for
+                                    more details regarding sparse gradients.
 
     Attributes:
         weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
@@ -185,7 +187,7 @@ class EmbeddingBag(Module):
 
     def __init__(self, num_embeddings, embedding_dim,
                  max_norm=None, norm_type=2, scale_grad_by_freq=False,
-                 mode='mean'):
+                 mode='mean', sparse=False):
         super(EmbeddingBag, self).__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@@ -194,6 +196,7 @@ class EmbeddingBag(Module):
         self.scale_grad_by_freq = scale_grad_by_freq
         self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
         self.mode = mode
+        self.sparse = sparse
 
         self.reset_parameters()
 
@@ -203,7 +206,7 @@ class EmbeddingBag(Module):
     def forward(self, input, offsets=None):
         return F.embedding_bag(self.weight, input, offsets,
                                self.max_norm, self.norm_type,
-                               self.scale_grad_by_freq, self.mode)
+                               self.scale_grad_by_freq, self.mode, self.sparse)
 
     def __repr__(self):
         s = '{name}({num_embeddings}, {embedding_dim}'