Compare commits

..

1 Commits

Author SHA1 Message Date
73c49ee963 Speed up fx graph iteration by implementing it in C++
ghstack-source-id: af7493f6f73baf00e30a6d5790a601729bd9c900
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128288
2024-06-08 17:12:47 -07:00
1431 changed files with 5837 additions and 7150 deletions

View File

@ -1,5 +0,0 @@
0.6b
manylinux_2_17
rocm6
04b5df8c8123f90cba3ede7e971e6fbc6040d506
3db6ecbc915893ff967abd6e1b43bd5f54949868873be60dc802086c3863e648

View File

@ -113,18 +113,18 @@ COPY triton_version.txt triton_version.txt
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt
# Install AOTriton (Early fail)
COPY ./aotriton_version.txt aotriton_version.txt
COPY ./common/common_utils.sh common_utils.sh
COPY ./common/install_aotriton.sh install_aotriton.sh
RUN ["/bin/bash", "-c", "./install_aotriton.sh /opt/rocm && rm -rf install_aotriton.sh aotriton_version.txt common_utils.sh"]
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton
# Install ccache/sccache (do this last, so we get priority in PATH)
COPY ./common/install_cache.sh install_cache.sh
ENV PATH /opt/cache/bin:$PATH
RUN bash ./install_cache.sh && rm install_cache.sh
# Install AOTriton
COPY ci_commit_pins/aotriton.txt aotriton.txt
COPY ./common/common_utils.sh common_utils.sh
COPY ./common/install_aotriton.sh install_aotriton.sh
RUN bash ./install_aotriton.sh /opt/rocm/aotriton && rm -rf install_aotriton.sh aotriton aotriton.txt common_utils.sh
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton
# Include BUILD_ENVIRONMENT environment variable in image
ARG BUILD_ENVIRONMENT
ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT}

View File

@ -0,0 +1 @@
24a3fe9cb57e5cda3c923df29743f9767194cc27

31
.ci/docker/common/install_aotriton.sh Executable file → Normal file
View File

