mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46758 It's in general helpful to support int32 indices and offsets, especially when such tensors are large and need to be transferred to accelerator backends. Since it may not be very useful to support the combination of int32 indices and int64 offsets, here we enforce that these two must have the same type. Test Plan: unit tests Reviewed By: ngimel Differential Revision: D24470808 fbshipit-source-id: 94b8a1d0b7fc9fe3d128247aa042c04d7c227f0b
58 lines
1.6 KiB
C++
58 lines
1.6 KiB
C++
#pragma once

#include <cstdint>

namespace caffe2 {

// clang-format off
/**
 * Embedding lookup with reduction.
 *
 * Each of the `output_size` output rows is the (optionally weighted,
 * optionally length-normalized) sum of the rows of `input` selected by one
 * contiguous segment of `indices`; segment boundaries are given by `offsets`.
 *
 * Array arguments:
 *   `input`   of size data_size * block_size
 *   `indices` of size index_size
 *   `offsets` of size output_size
 *   `weights` nullptr or array of size index_size
 *   `out`     of size output_size * block_size
 *
 * NOTE(review): the pseudocode below reads offsets[i+1] for i = output_size-1,
 * i.e. offsets[output_size], which is one element past the size documented
 * above. Confirm against the implementation whether callers must supply
 * output_size + 1 offsets or whether the final end offset is derived
 * differently (e.g. from index_size).
 *
 * Behavior is roughly equivalent to pseudocode:
 *
 *   pos = 0
 *   for (i = 0..output_size-1)
 *     for (k = 0..block_size-1)
 *       out[i*block_size + k] = 0
 *     start_offset = offsets[i]
 *     end_offset = offsets[i+1]
 *     length = end_offset - start_offset
 *     for (j = start_offset..end_offset-1)
 *       for (k = 0..block_size-1)
 *         out[i*block_size + k] += input[indices[pos]*block_size + k] *
 *           (weights ? weights[IS_WEIGHT_POSITIONAL ? j - start_offset : pos] : 1.0)
 *       pos += 1
 *     if (normalize_by_lengths && length > 0)
 *       for (k = 0..block_size-1)
 *         out[i*block_size + k] /= length
 *
 * Template parameters:
 *   IndexType            element type shared by `indices` and `offsets`
 *                        (both must use the same integer type)
 *   InType               element type of `input`
 *   OutType              element type of `out`
 *   IS_WEIGHT_POSITIONAL when true, `weights` is indexed by the position
 *                        within the current segment (j - start_offset);
 *                        otherwise by the running position `pos`
 *
 * Scalar parameters:
 *   block_size           number of elements in one row of `input` / `out`
 *   output_size          number of output rows (segments)
 *   index_size           number of entries in `indices`
 *   data_size            number of rows in `input`
 *   normalize_by_lengths when true, each non-empty output row is divided by
 *                        the length of its segment
 *
 * Segments are described by start `offsets` rather than per-segment lengths,
 * which is the convention of PyTorch's EmbeddingBag.
 */
// clang-format on
template <
    typename IndexType,
    typename InType,
    typename OutType,
    bool IS_WEIGHT_POSITIONAL = false>
void EmbeddingLookupIdx(
    const std::int64_t block_size,
    const std::int64_t output_size,
    const std::int64_t index_size,
    const std::int64_t data_size,
    const InType* input,
    const IndexType* indices,
    const IndexType* offsets,
    const float* weights, // optional, can be null for non-weighted sum
    const float* scale_bias, // optional scale & bias params for uint8 input
    bool normalize_by_lengths,
    OutType* out);

} // namespace caffe2