pytorch/caffe2/perfkernels/embedding_lookup_idx.h
Qi Zhou 0ec717c830 Support int32 indices and offsets in nn.EmbeddingBag (#46758)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46758

It is generally helpful to support int32 indices and offsets, especially when such tensors are large and need to be transferred to accelerator backends. Since supporting the combination of int32 indices and int64 offsets is unlikely to be useful, we enforce that the two have the same type.

Test Plan: unit tests

Reviewed By: ngimel

Differential Revision: D24470808

fbshipit-source-id: 94b8a1d0b7fc9fe3d128247aa042c04d7c227f0b
2020-11-03 23:33:50 -08:00

58 lines
1.6 KiB
C++

#pragma once
#include <cstdint>
namespace caffe2 {
// clang-format off
/**
* Embedding lookup with reduction.
*
* `input` of size data_size * block_size
* `indices` of size index_size
* `offsets` of size output_size
* `weights` nullptr or array of size index_size
* `out` of size output_size * block_size
*
* Behavior is roughly equivalent to pseudocode:
*
* pos = 0
* for (i = 0..output_size-1)
*   for (k = 0..block_size-1)
*     out[i*block_size + k] = 0
*   start_offset = offsets[i]
*   end_offset = offsets[i+1]
*   length = end_offset - start_offset
*   for (j = start_offset..end_offset-1)
*     for (k = 0..block_size-1)
*       out[i*block_size + k] += input[indices[pos]*block_size + k] *
*           (weights ? weights[IS_WEIGHT_POSITIONAL ? j - start_offset : pos] : 1.0)
*     pos += 1
*   if (normalize_by_lengths && length > 0)
*     for (k = 0..block_size-1)
*       out[i*block_size + k] /= length
*
* TODO: make this API also take "offsets" rather than "lengths" to match the
* API for PyTorch's EmbeddingBag
*/
// clang-format on
template <
    typename IndexType,
    typename InType,
    typename OutType,
    bool IS_WEIGHT_POSITIONAL = false>
void EmbeddingLookupIdx(
    const std::int64_t block_size,
    const std::int64_t output_size,
    const std::int64_t index_size,
    const std::int64_t data_size,
    const InType* input,
    const IndexType* indices,
    const IndexType* offsets,
    const float* weights, // optional, can be null for non-weighted sum
    const float* scale_bias, // optional scale & bias params for uint8 input
    bool normalize_by_lengths,
    OutType* out);
} // namespace caffe2
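
The pseudocode in the header comment can be sketched as a plain scalar loop. This is a minimal illustration, not the caffe2 kernel: `EmbeddingLookupIdxRef` and `SumExample` are hypothetical names, the sketch handles only `float` input and output (no `scale_bias`), and it returns `bool` for bounds checks that the declaration above does not promise. It also assumes `offsets` holds `output_size + 1` entries, because the pseudocode reads `offsets[i+1]`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical scalar reference that follows the header's pseudocode.
// Assumption: `offsets` has output_size + 1 entries (offsets[i + 1] is read).
// Returns false on an out-of-range offset or index (an addition of this sketch).
template <typename IndexType, bool IS_WEIGHT_POSITIONAL = false>
bool EmbeddingLookupIdxRef(
    std::int64_t block_size,
    std::int64_t output_size,
    std::int64_t index_size,
    std::int64_t data_size,
    const float* input,
    const IndexType* indices,
    const IndexType* offsets,
    const float* weights, // optional, nullptr means an unweighted sum
    bool normalize_by_lengths,
    float* out) {
  std::int64_t pos = 0; // running position in `indices`
  for (std::int64_t i = 0; i < output_size; ++i) {
    float* row = out + i * block_size;
    for (std::int64_t k = 0; k < block_size; ++k) {
      row[k] = 0.0f;
    }
    const std::int64_t start = offsets[i];
    const std::int64_t end = offsets[i + 1];
    if (start < 0 || start > end || end > index_size) {
      return false;
    }
    for (std::int64_t j = start; j < end; ++j, ++pos) {
      const std::int64_t idx = indices[pos];
      if (idx < 0 || idx >= data_size) {
        return false;
      }
      // Positional weights are indexed within the bag, otherwise globally.
      const float w =
          weights ? weights[IS_WEIGHT_POSITIONAL ? j - start : pos] : 1.0f;
      for (std::int64_t k = 0; k < block_size; ++k) {
        row[k] += input[idx * block_size + k] * w;
      }
    }
    const std::int64_t length = end - start;
    if (normalize_by_lengths && length > 0) {
      for (std::int64_t k = 0; k < block_size; ++k) {
        row[k] /= static_cast<float>(length);
      }
    }
  }
  return true;
}

// Usage: three embedding rows of width 2, pooled into two bags with
// int32 indices and int32 offsets (both share one IndexType).
std::vector<float> SumExample() {
  const float input[] = {1, 2, 3, 4, 5, 6};  // rows: {1,2}, {3,4}, {5,6}
  const std::int32_t indices[] = {0, 2, 1};  // bag 0: rows 0 and 2; bag 1: row 1
  const std::int32_t offsets[] = {0, 2, 3};  // output_size + 1 entries
  std::vector<float> out(4);
  const bool ok = EmbeddingLookupIdxRef<std::int32_t>(
      2, 2, 3, 3, input, indices, offsets, nullptr, false, out.data());
  assert(ok);
  return out; // bag 0 = {6, 8}, bag 1 = {3, 4}
}
```

Because `indices` and `offsets` share the single `IndexType` template parameter, passing an `int32` array for one and an `int64` array for the other does not compile in this sketch, which mirrors the same-type constraint described in the commit message.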