Move EmbeddingBag into ATen (#4856)

This diff moves the EmbeddingBag implementation into ATen, adding native CPU and CUDA kernels, and adds support for sparse gradients with respect to the embedding weight.
cpuhrsch
2018-02-12 14:20:32 -05:00
committed by Sam Gross
parent 177b4509ce
commit 07be53b57f
8 changed files with 623 additions and 27 deletions
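As a rough illustration of the user-facing change (mirroring the updated tests in this diff, not part of the commit itself): with sparse=True, the gradient with respect to the embedding weight comes back as a sparse tensor, which can be densified for inspection. The shapes and values below are arbitrary.

import torch
import torch.nn as nn
from torch.autograd import Variable

es = nn.EmbeddingBag(10, 3, mode='sum', sparse=True)
input = Variable(torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9]))
offsets = Variable(torch.LongTensor([0, 4]))        # two bags: indices [0:4) and [4:8)
output = es(input, offsets)                         # shape (2, 3)
output.sum().backward()
dense_grad = es.weight.grad.data.to_dense()         # sparse gradient, densified as in the tests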


@ -115,7 +115,7 @@ Tensor embedding_backward_cpu(
int64_t end = start + (num_weights/nthreads + 1);
for (int64_t i = 0; i < numel; i++) {
if (indices_data[i] != padding_idx) {
int64_t k = indices_data[i] - TH_INDEX_BASE;
int64_t k = indices_data[i];
if (k >= start && k < end) {
double scale = 1.0;
if (scale_grad_by_freq) {


@ -0,0 +1,236 @@
#include "ATen/ATen.h"
#include "ATen/Check.h"
#include "ATen/NativeFunctions.h"
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <vector>
#include <TH/THBlas.h>
#ifdef _OPENMP
#include <omp.h>
#endif
namespace at {
namespace native {
static void make_offset2bag(const Tensor &offsets, const Tensor &indices,
Tensor &offset2bag) {
offset2bag.index_fill_(0, offsets, 1); // offset2bag = [1 0 1 0 1]
offset2bag[0] = 0; // offset2bag = [0 0 1 0 1]
offset2bag = offset2bag.cumsum(0); // offset2bag = [0 0 1 1 2]
}
static void make_bag_size(const Tensor &offsets, const Tensor &indices,
const int64_t mode, Tensor &bag_size) {
if (mode == 1) { // MODE_MEAN
if (offsets.sizes()[0] != 1) {
bag_size.slice(0, 0, bag_size.sizes()[0] - 1, 1) =
offsets.slice(0, 1, offsets.sizes()[0], 1) -
offsets.slice(0, 0, offsets.sizes()[0] - 1, 1);
bag_size[-1] = indices.sizes()[0] - offsets[-1];
}
}
}
static Tensor apply_bag_size(const Tensor &offsets, const Tensor &indices,
const int64_t mode, Tensor &output,
const Tensor &bag_size) {
if (mode == 1) { // MODE_MEAN
if (offsets.sizes()[0] == 1) {
auto bag_size_ = indices.sizes()[0];
output /= bag_size_;
} else {
auto bag_size_ =
bag_size.toType(output.type()).unsqueeze(1).expand_as(output);
output /= bag_size_;
}
}
return output;
}
static Tensor apply_bag_size_backward(const Tensor &offsets,
const Tensor &indices, const int64_t mode,
Tensor &output, const Tensor &offset2bag,
const Tensor &bag_size) {
if (mode == 1) { // MODE_MEAN
if (offsets.sizes()[0] == 1) {
auto bag_size_ = indices.sizes()[0];
output /= bag_size_;
} else {
auto bag_size_ = bag_size.toType(output.type())
.unsqueeze(1)
.index_select(0, offset2bag);
output /= bag_size_;
}
}
return output;
}
std::tuple<Tensor, Tensor, Tensor>
embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
const Tensor &offsets__, const bool scale_grad_by_freq,
const int64_t mode, bool sparse) {
auto indices_arg = TensorArg(indices__, "indices__", 1);
checkScalarType("embedding_bag", indices_arg, kLong);
auto offsets_arg = TensorArg(offsets__, "offsets__", 1);
checkScalarType("embedding_bag", offsets_arg, kLong);
Tensor indices = indices__.contiguous();
Tensor offsets = offsets__.contiguous();
auto bag_size = indices.type().zeros(offsets.sizes());
auto offset2bag =
indices__.type().zeros({indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]
make_offset2bag(offsets, indices, offset2bag);
auto output = weight.type().zeros({offsets.sizes()[0], weight.sizes()[1]});
auto index_output = weight.index_select(0, indices);
output.index_add_(0, offset2bag, index_output);
make_bag_size(offsets, indices, mode, bag_size);
auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
return std::tuple<Tensor, Tensor, Tensor>(ret, offset2bag, bag_size);
}
Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__,
const Tensor &offsets__,
const Tensor &offset2bag__,
const Tensor &bag_size_, int64_t num_weights,
bool scale_grad_by_freq, int64_t mode,
bool sparse) {
auto indices_arg = TensorArg(indices__, "indices__", 1);
checkScalarType("embedding_bag", indices_arg, kLong);
auto offsets_arg = TensorArg(offsets__, "offsets__", 1);
checkScalarType("embedding_bag", offsets_arg, kLong);
auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 1);
checkScalarType("embedding_bag", offset2bag_arg, kLong);
checkContiguous("embedding_bag", offset2bag_arg);
Tensor indices = indices__.contiguous();
Tensor offsets = offsets__.contiguous();
if (sparse) {
return at::embedding_bag_sparse_backward(
grad_, indices, offsets, offset2bag__, bag_size_, num_weights,
scale_grad_by_freq, mode);
} else {
return at::embedding_bag_dense_backward(
grad_, indices, offsets, offset2bag__, bag_size_, num_weights,
scale_grad_by_freq, mode);
}
}
Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
const Tensor &offsets__,
const Tensor &offset2bag__,
const Tensor &bag_size_, int64_t num_weights,
bool scale_grad_by_freq, int64_t mode) {
auto grad = grad_.contiguous();
auto indices_arg = TensorArg(indices__, "indices__", 1);
checkScalarType("embedding_bag", indices_arg, kLong);
auto offsets_arg = TensorArg(offsets__, "offsets__", 1);
checkScalarType("embedding_bag", offsets_arg, kLong);
auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 1);
checkScalarType("embedding_bag", offset2bag_arg, kLong);
checkContiguous("embedding_bag", offset2bag_arg);
Tensor indices_ = indices__.contiguous();
Tensor offsets_ = offsets__.contiguous();
Tensor &offset2bag_ = const_cast<Tensor &>(offset2bag__);
auto ind_sort_ = indices_.sort();
auto indices = std::get<0>(ind_sort_);
auto ind_sort = std::get<1>(ind_sort_);
auto offset2bag = offset2bag_.index_select(0, ind_sort);
auto indices_data = indices.data<int64_t>();
auto offsets_data = offsets_.data<int64_t>();
auto offset2bag_data = offset2bag.data<int64_t>();
int64_t numel = indices.numel();
std::vector<int64_t> counts(num_weights);
for (int i = 0; i < numel; i++) {
counts[indices_data[i]] = 0;
}
for (int i = 0; i < numel; i++) {
counts[indices_data[i]]++;
}
std::vector<int64_t> counts_uniq;
counts_uniq.reserve(num_weights);
int64_t o = 0;
for (int64_t i = 0; i < numel; i += counts[indices_data[i]]) {
counts_uniq.push_back(counts[indices_data[i]]);
if (o > 0) {
counts_uniq[o] += counts_uniq[o - 1];
}
o++;
}
auto index_grad_weight =
grad.type().zeros({num_weights, grad.sizes()[1]}).contiguous();
#pragma omp parallel for if (numel > 1000)
for (int64_t i = 0; i < (int64_t)counts_uniq.size(); i++) {
int64_t start = i == 0 ? 0 : counts_uniq[i - 1];
int64_t index = indices_data[start];
for (int64_t j = start; j < counts_uniq[i]; j++) {
int64_t source = offset2bag_data[j];
double scale = 1.0;
if (scale_grad_by_freq) {
scale /= counts[indices_data[i]];
}
if (mode == 1) { // MODE_MEAN
if (offsets_.sizes()[0] == 1) {
auto bag_size = indices.sizes()[0];
scale /= bag_size;
} else {
if (source == offsets_.sizes()[0] - 1) {
scale /= indices.sizes()[0] - offsets_data[offsets_.sizes()[0] - 1];
} else {
scale /= offsets_data[source + 1] - offsets_data[source];
}
}
}
int64_t ddim = grad.sizes()[1];
if (grad.type().scalarType() == kFloat) {
auto igwd = index_grad_weight.data<float>();
auto gd = grad.data<float>();
THFloatBlas_axpy(ddim, (float)scale, gd + ddim * source, 1,
igwd + ddim * index, 1);
} else if (grad.type().scalarType() == kDouble) {
auto igwd = index_grad_weight.data<double>();
auto gd = grad.data<double>();
THDoubleBlas_axpy(ddim, (double)scale, gd + ddim * source, 1,
igwd + ddim * index, 1);
} else {
index_grad_weight[index].add_(grad[source], scale);
}
}
}
return index_grad_weight;
}
Tensor embedding_bag_sparse_backward(
const Tensor &grad_, const Tensor &indices__, const Tensor &offsets__,
const Tensor &offset2bag__, const Tensor &bag_size_, int64_t num_weights,
bool scale_grad_by_freq, int64_t mode) {
auto indices_arg = TensorArg(indices__, "indices__", 1);
checkScalarType("embedding_bag", indices_arg, kLong);
auto offsets_arg = TensorArg(offsets__, "offsets__", 1);
checkScalarType("embedding_bag", offsets_arg, kLong);
auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 1);
checkScalarType("embedding_bag", offset2bag_arg, kLong);
Tensor indices = indices__.contiguous();
Tensor offsets = offsets__.contiguous();
Tensor offset2bag = offset2bag__.contiguous();
Tensor grad = grad_;
Tensor index_grad = grad_.index_select(0, offset2bag);
index_grad = apply_bag_size_backward(offsets, indices, mode, index_grad,
offset2bag, bag_size_);
return native::embedding_backward(index_grad, indices, num_weights, -1,
scale_grad_by_freq, true);
}
}
} // namespace at::native
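To make the offset2bag / bag_size bookkeeping above concrete, here is a small torch-level sketch of the same forward recipe embedding_bag_cpu follows (illustrative only; the tensor values are taken from the test case further down in this diff):

import torch

weight = torch.arange(1, 11).float().view(5, 2)     # 5 embeddings of dim 2
indices = torch.LongTensor([3, 1, 1, 1, 4, 0])      # 6 lookups
offsets = torch.LongTensor([0, 3])                  # 2 bags: [0:3) and [3:6)

# make_offset2bag: mark each bag start, clear position 0, then cumsum
offset2bag = torch.zeros(indices.size(0)).long()    # [0 0 0 0 0 0]
offset2bag.index_fill_(0, offsets, 1)               # [1 0 0 1 0 0]
offset2bag[0] = 0                                   # [0 0 0 1 0 0]
offset2bag = offset2bag.cumsum(0)                   # [0 0 0 1 1 1], i.e. the bag id of every index

# forward: gather the selected rows, then scatter-add them into their bags
output = torch.zeros(offsets.size(0), weight.size(1))
output.index_add_(0, offset2bag, weight.index_select(0, indices))
# MODE_MEAN (mode == 1) would additionally divide each output row by its bag size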


@ -0,0 +1,283 @@
#include "ATen/ATen.h"
#include "ATen/Check.h"
#include "ATen/Dispatch.h"
#include "ATen/NativeFunctions.h"
#include "ATen/cuda/AccumulateType.h"
#include "ATen/cuda/CUDATensorMethods.cuh"
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCNumerics.cuh>
#include <THC/THCTensorMathReduce.cuh>
#include <THC/THCTensorSort.cuh>
#include <THC/THCThrustAllocator.cuh>
#include <THCUNN/THCHalfAutoNumerics.cuh>
#include <thrust/execution_policy.h>
#include <thrust/unique.h>
const int WARP_SIZE = 32;
const int MODE_SUM = 0;
const int MODE_MEAN = 1;
namespace at {
namespace native {
namespace {
template <typename scalar_t>
__global__ void EmbeddingBag_updateOutputKernel(
int64_t *input, int64_t *offsets, scalar_t *weight, scalar_t *output,
int64_t *offset2bag, int64_t numIndices, int64_t numBags, int64_t stride,
int mode, int64_t *bag_size) {
// the strategy here is that each bag x feature is handled by a single thread
using accscalar_t = cuda::acc_type<scalar_t>;
int64_t chunksPerBag = THCCeilDiv(stride, (int64_t)blockDim.x);
int64_t numChunks = numBags * chunksPerBag;
int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
int64_t chunkStride = gridDim.x * blockDim.y;
for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
if (featureDim < stride) {
int64_t bag = chunk / chunksPerBag;
scalar_t *weightFeat = weight + featureDim;
int64_t begin = offsets[bag];
int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices;
assert(end >= begin);
accscalar_t weightFeatSum = scalar_cast<accscalar_t>(0);
int64_t bag_size_ = 0;
for (int64_t emb = begin; emb < end; emb++) {
const int weightRow = ((int)input[emb]) * stride;
weightFeatSum += scalar_cast<accscalar_t>(weightFeat[weightRow]);
bag_size_++;
if (featureDim == 0) {
offset2bag[emb] = bag;
}
}
if (mode == MODE_MEAN) {
weightFeatSum = weightFeatSum / scalar_cast<accscalar_t>(bag_size_);
bag_size[bag] = bag_size_;
}
(void)MODE_SUM; // silence warnings about unused MODE_SUM;
output[bag * stride + featureDim] = scalar_cast<scalar_t>(weightFeatSum);
}
}
}
// FIXME: removed the accGradParametersKernelByFeature case present in
// LookupTable. That kernel is faster at small sizes (<768 indices), which
// typically do not need EmbeddingBag (LookupTable + Sum works fine), but it
// would still be nice not to be slow in that case.
template <typename scalar_t>
__global__ void EmbeddingBag_accGradParametersKernel(
int64_t *input, int64_t *indices, scalar_t *gradOutput,
scalar_t *gradWeight, int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
int64_t stride, int mode, int64_t *bag_size) {
using accscalar_t = cuda::acc_type<scalar_t>;
int idx = blockIdx.x * 4 + threadIdx.y;
// Each warp is responsible for an input into the LookupTable.
// If the preceding input has the same value as this input, then the warp
// exits immediately. The warp also processes subsequent inputs with the
// same value.
//
// Input Warp
// 1 <warp 1>
// 1 <warp 1> (<warp 2> exits without doing any work)
// 5 <warp 3>
// 8 <warp 4>
// Number of values processed by each thread (grain size)
const int SZ = 4;
if (idx < numel && (idx == 0 || input[idx] != input[idx - 1])) {
do {
const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
const int weightRow = ((int)input[idx]) * stride;
// Note: only this line changes from LookupTable_accgradParametersKernel
const int origRow = ((int)indices[idx]);
const int seq_number = offset2bag[origRow];
const int gradOutputRow = ((int)seq_number) * stride;
const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
accscalar_t gradient[SZ];
accscalar_t weight[SZ];
#pragma unroll
for (int ii = 0; ii < SZ; ii++) {
int featureDim = startFeature + ii * WARP_SIZE;
if (featureDim < stride) {
gradient[ii] =
scalar_cast<accscalar_t>(gradOutput[gradOutputRow + featureDim]);
if (mode == MODE_MEAN) {
gradient[ii] /= bag_size[seq_number];
}
weight[ii] =
scalar_cast<accscalar_t>(gradWeight[weightRow + featureDim]);
}
}
#pragma unroll
for (int ii = 0; ii < SZ; ii++) {
weight[ii] += gradient[ii] * scale;
}
#pragma unroll
for (int ii = 0; ii < SZ; ii++) {
int featureDim = startFeature + ii * WARP_SIZE;
if (featureDim < stride) {
gradWeight[weightRow + featureDim] =
scalar_cast<scalar_t>(weight[ii]);
}
}
idx++;
} while (idx < numel && input[idx] == input[idx - 1]);
}
}
}
std::tuple<Tensor, Tensor, Tensor>
embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
const Tensor &offsets, const bool scale_grad_by_freq,
const int64_t mode, bool sparse) {
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_bag_cuda", indices_arg, kLong);
checkContiguous("embedding_bag_cuda", indices_arg);
auto offsets_arg = TensorArg(offsets, "offsets", 1);
checkScalarType("embedding_bag_cuda", offsets_arg, kLong);
checkContiguous("embedding_bag_cuda", offsets_arg);
auto weight_arg = TensorArg(weight, "weight", 1);
checkContiguous("embedding_bag_cuda", weight_arg);
checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg);
checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg);
int64_t numIndices = indices.sizes()[0];
int64_t numBags = offsets.sizes()[0];
int64_t stride = weight.sizes()[1];
auto bag_size = indices.type().zeros(offsets.sizes());
auto offset2bag =
indices.type().zeros({indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]
cudaStream_t stream = globalContext().getCurrentCUDAStream();
auto output = weight.type().zeros({offsets.sizes()[0], weight.sizes()[1]});
dim3 block = dim3(32, 8);
int grid = 1024;
DISPATCH_ALL_FLOATING_TYPES(weight.type(), "embedding_bag_cuda", [&]() {
EmbeddingBag_updateOutputKernel<scalar_t><<<grid, block, 0, stream>>>(
indices.data<int64_t>(), offsets.data<int64_t>(),
weight.data<scalar_t>(), output.data<scalar_t>(),
offset2bag.data<int64_t>(), numIndices, numBags, stride, mode,
bag_size.data<int64_t>());
});
THCudaCheck(cudaGetLastError());
return std::tuple<Tensor, Tensor, Tensor>(output, offset2bag, bag_size);
}
Tensor embedding_bag_backward_cuda(const Tensor &grad_, const Tensor &indices,
const Tensor &offsets,
const Tensor &offset2bag,
const Tensor &bag_size_, int64_t num_weights,
bool scale_grad_by_freq, int64_t mode) {
Tensor grad = grad_.contiguous();
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_bag_cuda", indices_arg, kLong);
checkContiguous("embedding_bag_cuda", indices_arg);
auto offsets_arg = TensorArg(offsets, "offsets", 1);
checkScalarType("embedding_bag_cuda", offsets_arg, kLong);
checkContiguous("embedding_bag_cuda", offsets_arg);
auto grad_arg = TensorArg(grad, "grad", 1);
checkContiguous("embedding_bag_cuda", grad_arg);
checkSameGPU("embedding_bag_cuda", grad_arg, offsets_arg);
checkSameGPU("embedding_bag_cuda", grad_arg, indices_arg);
Tensor &bag_size = const_cast<Tensor &>(bag_size_);
auto grad_weight = grad_.type().zeros({num_weights, grad.sizes()[1]});
int nDim = indices.ndimension();
ptrdiff_t numel = indices.numel();
int64_t stride = grad_weight.stride(0);
cudaStream_t stream = globalContext().getCurrentCUDAStream();
auto sorted_indices = indices.type().tensor(indices.sizes());
auto orig_indices = indices.type().tensor(indices.sizes());
using device_ptr = thrust::device_ptr<int64_t>;
// Sort the inputs into sorted with the corresponding indices; we
// don't need a stable or multidimensional sort, so just use Thrust
// directly
{
sorted_indices.copy_(indices);
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
// Fill sortedOrigIndices with sequential indices
auto count_iter = thrust::counting_iterator<int64_t>(0);
auto orig_data = device_ptr(orig_indices.data<int64_t>());
thrust::copy(policy, count_iter, count_iter + numel, orig_data);
// Sort; a stable sort is not required
auto sorted_data = device_ptr(sorted_indices.data<int64_t>());
thrust::sort_by_key(policy, sorted_data, sorted_data + numel, orig_data,
ThrustLTOp<int64_t>());
}
Tensor count;
if (scale_grad_by_freq) {
count = indices.type().tensor(indices.sizes());
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = device_ptr(sorted_indices.data<int64_t>());
auto count_data = device_ptr(count.data<int64_t>());
thrust::inclusive_scan_by_key(policy, sorted_data, sorted_data + numel,
thrust::make_constant_iterator(1),
count_data);
// Take the maximum of each count per unique key in reverse:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 3 3 3 2 2 1 2 2
thrust::inclusive_scan_by_key(
policy, thrust::make_reverse_iterator(sorted_data + numel),
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + numel),
thrust::make_reverse_iterator(count_data + numel),
thrust::equal_to<int64_t>(), thrust::maximum<int64_t>());
}
dim3 grid(THCCeilDiv(numel, (ptrdiff_t)4), THCCeilDiv(stride, (int64_t)128));
dim3 block(32, 4);
DISPATCH_ALL_FLOATING_TYPES(
grad.type(), "embedding_bag_backward_cuda", [&]() {
EmbeddingBag_accGradParametersKernel<
scalar_t><<<grid, block, 0, stream>>>(
sorted_indices.data<int64_t>(), orig_indices.data<int64_t>(),
grad.data<scalar_t>(), grad_weight.data<scalar_t>(),
offset2bag.data<int64_t>(),
count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
mode, bag_size.data<int64_t>());
});
THCudaCheck(cudaGetLastError());
return grad_weight;
}
}
}
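The count computation above (used only when scale_grad_by_freq is set) is easiest to see on the example from the comments: a forward scan-by-key produces a running count within each run of equal indices, and a reverse max-scan-by-key broadcasts each run's total back over the whole run. A plain-Python sketch of the same two passes:

# Sketch of the two scans that thrust::inclusive_scan_by_key performs above.
sorted_indices = [2, 5, 5, 5, 7, 7, 8, 9, 9]

# Pass 1: running count within each run of equal keys
count = []
for i, v in enumerate(sorted_indices):
    count.append(count[-1] + 1 if i > 0 and sorted_indices[i - 1] == v else 1)
# count == [1, 1, 2, 3, 1, 2, 1, 1, 2]

# Pass 2: sweep backwards taking the running maximum within each run,
# so every element of a run ends up holding that run's total count
for i in range(len(sorted_indices) - 2, -1, -1):
    if sorted_indices[i] == sorted_indices[i + 1]:
        count[i] = max(count[i], count[i + 1])
# count == [1, 3, 3, 3, 2, 2, 1, 2, 2]

# EmbeddingBag_accGradParametersKernel then scales each gradient row by 1 / count[idx].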


@ -187,6 +187,24 @@
- func: empty_like(Tensor self) -> Tensor
variants: function
- func: embedding_bag(Tensor weight, IndexTensor indices, IndexTensor offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false) -> (Tensor, Tensor, Tensor)
variants: function
dispatch:
CPU: embedding_bag_cpu
CUDA: embedding_bag_cuda
- func: embedding_bag_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse) -> Tensor
variants: function
- func: embedding_bag_sparse_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor
variants: function
- func: embedding_bag_dense_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor
variants: function
dispatch:
CPU: embedding_bag_backward_cpu
CUDA: embedding_bag_backward_cuda
- func: expand(Tensor self, IntList size) -> Tensor
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
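One way to sanity-check the dense backward path declared here is a numerical gradient check through the Python entry point (a sketch under the Variable-era API used elsewhere in this diff; shapes and values are arbitrary):

import torch
import torch.nn.functional as F
from torch.autograd import Variable, gradcheck

weight = Variable(torch.randn(10, 3).double(), requires_grad=True)
indices = Variable(torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9]))
offsets = Variable(torch.LongTensor([0, 4]))

# Exercises embedding_bag -> embedding_bag_backward -> embedding_bag_dense_backward (CPU)
assert gradcheck(lambda w: F.embedding_bag(w, indices, offsets, mode='sum'), (weight,))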


@ -1205,10 +1205,11 @@ class TestNN(NNTestCase):
def test_gumbel_softmax_st_cuda(self):
self._test_gumbel_softmax_st(True)
def _test_EmbeddingBag(self, cuda, mode):
def _test_EmbeddingBag(self, cuda, mode, sparse):
# check a known test example
es = nn.EmbeddingBag(5, 2, mode=mode)
es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse)
es.weight.data.copy_(torch.arange(1, 11).resize_as_(es.weight.data))
input = Variable(torch.LongTensor([3, 1, 1, 1, 4, 0]))
offsets = Variable(torch.LongTensor([0, 3]))
grad_output = torch.arange(1, 5).view(2, 2).type(torch.Tensor)
@ -1245,8 +1246,11 @@ class TestNN(NNTestCase):
output = es(input, offsets)
output.backward(grad_output)
es_weight_grad = es.weight.grad.data
if sparse:
es_weight_grad = es.weight.grad.data.to_dense()
self.assertEqual(output.data, expected_output)
self.assertEqual(es.weight.grad.data, expected_grad_weight)
self.assertEqual(es_weight_grad, expected_grad_weight)
# check same example except as 2D (2 x 3)
input = Variable(input.data.view(2, -1))
@ -1254,12 +1258,15 @@ class TestNN(NNTestCase):
output = es(input)
output.backward(grad_output)
es_weight_grad = es.weight.grad.data
if sparse:
es_weight_grad = es.weight.grad.data.to_dense()
self.assertEqual(output.data, expected_output)
self.assertEqual(es.weight.grad.data, expected_grad_weight)
self.assertEqual(es_weight_grad, expected_grad_weight)
# now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length
def _test_vs_Embedding(N, D, B, L):
es = nn.EmbeddingBag(N, D, mode=mode)
es = nn.EmbeddingBag(N, D, mode=mode, sparse=sparse)
e = nn.Embedding(N, D)
e.weight.data.copy_(es.weight.data)
input = Variable(torch.rand(B, L).mul(N).long())
@ -1283,7 +1290,10 @@ class TestNN(NNTestCase):
output.backward(grad_output)
ref_output.backward(grad_output)
self.assertEqual(es.weight.grad, e.weight.grad)
es_weight_grad = es.weight.grad.data
if sparse:
es_weight_grad = es.weight.grad.data.to_dense()
self.assertEqual(es_weight_grad, e.weight.grad.data)
N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50)
_test_vs_Embedding(N, D, B, L)
@ -1291,7 +1301,7 @@ class TestNN(NNTestCase):
_test_vs_Embedding(*p)
# check that giving illegal input combos raises error
es = nn.EmbeddingBag(10, 20, mode=mode)
es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
input = Variable(torch.ones(3, 4))
offset = Variable(torch.arange(0, 3))
self.assertRaises(ValueError, lambda: es(input, offset))
@ -1348,14 +1358,18 @@ class TestNN(NNTestCase):
F.conv_transpose2d(Variable(x), Variable(torch.randn(16, 1, 1, 1)).cuda())
F.conv2d(Variable(x), Variable(torch.randn(1, 16, 1, 1)).cuda())
def test_EmbeddingBag(self):
self._test_EmbeddingBag(False, 'sum')
self._test_EmbeddingBag(False, 'mean')
def test_embedding_bag(self):
self._test_EmbeddingBag(False, 'sum', False)
self._test_EmbeddingBag(False, 'mean', False)
self._test_EmbeddingBag(False, 'sum', True)
self._test_EmbeddingBag(False, 'mean', True)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_EmbeddingBag_cuda(self):
self._test_EmbeddingBag(True, 'sum')
self._test_EmbeddingBag(True, 'mean')
def test_embedding_bag_cuda(self):
self._test_EmbeddingBag(True, 'sum', False)
self._test_EmbeddingBag(True, 'mean', False)
self._test_EmbeddingBag(True, 'sum', True)
self._test_EmbeddingBag(True, 'mean', True)
def test_fractional_max_pool2d(self):
x = Variable(torch.randn(1, 2, 7, 7), requires_grad=True)
@ -5386,6 +5400,20 @@ new_module_tests = [
jacobian_input=False,
check_gradgrad=False,
),
dict(
module_name='EmbeddingBag',
constructor_args=(4, 3),
input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)),
jacobian_input=False,
check_gradgrad=False,
),
dict(
module_name='EmbeddingBag_sparse',
constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True),
input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)),
jacobian_input=False,
check_gradgrad=False,
),
dict(
constructor=lambda: nn.Embedding(4, 3, sparse=True),
input_fn=lambda: Variable(torch.randperm(2).repeat(1, 2)),


@ -674,6 +674,9 @@
- name: embedding(Tensor weight, Tensor indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse)
weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)
- name: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse)
weight: embedding_bag_backward(grad, indices, offsets, result1, result2, weight.size(0), scale_grad_by_freq, mode, sparse)
- name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
self: not_implemented("embedding_renorm")


@ -1068,7 +1068,7 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2,
def embedding_bag(embedding_matrix, indices, offsets=None,
max_norm=None, norm_type=2, scale_grad_by_freq=False, mode='mean'):
max_norm=None, norm_type=2, scale_grad_by_freq=False, mode='mean', sparse=False):
r"""Computes sums or means of 'bags' of embeddings, without instantiating the
intermediate embeddings.
@ -1093,6 +1093,8 @@ def embedding_bag(embedding_matrix, indices, offsets=None,
scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the frequency of
the words in the dictionary.
mode (string, optional): 'sum' | 'mean'. Specifies the way to reduce the bag. Default: 'mean'
sparse (boolean, optional): if ``True``, gradient w.r.t. weight matrix will be a sparse tensor. See Notes
for more details regarding sparse gradients.
Shape:
- Embedding_matrix: FloatTensor `(V, embedding_dim)`,
@ -1131,19 +1133,42 @@ def embedding_bag(embedding_matrix, indices, offsets=None,
offsets = Variable(torch.arange(0, indices.numel(), indices.size(1),
out=indices.data.new().long()))
indices = indices.view(-1)
elif indices.dim() != 1:
elif indices.dim() == 1:
if offsets is None:
raise ValueError("offsets has to be a 1D Tensor but got None")
if offsets.dim() != 1:
raise ValueError("offsets has to be a 1D Tensor")
if offsets[0] != 0:
raise ValueError("offsets[0] has to be 0, i.e. the first sequence"
" in the mini-batch has to start from position 0."
"However, got {}".format(offsets[0]))
if offsets[-1] > indices.size(0):
raise ValueError("offsets[-1] has to be smaller than indices's length"
" ({}), but got offsets[-1] of {}"
.format(indices.size(0), offsets[-1]))
else:
raise ValueError("input has to be 1D or 2D Tensor,"
" but got Tensor of dimension {}".format(indices.dim()))
if offsets is None:
raise ValueError("offsets has to be a 1D Tensor but got None")
if mode == 'sum':
mode = 0
elif mode == 'mean':
mode = 1
else:
raise ValueError("mode has to be one of sum or mean")
return _functions.thnn.EmbeddingBag.apply(
embedding_matrix, indices, offsets,
max_norm, norm_type,
scale_grad_by_freq, mode
)
if max_norm is not None:
with torch.no_grad():
torch._C._VariableFunctions.embedding_renorm_(embedding_matrix, indices, max_norm, norm_type)
ret, _, _ = torch._C._VariableFunctions.embedding_bag(
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse)
return ret
def instance_norm(input, weight, bias, saved_running_mean, saved_running_var,
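A short sketch of the new offsets handling in embedding_bag (values are illustrative): a 2D input implies one bag per row with offsets generated internally, while explicit offsets must be 1D, start at 0, and stay within the index length.

import torch
import torch.nn.functional as F
from torch.autograd import Variable

weight = Variable(torch.randn(10, 3))
indices = Variable(torch.LongTensor([[1, 2, 4], [5, 4, 3]]))   # 2D input: offsets are implied
out = F.embedding_bag(weight, indices, mode='mean')            # one bag per row -> shape (2, 3)

bad_offsets = Variable(torch.LongTensor([1, 3]))               # offsets[0] must be 0
try:
    F.embedding_bag(weight, indices.view(-1), bad_offsets)
except ValueError as err:
    print(err)                                                 # "offsets[0] has to be 0, ..."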


@ -139,6 +139,8 @@ class EmbeddingBag(Module):
scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the frequency of
the words in the dictionary.
mode (string, optional): 'sum' | 'mean'. Specifies the way to reduce the bag. Default: 'mean'
sparse (boolean, optional): if ``True``, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for
more details regarding sparse gradients.
Attributes:
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
@ -185,7 +187,7 @@ class EmbeddingBag(Module):
def __init__(self, num_embeddings, embedding_dim,
max_norm=None, norm_type=2, scale_grad_by_freq=False,
mode='mean'):
mode='mean', sparse=False):
super(EmbeddingBag, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
@ -194,6 +196,7 @@ class EmbeddingBag(Module):
self.scale_grad_by_freq = scale_grad_by_freq
self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
self.mode = mode
self.sparse = sparse
self.reset_parameters()
@ -203,7 +206,7 @@ class EmbeddingBag(Module):
def forward(self, input, offsets=None):
return F.embedding_bag(self.weight, input, offsets,
self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.mode)
self.scale_grad_by_freq, self.mode, self.sparse)
def __repr__(self):
s = '{name}({num_embeddings}, {embedding_dim}'