@ -4,20 +4,21 @@ set -ex
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
TARBALL='aotriton.tar.bz2'
# This read command alwasy returns with exit code 1
read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < aotriton_version.txt || true
ARCH=$(uname -m)
AOTRITON_DIR="aotriton"
AOTRITON_PINNED_NAME="aotriton" # No .txt extension
AOTRITON_PINNED_COMMIT=$(get_pinned_commit ${AOTRITON_PINNED_NAME})
AOTRITON_INSTALL_PREFIX="$1"
AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}.tar.bz2"
cd "${AOTRITON_INSTALL_PREFIX}"
# Must use -L to follow redirects
curl -L --retry 3 -o "${TARBALL}" "${AOTRITON_URL}"
ACTUAL_SHA256=$(sha256sum "${TARBALL}" | cut -d " " -f 1)
if [ "${SHA256}" != "${ACTUAL_SHA256}" ]; then
echo -n "Error: The SHA256 of downloaded tarball is ${ACTUAL_SHA256},"
echo " which does not match the expected value ${SHA256}."
exit
fi
tar xf "${TARBALL}" && rm -rf "${TARBALL}"
git clone https://github.com/ROCm/aotriton.git "${AOTRITON_DIR}"
cd "${AOTRITON_DIR}"
git checkout "${AOTRITON_PINNED_COMMIT}"
git submodule sync --recursive
git submodule update --init --recursive --force --depth 1
mkdir build
cd build
cmake .. -G Ninja -DCMAKE_INSTALL_PREFIX=./install_dir -DCMAKE_BUILD_TYPE=Release -DAOTRITON_COMPRESS_KERNEL=OFF -DAOTRITON_NO_PYTHON=ON -DAOTRITON_NO_SHARED=ON
ninja install
mkdir -p "${AOTRITON_INSTALL_PREFIX}"
cp -r install_dir/* "${AOTRITON_INSTALL_PREFIX}"
find /tmp/ -mindepth 1 -delete
rm -rf ~/.triton

View File

@ -105,18 +105,18 @@ COPY triton_version.txt triton_version.txt
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt
# Install AOTriton
COPY ./aotriton_version.txt aotriton_version.txt
COPY ./common/common_utils.sh common_utils.sh
COPY ./common/install_aotriton.sh install_aotriton.sh
RUN ["/bin/bash", "-c", "./install_aotriton.sh /opt/rocm && rm -rf install_aotriton.sh aotriton_version.txt common_utils.sh"]
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton
# Install ccache/sccache (do this last, so we get priority in PATH)
COPY ./common/install_cache.sh install_cache.sh
ENV PATH /opt/cache/bin:$PATH
RUN bash ./install_cache.sh && rm install_cache.sh
# Install AOTriton
COPY ci_commit_pins/aotriton.txt aotriton.txt
COPY ./common/common_utils.sh common_utils.sh
COPY ./common/install_aotriton.sh install_aotriton.sh
RUN bash ./install_aotriton.sh /opt/rocm/aotriton && rm -rf install_aotriton.sh aotriton aotriton.txt common_utils.sh
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton
# Include BUILD_ENVIRONMENT environment variable in image
ARG BUILD_ENVIRONMENT
ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT}

View File

@ -62,6 +62,4 @@ readability-string-compare,
'
HeaderFilterRegex: '^(aten/|c10/|torch/).*$'
WarningsAsErrors: '*'
CheckOptions:
misc-header-include-cycle.IgnoredFilesList: 'format.h;ivalue.h;custom_class.h;Dict.h;List.h'
...

View File

@ -1 +1 @@
b829e936f7cc61b48149f5f957a451a38bf2a178
1980f8af5bcd0bb2ce51965cf79d8d4c25dad8a0

View File

@ -461,8 +461,15 @@ filegroup(
filegroup(
name = "caffe2_perfkernels_srcs",
srcs = [
"caffe2/perfkernels/adagrad.cc",
"caffe2/perfkernels/embedding_lookup.cc",
"caffe2/perfkernels/embedding_lookup_idx.cc",
"caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc",
"caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc",
"caffe2/perfkernels/fused_nbit_rowwise_conversion.cc",
"caffe2/perfkernels/lstm_unit_cpu_common.cc",
"caffe2/perfkernels/math_cpu_base.cc",
"caffe2/perfkernels/typed_axpy.cc",
],
)

View File

@ -865,13 +865,12 @@ cmake_dependent_option(
# Suspect users building from source will need this
add_definitions(-DFLASHATTENTION_DISABLE_ALIBI)
# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
# Eff Attention won't
# CAVEAT: Again, do not check USE_ROCM here Flash Attention2 will error while
# building for sm52 while Mem Eff Attention won't
cmake_dependent_option(
USE_MEM_EFF_ATTENTION
"Enable memory-efficient attention for scaled dot product attention.\
Will be disabled if not supported by the platform" ON
"USE_CUDA OR USE_ROCM" OFF)
Will be disabled if not supported by the platform" ON "USE_CUDA" OFF)
if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")

View File

@ -40,7 +40,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
### Untrusted inputs during training and prediction
If you plan to open your model to untrusted inputs, be aware that inputs can also be used as vectors by malicious agents. To minimize risks, make sure to give your model only the permissions strictly required, and keep your libraries updated with the latest security patches.
If you plan to open your model to untrusted inputs, be aware that inputs can also be used as vectors by malicious agents. To minimize risks, make sure to give your model only the permisisons strictly required, and keep your libraries updated with the lates security patches.
If applicable, prepare your model against bad inputs and prompt injections. Some recommendations:
- Pre-analysis: check how the model performs by default when exposed to prompt injection (e.g. using fuzzing for prompt injection).

View File

@ -364,7 +364,7 @@ class TORCH_API Context {
bool enabled_flashSDP = true;
bool enabled_mem_efficientSDP = true;
bool enabled_mathSDP = true;
bool enabled_cudnnSDP = true;
bool enabled_cudnnSDP = false;
#ifdef USE_ROCM
bool benchmark_cudnn = true;
#else
@ -385,11 +385,8 @@ class TORCH_API Context {
? at::LinalgBackend::Cusolver
: at::LinalgBackend::Default;
at::BlasBackend blas_preferred_backend =
#ifdef USE_ROCM
(c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false)
#else
(c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true)
#endif
(c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true ||
c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true)
? at::BlasBackend::Cublaslt
: at::BlasBackend::Cublas;
#ifdef C10_MOBILE

View File

@ -143,7 +143,7 @@ static Device getATenDevice(const DLDevice& ctx, void* data) {
return at::detail::getXPUHooks().getDeviceFromPtr(data);
default:
TORCH_CHECK(
false, "Unsupported device_type: ", std::to_string(ctx.device_type));
false, "Unsupported device_type: " + c10::to_string(ctx.device_type));
}
}
@ -167,7 +167,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
break;
default:
TORCH_CHECK(
false, "Unsupported kUInt bits ", std::to_string(dtype.bits));
false, "Unsupported kUInt bits " + c10::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLInt:
@ -186,7 +186,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
break;
default:
TORCH_CHECK(
false, "Unsupported kInt bits ", std::to_string(dtype.bits));
false, "Unsupported kInt bits " + c10::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat:
@ -202,7 +202,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
break;
default:
TORCH_CHECK(
false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
false, "Unsupported kFloat bits " + c10::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLBfloat:
@ -212,7 +212,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
break;
default:
TORCH_CHECK(
false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
false, "Unsupported kFloat bits " + c10::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLComplex:
@ -228,7 +228,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
break;
default:
TORCH_CHECK(
false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
false, "Unsupported kFloat bits " + c10::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLBool:
@ -238,11 +238,11 @@ ScalarType toScalarType(const DLDataType& dtype) {
break;
default:
TORCH_CHECK(
false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
false, "Unsupported kDLBool bits " + c10::to_string(dtype.bits));
}
break;
default:
TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code));
TORCH_CHECK(false, "Unsupported code " + c10::to_string(dtype.code));
}
return stype;
}
@ -298,7 +298,9 @@ Tensor fromDLPack(DLManagedTensor* src) {
return fromDLPack(src, std::move(deleter));
}
Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter) {
Tensor fromDLPack(
DLManagedTensor* src,
std::function<void(void*)> deleter) {
Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data);
ScalarType stype = toScalarType(src->dl_tensor.dtype);
if (!src->dl_tensor.strides) {

View File

@ -19,13 +19,7 @@ MemOverlap has_internal_overlap(TensorImpl* t) {
auto strides = t->sym_strides();
auto sizes = t->sym_sizes();
for (const auto i : c10::irange(strides.size())) {
// NB: The size oblivious test is written very carefully here. When
// unbacked SymInts are involved, we should try to conservatively report
// if memory overlap /could/ happen under some setting of unbacked
// SymInts. Thus, if I have u0 size, we should assume that this has > 1
// elements (first expression), but if I have a u0 stride, I should NOT
// assume that it is not zero (second expression)
if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[i].sym_gt(1)) && strides[i] == 0) {
if (strides[i] == 0 && sizes[i] > 1) {
return MemOverlap::Yes;
}
}

View File

@ -197,7 +197,7 @@ TORCH_API std::ostream& operator<<(
const std::vector<TensorIndex>& tensor_indices);
namespace impl {
inline Tensor applySlice(
static inline Tensor applySlice(
const Tensor& self,
int64_t dim,
c10::SymInt start,
@ -227,7 +227,7 @@ inline Tensor applySlice(
dim, std::move(start), std::move(stop), std::move(step));
}
inline Tensor applySelect(
static inline Tensor applySelect(
const Tensor& self,
int64_t dim,
SymInt index,
@ -266,7 +266,9 @@ inline Tensor applySelect(
return self.select_symint(dim, std::move(index));
}
inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
static inline Tensor boolToIndexingTensorCPUOrCUDA(
const Tensor& self,
bool value) {
// booleans add a dimension of size 1. true indexes this dimension as if 0:,
// false as empty.
if (value) {
@ -276,7 +278,7 @@ inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
}
}
inline Tensor boolToIndexingTensorNonNativeDeviceType(
static inline Tensor boolToIndexingTensorNonNativeDeviceType(
const Tensor& self,
bool value) {
// booleans add a dimension of size 1. true indexes this dimension as if 0:,
@ -288,7 +290,7 @@ inline Tensor boolToIndexingTensorNonNativeDeviceType(
}
}
inline Tensor boolToIndexingTensor(
static inline Tensor boolToIndexingTensor(
const Tensor& self,
bool value,
const at::Device& self_device) {
@ -299,13 +301,13 @@ inline Tensor boolToIndexingTensor(
}
}
inline Tensor scalarToTensorNonNativeDeviceType(
static inline Tensor scalarToTensorNonNativeDeviceType(
const Scalar& v,
const TensorOptions& options) {
return at::scalar_tensor(v, options);
}
inline void recordTensorIndex(
static inline void recordTensorIndex(
const Tensor& tensor,
std::vector<Tensor>& outIndices,
int64_t* dim_ptr) {
@ -315,7 +317,7 @@ inline void recordTensorIndex(
(*dim_ptr)++;
};
inline c10::List<::std::optional<Tensor>> typeConvertIndices(
static inline c10::List<::std::optional<Tensor>> typeConvertIndices(
const Tensor& /*self*/,
std::vector<Tensor>&& indices) {
c10::List<::std::optional<Tensor>> converted_inds;
@ -336,7 +338,7 @@ inline c10::List<::std::optional<Tensor>> typeConvertIndices(
// construct a `std::vector` container to be consumed by the C++
// `count_specified_dimensions` function, which adds 100s of nanoseconds
// overhead and is undesirable.
inline int64_t count_specified_dimensions(
static inline int64_t count_specified_dimensions(
const ArrayRef<TensorIndex>& indices) {
// Count the number of indexed dimensions (everything but ellipsis and None)
int64_t count = 0;
@ -370,7 +372,7 @@ inline int64_t count_specified_dimensions(
//
// The rest of the functions are in `at::indexing::impl` namespace, signifying
// that they shouldn't be used from Python indexing implementation.
inline Tensor scalarToTensor(
static inline Tensor scalarToTensor(
const Scalar& v,
const TensorOptions& options,
const at::Device& self_device) {
@ -385,7 +387,7 @@ inline Tensor scalarToTensor(
// To match numpy semantics:
// As a special case for backwards compatibility,
// strip away unit dimensions from the left of 'src'
inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
size_t first_non1_src = sizes.size();
for (const auto i : c10::irange(sizes.size())) {
// Unbacked SymInt has different behavior, but this is sound because
@ -400,7 +402,7 @@ inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
return sizes.slice(first_non1_src);
}
inline void copy_to(const Tensor& dst, const Tensor& src) {
static inline void copy_to(const Tensor& dst, const Tensor& src) {
if (dst.sym_sizes().equals(src.sym_sizes())) {
// A shortcut to avoid generating hard-coded constant sizes during tracing.
// This is not a perfect solution: when src & dst have different shapes,
@ -419,7 +421,7 @@ inline void copy_to(const Tensor& dst, const Tensor& src) {
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
inline Tensor handleDimInMultiDimIndexing(
static inline Tensor handleDimInMultiDimIndexing(
const Tensor& prev_dim_result,
const Tensor& original_tensor,
const TensorIndex& index,
@ -507,7 +509,7 @@ inline Tensor handleDimInMultiDimIndexing(
namespace impl {
// This mirrors `applySlicing` in
// torch/csrc/autograd/python_variable_indexing.cpp
inline Tensor applySlicing(
static inline Tensor applySlicing(
const Tensor& self,
const ArrayRef<TensorIndex>& indices,
std::vector<Tensor>& outIndices,
@ -548,13 +550,13 @@ inline Tensor applySlicing(
}
} // namespace impl
inline Tensor dispatch_index(
static inline Tensor dispatch_index(
const Tensor& self,
std::vector<Tensor>&& indices) {
return self.index(impl::typeConvertIndices(self, std::move(indices)));
}
inline Tensor dispatch_index_put_(
static inline Tensor dispatch_index_put_(
Tensor& self,
std::vector<Tensor>&& indices,
const Tensor& value) {
@ -596,7 +598,7 @@ inline Tensor dispatch_index_put_(
// torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting
// `disable_slice_optimization` when calling C++ tensor indexing functions from
// Python ]
inline Tensor get_item(
static inline Tensor get_item(
const Tensor& self,
const ArrayRef<TensorIndex>& indices,
bool disable_slice_optimization = false) {
@ -662,7 +664,7 @@ inline Tensor get_item(
// torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a
// Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++
// tensor indexing functions from Python ]
inline void set_item(
static inline void set_item(
const Tensor& self,
const ArrayRef<TensorIndex>& indices,
const Tensor& value,

View File

@ -22,6 +22,7 @@
#endif
#include <c10/util/irange.h>
#include <c10/util/string_utils.h>
#include <c10/util/SmallBuffer.h>
#include <array>
@ -1397,7 +1398,7 @@ bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) {
break;
}
default:
TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type", std::to_string((int)setup_type));
TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type", c10::to_string((int)setup_type));
}
//coalescing dimensions consists of collapsing dimensions to 1 (we are limited to contiguous no-broadcast cases here)
if (ndim() > 1){

View File

@ -31,7 +31,7 @@ struct TemplateEnv {
// Add a number 'v' to the map at key 'k'
template <typename T>
void d(const std::string& k, const T& v) {
strings_[k] = std::to_string(v);
strings_[k] = c10::to_string(v);
lists_.erase(k);
}

View File

@ -150,7 +150,7 @@ Generator make_generator(Args&&... args) {
* the backend generator type (CPU/CUDAGeneratorImpl etc.)
*/
template <typename T>
inline T * check_generator(std::optional<Generator> gen) {
static inline T * check_generator(std::optional<Generator> gen) {
TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt");
TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed");
TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'");
@ -164,7 +164,7 @@ inline T * check_generator(std::optional<Generator> gen) {
* the backend generator type (CPU/CUDAGeneratorImpl etc.)
*/
template <typename T>
inline T* get_generator_or_default(const std::optional<Generator>& gen, const Generator& default_gen) {
static inline T* get_generator_or_default(const std::optional<Generator>& gen, const Generator& default_gen) {
return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen);
}
@ -177,7 +177,7 @@ namespace detail {
* - The new state tensor must be a torch.ByteTensor
* - Data of the new state tensor must be contiguous
*/
inline void check_rng_state(const c10::TensorImpl& new_state) {
static inline void check_rng_state(const c10::TensorImpl& new_state) {
TORCH_CHECK_TYPE(
new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte,
"RNG state must be a torch.ByteTensor"

View File

@ -953,7 +953,7 @@ TensorBase make_tensor_base(Args&&... args) {
} // namespace detail
inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
static inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
return legacyExtractDispatchKey(t.key_set());
}

View File

@ -21,7 +21,7 @@ namespace impl {
// on TLS.
//
// NB: If there is no valid dispatch key, this will return Undefined
inline DispatchKeySet computeDispatchKeySet(
static inline DispatchKeySet computeDispatchKeySet(
DispatchKeySet ks,
// The key mask lets us eliminate (by zero entries) keys which should not
// be considered for dispatch. There are two cases when we use this:

View File

@ -66,51 +66,51 @@ class Operation {
// treat the last N elements of the stack as a list, looking up
// element i
inline IValue& peek(Stack& stack, size_t i, size_t N) {
static inline IValue& peek(Stack& stack, size_t i, size_t N) {
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
return *(stack.end() - N + i);
}
inline IValue& peek(Stack* stack, size_t i, size_t N) {
static inline IValue& peek(Stack* stack, size_t i, size_t N) {
return peek(*stack, i, N);
}
inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
static inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
return *(stack.end() - N + i);
}
inline const IValue& peek(const Stack* stack, size_t i, size_t N) {
static inline const IValue& peek(const Stack* stack, size_t i, size_t N) {
return peek(*stack, i, N);
}
// treat the last N elements of the stack as a list, looking up the
// slice starting at index i and having length len
inline at::ArrayRef<IValue> peekSlice(
static inline at::ArrayRef<IValue> peekSlice(
const Stack& stack,
size_t i,
size_t len,
size_t N) {
return at::ArrayRef<IValue>(stack).slice(stack.size() - N + i, len);
}
inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
static inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
return peekSlice(stack, 0, N, N);
}
inline at::ArrayRef<IValue> last(const Stack* stack, size_t N) {
static inline at::ArrayRef<IValue> last(const Stack* stack, size_t N) {
return last(*stack, N);
}
inline void drop(Stack& stack, size_t n) {
static inline void drop(Stack& stack, size_t n) {
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
stack.erase(stack.end() - n, stack.end());
}
inline void drop(Stack* stack, size_t n) {
static inline void drop(Stack* stack, size_t n) {
drop(*stack, n);
}
inline IValue pop(Stack& stack) {
static inline IValue pop(Stack& stack) {
auto r = std::move(stack.back());
stack.pop_back();
return r;
}
inline IValue pop(Stack* stack) {
static inline IValue pop(Stack* stack) {
return pop(*stack);
}
inline std::vector<IValue> pop(Stack& stack, size_t n) {
static inline std::vector<IValue> pop(Stack& stack, size_t n) {
std::vector<IValue> result;
result.reserve(n);
for (const auto i : c10::irange(n)) {
@ -127,7 +127,7 @@ inline std::vector<IValue> pop(Stack& stack, size_t n) {
// b = pop(stack).toTensor();
// a = pop(stack).toInt();
template <typename... Types>
inline void pop(Stack& stack, Types&... args) {
static inline void pop(Stack& stack, Types&... args) {
size_t i = 0;
constexpr size_t N = sizeof...(args);
(void)std::initializer_list<int>{
@ -135,15 +135,15 @@ inline void pop(Stack& stack, Types&... args) {
drop(stack, N);
}
template <typename... Types>
inline void pop(Stack* stack, Types&... args) {
static inline void pop(Stack* stack, Types&... args) {
pop(*stack, args...);
}
template <typename Type>
inline void push_one(Stack& stack, Type&& arg) {
static inline void push_one(Stack& stack, Type&& arg) {
stack.emplace_back(std::forward<Type>(arg));
}
inline void push_one(Stack& stack, c10::TensorOptions options) {
static inline void push_one(Stack& stack, c10::TensorOptions options) {
stack.emplace_back(c10::typeMetaToScalarType(options.dtype()));
stack.emplace_back(options.layout());
stack.emplace_back(options.device());
@ -151,15 +151,15 @@ inline void push_one(Stack& stack, c10::TensorOptions options) {
}
template <typename... Types>
inline void push(Stack& stack, Types&&... args) {
static inline void push(Stack& stack, Types&&... args) {
(void)std::initializer_list<int>{(push_one(stack, std::forward<Types>(args)), 0)...};
}
template <typename... Types>
inline void push(Stack* stack, Types&&... args) {
static inline void push(Stack* stack, Types&&... args) {
return push(*stack, std::forward<Types>(args)...);
}
template <class T>
inline void push_list_elements(Stack& stack, const c10::List<T>& elements) {
static inline void push_list_elements(Stack& stack, const c10::List<T>& elements) {
for (T elem : elements) {
stack.push_back(std::move(elem));
}

View File

@ -59,6 +59,13 @@ view_as_complex_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
return std::make_tuple(result, 0);
}
std::tuple<Tensor,optional<int64_t>>
to_other_batch_rule(const Tensor& self, optional<int64_t> self_bdim,
const Tensor& other, optional<int64_t> other_bdim,
bool non_blocking,
bool copy, std::optional<at::MemoryFormat> memory_format) {
return std::make_tuple(self.to(other, non_blocking, copy, memory_format), self_bdim);
}
}
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {

View File

@ -23,7 +23,7 @@ enum class GeluType {
END
};
inline GeluType get_gelutype_enum(const c10::string_view approximate) {
static GeluType get_gelutype_enum(const c10::string_view approximate) {
if (approximate == "none") {
return GeluType::None;
} else if (approximate == "tanh") {
@ -33,7 +33,7 @@ inline GeluType get_gelutype_enum(const c10::string_view approximate) {
}
}
inline std::string gelutype_to_string(const GeluType type) {
static std::string gelutype_to_string(const GeluType type) {
switch(type) {
case GeluType::None: return "none";
case GeluType::Tanh: return "tanh";

View File

@ -28,15 +28,15 @@ using adaptive_max_pooling3d_backward_fn = void(*)(const Tensor& grad_input, con
DECLARE_DISPATCH(adaptive_max_pooling3d_fn, adaptive_max_pool3d_kernel);
DECLARE_DISPATCH(adaptive_max_pooling3d_backward_fn, adaptive_max_pool3d_backward_kernel);
inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
static inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
return (a / b) * c + ((a % b) * c) / b;
}
inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
static inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
return 1 + ((a + 1) * c - 1) / b;
}
inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) {
static inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) {
int64_t ndim = gradOutput_.ndimension();
for (const auto i : c10::irange(1, ndim)) {
TORCH_CHECK(gradOutput_.size(i) > 0,

View File

@ -75,7 +75,7 @@ namespace {
}
}
inline bool cudnnv8_enabled_check_debug() {
static inline bool cudnnv8_enabled_check_debug() {
static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
static uint8_t cudnnv8_debugcount = 0;
@ -86,7 +86,7 @@ inline bool cudnnv8_enabled_check_debug() {
return cudnnv8_flag == 1;
}
inline bool cudnnv8_use_heur_mode_b() {
static inline bool cudnnv8_use_heur_mode_b() {
return is_cudnnv8_heuristic_mode_b();
}
@ -186,7 +186,7 @@ static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, co
// (which the user can change) and computed inputs (which the user can
// only indirectly affect). It would be an interesting exercise to
// come up with a general framework to handle such situations.)
inline void convolution_shape_check(
static void convolution_shape_check(
CheckedFrom c,
const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
@ -212,7 +212,7 @@ inline void convolution_shape_check(
// takes an extra output_padding argument to resolve the ambiguity.
template <typename T>
inline std::vector<T> _conv_output_size(
static inline std::vector<T> _conv_output_size(
ArrayRef<T> input_size, ArrayRef<T> weight_size,
ArrayRef<T> padding, ArrayRef<T> stride, ArrayRef<T> dilation = ArrayRef<T>()
) {
@ -231,14 +231,14 @@ inline std::vector<T> _conv_output_size(
return output_size;
}
inline std::vector<int64_t> conv_output_size(
static inline std::vector<int64_t> conv_output_size(
IntArrayRef input_size, IntArrayRef weight_size,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
) {
return _conv_output_size(input_size, weight_size, padding, stride, dilation);
}
inline std::vector<c10::SymInt> conv_output_size(
static inline std::vector<c10::SymInt> conv_output_size(
SymIntArrayRef input_size, SymIntArrayRef weight_size,
SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef()
) {
@ -264,14 +264,14 @@ std::vector<T> _conv_input_size(
return input_size;
}
inline std::vector<c10::SymInt> conv_input_size(
static inline std::vector<c10::SymInt> conv_input_size(
SymIntArrayRef output_size, SymIntArrayRef weight_size,
SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups
) {
return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
}
inline std::vector<int64_t> conv_input_size(
static inline std::vector<int64_t> conv_input_size(
IntArrayRef output_size, IntArrayRef weight_size,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
@ -295,27 +295,27 @@ std::vector<T> _conv_weight_size(
return weight_size;
}
inline std::vector<c10::SymInt> conv_weight_size(
static inline std::vector<c10::SymInt> conv_weight_size(
SymIntArrayRef input_size, SymIntArrayRef output_size,
SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
}
inline std::vector<int64_t> conv_weight_size(
static inline std::vector<int64_t> conv_weight_size(
IntArrayRef input_size, IntArrayRef output_size,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
}
inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
std::vector<int64_t> shape(dim, 1);
shape[1] = -1;
return bias.reshape(shape);
}
inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
static inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
// disable NHWC for float64 input.
if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
input.scalar_type() == at::kDouble ||
@ -351,7 +351,7 @@ TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
// disable NHWC for float64 input.
if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
@ -378,7 +378,7 @@ inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Ten
return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
}
inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
static inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
// disable NHWC for float64 input.
if (input.scalar_type() == at::kDouble ||
@ -405,7 +405,7 @@ inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Ten
return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
}
inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
static inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
auto input_memory_format = input.suggest_memory_format();
auto weight_memory_format = weight.suggest_memory_format();
@ -417,7 +417,7 @@ inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tenso
return can_use_thnn_channels_last_2d;
}
inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
static inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
// check layout only for xpu tensor.
if (!input.is_xpu() || !weight.is_xpu()) {

View File

@ -254,7 +254,7 @@ C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler<a
* See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h.
*/
template<typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t digamma_one(scalar_t x) {
C10_DEVICE static inline scalar_t digamma_one(scalar_t x) {
constexpr accscalar_t PSI_10 = 2.25175258906672110764;
if (x == 0) {
return INFINITY;
@ -376,7 +376,7 @@ C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
// Assumes x is close to zero and uses a Taylor expansion.
template <typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
C10_DEVICE static inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha)
- digamma_one<scalar_t, accscalar_t>(alpha + beta) - compat_log(x);
scalar_t numer = 1;
@ -394,7 +394,7 @@ C10_DEVICE inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, sc
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
// Assumes x is close to zero and uses a Taylor expansion.
template <typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
C10_DEVICE static inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha + beta) - digamma_one<scalar_t, accscalar_t>(beta);
scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha;
for (int i = 1; i <= 8; ++i) {
@ -412,7 +412,7 @@ C10_DEVICE inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, sca
// Assumes alpha and beta are both large and uses a Rice saddle point expansion.
// To ensure numerical stability, this computation is performed at higher precision.
template<typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
C10_DEVICE static inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
const accscalar_t total = alpha + beta;
const accscalar_t mean = alpha / total;
const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total;
@ -452,7 +452,7 @@ C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha
// This function inputs total=alpha+beta to make it easy to implement
// Dirichlet reparameterized gradients in terms of Betas.
template<typename scalar_t, typename accscalar_t>
C10_HOST_DEVICE inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
C10_HOST_DEVICE static inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
accscalar_t x_ = static_cast<accscalar_t>(x);
accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
accscalar_t total_ = static_cast<accscalar_t>(total);

View File

@ -6,7 +6,7 @@
namespace at::native {
template<typename scalar_t>
inline std::vector<int> generate_intervals(
static inline std::vector<int> generate_intervals(
scalar_t sample,
int64_t inputSize,
int64_t outputSize,
@ -28,7 +28,7 @@ inline std::vector<int> generate_intervals(
}
template <int64_t ndim>
inline void fractional_max_pool_check_shape(
static inline void fractional_max_pool_check_shape(
const Tensor& input,
const Tensor& randomSamples) {

View File

@ -27,7 +27,7 @@
namespace at::native {
inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
static inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
if (tensor.is_conj()) {
return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
} else {
@ -35,7 +35,7 @@ inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
}
}
inline DimVector batched_matrix_contiguous_strides(
static inline DimVector batched_matrix_contiguous_strides(
const IntArrayRef sizes,
const bool f_contig = false) {
// f_contig chooses between the strides of a batch of Fortran (F-contiguous)
@ -62,7 +62,7 @@ inline DimVector batched_matrix_contiguous_strides(
* P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N'
* matrix starting at Q.data_ptr()[B * M' * N'].
*/
inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
// If src is already in batched column major format, then
// this will be efficient (no reordering of the data will occur)
// because the first transpose will make the tensor contiguous,
@ -75,7 +75,7 @@ inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
/*
* contig chooses between C-contig (true) and F-contig (false)
*/
inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
static inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
: c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
: cloneBatchedColumnMajor(clone));
@ -92,7 +92,7 @@ inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor&
* which is either the original batch size of the input, or its larger
* broadcasted shape.
*/
inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
static inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
at::OptionalIntArrayRef desired_batch_sizes = c10::nullopt) {
nrows = (nrows == -1) ? src.size(-2) : nrows;
auto copy_sizes = desired_batch_sizes.has_value()
@ -109,7 +109,7 @@ inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
* Given batches of matrices with arbitrary batch dim,
* computes the number of batches.
*/
inline int64_t batchCount(const Tensor& batched_matrices) {
static inline int64_t batchCount(const Tensor& batched_matrices) {
int64_t result = 1;
for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
result *= batched_matrices.size(i);
@ -118,15 +118,15 @@ inline int64_t batchCount(const Tensor& batched_matrices) {
}
// Computes the number of elements of a matrix in a batched matrix tensor
inline int64_t matrixStride(const Tensor& batched_matrices) {
static inline int64_t matrixStride(const Tensor& batched_matrices) {
return batched_matrices.size(-1) * batched_matrices.size(-2);
}
// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
static inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions.");
}
inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
static inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
checkIsMatrix(self, f_name, arg_name);
TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2),
f_name,
@ -134,7 +134,7 @@ inline void squareCheckInputs(const Tensor& self, const char* const f_name, cons
"but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices");
}
inline void checkInputsSolver(const Tensor& A,
static inline void checkInputsSolver(const Tensor& A,
const Tensor& B,
const bool left,
const char* const f_name) {
@ -146,14 +146,14 @@ inline void checkInputsSolver(const Tensor& A,
" (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
}
inline bool is_row_or_column_contiguous(const Tensor& t) {
static inline bool is_row_or_column_contiguous(const Tensor& t) {
// This could be made more general, similar to how it's checked in matmul, which would allow to
// ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
// We choose to be conservative for simplicity
return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
}
inline TransposeType to_transpose_type(const bool contig, const bool conj) {
static inline TransposeType to_transpose_type(const bool contig, const bool conj) {
if (conj) {
if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
else { return TransposeType::ConjTranspose; }
@ -261,7 +261,7 @@ void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const fu
}
// Returns the epsilon value for floating types except half
inline double _get_epsilon(const ScalarType& sc_type) {
static inline double _get_epsilon(const ScalarType& sc_type) {
switch (sc_type) {
case at::ScalarType::Float:
return static_cast<double>(std::numeric_limits<float>::epsilon());
@ -274,7 +274,7 @@ inline double _get_epsilon(const ScalarType& sc_type) {
// Validates input shapes and devices
// for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
TORCH_CHECK(self.device() == A.device(),
"Expected b and A to be on the same device, but found b on ",
self.device(), " and A on ", A.device(), " instead.");
@ -293,7 +293,7 @@ inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const ch
" but each b matrix is ", self.size(-2), " by ", self.size(-1));
}
inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
auto dtype = t.scalar_type();
TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)),
f_name, ": Expected a floating point or complex tensor as input. Got ", dtype);
@ -305,13 +305,13 @@ inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, co
// Checks if all the Tensors in a TensorList are of the same dimensions
inline void checkAllSameDim(TensorList tensors, int64_t dim) {
static inline void checkAllSameDim(TensorList tensors, int64_t dim) {
for (auto &t : tensors) {
TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead.");
}
}
inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
// broadcast the batch dimensions of arg1 and arg2.
IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
@ -325,7 +325,7 @@ inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_
return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size));
}
inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
static inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
// If there's no name we assume we don't want to check the errors
if (name != nullptr) {
linearSolveCheckInputs(arg1, arg2, name);
@ -338,7 +338,7 @@ inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1
return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
}
inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
static inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims);
IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims);
auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes);
@ -346,7 +346,7 @@ inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor&
}
// Return a permutation with the given axes moved to the end.
inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
const std::vector<int64_t> a = axes.vec();
const int64_t ndim = self.ndimension();
std::vector<int64_t> perm;
@ -368,7 +368,7 @@ inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
}
// parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
static inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
bool compute_q;
bool reduced;
if (mode == "reduced") {
@ -388,7 +388,7 @@ inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
}
// Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition
inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
static inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
const Tensor& input,
bool reduced) {
int64_t m = input.size(-2), n = input.size(-1);
@ -407,7 +407,7 @@ inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
return std::make_tuple(q_sizes, q_strides, n_columns_q);
}
inline bool svd_uses_cusolver(const Tensor& A) {
static inline bool svd_uses_cusolver(const Tensor& A) {
// if cusolver is available, it is used unconditionally
return A.is_cuda()
&& at::globalContext().hasCuSOLVER()
@ -417,7 +417,7 @@ inline bool svd_uses_cusolver(const Tensor& A) {
// Function used instead of .to so that the original strides are retained
// .to doesn't retain strides and make the output tensor contiguous
inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
static inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
auto strided_to = at::empty_strided(original_tensor.sizes(),
original_tensor.strides(),
options);
@ -433,7 +433,7 @@ inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOpti
// For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by
// calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will
// be `vec(0, 2, 1, 3)`.
inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
static inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
TORCH_CHECK(
(dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
"duplicate or invalid dimensions");
@ -453,7 +453,7 @@ inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64
// will reverse a given permutation.
// The reverse permutation array is created by swapping the indices and their
// associated values from the given permutation array.
inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
static inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
int64_t ndim = permutation.size();
std::vector<int64_t> reverse_permutation(ndim);
for (const auto dim_ind : c10::irange(ndim)) {
@ -464,7 +464,7 @@ inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> perm
// Compute R-work array size for MAGMA/LAPACK cgesdd/zgesdd
// See https://github.com/Reference-LAPACK/lapack/blob/122506cd8b6ce050a200920c3d4c0b153b150fd8/SRC/cgesdd.f#L186
inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
static inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
auto mn = std::min(m, n);
auto mx = std::max(m, n);
if (jobz == 'N') {
@ -484,14 +484,14 @@ inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
// This function checks whether the uplo argument input is valid
// Allowed strings are "u", "U", "l", "L"
inline void checkUplo(const c10::string_view uplo) {
static inline void checkUplo(const c10::string_view uplo) {
// To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
"Expected UPLO argument to be 'L' or 'U', but got ", uplo);
}
inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
static inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
TORCH_CHECK(
result.device() == input.device(),
fn_name,
@ -504,7 +504,7 @@ inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor in
// (either floating or complex type input), so we can check whether input's dtype can be casted to result's dtype.
// According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
// c10::canCast is used for checking the "safe copy" dtype requirements.
inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type());
TORCH_CHECK(
can_cast,
@ -514,7 +514,7 @@ inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result
}
// Alternatively, we can check whether the specific expected output type (result_type) can be safely casted to out tensor dtype (out_type)
inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
static inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
bool can_cast = c10::canCast(result_type, out_type);
TORCH_CHECK(
can_cast,
@ -523,7 +523,7 @@ inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType ou
out_name, " with dtype ", out_type);
}
inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
static inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
TORCH_CHECK(!at::isComplexType(tol.scalar_type()),
f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type());
}
@ -538,7 +538,7 @@ inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f
Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m).
This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
*/
inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
static inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1); // input.shape[:-1]
bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape));
return vector_case;
@ -547,7 +547,7 @@ inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other)
/*
Computes linear indices for a tensor with original_shape to access its elements like it was a materialized broadcast tensor.
*/
inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
static inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous();
}
@ -578,7 +578,7 @@ class BroadcastLinearIndices {
}
};
inline bool is_blas_compatible_column_major_order(const Tensor& input) {
static inline bool is_blas_compatible_column_major_order(const Tensor& input) {
IntArrayRef input_strides = input.strides();
IntArrayRef input_sizes = input.sizes();
auto ndim = input.dim();
@ -599,7 +599,7 @@ inline bool is_blas_compatible_column_major_order(const Tensor& input) {
batch_stride_compatible;
}
inline bool is_blas_compatible_row_major_order(const Tensor& input) {
static inline bool is_blas_compatible_row_major_order(const Tensor& input) {
IntArrayRef input_strides = input.strides();
IntArrayRef input_sizes = input.sizes();
auto ndim = input.dim();

View File

@ -675,6 +675,15 @@ Tensor nll_loss_symint(const Tensor & self, const Tensor & target, const std::op
return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, std::move(ignore_index)));
}
// Duplicate of above code for non-symbolic ints. Kept for BC purposes and to minimize breakages.
static Tensor nll_loss(const Tensor & self, const Tensor & target, const std::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, ignore_index));
}
Tensor nll_loss_nd_symint(
const Tensor& self,
const Tensor& target,

View File

@ -147,7 +147,7 @@ jiterator_also_stringify_as(jiterator_code(
#define CENTRAL_RANGE 0.7
template <typename T>
inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
calc_erfinv(T y) {
/* Function to calculate inverse error function. Rational approximation
is used to generate an initial approximation, which is then improved to
@ -232,7 +232,7 @@ Date: February 1996
* See note [3-Clause BSD License for the Cephes Math Library].
*/
template <typename scalar_t, bool is_cuda=false>
C10_HOST_DEVICE inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ {
C10_HOST_DEVICE static inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ {
using acc_t = at::acc_type<scalar_t, is_cuda>;
const acc_t MACHEP = acc_t{1.11022302462515654042E-16};
constexpr acc_t zero = acc_t{0.0};
@ -324,7 +324,7 @@ C10_HOST_DEVICE inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_floa
* N 0
*/
template <typename T>
C10_HOST_DEVICE inline T polevl(const T x, const T A[], size_t len) {
C10_HOST_DEVICE static inline T polevl(const T x, const T A[], size_t len) {
T result = 0;
for (size_t i = 0; i <= len; i++) {
result = result * x + A[i];
@ -332,7 +332,7 @@ C10_HOST_DEVICE inline T polevl(const T x, const T A[], size_t len) {
return result;
}
inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ {
static inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ {
double sign = +1;
double result = 0;
if (x < 0.5) {
@ -350,7 +350,7 @@ inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ {
return sign * result;
}
inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ {
static inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ {
float sign = +1;
float result = 0;
if (x < 0.5f) {
@ -372,7 +372,7 @@ inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ {
* This function is derived from the implementation of the digamma function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library].
*/
inline double calc_digamma(double x) {
static inline double calc_digamma(double x) {
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
static double PSI_10 = 2.25175258906672110764;
if (x == 0) {
@ -430,7 +430,7 @@ inline double calc_digamma(double x) {
* This function is derived from the implementation of the digamma function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library].
*/
inline float calc_digamma(float x) {
static inline float calc_digamma(float x) {
// See [C++ Standard Reference: Gamma Function]
static float PSI_10 = 2.25175258906672110764f;
if (x == 0) {
@ -485,16 +485,16 @@ inline float calc_digamma(float x) {
return result + logf(x) - (0.5f / x) - y;
}
inline c10::BFloat16 calc_digamma(c10::BFloat16 a) {
static inline c10::BFloat16 calc_digamma(c10::BFloat16 a) {
return calc_digamma(static_cast<float>(a));
}
inline c10::Half calc_digamma(c10::Half a) {
static inline c10::Half calc_digamma(c10::Half a) {
return calc_digamma(static_cast<float>(a));
}
template <typename scalar_t, bool is_cuda=false>
inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) {
static inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) {
// already blocked if n <= 1
const auto one = scalar_t{1};
return ((n % 2) ? one : -one) *
@ -519,7 +519,7 @@ inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) {
* See NOTICE for the licenses.
*/
template <typename scalar_t>
scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
static scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
const scalar_t denom[], int64_t N) {
// evaluating rational function, i.e., the ratio of two polynomials
// the coefficients for numerator are given by `num` while coeffs for
@ -1061,7 +1061,7 @@ static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) {
}
template <typename scalar_t>
inline scalar_t calc_igammac(scalar_t a, scalar_t x) {
static inline scalar_t calc_igammac(scalar_t a, scalar_t x) {
/* the calculation of the regularized upper incomplete gamma function
* is done differently based on the values of a and x:
* - if x and/or a is at the boundary of defined region, then assign the
@ -1141,7 +1141,7 @@ inline scalar_t calc_igammac(scalar_t a, scalar_t x) {
}
template <typename scalar_t>
scalar_t calc_igamma(scalar_t a, scalar_t x) {
static inline scalar_t calc_igamma(scalar_t a, scalar_t x) {
/* the calculation of the regularized lower incomplete gamma function
* is done differently based on the values of a and x:
* - if x and/or a is at the boundary of defined region, then assign the
@ -1203,39 +1203,39 @@ scalar_t calc_igamma(scalar_t a, scalar_t x) {
}
template <>
C10_UNUSED inline c10::BFloat16 calc_igamma<c10::BFloat16>(c10::BFloat16 a, c10::BFloat16 x) {
C10_UNUSED c10::BFloat16 calc_igamma<c10::BFloat16>(c10::BFloat16 a, c10::BFloat16 x) {
return calc_igamma<float>(float(a), float(x));
}
template <>
C10_UNUSED inline c10::Half calc_igamma<c10::Half>(c10::Half a, c10::Half x) {
C10_UNUSED c10::Half calc_igamma<c10::Half>(c10::Half a, c10::Half x) {
return calc_igamma<float>(float(a), float(x));
}
template <>
C10_UNUSED inline c10::BFloat16 calc_igammac<c10::BFloat16>(c10::BFloat16 a, c10::BFloat16 x) {
C10_UNUSED c10::BFloat16 calc_igammac<c10::BFloat16>(c10::BFloat16 a, c10::BFloat16 x) {
return calc_igammac<float>(float(a), float(x));
}
template <>
C10_UNUSED inline c10::Half calc_igammac<c10::Half>(c10::Half a, c10::Half x) {
C10_UNUSED c10::Half calc_igammac<c10::Half>(c10::Half a, c10::Half x) {
return calc_igammac<float>(float(a), float(x));
}
inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); }
template <typename T>
inline T abs_impl(T v) {
static T abs_impl(T v) {
return std::abs(v);
}
template <>
C10_UNUSED inline uint8_t abs_impl(uint8_t v) {
C10_UNUSED uint8_t abs_impl(uint8_t v) {
return v;
}
template <typename T>
inline typename std::enable_if<std::is_integral<T>::value, T>::type
static inline typename std::enable_if<std::is_integral<T>::value, T>::type
calc_gcd(T a, T b) {
a = abs_impl(a);
b = abs_impl(b);
@ -1284,7 +1284,7 @@ C10_HOST_DEVICE c10::complex<T> exp2_impl(c10::complex<T> x) {
* required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 1.
*/
template <typename T>
inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
chbevl(const T x, const T array[], size_t len) {
T b0, b1, b2;
@ -1310,7 +1310,7 @@ chbevl(const T x, const T array[], size_t len) {
* of all inputs to convert them into the domain of the approximation.
*/
template <typename T>
inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
static inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
/* Chebyshev coefficients for exp(-x) I0(x)
* in the interval [0,8].
*
@ -1336,7 +1336,7 @@ inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
};
template <typename T>
inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
static inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
/* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
* in the inverted interval [8,infinity].
*
@ -1361,7 +1361,7 @@ inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
};
template <typename T>
inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
static inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
chebyshev_coefficients_i1e_A() {
/* Chebyshev coefficients for exp(-x) I1(x)
* in the interval [0,8].
@ -1388,7 +1388,7 @@ chebyshev_coefficients_i1e_A() {
};
template <typename T>
inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
static inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
chebyshev_coefficients_i1e_A() {
/* Chebyshev coefficients for exp(-x) I1(x)
* in the interval [0,8].
@ -1417,7 +1417,7 @@ chebyshev_coefficients_i1e_A() {
};
template <typename T>
inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
static inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
chebyshev_coefficients_i1e_B() {
/* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
* in the inverted interval [8,infinity].
@ -1443,7 +1443,7 @@ chebyshev_coefficients_i1e_B() {
};
template <typename T>
inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
static inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
chebyshev_coefficients_i1e_B() {
/* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
* in the inverted interval [8,infinity].
@ -1463,7 +1463,7 @@ chebyshev_coefficients_i1e_B() {
};
template <typename T>
inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
calc_i0(T _x) {
T x = std::abs(_x);
@ -1481,7 +1481,7 @@ calc_i0(T _x) {
}
// Upcast bfloat16 input to float for numerical accuracy purposes
inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }
static inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }
/*
* This function is derived from the implementation of the i1 function in the Cephes Math Library.
@ -1493,7 +1493,7 @@ inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float
* of all inputs to convert them into the domain of the approximation.
*/
template <typename T>
inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
calc_i1(T _x) {
T x = std::abs(_x);
@ -1522,7 +1522,7 @@ calc_i1(T _x) {
* of all inputs to convert them into the domain of the approximation.
*/
template <typename T>
inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
calc_i1e(T _x) {
T x = std::abs(_x);
@ -1549,7 +1549,7 @@ calc_i1e(T _x) {
* (integrated from minus infinity to x) is equal to y.
*/
template <typename T>
inline C10_HOST_DEVICE T calc_ndtri(T y0) {
static inline C10_HOST_DEVICE T calc_ndtri(T y0) {
/* sqrt(2pi) */
constexpr T s2pi = 2.50662827463100050242E0;
@ -1737,7 +1737,7 @@ inline C10_HOST_DEVICE T calc_ndtri(T y0) {
template <typename T>
C10_HOST_DEVICE inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
C10_HOST_DEVICE static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
erfcx_y100(T y100)
{
switch (static_cast<int>(y100)) {
@ -2148,7 +2148,7 @@ return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682
}
template <typename T>
C10_HOST_DEVICE inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
C10_HOST_DEVICE static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
calc_erfcx(T x)
{
if (at::_isnan(x)) {
@ -2188,7 +2188,7 @@ calc_erfcx(T x)
* See NOTICE for the licenses.
*/
template <typename T>
inline C10_HOST_DEVICE T calc_log_ndtr(T x) {
static inline C10_HOST_DEVICE T calc_log_ndtr(T x) {
T t = x * c10::frac_sqrt_2<T>;
if (x < T{-1.0}) {
return std::log(calc_erfcx(-t) / 2) - t * t;
@ -2198,7 +2198,7 @@ inline C10_HOST_DEVICE T calc_log_ndtr(T x) {
}
template<typename T>
inline C10_HOST_DEVICE T airy_ai_forward(T x) {
static inline C10_HOST_DEVICE T airy_ai_forward(T x) {
static const T AN[] = {
+3.46538101525629032477e-01,
+1.20075952739645805542e+01,
@ -2377,7 +2377,7 @@ inline C10_HOST_DEVICE T airy_ai_forward(T x) {
} // T airy_ai(T x)
template<typename T>
inline C10_HOST_DEVICE T bessel_j0_forward(T x) {
static inline C10_HOST_DEVICE T bessel_j0_forward(T x) {
static const T PP[] = {
+7.96936729297347051624e-04,
+8.28352392107440799803e-02,
@ -2489,7 +2489,7 @@ inline C10_HOST_DEVICE T bessel_j0_forward(T x) {
} // bessel_j0_forward(T x)
template<typename T>
inline C10_HOST_DEVICE T bessel_j1_forward(T x) {
static inline C10_HOST_DEVICE T bessel_j1_forward(T x) {
static const T PP[] = {
+7.62125616208173112003e-04,
+7.31397056940917570436e-02,
@ -2597,7 +2597,7 @@ inline C10_HOST_DEVICE T bessel_j1_forward(T x) {
} // bessel_j1_forward(T x)
template<typename T>
inline C10_HOST_DEVICE T bessel_y0_forward(T x) {
static inline C10_HOST_DEVICE T bessel_y0_forward(T x) {
static const T PP[] = {
+7.96936729297347051624e-04,
+8.28352392107440799803e-02,
@ -2712,7 +2712,7 @@ inline C10_HOST_DEVICE T bessel_y0_forward(T x) {
} // bessel_y0_forward(T x)
template<typename T>
inline C10_HOST_DEVICE T bessel_y1_forward(T x) {
static inline C10_HOST_DEVICE T bessel_y1_forward(T x) {
static const T PP[] = {
+7.62125616208173112003e-04,
+7.31397056940917570436e-02,
@ -2826,7 +2826,7 @@ inline C10_HOST_DEVICE T bessel_y1_forward(T x) {
} // bessel_y1_forward(T x)
template<typename T>
inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) {
static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) {
if (n < 0) {
return T(0.0);
}
@ -2865,12 +2865,12 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) {
} // chebyshev_polynomial_t_forward(T x, int64_t n)
template<typename T, bool is_cuda=false>
inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) {
static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) {
return chebyshev_polynomial_t_forward(x, static_cast<int64_t>(n));
} // chebyshev_polynomial_t_forward(T x, T n)
template<typename T>
inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) {
static inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) {
if (n < 0) {
return T(0.0);
}
@ -2913,12 +2913,12 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) {
} // chebyshev_polynomial_u_forward(T x, int64_t n)
template<typename T, bool is_cuda=false>
inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) {
static inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) {
return chebyshev_polynomial_u_forward(x, static_cast<int64_t>(n));
} // chebyshev_polynomial_u_forward(T x, T n)
template<typename T>
inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) {
static inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) {
if (n < 0) {
return T(0.0);
}
@ -2969,12 +2969,12 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) {
} // chebyshev_polynomial_v_forward(T x, int64_t n)
template<typename T, bool is_cuda=false>
inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) {
static inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) {
return chebyshev_polynomial_v_forward(x, static_cast<int64_t>(n));
} // chebyshev_polynomial_v_forward(T x, T n)
template<typename T>
inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) {
static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) {
if (n < 0) {
return T(0.0);
}
@ -3029,12 +3029,12 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) {
} // chebyshev_polynomial_w_forward(T x, int64_t n)
template<typename T, bool is_cuda=false>
inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) {
static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) {
return chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
} // chebyshev_polynomial_w_forward(T x, T n)
template<typename T>
inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) {
static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) {
if (n < 0) {
return T(0.0);
}
@ -3061,17 +3061,17 @@ inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) {
} // hermite_polynomial_h_forward(T x, int64_t n)
template<typename T, bool is_cuda=false, std::enable_if_t<!std::is_floating_point<T>::value, int> = 0>
inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
return hermite_polynomial_h_forward(x, static_cast<int64_t>(n));
} // hermite_polynomial_h_forward(T x, T n)
template<typename T, bool is_cuda=false, std::enable_if_t<std::is_floating_point<T>::value, int> = 0>
inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
return hermite_polynomial_h_forward(x, ((!std::isinf(n)) && (!std::isnan(n))) ? static_cast<int64_t>(n) : static_cast<int64_t>(-1));
} // hermite_polynomial_h_forward(T x, T n)
template<typename T>
inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) {
static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) {
if (n < 0) {
return T(0.0);
}
@ -3098,12 +3098,12 @@ inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) {
} // hermite_polynomial_he_forward(T x, int64_t n)
template<typename T, bool is_cuda=false>
inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) {
static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) {
return hermite_polynomial_he_forward(x, static_cast<int64_t>(n));
} // hermite_polynomial_he_forward(T x, T n)
template<typename T>
inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) {
static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) {
if (n < 0) {
return T(0.0);
}
@ -3134,12 +3134,12 @@ inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) {
} // laguerre_polynomial_l_forward(T x, int64_t n)
template<typename T, bool is_cuda=false>
inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) {
static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) {
return laguerre_polynomial_l_forward(x, static_cast<int64_t>(n));
} // laguerre_polynomial_l_forward(T x, T n)
template<typename T>
inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) {
static inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) {
if (n < 0) {
return T(0.0);
}
@ -3174,12 +3174,12 @@ inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) {
} // legendre_polynomial_p_forward(T x, int64_t n)
template<typename T, bool is_cuda=false>
inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) {
static inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) {
return legendre_polynomial_p_forward(x, static_cast<int64_t>(n));
} // legendre_polynomial_p_forward(T x, T n)
template<typename T>
inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) {
static inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) {
static const T A[] = {
-4.41534164647933937950e-18,
+3.33079451882223809783e-17,
@ -3268,7 +3268,7 @@ inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) {
} // modified_bessel_i0_forward(T x)
template<typename T>
inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) {
static inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) {
static const T A[] = {
+2.77791411276104639959e-18,
-2.11142121435816608115e-17,
@ -3364,7 +3364,7 @@ inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) {
} // modified_bessel_i1_forward(T x)
template<typename T>
inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) {
static inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) {
static const T A[] = {
+1.37446543561352307156e-16,
+4.25981614279661018399e-14,
@ -3441,7 +3441,7 @@ inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) {
} // modified_bessel_k0_forward(T x)
template<typename T>
inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) {
static inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) {
static const T A[] = {
-7.02386347938628759343e-18,
-2.42744985051936593393e-15,
@ -3519,7 +3519,7 @@ inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) {
} // modified_bessel_k1_forward(T x)
template<typename T>
inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) {
static inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) {
static const T A[] = {
+1.37446543561352307156e-16,
+4.25981614279661018399e-14,
@ -3596,7 +3596,7 @@ inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) {
} // T scaled_modified_bessel_k0_forward(T x)
template<typename T>
inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) {
static inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) {
static const T A[] = {
-7.02386347938628759343e-18,
-2.42744985051936593393e-15,
@ -3674,7 +3674,7 @@ inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) {
} // T scaled_modified_bessel_k1_forward(T x)
template<typename T>
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) {
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) {
if (n < 0) {
return T(0.0);
}
@ -3717,12 +3717,12 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n)
} // shifted_chebyshev_polynomial_t_forward(T x, int64_t n)
template<typename T, bool is_cuda=false>
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) {
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) {
return shifted_chebyshev_polynomial_t_forward(x, static_cast<int64_t>(n));
} // shifted_chebyshev_polynomial_t_forward(T x, T n)
template<typename T>
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) {
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) {
if (n < 0) {
return T(0.0);
}
@ -3769,12 +3769,12 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n)
} // shifted_chebyshev_polynomial_u_forward(T x, int64_t n)
template<typename T, bool is_cuda=false>
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) {
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) {
return shifted_chebyshev_polynomial_u_forward(x, static_cast<int64_t>(n));
} // shifted_chebyshev_polynomial_u_forward(T x, T n)
template<typename T>
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) {
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) {
if (n < 0) {
return T(0.0);
}
@ -3825,12 +3825,12 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n)
} // shifted_chebyshev_polynomial_v_forward(T x, int64_t n)
template<typename T, bool is_cuda=false>
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) {
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) {
return shifted_chebyshev_polynomial_v_forward(x, static_cast<int64_t>(n));
} // shifted_chebyshev_polynomial_v_forward(T x, T n)
template<typename T>
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) {
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) {
if (n < 0) {
return T(0.0);
}
@ -3881,12 +3881,12 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n)
} // shifted_chebyshev_polynomial_w_forward(T x, int64_t n)
template<typename T, bool is_cuda=false>
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) {
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) {
return shifted_chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
} // shifted_chebyshev_polynomial_w_forward(T x, T n)
template<typename T>
inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) {
static inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) {
if (std::isinf(x)) {
return T(0.0);
}

View File

@ -26,7 +26,7 @@ DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel);
namespace padding {
template <int dim>
inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
static inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
TORCH_CHECK(padding.size() == 2 * dim,
"padding size is expected to be ", 2 * dim,

View File

@ -48,7 +48,7 @@ DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel);
namespace {
template <typename dest_t, typename src_t>
inline dest_t
static inline dest_t
safe_downcast(src_t v)
{
TORCH_CHECK(std::numeric_limits<dest_t>::min() <= v && v <= std::numeric_limits<dest_t>::max(),
@ -58,7 +58,7 @@ safe_downcast(src_t v)
}
template<typename T>
inline T pooling_output_shape_pad_lr(
static inline T pooling_output_shape_pad_lr(
T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
bool ceil_mode) {
T outputSize = div_rtn<T>(
@ -75,7 +75,7 @@ inline T pooling_output_shape_pad_lr(
}
template<typename T>
inline T pooling_output_shape(
static inline T pooling_output_shape(
T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) {
TORCH_CHECK(stride != 0, "stride should not be zero");
TORCH_CHECK(pad >= 0,
@ -117,7 +117,7 @@ inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
}
// AveragePool2d/DilatedMaxPool2d (forward)
inline void
static inline void
pool2d_shape_check(
const Tensor& input,
int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
@ -164,7 +164,7 @@ pool2d_shape_check(
}
// DilatedMaxPool2d (backward)
inline void
static inline void
max_pool2d_backward_shape_check(
const Tensor& input,
const Tensor& gradOutput,
@ -192,7 +192,7 @@ max_pool2d_backward_shape_check(
}
// AveragePool2d (backward)
inline void
static inline void
avg_pool2d_backward_shape_check(
const Tensor& input,
const Tensor& gradOutput,
@ -218,7 +218,7 @@ avg_pool2d_backward_shape_check(
}
// AveragePool3d/DilatedMaxPool3d (forward)
inline void
static inline void
pool3d_shape_check(
const Tensor& input,
int64_t nslices,
@ -280,7 +280,7 @@ pool3d_shape_check(
"Output size is too small");
}
inline void
static inline void
max_pool3d_backward_shape_check(
const Tensor& input,
const Tensor& gradOutput,
@ -317,7 +317,7 @@ max_pool3d_backward_shape_check(
check_dim_size(indices, ndim, ndim-1, owidth);
}
inline void
static inline void
avg_pool3d_backward_shape_check(
const Tensor& input,
const Tensor& gradOutput,

View File

@ -24,7 +24,7 @@ namespace native {
// only non-zero result.
template <class T,
typename std::enable_if<std::is_integral<T>::value, T>::type* = nullptr>
inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
static inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
T result = 1;
while (b) {
if (b & 1) {
@ -38,13 +38,13 @@ inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
template <class T,
typename std::enable_if<std::is_integral<T>::value && !std::is_signed<T>::value, T>::type* = nullptr>
inline HOST_DEVICE T powi(T a, T b) {
static inline HOST_DEVICE T powi(T a, T b) {
return powi_impl(a, b);
}
template <class T,
typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value, T>::type* = nullptr>
inline HOST_DEVICE T powi(T a, T b) {
static inline HOST_DEVICE T powi(T a, T b) {
if ( b < 0 ) {
if ( a == 1 ) {
return 1;

View File

@ -31,7 +31,7 @@ constexpr scalar_t lower_bound() {
return lim::has_infinity ? -lim::infinity() : lim::lowest();
}
inline Tensor restride_dim(
static inline Tensor restride_dim(
const Tensor& src, int64_t dim,
IntArrayRef replacement_shape
) {
@ -96,13 +96,13 @@ inline std::optional<Tensor> _allreduce_return_trivial(
" but found ", out.option())\
}
inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) {
static inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) {
OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self);
OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options());
OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options());
}
inline Tensor integer_upcast(const Tensor& self, std::optional<ScalarType> dtype) {
static inline Tensor integer_upcast(const Tensor& self, std::optional<ScalarType> dtype) {
ScalarType scalarType = self.scalar_type();
TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented");
ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType);
@ -111,7 +111,7 @@ inline Tensor integer_upcast(const Tensor& self, std::optional<ScalarType> dtype
using DimMask = TensorIterator::DimMask;
inline DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
if (opt_dims.has_value()) {
return DimVector(opt_dims.value());
} else {
@ -121,7 +121,7 @@ inline DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
}
}
inline DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) {
static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) {
DimMask mask;
if (opt_dims.has_value()) {
auto dims = opt_dims.value();
@ -150,7 +150,7 @@ inline DimVector shape_from_dim_mask(const Tensor& self, DimMask mask, bool keep
return shape;
}
inline void resize_reduction_result(
static void resize_reduction_result(
Tensor& result, const Tensor& self, DimMask mask, bool keepdim,
ScalarType /*dtype*/)
{
@ -167,7 +167,7 @@ inline Tensor create_reduction_result(
return at::empty(shape, self.options().dtype(dtype));
}
inline Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
static Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
if (keepdim) {
return result;
}
@ -182,7 +182,7 @@ inline Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask,
return result.as_strided(shape, stride);
}
inline TensorIterator make_reduction(
static TensorIterator make_reduction(
const char* name, Tensor& result, const Tensor& self,
at::OptionalIntArrayRef dim_opt,
bool keepdim, ScalarType in_dtype, ScalarType out_dtype) {
@ -207,7 +207,7 @@ inline TensorIterator make_reduction(
return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
}
inline C10_UNUSED TensorIterator make_reduction(
static C10_UNUSED TensorIterator make_reduction(
const char* name, Tensor& result, const Tensor& self,
at::OptionalIntArrayRef dim, bool keepdim, ScalarType out_dtype) {
// special case for type promotion in mixed precision, improves computational
@ -222,7 +222,7 @@ inline C10_UNUSED TensorIterator make_reduction(
return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype);
}
inline TensorIterator make_reduction(
static TensorIterator make_reduction(
const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
at::OptionalIntArrayRef dim_opt, bool keepdim, ScalarType dtype1,
ScalarType dtype2) {
@ -259,13 +259,13 @@ inline TensorIterator make_reduction(
return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
}
inline C10_UNUSED TensorIterator make_reduction(
static C10_UNUSED TensorIterator make_reduction(
const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype) {
return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype);
}
inline void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) {
static void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) {
if (self.ndimension() == 0) {
TORCH_CHECK_INDEX(dim == 0 || dim == -1, fn_name,
": Expected reduction dim -1 or 0 for scalar but got ", dim);
@ -276,7 +276,7 @@ inline void zero_numel_check_dims(const Tensor& self, const int64_t dim, const c
}
}
inline void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
static void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
TORCH_CHECK(
!dim.empty(),
fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
@ -286,7 +286,7 @@ inline void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, con
}
}
inline std::vector<int64_t> get_zero_numel_tensor_size(
static std::vector<int64_t> get_zero_numel_tensor_size(
const Tensor& self,
const int64_t dim,
const bool keepdim,
@ -313,7 +313,7 @@ inline std::vector<int64_t> get_zero_numel_tensor_size(
// This function should be called when you are reducing a zero-numel tensor and want to
// resize the output and return it. This function exists for resizing zero-numel
// tensors when the size of the reduction dimension is non-zero.
inline C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices,
static C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices,
const Tensor& self, const int64_t dim,
const bool keepdim, const char *fn_name) {
auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name);
@ -349,7 +349,7 @@ inline ScalarType get_dtype_from_result(Tensor& result, std::optional<ScalarType
namespace at::meta {
inline C10_UNUSED DimVector get_reduction_shape(
static C10_UNUSED DimVector get_reduction_shape(
const Tensor& self,
IntArrayRef dims,
bool keepdim,
@ -358,7 +358,7 @@ inline C10_UNUSED DimVector get_reduction_shape(
return native::shape_from_dim_mask(self, mask, keepdim);
}
inline void resize_reduction(
static void resize_reduction(
impl::MetaBase& meta,
const Tensor& self,
OptionalIntArrayRef opt_dims,
@ -379,7 +379,7 @@ inline void resize_reduction(
meta.maybe_get_output(), self, dims_, keepdim);
}
inline void resize_reduction_with_indices(
static void resize_reduction_with_indices(
impl::MetaBase& meta,
const Tensor& self,
IntArrayRef dims,
@ -396,7 +396,7 @@ inline void resize_reduction_with_indices(
meta.maybe_get_output(1), self, dims_, keepdim);
}
inline TensorIterator make_reduction(
static TensorIterator make_reduction(
const Tensor& self,
const Tensor& result,
OptionalIntArrayRef opt_dims,
@ -412,7 +412,7 @@ inline TensorIterator make_reduction(
return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
}
inline TensorIterator make_reduction(
static TensorIterator make_reduction(
const Tensor& self,
const Tensor& result1,
const Tensor& result2,
@ -434,7 +434,7 @@ inline TensorIterator make_reduction(
return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
}
inline C10_UNUSED TensorIterator make_reduction_from_out_ty(
static C10_UNUSED TensorIterator make_reduction_from_out_ty(
const Tensor& self,
const Tensor& result,
OptionalIntArrayRef opt_dims,

View File

@ -6,7 +6,7 @@ namespace at::native {
enum class ReductionType {MAX, MEAN, MIN, SUM, PROD};
inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
static inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
if (reduce == "max" || reduce == "amax") {
return ReductionType::MAX;
} else if (reduce == "mean") {
@ -23,7 +23,7 @@ inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
}
// used for `scatter_reduce`, old options for BC.
inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
static inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
if (use_new_options) {
return get_reduction_enum(reduce);
} else {

View File

@ -40,7 +40,7 @@ TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes);
TORCH_API void resize_bytes_nocuda(const Storage& storage, c10::SymInt size_bytes);
inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
// It does not make sense to try to resize a storage
// to hold 0 elements, and this can break
// if storage_offset is positive but
@ -79,7 +79,7 @@ template <>
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); }
template <typename T>
inline void checkInBoundsForStorage(
static inline void checkInBoundsForStorage(
ArrayRef<T> size,
ArrayRef<T> stride,
T storage_offset,
@ -111,7 +111,7 @@ inline void checkInBoundsForStorage(
}
template <typename T>
inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
static inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
ArrayRef<T> size, ArrayRef<T> stride) {
// FIXME: stride should be optional
if (stride.data()) {

View File

@ -440,6 +440,15 @@ TORCH_IMPL_FUNC(log_softmax_backward_cpu_out) (
}
}
static Tensor softmax(const Tensor& input_, const int64_t dim_) {
auto result = [&]() {
NoNamesGuard guard;
return at::_softmax(input_, dim_, false);
}();
namedinference::propagate_names(result, input_);
return result;
}
Tensor softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
auto result = [&]() {
NoNamesGuard guard;
@ -496,6 +505,15 @@ Tensor special_softmax(const Tensor& input_, const int64_t dim_, std::optional<S
return at::softmax(input_, dim_, dtype);
}
static Tensor log_softmax(const Tensor& input_, const int64_t dim_) {
auto result = [&]() {
NoNamesGuard guard;
return at::_log_softmax(input_, dim_, false);
}();
namedinference::propagate_names(result, input_);
return result;
}
Tensor log_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
auto result = [&]() {
NoNamesGuard guard;

View File

@ -172,10 +172,18 @@ Tensor arange(
return at::arange_out(result, start, end, step);
}
static Tensor& arange_start_out(const Scalar& start, const Scalar& end, Tensor& result) {
return at::arange_out(result, start, end, /*step=*/1);
}
Tensor& arange_out(const Scalar& end, Tensor& result) {
return at::arange_out(result, /*start=*/0, end, /*step=*/1);
}
static Tensor& arange_out(Tensor& result, const Scalar& start, const Scalar& end) {
return at::arange_out(result, start, end, /*step=*/1);
}
Tensor _dim_arange(const Tensor& like, int64_t dim) {
return at::arange(like.size(dim), like.options().dtype(at::kLong));
}

View File

@ -105,6 +105,10 @@ Tensor & detach_(Tensor & self) {
return self;
}
static Tensor contiguous(const Tensor & self) {
return contiguous(self, MemoryFormat::Contiguous);
}
Tensor contiguous(const Tensor& self, MemoryFormat memory_format) {
if (self.is_contiguous(memory_format)) {
return self;

View File

@ -1181,6 +1181,14 @@ Tensor as_strided_tensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef s
return result;
}
static Tensor as_strided_tensorimpl_meta(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional<int64_t> storage_offset_) {
auto storage_offset = storage_offset_.value_or(self.storage_offset());
auto result = at::detail::make_tensor<TensorImpl>(
c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
setStrided(result, size, stride, storage_offset);
return result;
}
template <typename T>
inline void setStridedUnchecked(
const Tensor& self,
@ -1241,6 +1249,10 @@ const Tensor &as_strided__symint(const Tensor& self, SymIntArrayRef size, SymInt
return self;
}
static Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
}
// Should just use narrow_copy_out, but this API is used internally at Meta:
// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561
Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
@ -3575,6 +3587,10 @@ Tensor view_as(const Tensor& self, const Tensor& other) {
return self.view_symint(other.sym_sizes());
}
static int64_t numel(const Tensor& self) {
return self.unsafeGetTensorImpl()->numel();
}
std::vector<Tensor> unbind(const Tensor &self, int64_t dim) {
dim = maybe_wrap_dim(dim, self.dim());
int64_t size = self.size(dim);

View File

@ -180,6 +180,10 @@ TORCH_IMPL_FUNC(triu_cpu)(const Tensor& self, int64_t k, const Tensor &result) {
compute_triu_tril<UpperTriangle>(self, k, result);
}
static Tensor trace_backward(const Tensor& grad, at::IntArrayRef sizes) {
return at::native::trace_backward_symint(grad, c10::fromIntArrayRefSlow(sizes));
}
Tensor trace_backward_symint(const Tensor& grad, c10::SymIntArrayRef sizes) {
if (sizes.size() != 2) {
throw std::runtime_error("expected matrix input");

View File

@ -210,7 +210,7 @@ std::array<int64_t, 5> upsample_3d_common_check(IntArrayRef input_size, IntArray
return {nbatch, channels, output_depth, output_height, output_width};
}
inline void upsample_2d_shape_check(
static inline void upsample_2d_shape_check(
const Tensor& input,
const Tensor& grad_output,
int64_t nbatch,
@ -251,7 +251,7 @@ inline void upsample_2d_shape_check(
}
template <typename scalar_t>
inline scalar_t compute_scales_value(
static inline scalar_t compute_scales_value(
const std::optional<double> scale,
int64_t input_size,
int64_t output_size) {
@ -263,7 +263,7 @@ inline scalar_t compute_scales_value(
}
template <typename scalar_t>
inline scalar_t area_pixel_compute_scale(
static inline scalar_t area_pixel_compute_scale(
int64_t input_size,
int64_t output_size,
bool align_corners,
@ -281,7 +281,7 @@ inline scalar_t area_pixel_compute_scale(
}
template <typename scalar_t>
inline scalar_t area_pixel_compute_source_index(
static inline scalar_t area_pixel_compute_source_index(
scalar_t scale,
int64_t dst_index,
bool align_corners,
@ -308,7 +308,7 @@ inline scalar_t area_pixel_compute_source_index(
}
}
inline int64_t nearest_neighbor_compute_source_index(
static inline int64_t nearest_neighbor_compute_source_index(
const float scale,
int64_t dst_index,
int64_t input_size) {
@ -319,7 +319,7 @@ inline int64_t nearest_neighbor_compute_source_index(
return src_index;
}
inline int64_t nearest_neighbor_exact_compute_source_index(
static inline int64_t nearest_neighbor_exact_compute_source_index(
const float scale,
int64_t dst_index,
int64_t input_size) {
@ -331,7 +331,7 @@ inline int64_t nearest_neighbor_exact_compute_source_index(
return src_index;
}
inline int64_t nearest_idx(
static inline int64_t nearest_idx(
int64_t output_index,
int64_t input_size,
int64_t output_size,
@ -352,7 +352,7 @@ inline int64_t nearest_idx(
}
}
inline int64_t nearest_exact_idx(
static inline int64_t nearest_exact_idx(
int64_t output_index,
int64_t input_size,
int64_t output_size,
@ -392,17 +392,17 @@ static void upsample_increment_value_bounded(
// Based on
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
template <typename scalar_t>
inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
return ((A + 2) * x - (A + 3)) * x * x + 1;
}
template <typename scalar_t>
inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}
template <typename scalar_t>
inline void get_cubic_upsample_coefficients(
static inline void get_cubic_upsample_coefficients(
scalar_t coeffs[4],
scalar_t t) {
scalar_t A = -0.75;
@ -418,7 +418,7 @@ inline void get_cubic_upsample_coefficients(
}
template <typename scalar_t>
inline scalar_t cubic_interp1d(
static inline scalar_t cubic_interp1d(
scalar_t x0,
scalar_t x1,
scalar_t x2,
@ -434,7 +434,7 @@ inline scalar_t cubic_interp1d(
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size`, causing overflow. So we guard it with `std::min` below.
template<typename scalar_t, typename opmath_t>
inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
static inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
input_index = std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
lambda = std::min(
std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
@ -443,7 +443,7 @@ inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64
}
template<typename scalar_t, typename opmath_t>
inline void compute_source_index_and_lambda(
static inline void compute_source_index_and_lambda(
int64_t& input_index0,
int64_t& input_index1,
scalar_t& lambda0,

View File

@ -82,7 +82,7 @@ dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& o
template <typename func_t,
typename std::enable_if<!std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
inline void
static inline void
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
using traits = function_traits<func_t>;
using result_type = typename traits::result_type;
@ -97,7 +97,7 @@ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t
template <typename func_t,
typename std::enable_if<std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
inline void
static inline void
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
using traits = function_traits<func_t>;
for (; i < n; i++) {
@ -111,7 +111,7 @@ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t
// Basic loop operation (one output, N inputs). May be auto-vectorized
// by the compiler. Supports inputs and outputs of different types.
template <typename func_t>
inline void
static inline void
basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
using traits = function_traits<func_t>;
constexpr int ntensors = traits::arity + 1;
@ -166,7 +166,7 @@ void handle_tuple_outputs(char* C10_RESTRICT data[],
// 2. Iterate over the members of the returned tuple, set the corresponding
// output tensor by the tuple member in `handle_tuple_outputs` function.
template <typename func_t>
inline void
static inline void
multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
using traits = function_traits<func_t>;
@ -195,7 +195,7 @@ multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_
// a scalar (stride 0). It's position is indicated by the argument `S`. If `S`
// is 0, then there are no scalar inputs.
template <typename func_t, typename vec_func_t>
inline void
static inline void
vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
using traits = function_traits<vec_func_t>;
using scalar_t = typename function_traits<func_t>::result_type;
@ -228,7 +228,7 @@ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, ve
template <typename traits, typename cb_t>
inline void unroll_contiguous_scalar_checks(
static inline void unroll_contiguous_scalar_checks(
const int64_t* /*strides*/,
std::index_sequence<>,
cb_t&& cb) {
@ -236,7 +236,7 @@ inline void unroll_contiguous_scalar_checks(
}
template <typename traits, typename cb_t, size_t INDEX0, size_t ...INDEX>
inline void unroll_contiguous_scalar_checks(
static inline void unroll_contiguous_scalar_checks(
const int64_t* strides,
std::index_sequence<INDEX0, INDEX...>,
cb_t&& cb) {

View File

@ -21,21 +21,21 @@ using namespace vec;
// reduction that is contiguous over the input in dim 0
template <typename traits>
inline bool is_contiguous_reduction(const int64_t* strides) {
static inline bool is_contiguous_reduction(const int64_t* strides) {
return strides[0] == 0 &&
strides[1] == sizeof(typename traits::arg2_t);
}
// reduction that is contiguous over the input in dim 1
template <typename traits>
inline bool is_outer_reduction(const int64_t* strides) {
static inline bool is_outer_reduction(const int64_t* strides) {
return strides[0] == 0 &&
strides[2] == sizeof(typename traits::result_type) &&
strides[3] == sizeof(typename traits::arg2_t);
}
template <typename func_t, typename vec_func_t>
inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
static inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
func_t op, vec_func_t vop, bool reduce) {
VEC_LOOP_HEADER(func_t, data)
const char* in1_ptr = data[1];
@ -69,7 +69,7 @@ inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
}
template <typename F>
inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) {
static inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) {
for (const auto j C10_UNUSED : c10::irange(n)) {
f();
data[0] += strides[0];
@ -79,7 +79,7 @@ inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n,
// computes the reduction out = op(out, in)
template <typename func_t, typename vec_func_t>
inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) {
static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) {
VEC_LOOP_HEADER(func_t, data)
int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
int64_t count = n / (4 * Vec::size());
@ -93,7 +93,7 @@ inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_fu
// computes the reduction out = op(out, in)
template <typename func_t, typename vec_func_t>
inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) {
static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) {
VEC_LOOP_HEADER(func_t, data)
// reduce down each column of 4 * Vec::size() elements (128 or 256 bytes)
@ -132,13 +132,13 @@ static void set_results(const res_t result, const TensorIteratorBase &iter, cons
}
template<typename traits, std::size_t i = 0, typename... tuple_t>
inline typename std::enable_if<i == sizeof...(tuple_t), std::size_t>::type
static inline typename std::enable_if<i == sizeof...(tuple_t), std::size_t>::type
for_each_in_tuple(const std::tuple<tuple_t...>& /*t*/, const TensorIteratorBase& /*iter*/, const int /*num_outputs*/) {
return i;
}
template<typename traits, std::size_t i = 0, typename... tuple_t>
inline typename std::enable_if<i < sizeof...(tuple_t), std::size_t>::type
static inline typename std::enable_if<i < sizeof...(tuple_t), std::size_t>::type
for_each_in_tuple(const std::tuple<tuple_t...>& t, const TensorIteratorBase &iter, const int num_outputs) {
if (i < (size_t)num_outputs) {
set_result<traits>(i, std::get<i>(t), iter, num_outputs);
@ -286,7 +286,7 @@ void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vo
// when reduction is on most inner dimension (dim 0 in TensorIterator)
// and input has contiguous most inner dimension, `binary_kernel_reduce_lastdim`
// can be used.
inline bool is_reduce_lastdim(TensorIteratorBase& iter) {
static inline bool is_reduce_lastdim(TensorIteratorBase& iter) {
return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0)
&& iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1);
}

View File

@ -1002,7 +1002,7 @@ std::string generate_code(
std::string extra_args = "";
for (size_t i = 0; i < extra_args_typenames.size(); i++) {
auto type = std::string(extra_args_typenames[i]);
auto name = "extra_arg_" + std::to_string(i);
auto name = "extra_arg_" + std::string(to_string(i));
extra_params += "," + type + " " + name;
extra_args += ", " + name;
}

View File

@ -614,13 +614,6 @@ void run_cudnn_SDP_bprop(
Tensor& dV,
const Tensor& dropoutseed,
const Tensor& dropoutoffset) {
Tensor dO_ = dO;
if (!dO.strides()[dO.strides().size() - 1]) {
TORCH_WARN(
"cuDNN SDPA backward got an innermost stride of 0 in grad_out, which is unsupported. Materializing a contiguous\
tensor which will increase memory usage...");
dO_ = dO.contiguous();
}
cudnnHandle_t handle = getCudnnHandle();
auto key = MHACacheKeyWrapper(
b, h, s_q, s_kv, d, q, k, v, dropout_probability, is_causal, true);
@ -642,7 +635,7 @@ void run_cudnn_SDP_bprop(
k,
v,
o,
dO_,
dO,
softmaxstats,
dQ,
dK,

View File

@ -5,7 +5,7 @@
namespace at::native {
inline void col2im_shape_check(
static inline void col2im_shape_check(
const Tensor& input,
const Tensor& grad_output,
int64_t output_height,
@ -135,7 +135,7 @@ inline void col2im_shape_check(
}
}
inline void im2col_shape_check(
static inline void im2col_shape_check(
const Tensor& input,
const Tensor& grad_output,
int64_t kernel_height,

View File

@ -27,7 +27,53 @@ Tensor mkldnn_convolution(
TORCH_CHECK(false, "mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
}
static Tensor mkldnn_convolution_backward_input(
IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
TORCH_CHECK(false, "mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support");
}
static std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
TORCH_CHECK(false, "mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support");
}
static std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask) {
TORCH_CHECK(false, "mkldnn_convolution_backward: ATen not compiled with MKLDNN support");
}
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_backward_stub);
static Tensor mkldnn_convolution_transpose(
const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) {
TORCH_CHECK(false, "mkldnn_convolution_transpose: ATen not compiled with MKLDNN support");
}
static Tensor mkldnn_convolution_transpose_backward_input(
IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool bias_defined) {
TORCH_CHECK(false, "mkldnn_convolution_transpose_backward_input: ATen not compiled with MKLDNN support");
}
static std::tuple<Tensor, Tensor> mkldnn_convolution_transpose_backward_weights(
IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool bias_defined) {
TORCH_CHECK(false, "mkldnn_convolution_transpose_backward_weights: ATen not compiled with MKLDNN support");
}
static std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_transpose_backward(
const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, std::array<bool,3> output_mask) {
TORCH_CHECK(false, "mkldnn_convolution_transpose_backward: ATen not compiled with MKLDNN support");
}
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_stub);
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub);

View File

@ -14,7 +14,6 @@
#include <ATen/ops/native_batch_norm_backward_native.h>
#include <ATen/ops/native_batch_norm_native.h>
#endif
#include <ATen/native/mkldnn/Utils.h>
#if !AT_MKLDNN_ENABLED()
@ -38,7 +37,7 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(
TORCH_CHECK(false, "mkldnn_batch_norm_backward: ATen not compiled with MKLDNN support");
}
std::tuple<Tensor, Tensor, Tensor> mkldnn_layer_norm_last_index_weight_bias_f32(
static std::tuple<Tensor, Tensor, Tensor> mkldnn_layer_norm_last_index_weight_bias_f32(
const Tensor& input,
IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias,
double eps, bool inplace) {
@ -82,6 +81,7 @@ std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
#else // AT_MKLDNN_ENABLED
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
#include <ATen/native/layer_norm.h>
#include <ideep/abstract_types.hpp>

View File

@ -36,7 +36,7 @@ void check_mkldnn_binary_fusion_inputs(
const Tensor& weight,
const Tensor& bias);
inline std::vector<int64_t> padding_r(
static inline std::vector<int64_t> padding_r(
IntArrayRef padding, IntArrayRef output_padding)
{
// ConvTranpose padding adjustment
@ -60,7 +60,7 @@ inline std::vector<int64_t> padding_r(
// Make sure input has default contiguous strides if it's contiguous tensors for better performance.
// For example, for tensor of size = [1, 1280], stride = [0, 1], we'll convert it to size = [1, 1280], stride = [1280, 1]
// before calling oneDNN for better performance.
inline Tensor may_convert_to_default_contiguous_strides(const Tensor& input) {
static inline Tensor may_convert_to_default_contiguous_strides(const Tensor& input) {
auto input_size = input.sizes().vec();
auto input_stride = input.strides().vec();
auto input_default_contiguous_strides = c10::contiguous_strides(input_size);

View File

@ -14728,12 +14728,12 @@
CUDA: _scaled_dot_product_efficient_attention_backward_cuda
tags: nondeterministic_seeded
- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
dispatch:
CUDA: _scaled_dot_product_cudnn_attention_cuda
tags: nondeterministic_seeded
- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
dispatch:
CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
tags: nondeterministic_seeded

View File

@ -5,7 +5,6 @@
#include <ATen/ExpandUtils.h>
#include <torch/library.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/native/quantized/cpu/BinaryOps.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
@ -501,7 +500,7 @@ TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
} // namespace
Tensor quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point){
static Tensor quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point){
return qadd<false>(std::move(qa), std::move(qb), scale, zero_point);
}

View File

@ -172,6 +172,16 @@ Tensor mean_quantized_cpu(
return result;
}
static Tensor& mean_out_quantized_cpu(
Tensor& result,
const Tensor& self,
DimnameList dim,
bool keepdim,
std::optional<ScalarType> opt_dtype) {
return mean_out_quantized_cpu(
self, dimnames_to_positions(self, dim), keepdim, opt_dtype, result);
}
// qstd
inline bool is_std_inner_dim_fast_path(
const Tensor& self,

View File

@ -216,6 +216,20 @@ Tensor upsample_bilinear2d_quantized_cpu(
}
}
using at::native::upsample::compute_output_size;
using at::native::upsample::get_scale_value;
static Tensor upsample_bilinear2d_quantized_cpu(
const Tensor& input,
at::OptionalIntArrayRef output_size,
bool align_corners,
std::optional<ArrayRef<double>> scale_factors) {
auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
auto scale_h = get_scale_value(scale_factors, 0);
auto scale_w = get_scale_value(scale_factors, 1);
return upsample_bilinear2d_quantized_cpu(input, osize, align_corners, scale_h, scale_w);
}
DEFINE_DISPATCH(qupsample_bilinear2d_nhwc_stub);
} // namespace native
} // namespace at

View File

@ -218,5 +218,25 @@ Tensor _upsample_nearest_exact2d_quantized_cpu(
return _upsample_nearest2d_quantized_cpu<nearest_neighbor_exact_compute_source_index>(input, osize, scale_h, scale_w);
}
static Tensor upsample_nearest2d_quantized_cpu(
const Tensor& input,
at::OptionalIntArrayRef output_size,
std::optional<ArrayRef<double>> scale_factors) {
auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
auto scale_h = get_scale_value(scale_factors, 0);
auto scale_w = get_scale_value(scale_factors, 1);
return upsample_nearest2d_quantized_cpu(input, osize, scale_h, scale_w);
}
static Tensor _upsample_nearest_exact2d_quantized_cpu(
const Tensor& input,
at::OptionalIntArrayRef output_size,
std::optional<ArrayRef<double>> scale_factors) {
auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
auto scale_h = get_scale_value(scale_factors, 0);
auto scale_w = get_scale_value(scale_factors, 1);
return _upsample_nearest_exact2d_quantized_cpu(input, osize, scale_h, scale_w);
}
} // namespace native
} // namespace at

View File

@ -1,7 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <algorithm>
#include <cmath>
#include <string>
#include <vector>
#include <ATen/core/Tensor.h>
@ -36,6 +35,7 @@
#endif
#include <c10/util/irange.h>
#include <c10/util/string_utils.h>
namespace {
// To have a sanity check for maximum matrix size.
@ -1848,15 +1848,15 @@ class QConvInt8ForBC final {
int64_t output_zero_point) {
if (kReluFused) {
TORCH_WARN_ONCE(
"Arguments [stride, padding, dilation, groups] in ops.quantized.conv" +
std::to_string(kSpatialDim),
"d_relu, have been removed, please update your model to remove these arguments.");
"Arguments [stride, padding, dilation, groups] in ops.quantized.conv"
+ c10::to_string(kSpatialDim) + "d_relu, " +
"have been removed, please update your model to remove these arguments.");
return packed_weight->apply_relu(act, output_scale, output_zero_point);
} else {
TORCH_WARN_ONCE(
"Arguments [stride, padding, dilation, groups] in ops.quantized.conv",
std::to_string(kSpatialDim),
"d, have been removed, please update your model to remove these arguments.");
"Arguments [stride, padding, dilation, groups] in ops.quantized.conv"
+ c10::to_string(kSpatialDim) + "d, " +
"have been removed, please update your model to remove these arguments.");
return packed_weight->apply(act, output_scale, output_zero_point);
}
}

View File

@ -342,10 +342,7 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
output_shape[cols_dim] = output_columns;
at::SymDimVector output_shape_vec(output_shape);
return at::empty_symint(
output_shape_vec,
weight.options().dtype(weight.scalar_type()),
weight.suggest_memory_format());
return at::empty_symint(output_shape_vec, weight.options().dtype(weight.scalar_type()), weight.suggest_memory_format());
}
namespace {
@ -376,10 +373,9 @@ Tensor _qembeddingbag_nbit_prepack_helper(
int NUM_ELEM_PER_BYTE = 8 / bit_width;
TORCH_CHECK(
weight_contig.size(weight.dim() - 1) % NUM_ELEM_PER_BYTE == 0,
"qembeddingbag_",
std::to_string(bit_width),
"bit_prepack only works for the number of columns a multiple of ",
std::to_string(NUM_ELEM_PER_BYTE));
"qembeddingbag_" + c10::to_string(bit_width) +
"bit_prepack only works for the number of columns a multiple of " +
c10::to_string(NUM_ELEM_PER_BYTE));
// The "fused" representation stores the scale and bias with the
// row-wise quantized data in one tensor.
@ -555,9 +551,11 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
TORCH_FN(QEmbeddingPackWeights::run));
}
TORCH_LIBRARY_IMPL(quantized, Meta, m) {
m.impl(
"quantized::embedding_bag_byte_prepack", qembeddingbag_byte_prepack_meta);
"quantized::embedding_bag_byte_prepack",
qembeddingbag_byte_prepack_meta);
}
} // namespace

View File

@ -624,6 +624,15 @@ Tensor _sparse_softmax(const Tensor& self, Dimname dim, optional<ScalarType> dty
return at::_sparse_softmax(self, dimname_to_position(self, dim), dtype);
}
static Tensor _sparse_log_softmax(const Tensor& input_, const int64_t dim_) {
auto result = [&]() {
NoNamesGuard guard;
return at::_sparse_log_softmax(input_, dim_, false);
}();
namedinference::propagate_names(result, input_);
return result;
}
Tensor _sparse_log_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
auto result = [&]() {
NoNamesGuard guard;

View File

@ -270,6 +270,10 @@ Tensor& div_sparse_(Tensor& self, const Tensor& value) {
return div_out_sparse_zerodim(self, value, self);
}
static SparseTensor& div_out_sparse_scalar(const SparseTensor& t, Scalar value, SparseTensor& r) {
return div_out_sparse_zerodim(t, wrapped_scalar_tensor(value), r);
}
Tensor div_sparse(const Tensor& self, const Tensor& value, std::optional<c10::string_view> rounding_mode) {
auto commonDtype = at::result_type(self, value);
if (c10::isIntegralType(commonDtype, /*includeBool=*/true) && !rounding_mode.has_value()) {
@ -283,6 +287,10 @@ Tensor& div_sparse_(Tensor& self, const Tensor& value, std::optional<c10::string
return div_out_sparse_zerodim(self, value, std::move(rounding_mode), self);
}
static SparseTensor& div_out_sparse_scalar(const SparseTensor& t, Scalar value, std::optional<c10::string_view> rounding_mode, SparseTensor& r) {
return div_out_sparse_zerodim(t, wrapped_scalar_tensor(value), std::move(rounding_mode), r);
}
// --------------------------------------------------------------------
// floor_divide(SparseTensor, Scalar)
// --------------------------------------------------------------------
@ -342,6 +350,10 @@ Tensor& floor_divide_sparse_(Tensor& self, const Tensor& value) {
return floor_divide_out_sparse_zerodim(self, value, self);
}
static SparseTensor& floor_divide_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, const Scalar& value) {
return floor_divide_out_sparse_zerodim(t, wrapped_scalar_tensor(value), r);
}
// --------------------------------------------------------------------
// norm(SparseTensor, Scalar)
// --------------------------------------------------------------------

View File

@ -666,7 +666,7 @@ Tensor scaled_dot_product_attention(
case sdp::SDPBackend::cudnn_attention: {
bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention(
query_, key, value, compute_logsumexp, dropout_p, is_causal, scale);
query_, key, value, dropout_p, is_causal, compute_logsumexp, scale);
return std::get<0>(out_lse_softmax);
}
case sdp::SDPBackend::flash_attention: {

View File

@ -72,17 +72,10 @@
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
#endif
#ifdef USE_MEM_EFF_ATTENTION
#ifndef USE_ROCM
// MemoryEfficient Attention Specific Imports for CUDA
// MemoryEfficient Attention Specific Imports
#include <ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
#else
// MemoryEfficient Attention Specific Imports for ROCM
#include <ATen/native/transformers/hip/aotriton_adapter.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
#endif
#endif
namespace at {
@ -735,27 +728,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
return std::make_tuple(attention, logsumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, philox_seed, philox_offset, debug_attn_mask);
}
// Adapted from TE
// extract seed and offset from PhiloxCudaState
__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr) {
if (arg.captured_) {
*seed_ptr = static_cast<int64_t>(*arg.seed_.ptr);
*offset_ptr = static_cast<int64_t>(
*(arg.offset_.ptr) + static_cast<int64_t>(arg.offset_intragraph_));
} else {
*seed_ptr = static_cast<int64_t>(arg.seed_.val);
*offset_ptr = static_cast<int64_t>(arg.offset_.val);
}
}
std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_cuda(
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_cuda(
const Tensor& query,
const Tensor& key,
const Tensor& value,
bool compute_logsumexp,
double dropout_p,
bool is_causal,
c10::optional<double> scale) {
bool training,
std::optional<double> scale) {
// Used for tracking usage statistics
C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn");
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
@ -774,33 +754,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_c
Tensor attention, log_sumexp;
at::Tensor cudnn_seed, cudnn_offset;
cudnn_seed = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
cudnn_offset = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;
// See Note [Seed and Offset Device] in _efficient_attention_forward
at::PhiloxCudaState philox_state;
const bool in_capture_stream =
at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None;
if (use_dropout) {
// Device
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
// if using dropout, we produce 1 random number for each element of the
// attention tensor
// TODO(eqy): should state be advanced per thread (local) amount or per call/launch (global) amount
philox_state = gen->philox_cuda_state(batch_size * num_heads * max_seqlen_batch_q * max_seqlen_batch_k);
unpack_cudnn<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
philox_state, static_cast<int64_t*>(cudnn_seed.data_ptr()), static_cast<int64_t*>(cudnn_offset.data_ptr()));
}
auto cudnn_seed = at::zeros({1}, query.options().dtype(kLong));
auto cudnn_offset = at::zeros({1}, query.options().dtype(kLong));
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
Tensor debugmask;
run_cudnn_SDP_fprop(batch_size/*int64_t b*/,
num_heads/*int64_t h*/,
@ -808,7 +764,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_c
max_seqlen_batch_k/*int64_t s_kv*/,
head_dim/*int64_t d*/,
softmax_scale/*float scaling_factor*/,
compute_logsumexp/* bool */,
training/* bool */,
is_causal/* bool */,
dropout_p/*double dropout_probability*/,
query/* Tensor q*/,
@ -819,7 +775,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_c
cudnn_seed/*Tensor dropoutseed*/,
cudnn_offset/*Tensor dropoutoffset*/);
return std::make_tuple(attention, log_sumexp, cudnn_seed, cudnn_offset);
return std::make_tuple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, cudnn_seed, cudnn_offset, Tensor());
}
std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
@ -1106,64 +1062,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
offset_t = at::empty({}, at::dtype(at::kLong).device(device));
}
#ifdef USE_ROCM
// ROCM Implementation
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)")
}
// AOTriton may accept aligned on logsumexp tensor in the future for better
// performance, but for now it requires compact logsumexp tensor, even if
// compute_logsumexp is false
constexpr int kAlignLSE = 1;
res = at::empty({B, M, num_heads, Kv}, query.options());
logsumexp = at::empty(
{ B, num_heads, max_seqlen_q },
query.options().dtype(at::ScalarType::Float));
at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
at::Tensor q_t = query.transpose(1, 2);
at::Tensor k_t = key.transpose(1, 2);
at::Tensor v_t = value.transpose(1, 2);
at::Tensor output_t = res.transpose(1, 2);
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
}
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
using aotriton::v2::flash::attn_fwd;
using sdp::aotriton_adapter::mk_aotensor;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
hipError_t err; // TODO: Error handling
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
softmax_scale,
mk_aotensor<2>(softmax_lse, "M"),
mk_aotensor(output_t, "Out"),
dropout_p,
use_dropout ? *seed_t.data_ptr<int64_t>() : 0,
use_dropout ? *offset_t.data_ptr<int64_t>() : 0,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
if (!compute_logsumexp) {
// Set the tensor to empty when compute_logsumexp is false
logsumexp = at::empty(
{ B * num_heads, max_seqlen_q, 0 },
query.options().dtype(at::ScalarType::Float));
}
#else
// CUDA Implementation
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
const int computeCapability = p->major * 10 + p->minor;
@ -1333,7 +1231,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
TORCH_CHECK(kernel_launched, "cutlassF: no kernel found to launch!");
AT_CUDA_CHECK(cudaGetLastError());
#endif // USE_ROCM
return std::make_tuple(
std::move(res),
std::move(logsumexp),
@ -1354,7 +1251,7 @@ Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tenso
REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda);
#if defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM)
#ifdef USE_MEM_EFF_ATTENTION
namespace {
/**
* simple kernel that populates a tensor with rand uniform values.
@ -1404,7 +1301,7 @@ __global__ void rand_uniform_kernel(
}
}
} // namespace
#endif // defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM)
#endif
/**
* fill tensor with random uniform values. only used for testing, not much
* attention is paid to performance
@ -1422,17 +1319,6 @@ at::Tensor& _fill_mem_eff_dropout_mask_(
const int64_t n_keys = self.size(3);
#if defined(USE_MEM_EFF_ATTENTION)
#ifdef USE_ROCM
using aotriton::v2::flash::debug_fill_dropout_rng;
using sdp::aotriton_adapter::mk_aotensor;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
hipError_t err; // TODO: Error handling
err = debug_fill_dropout_rng(mk_aotensor(self, "r"),
static_cast<uint64_t>(seed),
static_cast<uint64_t>(offset),
stream);
#else
at::PhiloxCudaState rng_engine_inputs;
rng_engine_inputs = at::PhiloxCudaState(seed, offset);
at::cuda::CUDAGuard device_guard(self.device());
@ -1446,7 +1332,6 @@ at::Tensor& _fill_mem_eff_dropout_mask_(
rng_engine_inputs,
reinterpret_cast<float*>(self.data_ptr()),
self.numel());
#endif
return self;
#endif

View File

@ -36,18 +36,11 @@
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
#endif
#ifdef USE_MEM_EFF_ATTENTION
#ifndef USE_ROCM
// MemoryEfficient Attention Specific Imports for CUDA
// MemoryEfficient Attention Specific Imports
#include <ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
#else
// MemoryEfficient Attention Specific Imports for ROCM
#include <ATen/native/transformers/hip/aotriton_adapter.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
#endif
#endif
#ifdef __HIP_PLATFORM_AMD__
@ -171,32 +164,18 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
const Tensor& value,
const Tensor& out,
const Tensor& logsumexp,
const Tensor& philox_seed,
const Tensor& philox_offset,
// const Tensor& cumulative_sequence_length_q,
// const Tensor& cumulative_sequence_length_k,
// const int64_t max_seqlen_batch_q,
// const int64_t max_seqlen_batch_k,
const Tensor& cumulative_sequence_length_q,
const Tensor& cumulative_sequence_length_k,
const int64_t max_seqlen_batch_q,
const int64_t max_seqlen_batch_k,
double dropout_p,
bool is_causal,
c10::optional<double> scale) {
auto& ctx = at::globalContext();
if (ctx.deterministicAlgorithms()) {
if (ctx.deterministicAlgorithmsWarnOnly()) {
TORCH_WARN_ONCE(
"cuDNN Attention defaults to a non-deterministic algorithm. ",
"To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False).");
}
}
const Tensor& philox_seed,
const Tensor& philox_offset,
std::optional<double> scale) {
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim = query.size(3);
const int64_t max_seqlen_batch_q = query.size(1);
const int64_t max_seqlen_batch_k = key.size(1);
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
@ -369,6 +348,7 @@ _efficient_attention_backward(
grad_bias = at::empty(sz, bias->options())
.slice(/*dim=*/-1, /*start=*/0, /*end=*/lastDim);
}
at::Tensor workspace;
const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;
@ -388,62 +368,6 @@ _efficient_attention_backward(
}
}
#ifdef USE_ROCM
// ROCM Implementation
TORCH_CHECK(!num_splits_key.has_value(),
"ROCM does not support num_split_keys in _efficient_attention_forward");
TORCH_CHECK(!window_size.has_value(),
"ROCM does not support window_size in _efficient_attention_forward");
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)")
}
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now");
}
at::Tensor q_t = query.permute({0,2,1,3});
at::Tensor k_t = key.permute({0,2,1,3});
at::Tensor v_t = value.permute({0,2,1,3});
at::Tensor out_t = out.permute({0,2,1,3});
at::Tensor dq_t = grad_q.permute({0,2,1,3});
at::Tensor dk_t = grad_k.permute({0,2,1,3});
at::Tensor dv_t = grad_v.permute({0,2,1,3});
at::Tensor dout_t = grad_out.permute({0,2,1,3});
at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q});
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
hipError_t err;
using aotriton::v2::flash::attn_bwd;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
err = attn_bwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
mk_aotensor<2>(delta, "delta"),
float(dropout_p),
rng_engine_inputs.seed_.val,
rng_engine_inputs.offset_.val,
is_causal,
stream);
#else
at::Tensor workspace;
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
const int computeCapability = p->major * 10 + p->minor;
@ -700,9 +624,8 @@ _efficient_attention_backward(
}));
TORCH_CHECK(kernel_launched, "cutlassB: no kernel found to launch!");
AT_CUDA_CHECK(cudaGetLastError());
#endif // USE_ROCM
return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v), std::move(grad_bias));
#endif // defined(USE_MEM_EFF_ATTENTION)
#endif
TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
}

View File

@ -6,7 +6,6 @@
#include <ATen/core/Tensor.h>
#include <ATen/core/grad_mode.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
@ -45,28 +44,14 @@
namespace sdp {
namespace {
// TODO(eqy): more benchmarking to determine whether this should include sm86/89
// Needs to be kept in-sync with test_fused_chocie in test_transformers.py
bool check_prefer_cudnn_attention() {
auto dprops = at::cuda::getCurrentDeviceProperties();
return dprops->major >= 9;
}
// flash_attention V2 is universally faster than efficient_attention and Math
std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
constexpr std::array<SDPBackend, num_backends> default_order{
SDPBackend::flash_attention,
SDPBackend::cudnn_attention,
SDPBackend::efficient_attention,
SDPBackend::math};
constexpr std::array<SDPBackend, num_backends> cudnn_order{
SDPBackend::cudnn_attention,
SDPBackend::flash_attention,
SDPBackend::efficient_attention,
SDPBackend::math};
static const bool prefer_cudnn = check_prefer_cudnn_attention();
return prefer_cudnn ? cudnn_order : default_order;
return default_order;
}
bool use_tensor_cores(sdp_params const& params, cudaDeviceProp* dprops, bool is_half) {
@ -230,17 +215,6 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
// Mem Efficient attention supports hardware in the range [sm_50, sm_90]
using sm50 = SMVersion<5, 0>;
using sm90 = SMVersion<9, 0>;
#if USE_ROCM
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (debug) {
TORCH_WARN(
"Mem Efficient attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
}
return false;
}
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm50, sm90>(dprops)) {
if (debug) {
@ -253,7 +227,6 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
}
return false;
}
#endif
return true;
}
@ -466,6 +439,17 @@ bool check_cudnn_hardware_support(sdp_params const& params, bool debug) {
return true;
}
bool check_is_causal(sdp_params const& params, bool debug) {
// Check that the input is causal
if (!params.is_causal) {
if (debug) {
TORCH_WARN("CuDNN requires is_causal=True.");
}
return false;
}
return true;
}
bool check_for_nested_inputs(sdp_params const& params, bool debug) {
// Check that the input is nested
if (has_for_nested_inputs(params)) {
@ -489,6 +473,22 @@ bool check_dtypes_low_precision(sdp_params const& params, bool debug) {
}
}
bool check_runtime_enabled_cudnn(sdp_params const& params, bool debug) {
static c10::once_flag supported_flag;
static bool supported = false;
c10::call_once(supported_flag, []() {
supported = (c10::utils::check_env("TORCH_CUDNN_SDPA_ENABLED") == true);
});
if (!supported) {
if (debug) {
TORCH_WARN(
"The CuDNN backend needs to be enabled by setting the enviornment variable`TORCH_CUDNN_SDPA_ENABLED=1`");
}
return false;
}
return true;
}
bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) {
// We check the global context to see if user has explicitly turned of cudnn
// sdp kernels
@ -501,15 +501,13 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) {
return true;
}
bool check_cudnn_deterministic(const sdp_params& params, bool debug) {
auto& ctx = at::globalContext();
if (ctx.deterministicAlgorithms()) {
if (!ctx.deterministicAlgorithmsWarnOnly()) {
if (debug) {
TORCH_WARN("cuDNN SDPA is not deterministic.");
}
return false;
bool check_cudnn_requires_grad(sdp_params const& params, bool debug) {
// Check that the input is causal
if (input_requires_grad(params)) {
if (debug) {
TORCH_WARN("CuDNN does not currently support inputs with requires_grad=True.");
}
return false;
}
return true;
}
@ -517,29 +515,21 @@ bool check_cudnn_deterministic(const sdp_params& params, bool debug) {
} // namespace
bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \
(defined(CUDNN_VERSION) && CUDNN_VERSION < 8900)
TORCH_WARN_ONCE(!debug, "Torch was not compiled with cuDNN attention.");
return false;
#endif
// Define gate functions that determine if a flash kernel can be ran
// Replace with std::to_array when we migrate to c++20
constexpr auto general_constraints =
array_of<bool (*)(sdp_params const&, bool)>(
check_for_nested_inputs,
check_nonzero_sequence_lengths_dense,
check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim>*/>,
check_all_tensors_on_device,
check_tensor_shapes,
check_cudnn_tensor_shapes,
check_runtime_enabled_cudnn,
check_runtime_disabled_cudnn,
check_cudnn_deterministic,
// check_cudnn_layout,
check_cudnn_hardware_support,
check_all_tensors_on_device,
check_cudnn_tensor_shapes,
check_cudnn_layout,
// check_is_causal,
check_dtypes_low_precision,
check_for_attn_mask_cudnn,
check_cudnn_hardware_support
);
check_for_nested_inputs,
check_cudnn_requires_grad,
check_dtypes_low_precision);
for (auto& constraint : general_constraints) {
if (!constraint(params, debug)) {
return false;
@ -607,10 +597,6 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
constexpr auto less_than_sm80_mem_efficient_dtypes =
array_of<at::ScalarType>(at::kHalf, at::kFloat);
#ifdef USE_ROCM
constexpr auto aotriton_mem_efficient_dtypes =
array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
#endif
// Define gate functions that determine if a mem efficient kernel can be ran
constexpr auto general_constraints = array_of<bool (*)(sdp_params const&, bool)>(
@ -626,10 +612,6 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
}
if (has_for_nested_inputs(params)) {
#ifdef USE_ROCM
TORCH_WARN_ONCE(false, "[ROCM] no support for nested tensors in memory efficient attention.");
return false;
#endif
constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_requires_grad_and_nested,
check_batch_size_nested,
@ -652,14 +634,10 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
}
}
#ifdef USE_ROCM
return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
#else
auto dprop = at::cuda::getCurrentDeviceProperties();
if (dprop->major >= 8) {
return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug);
}
#endif
return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
}

View File

@ -1,130 +0,0 @@
#pragma once
#ifdef USE_ROCM
#include <aotriton/dtypes.h>
#include <aotriton/util.h>
////////////////////////////////////////////////////////////////////////////////
// Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h
////////////////////////////////////////////////////////////////////////////////
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
TORCH_CHECK(TENSOR.is_contiguous());
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
TORCH_CHECK( \
TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
TORCH_CHECK( \
uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
#define ASSIGN_CHECK_OVERFLOW(A, B) \
{ \
A = B; \
TORCH_CHECK( \
B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
}
namespace sdp {
namespace aotriton_adapter {
inline aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype)
{
#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname
CAST_TYPE(kByte, kUInt8);
CAST_TYPE(kUInt16, kUInt16);
CAST_TYPE(kUInt32, kUInt32);
CAST_TYPE(kUInt64, kUInt64);
CAST_TYPE(kChar, kInt8);
CAST_TYPE(kShort, kInt16);
CAST_TYPE(kInt, kInt32);
CAST_TYPE(kLong, kInt64);
CAST_TYPE(kHalf, kFloat16);
CAST_TYPE(kFloat, kFloat32);
CAST_TYPE(kBFloat16, kBFloat16);
return aotriton::DType::kUnknown;
#undef CAST_TYPE
}
template<typename TargetType, int Rank>
struct IntArrayRefCaster {
// std::array<TargetType, Rank> cast(IntArrayRef);
};
template<typename TargetType>
struct IntArrayRefCaster<TargetType, 1> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 1>{{ static_cast<TargetType>(ref.at(0)) }};
}
};
template<typename TargetType>
struct IntArrayRefCaster<TargetType, 2> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 2>{{
static_cast<TargetType>(ref.at(0)),
static_cast<TargetType>(ref.at(1))
}};
}
};
template<typename TargetType>
struct IntArrayRefCaster<TargetType, 3> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 3>{{
static_cast<TargetType>(ref.at(0)),
static_cast<TargetType>(ref.at(1)),
static_cast<TargetType>(ref.at(2))
}};
}
};
template<typename TargetType>
struct IntArrayRefCaster<TargetType, 4> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 4>{{
static_cast<TargetType>(ref.at(0)),
static_cast<TargetType>(ref.at(1)),
static_cast<TargetType>(ref.at(2)),
static_cast<TargetType>(ref.at(3))
}};
}
};
template<int Rank = 4>
aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view tensor_name)
{
const auto strides = q.strides();
int real_rank = strides.size();
if (real_rank != Rank) { // Lazy convertion of tensor_name
TORCH_CHECK(false,
std::string(tensor_name) + "'s rank should be " + std::to_string(Rank)
+ " but is " + std::to_string(real_rank));
}
return aotriton::TensorView<Rank>(reinterpret_cast<intptr_t>(q.data_ptr()),
IntArrayRefCaster<uint64_t, Rank>::cast(q.sizes()),
IntArrayRefCaster<uint64_t, Rank>::cast(strides),
cast_dtype(q.dtype()));
}
} // namespace aotriton_adapter
} // namespace sdp
namespace at::native {
inline int64_t ceil_div(int64_t numerator, int64_t denominator) {
return (numerator + (denominator - 1)) / denominator;
}
}
#endif // USE_ROCM

View File

@ -54,15 +54,16 @@
#include <ATen/ops/pad.h>
#endif
#include <ATen/native/transformers/hip/aotriton_adapter.h>
#include <ATen/native/transformers/hip/flash_attn/flash_api.h>
#include <c10/util/Exception.h>
#include <c10/util/CallOnce.h>
// AOTriton headers
#include <aotriton/dtypes.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
#include <aotriton/util.h>
namespace pytorch_flash {
@ -72,10 +73,90 @@ void check_gpu_arch(hipStream_t stream) {
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)")
"FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)")
}
}
aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype)
{
#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname
CAST_TYPE(kByte, kUInt8);
CAST_TYPE(kUInt16, kUInt16);
CAST_TYPE(kUInt32, kUInt32);
CAST_TYPE(kUInt64, kUInt64);
CAST_TYPE(kChar, kInt8);
CAST_TYPE(kShort, kInt16);
CAST_TYPE(kInt, kInt32);
CAST_TYPE(kLong, kInt64);
CAST_TYPE(kHalf, kFloat16);
CAST_TYPE(kFloat, kFloat32);
CAST_TYPE(kBFloat16, kBFloat16);
return aotriton::DType::kUnknown;
#undef CAST_TYPE
}
template<typename TargetType, int Rank>
struct IntArrayRefCaster {
// std::array<TargetType, Rank> cast(IntArrayRef);
};
template<typename TargetType>
struct IntArrayRefCaster<TargetType, 1> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 1>{{ static_cast<TargetType>(ref.at(0)) }};
}
};
template<typename TargetType>
struct IntArrayRefCaster<TargetType, 2> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 2>{{
static_cast<TargetType>(ref.at(0)),
static_cast<TargetType>(ref.at(1))
}};
}
};
template<typename TargetType>
struct IntArrayRefCaster<TargetType, 3> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 3>{{
static_cast<TargetType>(ref.at(0)),
static_cast<TargetType>(ref.at(1)),
static_cast<TargetType>(ref.at(2))
}};
}
};
template<typename TargetType>
struct IntArrayRefCaster<TargetType, 4> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 4>{{
static_cast<TargetType>(ref.at(0)),
static_cast<TargetType>(ref.at(1)),
static_cast<TargetType>(ref.at(2)),
static_cast<TargetType>(ref.at(3))
}};
}
};
template<int Rank = 4>
aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view tensor_name)
{
const auto strides = q.strides();
int real_rank = strides.size();
if (real_rank != Rank) { // Lazy convertion of tensor_name
TORCH_CHECK(false,
std::string(tensor_name) + "'s rank should be " + std::to_string(Rank)
+ " but is " + std::to_string(real_rank));
}
return aotriton::TensorView<Rank>(reinterpret_cast<intptr_t>(q.data_ptr()),
IntArrayRefCaster<uint64_t, Rank>::cast(q.sizes()),
IntArrayRefCaster<uint64_t, Rank>::cast(strides),
cast_dtype(q.dtype()));
}
}
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
@ -219,13 +300,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
hipError_t err; // TODO: Error handling
using aotriton::v2::flash::attn_fwd;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
empty_bias,
softmax_scale,
mk_aotensor<2>(M, "M"),
mk_aotensor(output_t, "Out"),
@ -418,20 +495,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
hipError_t err; // TODO: Error handling
{
using aotriton::v2::flash::attn_bwd;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
err = attn_bwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
empty_bias,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
empty_bias,
mk_aotensor<2>(softmax_lse_cont, "L"),
mk_aotensor<2>(delta, "delta"),
p_dropout,

View File

@ -266,18 +266,7 @@ inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug)
inline bool check_for_attn_mask(sdp_params const& params, bool debug) {
if (params.attn_mask.has_value()) {
if (debug) {
TORCH_WARN("Flash Attention do not support non-null attn_mask.");
}
return false;
}
return true;
}
// TODO(eqy): remove this once support is added
inline bool check_for_attn_mask_cudnn(sdp_params const& params, bool debug) {
if (params.attn_mask.has_value()) {
if (debug) {
TORCH_WARN("cuDNN Attention does not support non-null attn_mask.");
TORCH_WARN("Flash Attention does not support non-null attn_mask.");
}
return false;
}
@ -324,7 +313,7 @@ inline bool check_tensor_shapes(sdp_params const& params, bool debug) {
(query_dim == 4))) {
if (debug) {
TORCH_WARN(
"All fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ",
"Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ",
query_dim,
", Key dim: ",
params.key.dim(),
@ -436,7 +425,7 @@ inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool
if (zero_seq_len_q || zero_seq_len_k) {
if (debug) {
TORCH_WARN(
"All fused kernels do not support zero seq_len_q or seq_len_kv.");
"Both fused kernels do not support zero seq_len_q or seq_len_kv.");
}
return false;
}
@ -471,7 +460,7 @@ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool
}
epilogue_message << " instead.";
TORCH_WARN(
"All fused kernels require the last dimension of the input to have stride 1. ",
"Both fused kernels require the last dimension of the input to have stride 1. ",
"Got Query.stride(-1): ",
params.query.sym_stride(-1),
", Key.stride(-1): ",

View File

@ -25,10 +25,6 @@ from torch._dynamo.utils import clone_inputs
# We are primarily interested in tf32 datatype
torch.backends.cuda.matmul.allow_tf32 = True
# Enable FX graph caching
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
torch._inductor.config.fx_graph_cache = True
def _reassign_parameters(model):
# torch_geometric models register parameter as tensors due to

View File

@ -827,6 +827,7 @@ libtorch_python_core_sources = [
"torch/csrc/dynamo/guards.cpp",
"torch/csrc/dynamo/init.cpp",
"torch/csrc/functorch/init.cpp",
"torch/csrc/fx/node.cpp",
"torch/csrc/mps/Module.cpp",
"torch/csrc/mtia/Module.cpp",
"torch/csrc/inductor/aoti_runner/pybind.cpp",

View File

@ -49,33 +49,15 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
virtual SymNode mul(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
// NB: legacy, prefer float_truediv or int_truediv
virtual SymNode truediv(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode float_truediv(const SymNode& other) {
return truediv(other);
}
virtual SymNode int_truediv(const SymNode& other) {
return truediv(other);
}
// NB: legacy, prefer float_pow or pow_by_natural
virtual SymNode pow(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode float_pow(const SymNode& other) {
return pow(other);
}
virtual SymNode pow_by_natural(const SymNode& other) {
return pow(other);
}
// NB: legacy, prefer int_floordiv
virtual SymNode floordiv(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymNode int_floordiv(const SymNode& other) {
return floordiv(other);
}
virtual SymNode mod(const SymNode& other) {
TORCH_CHECK(false, "NYI");
}

View File

@ -1321,9 +1321,6 @@ if(USE_ROCM)
if(USE_FLASH_ATTENTION)
target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION)
endif()
if(USE_MEM_EFF_ATTENTION)
target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION)
endif()
endif()
if(BUILD_LITE_INTERPRETER)

View File

@ -0,0 +1,186 @@
#include "caffe2/perfkernels/adagrad.h"
#include <cmath>
#include "caffe2/perfkernels/common.h"
namespace caffe2 {
void adagrad_update__base(
int N,
const float* w,
const float* g,
const float* h,
float* nw,
float* nh,
float epsilon,
float decay,
const float lr,
const float weight_decay = 0.f) {
internal::adagrad_update_base_inlined(
N, w, g, h, nw, nh, decay, epsilon, lr, weight_decay);
}
void adagrad_update_prefetch__base(
int N,
const float* w,
const float* /* w_n */, // prefetch ptr
const float* g,
const float* h,
const float* /* h_n */, // prefetch ptr
float* nw,
float* /* nw_n */, // prefetch ptr
float* nh,
float* /* nh_n */, // prefetch ptr
float epsilon,
float lr,
float weight_decay = 0.f) {
adagrad_update__base(N, w, g, h, nw, nh, epsilon, 1.0f, lr, weight_decay);
}
void adagrad_fp16_update_prefetch__base(
int N,
const at::Half* w,
const at::Half* /* w_n */, // prefetch ptr
const float* g,
const at::Half* h,
const at::Half* /* h_n */, // prefetch ptr
at::Half* nw,
at::Half* /* nw_n */, // prefetch ptr
at::Half* nh,
at::Half* /* nh_n */, // prefetch ptr
float epsilon,
float lr,
float weight_decay = 0.f) {
internal::adagrad_update_base_inlined(
N, w, g, h, nw, nh, 1.0f, epsilon, lr, weight_decay);
}
// version without prefetching
decltype(adagrad_update__base) adagrad_update__avx2_fma;
decltype(adagrad_update__base) adagrad_update__avx512;
void adagrad_update(
int N,
const float* w,
const float* g,
const float* h,
float* nw,
float* nh,
float epsilon,
float decay,
float lr,
float weight_decay) {
AVX512_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay);
AVX2_FMA_DO(
adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay);
BASE_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr, weight_decay);
}
decltype(adagrad_update_prefetch__base) adagrad_update_prefetch__avx2_fma;
void adagrad_update_prefetch(
int N,
const float* w,
const float* w_n, // prefetch ptr
const float* g,
const float* h,
const float* h_n, // prefetch ptr
float* nw,
float* nw_n, // prefetch ptr
float* nh,
float* nh_n, // prefetch ptr
float epsilon,
float lr,
float weight_decay) {
AVX2_FMA_DO(
adagrad_update_prefetch,
N,
w,
w_n,
g,
h,
h_n,
nw,
nw_n,
nh,
nh_n,
epsilon,
lr,
weight_decay);
BASE_DO(
adagrad_update_prefetch,
N,
w,
w_n,
g,
h,
h_n,
nw,
nw_n,
nh,
nh_n,
epsilon,
lr,
weight_decay);
}
// Version with prefetching for embeddings and
// momentum using fp16
decltype(adagrad_fp16_update_prefetch__base)
adagrad_fp16_update_prefetch__avx2_fma;
void adagrad_fp16_update_prefetch(
int N,
const at::Half* w,
const at::Half* w_n, // prefetch ptr
const float* g,
const at::Half* h,
const at::Half* h_n, // prefetch ptr
at::Half* nw,
at::Half* nw_n, // prefetch ptr
at::Half* nh,
at::Half* nh_n, // prefetch ptr
float epsilon,
float lr,
float weight_decay) {
AVX2_FMA_DO(
adagrad_fp16_update_prefetch,
N,
w,
w_n,
g,
h,
h_n,
nw,
nw_n,
nh,
nh_n,
epsilon,
lr,
weight_decay);
BASE_DO(
adagrad_fp16_update_prefetch,
N,
w,
w_n,
g,
h,
h_n,
nw,
nw_n,
nh,
nh_n,
epsilon,
lr,
weight_decay);
}
} // namespace caffe2

View File

@ -0,0 +1,205 @@
#pragma once
#if defined(__AVX__) && !defined(__NVCC__) && \
(defined(__x86_64__) || defined(_M_X64) || defined(__i386__))
#define CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
#include <immintrin.h>
#endif
#include <c10/util/Half.h>
#include <c10/util/irange.h>
namespace caffe2 {
namespace internal {
// The following functions inside internal namespace are inlined because they
// are performance critical.
template <typename T>
static inline void adagrad_update_base_inlined(
int N,
const T* w,
const float* g,
const T* h,
T* nw,
T* nh,
float decay,
float epsilon,
float lr,
float weight_decay = 0.f) {
for (const auto i : c10::irange(N)) {
float gi = std::fma(weight_decay, w[i], g[i]);
float hi = decay * h[i] + gi * gi;
nh[i] = hi;
nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
}
}
// version with prefetching
// TODO(msmelyan)
// Crux of the computation is computing a / (sqrt(b) + epsilon),
// where a and b are vectors and epsilon is very small (eg., 10^-5) and does not
// change. Today it's computed using two vector sqrt and vector divide simd
// instructions. It is slow. We can take advantage of existing fast vector
// VRSQRTPS instruction that computes approximate reciprocals of square roots
// of the vector. It is 6x faster than vsrt and vdiv combinations. Since the
// addition of epsilon is just done to avoid division by zero, we approximate a
// / (sqrt(b) + epsilon) by a / (sqrt(b + sqrt(epsilon)) If we do that, we can
// use VRSQRTPS instead now. VRSQRTPS is not very accurate. Specifically, for
// the test on random numbers between 0.1 and 1 the absolute error was about
// 10^-3 compared to using slower but more accurate combination of vsqrt and
// vdiv. Extend Marat's function with more NR iterations to get more accuracy
// for training
// TODO(msmelyan)
// explore streaming stores, but need to have unique indices (deduplication)
inline void adagrad_update_prefetch_inlined(
int N,
const float* w,
#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
const float* w_n, // prefetch ptr
#else
const float* /* unused */,
#endif
const float* g,
const float* h,
#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
const float* h_n, // prefetch ptr
#else
const float* /* unused */,
#endif
float* nw,
#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
float* nw_n, // prefetch ptr
#else
float* /* unused */,
#endif
float* nh,
#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
float* nh_n, // prefetch ptr
#else
float* /* unused */,
#endif
float epsilon,
float lr,
float weight_decay = 0.f) {
auto i = 0;
#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
constexpr int kSize = 8;
for (; i + kSize <= N; i += kSize) {
_mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);
_mm_prefetch(reinterpret_cast<const char*>(&h_n[i]), _MM_HINT_T0);
_mm_prefetch(reinterpret_cast<const char*>(&nw_n[i]), _MM_HINT_T0);
_mm_prefetch(reinterpret_cast<const char*>(&nh_n[i]), _MM_HINT_T0);
__m256 gi = _mm256_loadu_ps(g + i);
__m256 hi = _mm256_loadu_ps(h + i);
__m256 wi = _mm256_loadu_ps(w + i);
#ifdef __FMA__
gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi);
#else
gi = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(weight_decay), wi), gi);
#endif
__m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi));
_mm256_storeu_ps(nh + i, nhi);
__m256 vtmp = _mm256_div_ps(
_mm256_mul_ps(_mm256_set1_ps(lr), gi),
_mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
_mm256_storeu_ps(nw + i, _mm256_add_ps(wi, vtmp));
}
#endif
adagrad_update_base_inlined(
N - i,
w + i,
g + i,
h + i,
nw + i,
nh + i,
1.0f,
epsilon,
lr,
weight_decay);
}
} // namespace internal
// version with prefetching
// TODO(msmelyan)
// Crux of the computation is computing a / (sqrt(b) + epsilon),
// where a and b are vectors and epsilon is very small (eg., 10^-5) and does not
// change. Today it's computed using two vector sqrt and vector divide simd
// instructions. It is slow. We can take advantage of existing fast vector
// VRSQRTPS instruction that computes approximate reciprocals of square roots
// of the vector. It is 6x faster than vsrt and vdiv combinations. Since the
// addition of epsilon is just done to avoid division by zero, we approximate a
// / (sqrt(b) + epsilon) by a / (sqrt(b + sqrt(epsilon)) If we do that, we can
// use VRSQRTPS instead now. VRSQRTPS is not very accurate. Specifically, for
// the test on random numbers between 0.1 and 1 the absolute error was about
// 10^-3 compared to using slower but more accurate combination of vsqrt and
// vdiv. Extend Marat's function with more NR iterations to get more accuracy
// for training
// TODO(msmelyan)
// explore streaming stores, but need to have inuque indices (deduplication)
void adagrad_update_prefetch(
int N,
const float* w,
const float* w_n, // prefetch ptr
const float* g,
const float* h,
const float* h_n, // prefetch ptr
float* nw,
float* nw_n, // prefetch ptr
float* nh,
float* nh_n, // prefetch ptr
float epsilon,
float lr,
float weight_decay = 0.f);
// Version with prefetching for embeddings and
// momentum using fp16
void adagrad_fp16_update_prefetch(
int N,
const at::Half* w,
const at::Half* w_n, // prefetch ptr
const float* g,
const at::Half* h,
const at::Half* h_n, // prefetch ptr
at::Half* nw,
at::Half* nw_n, // prefetch ptr
at::Half* nh,
at::Half* nh_n, // prefetch ptr
float epsilon,
float lr,
float weight_decay = 0.f);
// version without prefetching
void adagrad_update(
int N,
const float* w,
const float* g,
const float* h,
float* nw,
float* nh,
float epsilon,
float decay,
float lr,
float weight_decay = 0.f);
} // namespace caffe2
#ifdef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
#undef CAFFE2_PERFKERNELS_ADAGRAD_H_USE_INTRINSIC
#endif

View File

@ -0,0 +1,125 @@
#include "caffe2/perfkernels/adagrad.h"
#include "caffe2/perfkernels/cvtsh_ss_bugfix.h"
#include <emmintrin.h>
#include <immintrin.h>
namespace caffe2 {
// version without prefetching
void adagrad_update__avx2_fma(
int N,
const float* w,
const float* g,
const float* h,
float* nw,
float* nh,
float epsilon,
float decay,
float lr,
float weight_decay = 0.f) {
constexpr int kSize = 8;
auto i = 0;
for (; i + kSize <= N; i += kSize) {
__m256 gi = _mm256_loadu_ps(g + i);
__m256 hi = _mm256_loadu_ps(h + i);
__m256 wi = _mm256_loadu_ps(w + i);
gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi);
__m256 nhi = _mm256_add_ps(
_mm256_mul_ps(_mm256_set1_ps(decay), hi), _mm256_mul_ps(gi, gi));
_mm256_storeu_ps(nh + i, nhi);
__m256 vtmp = _mm256_div_ps(
_mm256_mul_ps(_mm256_set1_ps(lr), gi),
_mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
_mm256_storeu_ps(nw + i, _mm256_add_ps(wi, vtmp));
}
for (; i < N; ++i) {
float gi = std::fma(weight_decay, w[i], g[i]);
float hi = nh[i] = decay * h[i] + gi * gi;
nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
}
}
void adagrad_update_prefetch__avx2_fma(
int N,
const float* w,
const float* w_n, // prefetch ptr
const float* g,
const float* h,
const float* h_n, // prefetch ptr
float* nw,
float* nw_n, // prefetch ptr
float* nh,
float* nh_n, // prefetch ptr
float epsilon,
float lr,
float weight_decay = 0.f) {
internal::adagrad_update_prefetch_inlined(
N, w, w_n, g, h, h_n, nw, nw_n, nh, nh_n, epsilon, lr, weight_decay);
}
// Compute adagrad sparse, assumes embedding and momentum are at::Half
void adagrad_fp16_update_prefetch__avx2_fma(
int N,
const at::Half* w,
const at::Half* w_n, // prefetch ptr
const float* g,
const at::Half* h,
const at::Half* h_n, // prefetch ptr
at::Half* nw,
at::Half* nw_n, // prefetch ptr
at::Half* nh,
at::Half* nh_n, // prefetch ptr
float epsilon,
float lr,
float weight_decay = 0.f) {
constexpr int kSize = 8;
auto i = 0;
for (; i + kSize <= N; i += kSize) {
_mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);
_mm_prefetch(reinterpret_cast<const char*>(&h_n[i]), _MM_HINT_T0);
_mm_prefetch(reinterpret_cast<const char*>(&nw_n[i]), _MM_HINT_T0);
_mm_prefetch(reinterpret_cast<const char*>(&nh_n[i]), _MM_HINT_T0);
// only convert momentum and embedding, gradient is fp32
__m256 gi = _mm256_loadu_ps(g + i);
__m128i hhi = _mm_loadu_si128(reinterpret_cast<const __m128i*>(h + i));
__m256 hi = _mm256_cvtph_ps(hhi);
__m128i whi = _mm_loadu_si128(reinterpret_cast<const __m128i*>(w + i));
__m256 wi = _mm256_cvtph_ps(whi);
gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi);
__m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi));
__m128i nhhi = _mm256_cvtps_ph(nhi, 0);
_mm_storeu_si128(reinterpret_cast<__m128i*>(nh + i), nhhi);
__m256 vtmp = _mm256_div_ps(
_mm256_mul_ps(_mm256_set1_ps(lr), gi),
_mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
__m256 nwi = _mm256_add_ps(wi, vtmp);
__m128i nhwi = _mm256_cvtps_ph(nwi, 0);
_mm_storeu_si128(reinterpret_cast<__m128i*>(nw + i), nhwi);
}
for (; i < N; ++i) {
float gi = std::fma(
weight_decay,
_cvtsh_ss(reinterpret_cast<const unsigned short*>(w)[i]),
g[i]);
float nhi =
_cvtsh_ss(reinterpret_cast<const unsigned short*>(h)[i]) + gi * gi;
reinterpret_cast<unsigned short*>(nh)[i] = _cvtss_sh(nhi, 0);
float nwi = _cvtsh_ss(reinterpret_cast<const unsigned short*>(w)[i]) +
lr * gi / (std::sqrt(nhi) + epsilon);
reinterpret_cast<unsigned short*>(nw)[i] = _cvtss_sh(nwi, 0);
}
}
} // namespace caffe2

View File

@ -0,0 +1,45 @@
#include "caffe2/perfkernels/adagrad.h"
#include "caffe2/perfkernels/cvtsh_ss_bugfix.h"
#include <emmintrin.h>
#include <immintrin.h>
namespace caffe2 {
// version without prefetching
void adagrad_update__avx512(
int N,
const float* w,
const float* g,
const float* h,
float* nw,
float* nh,
float epsilon,
float decay,
float lr,
float weight_decay = 0.f) {
constexpr int kSize = 16;
auto i = 0;
for (; i + kSize <= N; i += kSize) {
__m512 gi = _mm512_loadu_ps(g + i);
__m512 hi = _mm512_loadu_ps(h + i);
__m512 wi = _mm512_loadu_ps(w + i);
gi = _mm512_fmadd_ps(_mm512_set1_ps(weight_decay), wi, gi);
__m512 nhi = _mm512_add_ps(
_mm512_mul_ps(_mm512_set1_ps(decay), hi), _mm512_mul_ps(gi, gi));
_mm512_storeu_ps(nh + i, nhi);
__m512 vtmp = _mm512_div_ps(
_mm512_mul_ps(_mm512_set1_ps(lr), gi),
_mm512_add_ps(_mm512_sqrt_ps(nhi), _mm512_set1_ps(epsilon)));
_mm512_storeu_ps(nw + i, _mm512_add_ps(wi, vtmp));
}
for (; i < N; ++i) {
float gi = std::fma(weight_decay, w[i], g[i]);
float hi = nh[i] = decay * h[i] + gi * gi;
nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
}
}
} // namespace caffe2

View File

@ -0,0 +1,113 @@
#include "caffe2/perfkernels/common.h"
#include <algorithm>
#include <cstdint>
#include <cmath>
namespace caffe2 {
namespace {
template <typename T>
void BoxCoxNaive(
std::size_t N,
std::size_t D,
const T* data_ptr,
const T* __restrict lambda1_ptr,
const T* __restrict lambda2_ptr,
T* output_ptr) {
constexpr T k_eps = static_cast<T>(1e-6);
for (std::size_t i = 0; i < N; i++) {
for (std::size_t j = 0; j < D; j++, data_ptr++, output_ptr++) {
T lambda1_v = lambda1_ptr[j];
T lambda2_v = lambda2_ptr[j];
T tmp = std::max(*data_ptr + lambda2_v, k_eps);
if (lambda1_v == 0) {
*output_ptr = std::log(tmp);
} else {
T lambda_1 = 1 / lambda1_v;
T pow = std::pow(tmp, lambda1_v);
*output_ptr = lambda_1 * pow - lambda_1;
}
}
}
}
}
#if defined(CAFFE2_PERF_WITH_AVX2) && defined(CAFFE2_PERF_USE_MKL)
namespace details {
template <typename T>
void compute_batch_box_cox__avx2_fma(
std::size_t N,
std::size_t D,
std::size_t block_size,
const T* data_ptr,
const T* __restrict lambda1_ptr,
const T* __restrict lambda2_ptr,
T* output_ptr);
extern template
void compute_batch_box_cox__avx2_fma<float>(
std::size_t N,
std::size_t D,
std::size_t block_size,
const float* self_data,
const float* __restrict lambda1_data,
const float* __restrict lambda2_data,
float* output_data);
extern template
void compute_batch_box_cox__avx2_fma<double>(
std::size_t N,
std::size_t D,
std::size_t block_size,
const double* self_data,
const double* __restrict lambda1_data,
const double* __restrict lambda2_data,
double* output_data);
} // namespace detail
#endif
template <typename T>
void compute_batch_box_cox(
std::size_t N,
std::size_t D,
std::size_t block_size,
const T* data,
const T* lambda1_data,
const T* lambda2_data,
T* output_data) {
#ifdef CAFFE2_PERF_WITH_AVX2
AVX2_FMA_DO(
details::compute_batch_box_cox,
N,
D,
block_size,
data,
lambda1_data,
lambda2_data,
output_data);
#endif
BoxCoxNaive<T>(N, D, data, lambda1_data, lambda2_data, output_data);
}
template void compute_batch_box_cox<float>(
std::size_t N,
std::size_t D,
std::size_t block_size,
const float* data,
const float* lambda1_data,
const float* lambda2_data,
float* output_data);
template void compute_batch_box_cox<double>(
std::size_t N,
std::size_t D,
std::size_t block_size,
const double* data,
const double* lambda1_data,
const double* lambda2_data,
double* output_data);
} // namespace caffe2

View File

@ -0,0 +1,35 @@
// Impmenets BoxCox operator for CPU
#pragma once
#include <cstdint>
namespace caffe2 {
template <typename T>
void compute_batch_box_cox(
std::size_t N,
std::size_t D,
std::size_t block_size,
const T* self_data,
const T* lambda1_data,
const T* lambda2_data,
T* output_data);
extern template void compute_batch_box_cox<float>(
std::size_t N,
std::size_t D,
std::size_t block_size,
const float* data,
const float* lambda1_data,
const float* lambda2_data,
float* output_data);
extern template void compute_batch_box_cox<double>(
std::size_t N,
std::size_t D,
std::size_t block_size,
const double* data,
const double* lambda1_data,
const double* lambda2_data,
double* output_data);
} // namespace caffe2

View File

@ -0,0 +1,399 @@
#include <immintrin.h>
#ifdef CAFFE2_PERF_USE_MKL
#include <c10/util/irange.h>
#include <caffe2/perfkernels/common.h>
#include <folly/SingletonThreadLocal.h>
#include "vectorizer.h"
// Enable compiler vectorized version only if numerical consistency is not
// required between dev and opt versions - disabled for now
#ifndef FAST_VECTORIZED_KERNEL
#define CPU_CAPABILITY_AVX2
#include <ATen/cpu/vec/vec.h>
namespace at::vec {
// Implements the vectorized version of std::max() operation,
// which DOESNOT propagates NaN for second argument
template <typename scalar_t>
Vectorized<scalar_t> max(const Vectorized<scalar_t>& a, const Vectorized<scalar_t>& b);
template <>
Vectorized<double> max(const Vectorized<double>& a, const Vectorized<double>& b) {
// std::max(NaN, nonNan) -> NaN
return _mm256_max_pd(b, a);
}
template <>
Vectorized<float> max(const Vectorized<float>& a, const Vectorized<float>& b) {
// std::max(NaN, nonNan) -> NaN
return _mm256_max_ps(b, a);
}
// Implements recieprocal method based on newton-rapson method
// 1. user RCP approximiation
// 2. update with RCP = RCP * (2 - X * RCP)
template <typename scalar_t>
Vectorized<scalar_t> fast_recieprocal(const Vectorized<scalar_t>& b);
template <typename scalar_t>
scalar_t fast_recieprocal(scalar_t b);
template<>
Vectorized<float> fast_recieprocal(const Vectorized<float>& b) {
auto minus2 = _mm256_set1_ps(-2.f);
auto rcp = _mm256_rcp_ps(b);
rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2));
rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2));
return rcp;
}
template <>
float fast_recieprocal(float b) {
auto minus2 = _mm_set_ss(-2.f);
auto b_reg = _mm_set_ss(b);
auto rcp = _mm_rcp_ss(b_reg);
rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2));
rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2));
return _mm_cvtss_f32(rcp);
}
template<>
Vectorized<double> fast_recieprocal(const Vectorized<double>& b) {
return b.reciprocal();
}
template <>
double fast_recieprocal(double b) {
return 1./b;
}
}
#endif
#include <cstdint>
#include <cmath>
#include <vector>
#include <mkl.h>
namespace caffe2::details {
// MKL VML function templates.
template <typename T>
void PackV(const int N, const T* a, const int* ia, T* y);
template <typename T>
void UnpackV(const int N, const T* a, T* y, const int* iy);
#define DELEGATE_PACKV_FUNCTION(T, OriginalFunc) \
template <> \
void PackV<T>(const int N, const T* a, const int* ia, T* y) { \
OriginalFunc(N, a, ia, y); \
}
DELEGATE_PACKV_FUNCTION(float, vsPackV)
DELEGATE_PACKV_FUNCTION(double, vdPackV)
#undef DELEGATE_PACKV_FUNCTION
#define DELEGATE_UNPACKV_FUNCTION(T, OriginalFunc) \
template <> \
void UnpackV<T>(const int N, const T* a, T* y, const int* iy) { \
OriginalFunc(N, a, y, iy); \
}
DELEGATE_UNPACKV_FUNCTION(float, vsUnpackV)
DELEGATE_UNPACKV_FUNCTION(double, vdUnpackV)
#undef DELEGATE_UNPACKV_FUNCTION
#ifndef FAST_VECTORIZED_KERNEL
template <typename T>
void box_cox_zero_lambda(
size_t D,
const T* const self_data,
const T* const lambda2_data,
T k_eps,
T* const output_data) {
int j = 0;
using Vec = at::vec::Vectorized<T>;
constexpr int64_t VLEN = Vec::size();
auto k_eps_vec = Vec(k_eps);
for(; j + VLEN < D; j += VLEN) {
auto data = Vec::loadu(self_data + j);
auto lambda2 = Vec::loadu(lambda2_data + j);
auto sum = data + lambda2;
auto max = at::vec::max(sum, k_eps_vec);
auto res = max.log();
res.store(output_data + j);
}
for ( ;j < D; ++j) {
auto sum = self_data[j] + lambda2_data[j];
auto max = std::max(sum, k_eps);
output_data[j] = std::log(max);
}
}
template <typename T>
void box_cox_nonzero_lambda(
int64_t D,
const T* data_ptr,
const T* lambda1_ptr,
const T* lambda2_ptr,
T k_eps,
T* out) {
int j = 0;
using Vec = at::vec::Vectorized<T>;
constexpr int64_t VLEN = Vec::size();
auto k_eps_vec = Vec(k_eps);
for(; j + VLEN < D; j += VLEN) {
auto data = Vec::loadu(data_ptr + j);
auto lambda2 = Vec::loadu(lambda2_ptr + j);
auto sum = data + lambda2;
auto max = at::vec::max(sum, k_eps_vec);
auto lambda1 = Vec::loadu(lambda1_ptr + j);
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
auto pow = max.pow(lambda1);
auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
res.store(out + j);
}
for ( ;j < D; ++j) {
auto sum = data_ptr[j] + lambda2_ptr[j];
auto max = std::max(sum, k_eps);
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]);
auto pow = std::pow(max, lambda1_ptr[j]);
out[j] = pow * lambda_over_1 - lambda_over_1;
}
}
#else
template <typename T>
void box_cox_zero_lambda(
size_t D,
const T* const self_data,
const T* const lambda2_data,
T k_eps,
T* const output_data) {
VECTOR_LOOP for (auto j=0 ;j < D; ++j) {
auto sum = self_data[j] + lambda2_data[j];
auto max = std::max(sum, k_eps);
output_data[j] = std::log(max);
}
}
template <typename T>
void box_cox_nonzero_lambda(
int64_t D,
const T* data_ptr,
const T* lambda1_ptr,
const T* lambda2_ptr,
T k_eps,
T* out) {
VECTOR_LOOP for (auto j=0 ;j < D; ++j) {
FAST_MATH
auto sum = data_ptr[j] + lambda2_ptr[j];
auto max = std::max(sum, k_eps);
auto lamda1 = lambda1_ptr[j];
auto lambda_over_1 = 1 / lamda1;
if constexpr (std::is_same<T, float>::value) {
lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1);
lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1);
}
auto pow = std::pow(max, lamda1);
out[j] = pow * lambda_over_1 - lambda_over_1;
}
}
#endif
template <typename T>
void box_cox_mixed_lambda(
const T* const self_data,
const std::vector<int>& nonzeros,
const std::vector<int>& zeros,
const T* const lambda1,
const T* const lambda2,
const T* const lambda2_z_,
T k_eps,
T* const buffer,
T* const output_data) {
PackV(nonzeros.size(), self_data, nonzeros.data(), buffer);
box_cox_nonzero_lambda<T>(
nonzeros.size(), buffer, lambda1, lambda2, k_eps, buffer);
UnpackV(nonzeros.size(), buffer, output_data, nonzeros.data());
PackV(zeros.size(), self_data, zeros.data(), buffer);
box_cox_zero_lambda<T>(
zeros.size(), buffer, lambda2_z_, k_eps, buffer);
UnpackV(zeros.size(), buffer, output_data, zeros.data());
}
template <typename T>
void TileArrayIntoVector(
const T* const a,
const size_t D,
const int K,
std::vector<T>& b) {
b.resize(K * D);
for (const auto k : c10::irange(K)) {
std::copy(a, a + D, b.begin() + k * D);
}
}
void TileIndicesInPlace(std::vector<int>& v, const std::size_t D, const std::size_t K) {
auto n = v.size();
v.resize(K * n);
for (const auto k : c10::irange(1, K)) {
for (const auto j : c10::irange(n)) {
v[k * n + j] = v[j] + k * D;
}
}
}
template <typename T>
void compute_batch_box_cox__avx2_fma(
std::size_t N,
std::size_t D,
std::size_t block_size,
const T* self_data,
const T* __restrict lambda1_data,
const T* __restrict lambda2_data,
T* output_data) {
constexpr T k_eps = static_cast<T>(1e-6);
FOLLY_DECLARE_REUSED(zeros, std::vector<int>);
FOLLY_DECLARE_REUSED(nonzeros, std::vector<int>);
// Don't bother calling reserve; calls after the first will get a
// correctly-sized allocation anyway.
for (const auto j : c10::irange(D)) {
if (lambda1_data[j] == 0) {
zeros.push_back(j);
} else {
nonzeros.push_back(j);
}
}
// Process K rows at a time for effective vectorization with small rows.
const auto K = std::min(N, (block_size + D - 1) / D);
FOLLY_DECLARE_REUSED(lambda1_, std::vector<T>);
FOLLY_DECLARE_REUSED(lambda2_, std::vector<T>);
FOLLY_DECLARE_REUSED(lambda2_z_, std::vector<T>);
if (nonzeros.size() == D) {
// ((x + lambda2)^lambda1 - 1)/lambda1, if lambda1 != 0
size_t i = 0;
if (K > 1) {
TileArrayIntoVector(lambda1_data, D, K, lambda1_);
TileArrayIntoVector(lambda2_data, D, K, lambda2_);
DCHECK_EQ(K * D, lambda1_.size());
DCHECK_EQ(K * D, lambda2_.size());
for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) {
box_cox_nonzero_lambda<T>(
K * D,
self_data,
lambda1_.data(),
lambda2_.data(),
k_eps,
output_data);
}
}
for (; i < N; i++, self_data += D, output_data += D) {
box_cox_nonzero_lambda<T>(
D, self_data, lambda1_data, lambda2_data, k_eps, output_data);
}
} else if (zeros.size() == D) {
// ln(x + lambda2), if lambda1 == 0
size_t i = 0;
if (K > 1) {
TileArrayIntoVector(lambda2_data, D, K, lambda2_z_);
DCHECK_EQ(K * D, lambda2_z_.size());
for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) {
box_cox_zero_lambda<T>(
K * D, self_data, lambda2_z_.data(), k_eps, output_data);
}
}
for (; i < N; i++, self_data += D, output_data += D) {
box_cox_zero_lambda<T>(
D, self_data, lambda2_data, k_eps, output_data);
}
} else {
// mix zeros and nonzeros
const size_t n = nonzeros.size();
if (K > 1) {
TileIndicesInPlace(nonzeros, 0, K);
TileIndicesInPlace(zeros, 0, K);
}
FOLLY_DECLARE_REUSED(buffer, std::vector<T>);
buffer.resize(std::max(nonzeros.size(), zeros.size()));
lambda1_.resize(nonzeros.size());
lambda2_.resize(nonzeros.size());
lambda2_z_.resize(zeros.size());
PackV(nonzeros.size(), lambda1_data, nonzeros.data(), lambda1_.data());
PackV(nonzeros.size(), lambda2_data, nonzeros.data(), lambda2_.data());
PackV(zeros.size(), lambda2_data, zeros.data(), lambda2_z_.data());
size_t i = 0;
if (K > 1) {
// Truncate to original size, and re-tile with offsets this time.
nonzeros.resize(n);
DCHECK_GT(D, n);
zeros.resize(D - n);
TileIndicesInPlace(nonzeros, D, K);
TileIndicesInPlace(zeros, D, K);
DCHECK_EQ(nonzeros.size(), lambda1_.size());
DCHECK_EQ(nonzeros.size(), lambda2_.size());
DCHECK_EQ(zeros.size(), lambda2_z_.size());
for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) {
box_cox_mixed_lambda<T>(
self_data,
nonzeros,
zeros,
lambda1_.data(),
lambda2_.data(),
lambda2_z_.data(),
k_eps,
buffer.data(),
output_data);
}
// Truncate to original size.
nonzeros.resize(n);
zeros.resize(D - n);
}
for (; i < N; i++, self_data += D, output_data += D) {
box_cox_mixed_lambda<T>(
self_data,
nonzeros,
zeros,
lambda1_.data(),
lambda2_.data(),
lambda2_z_.data(),
k_eps,
buffer.data(),
output_data);
}
}
};
template
void compute_batch_box_cox__avx2_fma<float>(
std::size_t N,
std::size_t D,
std::size_t block_size,
const float* self_data,
const float* __restrict lambda1_data,
const float* __restrict lambda2_data,
float* output_data);
template
void compute_batch_box_cox__avx2_fma<double>(
std::size_t N,
std::size_t D,
std::size_t block_size,
const double* self_data,
const double* __restrict lambda1_data,
const double* __restrict lambda2_data,
double* output_data);
} // namespace caffe2::detail
#endif

View File

@ -0,0 +1,75 @@
#pragma once
// Apple clang was fixed in 8.1
#if defined(__apple_build_version__) && \
((__clang_major__ < 8) || \
((__clang_major__ == 8) && (__clang_minor__ < 1)))
#define CAFFE2_INTERNAL_APPLE_NEED_FIX 1
#endif
// Regular clang was fixed in 3.9
#if defined(__clang__) && (__clang_major__ < 4) && (__clang_minor__ < 9)
#define CAFFE2_INTERNAL_CLANG_NEED_FIX 1
#endif
#if defined(CAFFE2_INTERNAL_APPLE_NEED_FIX) || \
defined(CAFFE2_INTERNAL_CLANG_NEED_FIX)
#include <c10/util/Half.h>
#include <emmintrin.h>
// This version of clang has a bug that _cvtsh_ss is not defined, see
// https://reviews.llvm.org/D16177
static __inline float
__attribute__((__always_inline__, __nodebug__, __target__("f16c")))
_cvtsh_ss(unsigned short a) {
__v8hi v = {(short)a, 0, 0, 0, 0, 0, 0, 0};
__v4sf r = __builtin_ia32_vcvtph2ps(v);
return r[0];
}
static __inline unsigned short
__attribute__((__always_inline__, __nodebug__, __target__("f16c")))
_cvtss_sh(float a, int imm8) {
unsigned short ret;
*reinterpret_cast<at::Half*>(&ret) = a;
return ret;
}
#endif // __APPLE_NEED_FIX || __CLANG_NEED_FIX
#undef __APPLE_NEED_FIX
#undef __CLANG_NEED_FIX
#if defined(_MSC_VER) && !defined(__clang__)
#include <c10/util/Half.h>
#include <cstdint>
// It seems that microsoft msvc does not have a _cvtsh_ss implementation so
// we will add a dummy version to it.
static inline float _cvtsh_ss(unsigned short x) {
union {
std::uint32_t intval;
float floatval;
} t1;
std::uint32_t t2, t3;
t1.intval = x & 0x7fff; // Non-sign bits
t2 = x & 0x8000; // Sign bit
t3 = x & 0x7c00; // Exponent
t1.intval <<= 13; // Align mantissa on MSB
t2 <<= 16; // Shift sign bit into position
t1.intval += 0x38000000; // Adjust bias
t1.intval = (t3 == 0 ? 0 : t1.intval); // Denormals-as-zero
t1.intval |= t2; // Re-insert sign bit
return t1.floatval;
}
static inline unsigned short _cvtss_sh(float x, int imm8) {
unsigned short ret;
*reinterpret_cast<at::Half*>(&ret) = x;
return ret;
}
#endif // _MSC_VER

View File

@ -0,0 +1,211 @@
#include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h"
#include "caffe2/perfkernels/common.h"
#include <c10/util/Logging.h>
#include <c10/util/irange.h>
namespace caffe2 {
/**
* Base implementation does runtime dispatch for each segment of reduction
* @return false if there is an out-of-bound error
*/
template <
typename IndexType,
typename InType,
typename OutType,
bool IS_WEIGHT_POSITIONAL = false>
static bool Fused8BitRowwiseEmbeddingLookupGenericSlow(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const InType* input,
const IndexType* indices,
const int* lengths,
const float* weights, // optional, can be null for sum reducer
bool normalize_by_lengths,
OutType* out) {
// block_size is the number of elements and fused_block_size is the size of
// an entire row, including scale and bias.
const auto scale_bias_offset = 8 / sizeof(InType);
const int64_t fused_block_size = block_size + scale_bias_offset;
int64_t current = 0;
for (const auto m : c10::irange(output_size)) {
memset(out, 0, sizeof(OutType) * block_size);
if (current + lengths[m] > index_size) {
return false;
}
for (int i = 0; i < lengths[m]; ++i) {
int64_t idx = indices[current];
if (idx < 0 || idx >= data_size) {
return false;
}
#ifdef __GNUC__
if (current + 1 < index_size) {
__builtin_prefetch(
input + fused_block_size * indices[current + 1], 0, 1);
}
#endif // __GNUC__
const float* scale_bias = reinterpret_cast<const float*>(
input + fused_block_size * indices[current] + block_size);
float weight = 1.0f;
if (weights) {
weight = weights[IS_WEIGHT_POSITIONAL ? i : current];
}
const float scale = weight * scale_bias[0];
const float bias = weight * scale_bias[1];
for (const auto j : c10::irange(block_size)) {
out[j] += scale * input[fused_block_size * indices[current] + j] + bias;
}
++current;
}
if (normalize_by_lengths && lengths[m]) {
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float scale = 1.f / lengths[m];
for (const auto j : c10::irange(block_size)) {
out[j] *= scale;
}
}
out += block_size;
}
return current == index_size;
}
// clang-format off
// Proxy back to generic implementation
#define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(IndexType, OutType) \
bool \
Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__base( \
const int64_t block_size, \
const int64_t output_size, \
const int64_t index_size, \
const int64_t data_size, \
const uint8_t* input, \
const IndexType* indices, \
const int* lengths, \
const float* weights, \
bool normalize_by_lengths, \
OutType* out) { \
return Fused8BitRowwiseEmbeddingLookupGenericSlow< \
IndexType, \
uint8_t, \
OutType, \
false>( \
block_size, \
output_size, \
index_size, \
data_size, \
input, \
indices, \
lengths, \
weights, \
normalize_by_lengths, \
out); \
} \
decltype( \
Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__base) \
Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__avx2_fma; \
bool Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType( \
const int64_t block_size, \
const int64_t output_size, \
const int64_t index_size, \
const int64_t data_size, \
const uint8_t* input, \
const IndexType* indices, \
const int* lengths, \
const float* weights, \
bool normalize_by_lengths, \
OutType* out) { \
const int32_t one = 1; \
CAFFE_ENFORCE_EQ( \
reinterpret_cast<const uint8_t*>(&one)[0], \
1, \
"Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \
AVX2_FMA_DO( \
Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false, \
block_size, \
output_size, \
index_size, \
data_size, \
input, \
indices, \
lengths, \
weights, \
normalize_by_lengths, \
out); \
BASE_DO( \
Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false, \
block_size, \
output_size, \
index_size, \
data_size, \
input, \
indices, \
lengths, \
weights, \
normalize_by_lengths, \
out); \
} \
template <> \
void Fused8BitRowwiseEmbeddingLookup<IndexType, uint8_t, OutType, false>( \
const int64_t block_size, \
const int64_t output_size, \
const int64_t index_size, \
const int64_t data_size, \
const uint8_t* input, \
const IndexType* indices, \
const int* lengths, \
const float* weights, \
bool normalize_by_lengths, \
OutType* out) { \
bool success = \
Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType( \
block_size, \
output_size, \
index_size, \
data_size, \
input, \
indices, \
lengths, \
weights, \
normalize_by_lengths, \
out); \
if (success) { \
return; \
} \
int64_t current = 0; \
for (int m = 0; m < output_size; ++m) { \
for (int i = 0; i < lengths[m]; ++i) { \
CAFFE_ENFORCE_LT(current, index_size); \
IndexType idx = indices[current]; \
CAFFE_ENFORCE( \
0 <= idx && idx < data_size, \
"Index ", \
current, \
" is out of bounds: ", \
idx, \
", range 0 to ", \
data_size); \
++current; \
} \
} \
CAFFE_ENFORCE_EQ( \
current, \
index_size, \
"Your input seems to be incorrect: the sum of lengths values should be " \
"the size of the indices tensor, but it appears not."); \
}
// clang-format on
FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, float);
FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t, float);
#undef FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION
} // namespace caffe2

View File

@ -0,0 +1,55 @@
#pragma once
#include <cstdint>
namespace caffe2 {
/**
* Embedding lookup with reduction.
*
* `input` of size data_size * (block_size + 8B)
* `indices` of size index_size
* `lengths` of size output_size
* `weights` nullptr or array of size index_size
* `out` of size output_size * block_size
* sum(lengths[i]) == index_size
*
* Note that block_size should be the number of quantized values per row in the
* data, i.e. excluding the scale and bias. The total (fused) block size is
* assumed to be this block_size, plus 4 bytes for scale and 4 bytes for bias.
*
* Behavior is roughly equivalent to pseudocode:
*
* pos = 0
* fused_block_size = block_size + 8B // quantized values and scale and bias
* for (i = 0..output_size-1)
* for (k = 0..block_size-1)
* out[i*block_size + k] = 0
* for (j = 0..lengths[i]-1)
* for (k = 0..block_size-1)
* out[i*block_size + k] += input[indices[pos]*(fused_block_size) + k] *
* (weights ? weights[IS_WEIGHT_POSITIONAL ? j : pos] : 1.0)
* pos += 1
* if (normalize_weights && lengths[i] > 0)
* for (k = 0..block_size-1)
* out[i*block_size + k] /= lengths[i]
*
*/
template <
typename IndexType,
typename InType,
typename OutType,
bool IS_WEIGHT_POSITIONAL = false>
void Fused8BitRowwiseEmbeddingLookup(
const std::int64_t block_size,
const std::int64_t output_size,
const std::int64_t index_size,
const std::int64_t data_size,
const InType* input,
const IndexType* indices,
const int* lengths,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
OutType* out);
} // namespace caffe2

View File

@ -0,0 +1,213 @@
#include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.h"
#include "caffe2/perfkernels/common.h"
#include <c10/util/Logging.h>
#include <c10/util/irange.h>
namespace caffe2 {
/**
* Base implementation does runtime dispatch for each segment of reduction
* @return false if there is an out-of-bound error
*/
template <
typename IndexType,
typename InType,
typename OutType,
bool IS_WEIGHT_POSITIONAL = false>
static bool Fused8BitRowwiseEmbeddingLookupGenericSlowIdx(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const InType* input,
const IndexType* indices,
const IndexType* offsets,
const float* weights, // optional, can be null for sum reducer
bool normalize_by_lengths,
OutType* out) {
// block_size is the number of elements and fused_block_size is the size of
// an entire row, including scale and bias.
const auto scale_bias_offset = 8 / sizeof(InType);
const int64_t fused_block_size = block_size + scale_bias_offset;
int64_t current = 0;
for (const auto m : c10::irange(output_size)) {
memset(out, 0, sizeof(OutType) * block_size);
if (current != offsets[m] - offsets[0]) {
return false;
}
int64_t start_offset = offsets[m];
int64_t end_offset = offsets[m + 1];
int64_t length = end_offset - start_offset;
for (const auto i : c10::irange(start_offset, end_offset)) {
int64_t idx = indices[current];
if (idx < 0 || idx >= data_size) {
return false;
}
#ifdef __GNUC__
if (current + 1 < index_size) {
__builtin_prefetch(
input + fused_block_size * indices[current + 1], 0, 1);
}
#endif // __GNUC__
const float* scale_bias = reinterpret_cast<const float*>(
input + fused_block_size * indices[current] + block_size);
float weight = 1.0f;
if (weights) {
weight = weights[IS_WEIGHT_POSITIONAL ? i : current];
}
const float scale = weight * scale_bias[0];
const float bias = weight * scale_bias[1];
for (const auto j : c10::irange(block_size)) {
out[j] += scale * input[fused_block_size * indices[current] + j] + bias;
}
++current;
}
if (normalize_by_lengths && length) {
float scale = 1.f / length;
for (const auto j : c10::irange(block_size)) {
out[j] *= scale;
}
}
out += block_size;
}
return current == index_size;
}
// clang-format off
// Proxy back to generic implementation
#define FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(IndexType, OutType) \
bool \
Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false__base( \
const int64_t block_size, \
const int64_t output_size, \
const int64_t index_size, \
const int64_t data_size, \
const uint8_t* input, \
const IndexType* indices, \
const IndexType* offsets, \
const float* weights, \
bool normalize_by_lengths, \
OutType* out) { \
return Fused8BitRowwiseEmbeddingLookupGenericSlowIdx< \
IndexType, \
uint8_t, \
OutType, \
false>( \
block_size, \
output_size, \
index_size, \
data_size, \
input, \
indices, \
offsets, \
weights, \
normalize_by_lengths, \
out); \
} \
decltype( \
Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false__base) \
Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false__avx2_fma; \
bool Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType( \
const int64_t block_size, \
const int64_t output_size, \
const int64_t index_size, \
const int64_t data_size, \
const uint8_t* input, \
const IndexType* indices, \
const IndexType* offsets, \
const float* weights, \
bool normalize_by_lengths, \
OutType* out) { \
const int32_t one = 1; \
CAFFE_ENFORCE_EQ( \
reinterpret_cast<const uint8_t*>(&one)[0], \
1, \
"Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \
AVX2_FMA_DO( \
Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false, \
block_size, \
output_size, \
index_size, \
data_size, \
input, \
indices, \
offsets, \
weights, \
normalize_by_lengths, \
out); \
BASE_DO( \
Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType##_false, \
block_size, \
output_size, \
index_size, \
data_size, \
input, \
indices, \
offsets, \
weights, \
normalize_by_lengths, \
out); \
} \
template <> \
void Fused8BitRowwiseEmbeddingLookupIdx<IndexType, uint8_t, OutType, false>( \
const int64_t block_size, \
const int64_t output_size, \
const int64_t index_size, \
const int64_t data_size, \
const uint8_t* input, \
const IndexType* indices, \
const IndexType* offsets, \
const float* weights, \
bool normalize_by_lengths, \
OutType* out) { \
bool success = \
Fused8BitRowwiseEmbeddingLookupIdx_##IndexType##_uint8_t_##OutType( \
block_size, \
output_size, \
index_size, \
data_size, \
input, \
indices, \
offsets, \
weights, \
normalize_by_lengths, \
out); \
if (success) { \
return; \
} \
int64_t current = 0; \
for (int m = 0; m < output_size; ++m) { \
for (int64_t i = offsets[m]; i < offsets[m + 1]; ++i) { \
CAFFE_ENFORCE_LT(current, index_size); \
IndexType idx = indices[current]; \
CAFFE_ENFORCE( \
0 <= idx && idx < data_size, \
"Index ", \
current, \
" is out of bounds: ", \
idx, \
", range 0 to ", \
data_size); \
++current; \
} \
} \
CAFFE_ENFORCE_EQ( \
current, \
index_size, \
"Your input seems to be incorrect: the sum of lengths values should be " \
"the size of the indices tensor, but it appears not."); \
}
// clang-format on
FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(int32_t, float);
FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION(int64_t, float);
#undef FUSED_8BIT_ROWWISE_EMBEDDING_IDX_SPECIALIZATION
} // namespace caffe2

View File

@ -0,0 +1,57 @@
#pragma once
#include <cstdint>
namespace caffe2 {
/**
* Embedding lookup with reduction.
*
* `input` of size data_size * (block_size + 8B)
* `indices` of size index_size
* `offsets` of size output_size
* `weights` nullptr or array of size index_size
* `out` of size output_size * block_size
*
* Note that block_size should be the number of quantized values per row in the
* data, i.e. excluding the scale and bias. The total (fused) block size is
* assumed to be this block_size, plus 4 bytes for scale and 4 bytes for bias.
*
* Behavior is roughly equivalent to pseudocode:
*
* pos = 0
* fused_block_size = block_size + 8B // quantized values and scale and bias
* for (i = 0..output_size-1)
* for (k = 0..block_size-1)
* out[i*block_size + k] = 0
* start_offset = offsets[i]
* end_offset = i == output_size-1 ? index_size : offsets[i+1] - 1
* length = end_offset - start_offset
* for (j = start_offset..end_offset)
* for (k = 0..block_size-1)
* out[i*block_size + k] += input[indices[pos]*(fused_block_size) + k] *
* (weights ? weights[IS_WEIGHT_POSITIONAL ? j : pos] : 1.0)
* pos += 1
* if (normalize_weights && length > 0)
* for (k = 0..block_size-1)
* out[i*block_size + k] /= length
*
*/
template <
typename IndexType,
typename InType,
typename OutType,
bool IS_WEIGHT_POSITIONAL = false>
void Fused8BitRowwiseEmbeddingLookupIdx(
const std::int64_t block_size,
const std::int64_t output_size,
const std::int64_t index_size,
const std::int64_t data_size,
const InType* input,
const IndexType* indices,
const IndexType* offsets,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
OutType* out);
} // namespace caffe2

View File

@ -0,0 +1,214 @@
#include "./fused_nbit_rowwise_conversion.h"
#include <c10/util/Half.h>
#include <algorithm>
#include <cmath>
#include "common.h"
#ifdef USE_FBGEMM
#include "fbgemm/QuantUtils.h"
#endif
namespace caffe2 {
void FloatToFused8BitRowwiseQuantized__base(
const float* input,
size_t input_rows,
int input_columns,
std::uint8_t* output) {
constexpr float kEpsilon = 1e-8f;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
int output_columns = input_columns + 2 * sizeof(float);
for (std::size_t row = 0; row < input_rows; ++row) {
const float* input_row = input + row * input_columns;
std::uint8_t* output_row = output + row * output_columns;
float* output_row_scale_bias =
reinterpret_cast<float*>(output_row + input_columns);
float minimum_element =
*std::min_element(input_row, input_row + input_columns);
float maximum_element =
*std::max_element(input_row, input_row + input_columns);
float range = maximum_element - minimum_element;
output_row_scale_bias[0] = range / 255.0f;
output_row_scale_bias[1] = minimum_element;
const auto inverse_scale = 255.0f / (range + kEpsilon);
for (std::size_t col = 0; col < static_cast<size_t>(input_columns); ++col) {
output_row[col] =
std::lrintf((input_row[col] - minimum_element) * inverse_scale);
}
}
}
void Fused8BitRowwiseQuantizedToFloat__base(
const std::uint8_t* input,
size_t input_rows,
int input_columns,
float* output) {
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
int output_columns = input_columns - 2 * sizeof(float);
for (std::size_t row = 0; row < input_rows; ++row) {
const std::uint8_t* input_row = input + row * input_columns;
const float* input_row_scale_bias =
reinterpret_cast<const float*>(input_row + output_columns);
float* output_row = output + row * output_columns;
for (std::size_t col = 0; col < static_cast<std::size_t>(output_columns); ++col) {
output_row[col] =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1];
}
}
}
void FloatToFused8BitRowwiseQuantized(
const float* input,
size_t input_rows,
int input_columns,
std::uint8_t* output) {
#ifdef USE_FBGEMM
fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
input, input_rows, input_columns, output);
#else
FloatToFused8BitRowwiseQuantized__base(
input, input_rows, input_columns, output);
#endif
}
void Fused8BitRowwiseQuantizedToFloat(
const std::uint8_t* input,
size_t input_rows,
int input_columns,
float* output) {
#ifdef USE_FBGEMM
fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<float>(
input, input_rows, input_columns, output);
#else
Fused8BitRowwiseQuantizedToFloat__base(
input, input_rows, input_columns, output);
#endif
}
void FloatToFusedNBitRowwiseQuantizedSBHalf__base(
int bit_rate,
const float* input,
size_t input_rows,
int input_columns,
std::uint8_t* output) {
int num_elem_per_byte = 8 / bit_rate;
int output_columns =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
(input_columns + num_elem_per_byte - 1) / num_elem_per_byte +
2 * sizeof(at::Half);
for (std::size_t row = 0; row < input_rows; ++row) {
const float* input_row = input + row * input_columns;
std::uint8_t* output_row = output + row * output_columns;
at::Half* output_row_scale_bias = reinterpret_cast<at::Half*>(
output_row +
(input_columns + num_elem_per_byte - 1) / num_elem_per_byte);
float minimum_element =
*std::min_element(input_row, input_row + input_columns);
float maximum_element =
*std::max_element(input_row, input_row + input_columns);
minimum_element = static_cast<at::Half>(minimum_element);
const float range = maximum_element - minimum_element;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
at::Half scale = range == 0 ? 1.0f : range / ((1 << bit_rate) - 1);
if (scale == 0) {
// Corner case handling when maximum_element == minimum_element
// Any scale would work because X - minimum_element will be 0 for all X
scale = 1.0f;
}
float inverse_scale = 1.0f / scale;
if (std::isinf(inverse_scale)) {
scale = 1.0f;
inverse_scale = 1.0f;
}
output_row_scale_bias[0] = scale;
output_row_scale_bias[1] = minimum_element;
for (std::size_t col = 0; col < static_cast<size_t>(input_columns); ++col) {
float X = input_row[col];
std::uint8_t quantized = std::max(
0,
std::min<int>(
std::lrintf((X - minimum_element) * inverse_scale),
(1 << bit_rate) - 1));
if (col % num_elem_per_byte == 0) {
output_row[col / num_elem_per_byte] = quantized;
} else {
output_row[col / num_elem_per_byte] |=
(quantized << ((col % num_elem_per_byte) * bit_rate));
}
}
}
}
void FusedNBitRowwiseQuantizedSBHalfToFloat__base(
int bit_rate,
const std::uint8_t* input,
size_t input_rows,
int input_columns,
float* output) {
int num_elem_per_byte = 8 / bit_rate;
int output_columns =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
(input_columns - 2 * sizeof(at::Half)) * num_elem_per_byte;
for (std::size_t row = 0; row < static_cast<size_t>(input_rows); ++row) {
const std::uint8_t* input_row = input + row * input_columns;
const at::Half* input_row_scale_bias = reinterpret_cast<const at::Half*>(
input_row +
(output_columns + num_elem_per_byte - 1) / num_elem_per_byte);
float scale = input_row_scale_bias[0];
float bias = input_row_scale_bias[1];
float* output_row = output + row * output_columns;
for (std::size_t col = 0; col < static_cast<std::size_t>(output_columns); ++col) {
std::uint8_t quantized = input_row[col / num_elem_per_byte];
quantized >>= (col % num_elem_per_byte) * bit_rate;
quantized &= (1 << bit_rate) - 1;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
output_row[col] = scale * quantized + bias;
}
}
}
void FloatToFusedNBitRowwiseQuantizedSBHalf(
int bit_rate,
const float* input,
size_t input_rows,
int input_columns,
std::uint8_t* output) {
#ifdef USE_FBGEMM
fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
bit_rate, input, input_rows, input_columns, output);
#else
FloatToFusedNBitRowwiseQuantizedSBHalf__base(
bit_rate, input, input_rows, input_columns, output);
#endif
}
void FusedNBitRowwiseQuantizedSBHalfToFloat(
int bit_rate,
const std::uint8_t* input,
size_t input_rows,
int input_columns,
float* output) {
#ifdef USE_FBGEMM
fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<float>(
bit_rate, input, input_rows, input_columns, output);
#else
FusedNBitRowwiseQuantizedSBHalfToFloat__base(
bit_rate, input, input_rows, input_columns, output);
#endif
}
} // namespace caffe2

View File

@ -0,0 +1,39 @@
#pragma once
#include <cstddef>
#include <cstdint>
namespace caffe2 {
void FloatToFused8BitRowwiseQuantized(
const float* input,
size_t input_rows,
int input_columns,
std::uint8_t* output);
void Fused8BitRowwiseQuantizedToFloat(
const std::uint8_t* input,
size_t input_rows,
int input_columns,
float* output);
/**
* Row-wise quantization with fp16 scale and bias
*
* @param bit_rate can be 2, 4, or 8
*/
void FloatToFusedNBitRowwiseQuantizedSBHalf(
int bit_rate,
const float* input,
size_t input_rows,
int input_columns,
std::uint8_t* output);
void FusedNBitRowwiseQuantizedSBHalfToFloat(
int bit_rate,
const std::uint8_t* input,
size_t input_rows,
int input_columns,
float* output);
} // namespace caffe2

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-defs
import argparse

View File

@ -0,0 +1,141 @@
#pragma once
#include <string.h>
#include <cmath>
#include <cstdint>
#include "c10/util/irange.h"
#include "caffe2/utils/conversions.h"
#include "vectorizer.h"
namespace caffe2 {
namespace perfkernels {
namespace {
template <typename T>
inline T sigmoid(T x) {
return 1 / (1 + std::exp(-x));
}
template <typename T>
inline T host_tanh(T x) {
return 2 * sigmoid(2 * x) - 1;
}
template <typename T>
inline void LstmUnitImpl(
const int N,
const int D,
const int t,
const T* H_prev,
const T* C_prev,
const T* X,
const int32_t* seqLengths,
const bool drop_states,
T* C,
T* H,
const float forget_bias) {
const T forgetBias = convert::To<float, T>(forget_bias);
for (const auto n : c10::irange(N)) {
const bool valid = seqLengths == nullptr || t < seqLengths[n];
if (!valid) {
if (drop_states) {
memset(H, 0, sizeof(T) * D);
memset(C, 0, sizeof(T) * D);
} else {
memcpy(H, H_prev, sizeof(T) * D);
memcpy(C, C_prev, sizeof(T) * D);
}
} else {
const T* X_D = &X[D];
const T* X_2D = &X[2 * D];
const T* X_3D = &X[3 * D];
VECTOR_LOOP for (const auto d : c10::irange(D)) {
const T i = sigmoid(X[d]);
const T f = sigmoid(X_D[d] + forgetBias);
const T o = sigmoid(X_2D[d]);
const T g = host_tanh(X_3D[d]);
const T c_prev = C_prev[d];
const T c = f * c_prev + i * g;
C[d] = c;
const T host_tanh_c = host_tanh(c);
H[d] = o * host_tanh_c;
}
}
H_prev += D;
C_prev += D;
X += 4 * D;
C += D;
H += D;
}
}
template <typename T>
inline void LstmUnitGradientImpl(
int N,
int D,
int t,
const T* C_prev,
const T* X,
const int32_t* seqLengths,
const T* C,
const T* H,
const T* C_diff,
const T* H_diff,
bool drop_states,
T* H_prev_diff,
T* C_prev_diff,
T* X_diff,
const float forget_bias) {
const T localForgetBias = convert::To<float, T>(forget_bias);
for (const auto n : c10::irange(N)) {
const bool valid = seqLengths == nullptr || t < seqLengths[n];
if (!valid) {
if (drop_states) {
memset(C_prev_diff, 0, sizeof(T) * D);
memset(H_prev_diff, 0, sizeof(T) * D);
} else {
memcpy(H_prev_diff, H_diff, sizeof(T) * D);
memcpy(C_prev_diff, C_diff, sizeof(T) * D);
}
memset(X_diff, 0, 4 * sizeof(T) * D);
} else {
VECTOR_LOOP for (const auto d : c10::irange(D)) {
T* c_prev_diff = C_prev_diff + d;
T* h_prev_diff = H_prev_diff + d;
T* i_diff = X_diff + d;
T* f_diff = X_diff + 1 * D + d;
T* o_diff = X_diff + 2 * D + d;
T* g_diff = X_diff + 3 * D + d;
const T i = sigmoid(X[d]);
const T f = sigmoid(X[1 * D + d] + localForgetBias);
const T o = sigmoid(X[2 * D + d]);
const T g = host_tanh(X[3 * D + d]);
const T c_prev = C_prev[d];
const T c = C[d];
const T host_tanh_c = host_tanh(c);
const T c_term_diff =
C_diff[d] + H_diff[d] * o * (1 - host_tanh_c * host_tanh_c);
*c_prev_diff = c_term_diff * f;
*h_prev_diff = 0; // not used in 'valid' case
*i_diff = c_term_diff * g * i * (1 - i);
*f_diff = c_term_diff * c_prev * f * (1 - f);
*o_diff = H_diff[d] * host_tanh_c * o * (1 - o);
*g_diff = c_term_diff * i * (1 - g * g);
}
}
C_prev += D;
X += 4 * D;
C += D;
H += D;
C_diff += D;
H_diff += D;
X_diff += 4 * D;
H_prev_diff += D;
C_prev_diff += D;
}
}
} // namespace
} // namespace perfkernels
} // namespace caffe2

View File

@ -0,0 +1,73 @@
#pragma once
#include <cstdint>
namespace caffe2 {
namespace detail {
// Forward declration of the LSTMUnit templated
// implementation
template <typename T>
void LstmUnitCpu(
const int N,
const int D,
const int t,
const T* H_prev,
const T* C_prev,
const T* X,
const int32_t* seqLengths,
const bool drop_states,
T* C,
T* H,
const float forget_bias);
// Forward specialization
extern template void LstmUnitCpu<float>(
const int N,
const int D,
const int t,
const float* H_prev,
const float* C_prev,
const float* X,
const int32_t* seqLengths,
const bool drop_states,
float* C,
float* H,
const float forget_bias);
template <typename T>
void LstmUnitGradientCpu(
int N,
int D,
int t,
const T* C_prev,
const T* X,
const int32_t* seqLengths,
const T* C,
const T* H,
const T* C_diff,
const T* H_diff,
bool drop_states,
T* H_prev_diff,
T* C_prev_diff,
T* X_diff,
const float forget_bias);
extern template void LstmUnitGradientCpu<float>(
int N,
int D,
int t,
const float* C_prev,
const float* X,
const int32_t* seqLengths,
const float* C,
const float* H,
const float* C_diff,
const float* H_diff,
bool drop_states,
float* H_prev_diff,
float* C_prev_diff,
float* X_diff,
const float forget_bias);
} // namespace detail
} // namespace caffe2

View File

@ -0,0 +1,123 @@
#include "caffe2/perfkernels/lstm_unit_cpu-impl.h"
namespace caffe2 {
namespace perfkernels {
namespace {
// Explicit initialize for float and AVX2 vectorization
template void LstmUnitImpl<float>(
const int N,
const int D,
const int t,
const float* H_prev,
const float* C_prev,
const float* X,
const int32_t* seqLengths,
const bool drop_states,
float* C,
float* H,
const float forget_bias);
template void LstmUnitGradientImpl<float>(
int N,
int D,
int t,
const float* C_prev,
const float* X,
const int32_t* seqLengths,
const float* C,
const float* H,
const float* C_diff,
const float* H_diff,
bool drop_states,
float* H_prev_diff,
float* C_prev_diff,
float* X_diff,
const float forget_bias);
} // namespace
// Define templated implementation fo LSTM kernels on CPU supporting AVX2
template <typename T>
void LstmUnitImpl__avx2_fma(
const int N,
const int D,
const int t,
const T* H_prev,
const T* C_prev,
const T* X,
const int32_t* seqLengths,
const bool drop_states,
T* C,
T* H,
const float forget_bias) {
LstmUnitImpl(
N, D, t, H_prev, C_prev, X, seqLengths, drop_states, C, H, forget_bias);
}
template <typename T>
void LstmUnitGradientImpl__avx2_fma(
int N,
int D,
int t,
const T* C_prev,
const T* X,
const int32_t* seqLengths,
const T* C,
const T* H,
const T* C_diff,
const T* H_diff,
bool drop_states,
T* H_prev_diff,
T* C_prev_diff,
T* X_diff,
const float forget_bias) {
LstmUnitGradientImpl(
N,
D,
t,
C_prev,
X,
seqLengths,
C,
H,
C_diff,
H_diff,
drop_states,
H_prev_diff,
C_prev_diff,
X_diff,
forget_bias);
}
// Explicit initialize for float
template void LstmUnitImpl__avx2_fma<float>(
const int N,
const int D,
const int t,
const float* H_prev,
const float* C_prev,
const float* X,
const int32_t* seqLengths,
const bool drop_states,
float* C,
float* H,
const float forget_bias);
template void LstmUnitGradientImpl__avx2_fma<float>(
int N,
int D,
int t,
const float* C_prev,
const float* X,
const int32_t* seqLengths,
const float* C,
const float* H,
const float* C_diff,
const float* H_diff,
bool drop_states,
float* H_prev_diff,
float* C_prev_diff,
float* X_diff,
const float forget_bias);
} // namespace perfkernels
} // namespace caffe2

View File

@ -0,0 +1,125 @@
#include "caffe2/perfkernels/lstm_unit_cpu_common.h"
#include "caffe2/perfkernels/common.h"
#include "caffe2/perfkernels/lstm_unit_cpu-impl.h"
namespace caffe2 {
namespace detail {
// Define templated implementation fo LSTM kernels on CPU
template <typename T>
void LstmUnitCpu(
const int N,
const int D,
const int t,
const T* H_prev,
const T* C_prev,
const T* X,
const int32_t* seqLengths,
const bool drop_states,
T* C,
T* H,
const float forget_bias) {
// Do CPU dispatching
AVX2_FMA_DO(
perfkernels::LstmUnitImpl,
N,
D,
t,
H_prev,
C_prev,
X,
seqLengths,
drop_states,
C,
H,
forget_bias);
perfkernels::LstmUnitImpl(
N, D, t, H_prev, C_prev, X, seqLengths, drop_states, C, H, forget_bias);
}
template <typename T>
void LstmUnitGradientCpu(
int N,
int D,
int t,
const T* C_prev,
const T* X,
const int32_t* seqLengths,
const T* C,
const T* H,
const T* C_diff,
const T* H_diff,
bool drop_states,
T* H_prev_diff,
T* C_prev_diff,
T* X_diff,
const float forget_bias) {
// Do CPU dispatching
AVX2_FMA_DO(
perfkernels::LstmUnitGradientImpl,
N,
D,
t,
C_prev,
X,
seqLengths,
C,
H,
C_diff,
H_diff,
drop_states,
H_prev_diff,
C_prev_diff,
X_diff,
forget_bias);
perfkernels::LstmUnitGradientImpl(
N,
D,
t,
C_prev,
X,
seqLengths,
C,
H,
C_diff,
H_diff,
drop_states,
H_prev_diff,
C_prev_diff,
X_diff,
forget_bias);
}
// Explicit initialize for float
template void LstmUnitCpu<float>(
const int N,
const int D,
const int t,
const float* H_prev,
const float* C_prev,
const float* X,
const int32_t* seqLengths,
const bool drop_states,
float* C,
float* H,
const float forget_bias);
template void LstmUnitGradientCpu<float>(
int N,
int D,
int t,
const float* C_prev,
const float* X,
const int32_t* seqLengths,
const float* C,
const float* H,
const float* C_diff,
const float* H_diff,
bool drop_states,
float* H_prev_diff,
float* C_prev_diff,
float* X_diff,
const float forget_bias);
} // namespace detail
} // namespace caffe2

View File

@ -0,0 +1,71 @@
#pragma once
#include <cstdint>
namespace caffe2 {
namespace perfkernels {
template <typename T>
void LstmUnitImpl__avx2_fma(
const int N,
const int D,
const int t,
const T* H_prev,
const T* C_prev,
const T* X,
const int32_t* seqLengths,
const bool drop_states,
T* C,
T* H,
const float forget_bias);
template <typename T>
void LstmUnitGradientImpl__avx2_fma(
int N,
int D,
int t,
const T* C_prev,
const T* X,
const int32_t* seqLengths,
const T* C,
const T* H,
const T* C_diff,
const T* H_diff,
bool drop_states,
T* H_prev_diff,
T* C_prev_diff,
T* X_diff,
const float forget_bias);
// Forward declaration of specialized functions
extern template void LstmUnitImpl__avx2_fma(
const int N,
const int D,
const int t,
const float* H_prev,
const float* C_prev,
const float* X,
const int32_t* seqLengths,
const bool drop_states,
float* C,
float* H,
const float forget_bias);
extern template void LstmUnitGradientImpl__avx2_fma(
int N,
int D,
int t,
const float* C_prev,
const float* X,
const int32_t* seqLengths,
const float* C,
const float* H,
const float* C_diff,
const float* H_diff,
bool drop_states,
float* H_prev_diff,
float* C_prev_diff,
float* X_diff,
const float forget_bias);
} // namespace perfkernels
} // namespace caffe2

35
caffe2/perfkernels/math.h Normal file
View File

@ -0,0 +1,35 @@
#pragma once
#include <cstdint>
namespace caffe2 {
namespace math {
// Returns the quantized and compressed values of floating inputs
// The "fused" representation stores the [bitwidth][tail][min][max]
// with the quantized data in one array. Since we store 8/bitwidth
// quantized data in one byte, the last buckets of some bytes may have
// unused bits. There are totally tail buckets are unused.
// We encode *bitwidth* and *tail* at the beginning,
// following by 32-bit floating data respresenting min and max.
// | bitwidth | tail | min | max | ... int8 data ... |
// | 1B | 1B | 4B | 4B | ...output_data....|
// In output_data: the b-th bucket of the i-th byte stores
// the i-th data of the b-th segment of input row
void quantize_and_compress(
const float* input_data,
std::uint8_t* output_data,
std::uint64_t input_size,
std::uint64_t bitwidth,
bool random,
const float* random_buffer);
void decompress_and_dequantize(
const std::uint8_t* input_data,
float* output_data,
std::uint64_t input_size);
} // namespace math
} // namespace caffe2

View File

@ -0,0 +1,246 @@
// Implements the math functions for CPU.
// The implementation in this file allows us to route the underlying numerical
// computation library to different compiler options (-mno-avx2 or -mavx2).
#include <immintrin.h>
#include <cmath>
#include <cstdint>
#include <c10/util/irange.h>
using std::uint64_t;
using std::uint8_t;
namespace caffe2 {
namespace math {
static constexpr double QEPSILON = 1e-8;
void quantize_and_compress__avx2(
const float* input_data,
uint8_t* output_data,
uint64_t input_size,
uint64_t bitwidth,
bool random,
const float* random_buffer) {
__m256i shuffle_mask_v = _mm256_set_epi8(
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
0x0c,
0x08,
0x04,
0x00,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
0xff,
0x0c,
0x08,
0x04,
0x00);
__m256i permute_mask_v =
_mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
uint64_t data_per_byte = 8 / bitwidth;
uint64_t tail = input_size % data_per_byte;
tail = tail ? data_per_byte - tail : 0;
uint64_t segment_size = (input_size + data_per_byte - 1) / data_per_byte;
// basic info
float minimum_element = INFINITY, maximum_element = -INFINITY;
for (const auto i : c10::irange(input_size)) {
minimum_element =
(input_data[i] < minimum_element) ? input_data[i] : minimum_element;
maximum_element =
(input_data[i] > maximum_element) ? input_data[i] : maximum_element;
}
output_data[0] = bitwidth;
output_data[1] = tail;
reinterpret_cast<float*>(output_data + 2)[0] = minimum_element;
reinterpret_cast<float*>(output_data + 2)[1] = maximum_element;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f);
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float gap_inverse = 1. / (gap + QEPSILON);
uint8_t max_q = (1 << bitwidth) - 1;
uint64_t bit_start = 0;
if (random) {
for (uint64_t start = 0; start < input_size; start += segment_size) {
uint64_t stride = start + segment_size <= input_size ? segment_size
: input_size - start;
uint64_t i = 0;
constexpr int VLEN = 8;
for (; i < stride / VLEN * VLEN; i += VLEN) {
__m256 r_v = _mm256_loadu_ps(&random_buffer[start + i]);
__m256 fval_v = _mm256_loadu_ps(input_data + start + i);
__m256 thetimes_v = _mm256_mul_ps(
_mm256_sub_ps(fval_v, _mm256_set1_ps(minimum_element)),
_mm256_set1_ps(gap_inverse));
__m256 rounded_v = _mm256_floor_ps(_mm256_add_ps(thetimes_v, r_v));
rounded_v = _mm256_max_ps(
_mm256_setzero_ps(),
_mm256_min_ps(_mm256_set1_ps(max_q), rounded_v));
__m256i qval_v = _mm256_cvtps_epi32(rounded_v);
__m256i orval_v = _mm256_cvtepu8_epi32(_mm_lddqu_si128(
reinterpret_cast<const __m128i*>(output_data + 10 + i)));
orval_v =
_mm256_or_si256(orval_v, _mm256_slli_epi32(qval_v, bit_start));
orval_v = _mm256_shuffle_epi8(orval_v, shuffle_mask_v);
orval_v = _mm256_permutevar8x32_epi32(orval_v, permute_mask_v);
*reinterpret_cast<int64_t*>(output_data + 10 + i) =
_mm256_extract_epi64(orval_v, 0);
}
for (; i < stride; ++i) {
float fval = input_data[start + i];
float thetimes = (fval - minimum_element) * gap_inverse;
float rounded = floor(thetimes + random_buffer[start + i]);
rounded = rounded < static_cast<float>(max_q)
? rounded
: static_cast<float>(max_q);
rounded = rounded > 0.0f ? rounded : 0.0f;
uint8_t qval = rounded;
uint8_t orval = output_data[10 + i];
output_data[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
}
bit_start += bitwidth;
}
} else {
// !random
for (uint64_t start = 0; start < input_size; start += segment_size) {
uint64_t stride = start + segment_size <= input_size ? segment_size
: input_size - start;
uint64_t i = 0;
constexpr int VLEN = 8;
for (; i < stride / VLEN * VLEN; i += VLEN) {
__m256 fval_v = _mm256_loadu_ps(input_data + start + i);
__m256 thetimes_v = _mm256_mul_ps(
_mm256_sub_ps(fval_v, _mm256_set1_ps(minimum_element)),
_mm256_set1_ps(gap_inverse));
thetimes_v = _mm256_max_ps(
_mm256_setzero_ps(),
_mm256_min_ps(_mm256_set1_ps(max_q), thetimes_v));
__m256i qval_v = _mm256_cvtps_epi32(_mm256_round_ps(
thetimes_v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
__m256i orval_v = _mm256_cvtepu8_epi32(_mm_lddqu_si128(
reinterpret_cast<const __m128i*>(output_data + 10 + i)));
orval_v =
_mm256_or_si256(orval_v, _mm256_slli_epi32(qval_v, bit_start));
orval_v = _mm256_shuffle_epi8(orval_v, shuffle_mask_v);
orval_v = _mm256_permutevar8x32_epi32(orval_v, permute_mask_v);
*reinterpret_cast<int64_t*>(output_data + 10 + i) =
_mm256_extract_epi64(orval_v, 0);
}
for (; i < stride; ++i) {
float fval = input_data[start + i];
float thetimes = (fval - minimum_element) * gap_inverse;
thetimes = thetimes < static_cast<float>(max_q)
? thetimes
: static_cast<float>(max_q);
thetimes = thetimes > 0.0f ? thetimes : 0.0f;
uint8_t qval = nearbyint(thetimes);
uint8_t orval = output_data[10 + i];
output_data[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
}
bit_start += bitwidth;
}
} // !random
}
void decompress_and_dequantize__avx2(
const uint8_t* input_data,
float* output_data,
uint64_t input_size) {
// basic info
const float minimum_element =
reinterpret_cast<const float*>(input_data + 2)[0];
const float maximum_element =
reinterpret_cast<const float*>(input_data + 2)[1];
const uint64_t bitwidth = input_data[0];
const float gap =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
(maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) +
QEPSILON; // for exact recovering
const uint64_t tail = input_data[1];
const uint64_t output_size = (input_size - 10) * (8 / bitwidth) - tail;
// decoding
uint64_t bit_start = 0;
const uint64_t segment_size = input_size - 10;
for (uint64_t start = 0; start < output_size; start += segment_size) {
uint64_t stride = start + segment_size <= output_size ? segment_size
: output_size - start;
uint8_t mask = (1 << bitwidth) - 1;
uint64_t i = 0;
// Can process 8 elements at a time because we need to expand uint8_t
// to int32_t to use epi32 vector instructions.
constexpr int VLEN = 8;
for (; i < stride / VLEN * VLEN; i += VLEN) {
__m128i in_v = _mm_lddqu_si128(
reinterpret_cast<const __m128i*>(input_data + 10 + i));
__m256i out_epi32_v = _mm256_and_si256(
_mm256_srli_epi32(_mm256_cvtepu8_epi32(in_v), bit_start),
_mm256_set1_epi32(mask));
__m256 out_v = _mm256_fmadd_ps(
_mm256_cvtepi32_ps(out_epi32_v),
_mm256_set1_ps(gap),
_mm256_set1_ps(minimum_element));
_mm256_storeu_ps(output_data + start + i, out_v);
}
for (; i < stride; ++i) {
output_data[start + i] =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
((input_data[10 + i] >> bit_start) & mask) * gap + minimum_element;
}
bit_start += bitwidth;
}
}
} // namespace math
} // namespace caffe2

View File

@ -0,0 +1,168 @@
// Implements the math functions for CPU.
// The implementation in this file allows us to route the underlying numerical
// computation library to different compiler options (-mno-avx2 or -mavx2).
#include <cfloat>
#include <cmath>
#include <cstdint>
#include "common.h"
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include "math.h"
#include <c10/util/irange.h>
using std::uint64_t;
using std::uint8_t;
namespace caffe2 {
namespace math {
static constexpr double QEPSILON = 1e-8;
void quantize_and_compress__base(
const float* input_data,
uint8_t* output_data,
uint64_t input_size,
uint64_t bitwidth,
bool random,
const float* random_buffer) {
uint64_t data_per_byte = 8 / bitwidth;
uint64_t tail = input_size % data_per_byte;
tail = tail ? data_per_byte - tail : 0;
uint64_t segment_size = (input_size + data_per_byte - 1) / data_per_byte;
// basic info
float minimum_element = INFINITY, maximum_element = -INFINITY;
for (const auto i : c10::irange(input_size)) {
minimum_element =
input_data[i] < minimum_element ? input_data[i] : minimum_element;
maximum_element =
input_data[i] > maximum_element ? input_data[i] : maximum_element;
}
output_data[0] = bitwidth;
output_data[1] = tail;
reinterpret_cast<float*>(output_data + 2)[0] = minimum_element;
reinterpret_cast<float*>(output_data + 2)[1] = maximum_element;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f);
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float gap_inverse = 1. / (gap + QEPSILON);
uint8_t max_q = (1 << bitwidth) - 1;
uint64_t bit_start = 0;
if (random) {
for (uint64_t start = 0; start < input_size; start += segment_size) {
uint64_t stride = start + segment_size <= input_size ? segment_size
: input_size - start;
uint64_t i = 0;
for (; i < stride; ++i) {
float fval = input_data[start + i];
float thetimes = (fval - minimum_element) * gap_inverse;
float rounded = floor(thetimes + random_buffer[start + i]);
rounded = rounded < static_cast<float>(max_q)
? rounded
: static_cast<float>(max_q);
rounded = rounded > 0.0f ? rounded : 0.0f;
uint8_t qval = rounded;
uint8_t orval = output_data[10 + i];
output_data[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
}
bit_start += bitwidth;
}
} else {
for (uint64_t start = 0; start < input_size; start += segment_size) {
uint64_t stride = start + segment_size <= input_size ? segment_size
: input_size - start;
uint64_t i = 0;
for (; i < stride; ++i) {
float fval = input_data[start + i];
float thetimes = (fval - minimum_element) * gap_inverse;
thetimes = thetimes < static_cast<float>(max_q)
? thetimes
: static_cast<float>(max_q);
thetimes = thetimes > 0.0f ? thetimes : 0.0f;
uint8_t qval = nearbyint(thetimes);
uint8_t orval = output_data[10 + i];
output_data[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
}
bit_start += bitwidth;
}
}
}
decltype(quantize_and_compress__base) quantize_and_compress__avx2;
void quantize_and_compress(
const float* input_data,
uint8_t* output_data,
uint64_t input_size,
uint64_t bitwidth,
bool random,
const float* random_buffer) {
AVX2_DO(
quantize_and_compress,
input_data,
output_data,
input_size,
bitwidth,
random,
random_buffer);
BASE_DO(
quantize_and_compress,
input_data,
output_data,
input_size,
bitwidth,
random,
random_buffer);
}
void decompress_and_dequantize__base(
const uint8_t* input_data,
float* output_data,
uint64_t input_size) {
// basic info
const float minimum_element =
reinterpret_cast<const float*>(input_data + 2)[0];
const float maximum_element =
reinterpret_cast<const float*>(input_data + 2)[1];
const uint64_t bitwidth = input_data[0];
const float gap =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
(maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) +
QEPSILON; // for exact recovering
const uint64_t tail = input_data[1];
const uint64_t output_size = (input_size - 10) * (8 / bitwidth) - tail;
// decoding
uint64_t bit_start = 0;
const uint64_t segment_size = input_size - 10;
for (uint64_t start = 0; start < output_size; start += segment_size) {
uint64_t stride = start + segment_size <= output_size ? segment_size
: output_size - start;
uint8_t mask = (1 << bitwidth) - 1;
uint64_t i = 0;
for (; i < stride; ++i) {
output_data[start + i] =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
((input_data[10 + i] >> bit_start) & mask) * gap + minimum_element;
}
bit_start += bitwidth;
}
}
decltype(decompress_and_dequantize__base) decompress_and_dequantize__avx2;
void decompress_and_dequantize(
const uint8_t* input_data,
float* output_data,
uint64_t input_size) {
AVX2_DO(decompress_and_dequantize, input_data, output_data, input_size);
BASE_DO(decompress_and_dequantize, input_data, output_data, input_size);
}
} // namespace math
} // namespace caffe2

View File

@ -0,0 +1,88 @@
#include <c10/util/Half.h>
#include "caffe2/perfkernels/typed_axpy.h"
#include "caffe2/perfkernels/common.h"
namespace caffe2 {
void TypedAxpy__base(int N, const float a, const float* x, float* y) {
for (int i = 0; i < N; ++i) {
y[i] += a * x[i];
}
}
decltype(TypedAxpy__base) TypedAxpy__avx2_fma;
decltype(TypedAxpy__base) TypedAxpy__avx_f16c;
template <>
void TypedAxpy<float, float>(int N, const float a, const float* x, float* y) {
AVX2_FMA_DO(TypedAxpy, N, a, x, y);
AVX_F16C_DO(TypedAxpy, N, a, x, y);
BASE_DO(TypedAxpy, N, a, x, y);
}
void TypedAxpyHalffloat__base(
int N,
const float a,
const at::Half* x,
float* y) {
for (int i = 0; i < N; ++i) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint32_t intval;
float floatval;
} t1;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t t2, t3;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
t1.intval = x[i].x & 0x7fff; // Non-sign bits
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
t2 = x[i].x & 0x8000; // Sign bit
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
t3 = x[i].x & 0x7c00; // Exponent
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
t1.intval <<= 13; // Align mantissa on MSB
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
t2 <<= 16; // Shift sign bit into position
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
t1.intval += 0x38000000; // Adjust bias
t1.intval = (t3 == 0 ? 0 : t1.intval); // Denormals-as-zero
t1.intval |= t2; // Re-insert sign bit
y[i] += t1.floatval * a;
}
}
decltype(TypedAxpyHalffloat__base) TypedAxpyHalffloat__avx2_fma;
decltype(TypedAxpyHalffloat__base) TypedAxpyHalffloat__avx_f16c;
template <>
void TypedAxpy<at::Half, float>(
int N,
const float a,
const at::Half* x,
float* y) {
AVX2_FMA_DO(TypedAxpyHalffloat, N, a, x, y);
AVX_F16C_DO(TypedAxpyHalffloat, N, a, x, y);
BASE_DO(TypedAxpyHalffloat, N, a, x, y);
}
void TypedAxpy_uint8_float__base(
int N,
const float a,
const std::uint8_t* x,
float* y) {
for (int i = 0; i < N; ++i) {
y[i] += (float)(x[i]) * a;
}
}
decltype(TypedAxpy_uint8_float__base) TypedAxpy_uint8_float__avx2_fma;
decltype(TypedAxpy_uint8_float__base) TypedAxpy_uint8_float__avx_f16c;
template <>
void TypedAxpy<std::uint8_t, float>(
int N,
const float a,
const std::uint8_t* x,
float* y) {
AVX2_FMA_DO(TypedAxpy_uint8_float, N, a, x, y);
BASE_DO(TypedAxpy_uint8_float, N, a, x, y);
}
} // namespace caffe2

View File

@ -0,0 +1,12 @@
#pragma once
namespace caffe2 {
// Similar to Axpy that calculate y = a * x + y, but allowing x and y to be
// of different data types.
// It also provides a performance optimization hint (use_a) to see if a is going
// to be 1 or not.
template <typename IN, typename OUT>
void TypedAxpy(int N, const OUT a, const IN* x, OUT* y);
} // namespace caffe2

View File

@ -0,0 +1,68 @@
#include "caffe2/perfkernels/cvtsh_ss_bugfix.h"
#include <c10/util/Half.h>
#include <emmintrin.h>
#include <immintrin.h>
namespace caffe2 {
void TypedAxpy__avx_f16c(int N, const float a, const float* x, float* y) {
int current = 0;
const int bound = (N % 8) ? N - 8 : N;
__m256 mma = _mm256_set1_ps(a);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (; current < bound; current += 8) {
_mm256_storeu_ps(
y + current,
_mm256_add_ps(
_mm256_mul_ps(mma, _mm256_loadu_ps(x + current)),
_mm256_loadu_ps(y + current)));
}
if (bound != N) {
while (current < N) {
y[current] += x[current] * a;
++current;
}
}
}
void TypedAxpyHalffloat__avx_f16c(
int N,
const float a,
const at::Half* x,
float* y) {
// if x does not start at the 16 byte boundary, we will process the first few.
// before we get to a real one.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
while ((reinterpret_cast<unsigned long>(x) % 16) && N) {
*(y++) += _cvtsh_ss((*(x++)).x) * a;
--N;
}
// From now on we can do vectorized additions using __m256, which is 8 floats,
// so we will vectorize every 8 element and then resort to cvtsh_ss.
__m256 mma = _mm256_set1_ps(a);
int current = 0;
const int bound = (N % 8) ? N - 8 : N;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (; current < bound; current += 8) {
__m128i mmx_16 =
_mm_loadu_si128(reinterpret_cast<const __m128i*>(x + current));
__m256 mmx_32 = _mm256_cvtph_ps(mmx_16);
__m256 mmy_in = _mm256_loadu_ps(y + current);
__m256 mmmul = _mm256_mul_ps(mmx_32, mma);
__m256 mmy_out = _mm256_add_ps(mmmul, mmy_in);
_mm256_storeu_ps(y + current, mmy_out);
}
if (bound != N) {
while (current < N) {
y[current] += _cvtsh_ss(x[current].x) * a;
++current;
}
}
}
} // namespace caffe2

View File

@ -0,0 +1,104 @@
#include "caffe2/perfkernels/cvtsh_ss_bugfix.h"
#include <c10/util/Half.h>
#include <emmintrin.h>
#include <immintrin.h>
namespace caffe2 {
void TypedAxpy__avx2_fma(int N, const float a, const float* x, float* y) {
int current = 0;
const int bound = (N % 8) ? N - 8 : N;
__m256 mma = _mm256_set1_ps(a);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (; current < bound; current += 8) {
_mm256_storeu_ps(
y + current,
_mm256_fmadd_ps(
mma, _mm256_loadu_ps(x + current), _mm256_loadu_ps(y + current)));
}
if (bound != N) {
while (current < N) {
y[current] += x[current] * a;
++current;
}
}
}
void TypedAxpyHalffloat__avx2_fma(
int N,
const float a,
const at::Half* x,
float* y) {
// if x does not start at the 16 byte boundary, we will process the first few.
// before we get to a real one.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
while ((reinterpret_cast<unsigned long>(x) % 16) && N) {
*(y++) += _cvtsh_ss((*(x++)).x) * a;
--N;
}
// From now on we can do vectorized additions using __m256, which is 8 floats,
// so we will vectorize every 8 element and then resort to cvtsh_ss.
__m256 mma = _mm256_set1_ps(a);
int current = 0;
const int bound = (N % 8) ? N - 8 : N;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (; current < bound; current += 8) {
__m128i mmx_16 =
_mm_loadu_si128(reinterpret_cast<const __m128i*>(x + current));
__m256 mmx_32 = _mm256_cvtph_ps(mmx_16);
__m256 mmy = _mm256_loadu_ps(y + current);
mmy = _mm256_fmadd_ps(mmx_32, mma, mmy);
_mm256_storeu_ps(y + current, mmy);
}
if (bound != N) {
while (current < N) {
y[current] += _cvtsh_ss(x[current].x) * a;
++current;
}
}
}
void TypedAxpy_uint8_float__avx2_fma(
int N,
const float a,
const std::uint8_t* x,
float* y) {
// if x does not start at the 16 byte boundary, we will process the first few.
// before we get to a real one.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
while ((reinterpret_cast<unsigned long>(x) % 16) && N) {
*(y++) += static_cast<float>(*(x++)) * a;
--N;
}
// From now on we can do vectorized additions using __m256, which is 8 floats,
// so we will vectorize every 8 element and then resort to cvtsh_ss.
__m256 mma = _mm256_set1_ps(a);
int current = 0;
const int bound = (N % 8) ? N - 8 : N;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
for (; current < bound; current += 8) {
__m256i mmx_int32 = _mm256_cvtepi8_epi32(
_mm_loadu_si128(reinterpret_cast<const __m128i*>(x + current)));
__m256 mmx_fp32 = _mm256_cvtepi32_ps(mmx_int32);
__m256 mmy = _mm256_loadu_ps(y + current);
mmy = _mm256_fmadd_ps(mmx_fp32, mma, mmy);
_mm256_storeu_ps(y + current, mmy);
}
if (bound != N) {
while (current < N) {
y[current] += (float)(x[current]) * a;
++current;
}
}
}
} // namespace caffe2

View File

@ -0,0 +1,28 @@
#pragma once
#if (ENABLE_VECTORIZATION > 0) && !defined(_DEBUG) && !defined(DEBUG)
#if defined(__clang__) && (__clang_major__ > 7)
#define IS_SANITIZER \
((__has_feature(address_sanitizer) == 1) || \
(__has_feature(memory_sanitizer) == 1) || \
(__has_feature(thread_sanitizer) == 1) || \
(__has_feature(undefined_sanitizer) == 1))
#if IS_SANITIZER == 0
#define VECTOR_LOOP _Pragma("clang loop vectorize(enable)")
#define FAST_MATH _Pragma("clang fp contract(fast)")
#define VECTORIZED_KERNEL 1
#endif
#elif defined(_OPENMP) && (_OPENMP >= 201511)
// Support with OpenMP4.5 and above
#define VECTOR_LOOP _Pragma("omp for simd")
#define VECTORIZED_KERNEL 1
#define FAST_MATH
#endif
#endif
#ifndef VECTOR_LOOP
// Not supported
#define VECTOR_LOOP
#define FAST_MATH
#endif

View File

@ -10,11 +10,9 @@ if(NOT __AOTRITON_INCLUDED)
set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}")
message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}")
else()
file(STRINGS "${CMAKE_CURRENT_SOURCE_DIR}/.ci/docker/aotriton_version.txt" __AOTRITON_CI_INFO)
list(GET __AOTRITON_CI_INFO 3 __AOTRITON_CI_COMMIT)
ExternalProject_Add(aotriton_external
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
GIT_TAG ${__AOTRITON_CI_COMMIT}
GIT_TAG 24a3fe9cb57e5cda3c923df29743f9767194cc27
SOURCE_DIR ${__AOTRITON_SOURCE_DIR}
BINARY_DIR ${__AOTRITON_BUILD_DIR}
PREFIX ${__AOTRITON_INSTALL_DIR}

View File

@ -121,7 +121,6 @@ function(caffe2_print_configuration_summary)
if(${USE_ROCM})
message(STATUS " ROCM_VERSION : ${ROCM_VERSION}")
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
endif()
message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}")
message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}")

View File

@ -479,7 +479,9 @@ function(torch_compile_options libname)
# templated classes crossing library boundary get duplicated (but identical)
# definitions. It's easier to just disable it.
target_compile_options(${libname} PRIVATE
$<$<COMPILE_LANGUAGE:CXX>: -fvisibility=hidden>)
$<$<COMPILE_LANGUAGE:CXX>: -fvisibility=hidden>
$<$<COMPILE_LANGUAGE:OBJC>: -fvisibility=hidden>
$<$<COMPILE_LANGUAGE:OBJCXX>: -fvisibility=hidden>)
endif()
# Use -O2 for release builds (-O3 doesn't improve perf, and -Os results in perf regression)

Some files were not shown because too many files have changed in this diff Show More