mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 14:59:34 +08:00
Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
73c49ee963 |
@ -1,5 +0,0 @@
|
||||
0.6b
|
||||
manylinux_2_17
|
||||
rocm6
|
||||
04b5df8c8123f90cba3ede7e971e6fbc6040d506
|
||||
3db6ecbc915893ff967abd6e1b43bd5f54949868873be60dc802086c3863e648
|
@ -113,18 +113,18 @@ COPY triton_version.txt triton_version.txt
|
||||
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
|
||||
RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt
|
||||
|
||||
# Install AOTriton (Early fail)
|
||||
COPY ./aotriton_version.txt aotriton_version.txt
|
||||
COPY ./common/common_utils.sh common_utils.sh
|
||||
COPY ./common/install_aotriton.sh install_aotriton.sh
|
||||
RUN ["/bin/bash", "-c", "./install_aotriton.sh /opt/rocm && rm -rf install_aotriton.sh aotriton_version.txt common_utils.sh"]
|
||||
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton
|
||||
|
||||
# Install ccache/sccache (do this last, so we get priority in PATH)
|
||||
COPY ./common/install_cache.sh install_cache.sh
|
||||
ENV PATH /opt/cache/bin:$PATH
|
||||
RUN bash ./install_cache.sh && rm install_cache.sh
|
||||
|
||||
# Install AOTriton
|
||||
COPY ci_commit_pins/aotriton.txt aotriton.txt
|
||||
COPY ./common/common_utils.sh common_utils.sh
|
||||
COPY ./common/install_aotriton.sh install_aotriton.sh
|
||||
RUN bash ./install_aotriton.sh /opt/rocm/aotriton && rm -rf install_aotriton.sh aotriton aotriton.txt common_utils.sh
|
||||
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton
|
||||
|
||||
# Include BUILD_ENVIRONMENT environment variable in image
|
||||
ARG BUILD_ENVIRONMENT
|
||||
ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT}
|
||||
|
1
.ci/docker/ci_commit_pins/aotriton.txt
Normal file
1
.ci/docker/ci_commit_pins/aotriton.txt
Normal file
@ -0,0 +1 @@
|
||||
24a3fe9cb57e5cda3c923df29743f9767194cc27
|
31
.ci/docker/common/install_aotriton.sh
Executable file → Normal file
31
.ci/docker/common/install_aotriton.sh
Executable file → Normal file
@ -4,20 +4,21 @@ set -ex
|
||||
|
||||
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
|
||||
|
||||
TARBALL='aotriton.tar.bz2'
|
||||
# This read command alwasy returns with exit code 1
|
||||
read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < aotriton_version.txt || true
|
||||
ARCH=$(uname -m)
|
||||
AOTRITON_DIR="aotriton"
|
||||
AOTRITON_PINNED_NAME="aotriton" # No .txt extension
|
||||
AOTRITON_PINNED_COMMIT=$(get_pinned_commit ${AOTRITON_PINNED_NAME})
|
||||
AOTRITON_INSTALL_PREFIX="$1"
|
||||
AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}.tar.bz2"
|
||||
|
||||
cd "${AOTRITON_INSTALL_PREFIX}"
|
||||
# Must use -L to follow redirects
|
||||
curl -L --retry 3 -o "${TARBALL}" "${AOTRITON_URL}"
|
||||
ACTUAL_SHA256=$(sha256sum "${TARBALL}" | cut -d " " -f 1)
|
||||
if [ "${SHA256}" != "${ACTUAL_SHA256}" ]; then
|
||||
echo -n "Error: The SHA256 of downloaded tarball is ${ACTUAL_SHA256},"
|
||||
echo " which does not match the expected value ${SHA256}."
|
||||
exit
|
||||
fi
|
||||
tar xf "${TARBALL}" && rm -rf "${TARBALL}"
|
||||
git clone https://github.com/ROCm/aotriton.git "${AOTRITON_DIR}"
|
||||
cd "${AOTRITON_DIR}"
|
||||
git checkout "${AOTRITON_PINNED_COMMIT}"
|
||||
git submodule sync --recursive
|
||||
git submodule update --init --recursive --force --depth 1
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -G Ninja -DCMAKE_INSTALL_PREFIX=./install_dir -DCMAKE_BUILD_TYPE=Release -DAOTRITON_COMPRESS_KERNEL=OFF -DAOTRITON_NO_PYTHON=ON -DAOTRITON_NO_SHARED=ON
|
||||
ninja install
|
||||
mkdir -p "${AOTRITON_INSTALL_PREFIX}"
|
||||
cp -r install_dir/* "${AOTRITON_INSTALL_PREFIX}"
|
||||
find /tmp/ -mindepth 1 -delete
|
||||
rm -rf ~/.triton
|
||||
|
@ -105,18 +105,18 @@ COPY triton_version.txt triton_version.txt
|
||||
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
|
||||
RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt
|
||||
|
||||
# Install AOTriton
|
||||
COPY ./aotriton_version.txt aotriton_version.txt
|
||||
COPY ./common/common_utils.sh common_utils.sh
|
||||
COPY ./common/install_aotriton.sh install_aotriton.sh
|
||||
RUN ["/bin/bash", "-c", "./install_aotriton.sh /opt/rocm && rm -rf install_aotriton.sh aotriton_version.txt common_utils.sh"]
|
||||
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton
|
||||
|
||||
# Install ccache/sccache (do this last, so we get priority in PATH)
|
||||
COPY ./common/install_cache.sh install_cache.sh
|
||||
ENV PATH /opt/cache/bin:$PATH
|
||||
RUN bash ./install_cache.sh && rm install_cache.sh
|
||||
|
||||
# Install AOTriton
|
||||
COPY ci_commit_pins/aotriton.txt aotriton.txt
|
||||
COPY ./common/common_utils.sh common_utils.sh
|
||||
COPY ./common/install_aotriton.sh install_aotriton.sh
|
||||
RUN bash ./install_aotriton.sh /opt/rocm/aotriton && rm -rf install_aotriton.sh aotriton aotriton.txt common_utils.sh
|
||||
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton
|
||||
|
||||
# Include BUILD_ENVIRONMENT environment variable in image
|
||||
ARG BUILD_ENVIRONMENT
|
||||
ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT}
|
||||
|
@ -62,6 +62,4 @@ readability-string-compare,
|
||||
'
|
||||
HeaderFilterRegex: '^(aten/|c10/|torch/).*$'
|
||||
WarningsAsErrors: '*'
|
||||
CheckOptions:
|
||||
misc-header-include-cycle.IgnoredFilesList: 'format.h;ivalue.h;custom_class.h;Dict.h;List.h'
|
||||
...
|
||||
|
2
.github/ci_commit_pins/audio.txt
vendored
2
.github/ci_commit_pins/audio.txt
vendored
@ -1 +1 @@
|
||||
b829e936f7cc61b48149f5f957a451a38bf2a178
|
||||
1980f8af5bcd0bb2ce51965cf79d8d4c25dad8a0
|
||||
|
@ -1099,6 +1099,7 @@ exclude_patterns = [
|
||||
'test/test_namedtuple_return_api.py',
|
||||
'test/test_native_functions.py',
|
||||
'test/test_native_mha.py',
|
||||
'test/test_nestedtensor.py',
|
||||
'test/test_nn.py',
|
||||
'test/test_out_dtype_op.py',
|
||||
'test/test_overrides.py',
|
||||
|
@ -461,8 +461,15 @@ filegroup(
|
||||
filegroup(
|
||||
name = "caffe2_perfkernels_srcs",
|
||||
srcs = [
|
||||
"caffe2/perfkernels/adagrad.cc",
|
||||
"caffe2/perfkernels/embedding_lookup.cc",
|
||||
"caffe2/perfkernels/embedding_lookup_idx.cc",
|
||||
"caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc",
|
||||
"caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup_idx.cc",
|
||||
"caffe2/perfkernels/fused_nbit_rowwise_conversion.cc",
|
||||
"caffe2/perfkernels/lstm_unit_cpu_common.cc",
|
||||
"caffe2/perfkernels/math_cpu_base.cc",
|
||||
"caffe2/perfkernels/typed_axpy.cc",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -865,13 +865,12 @@ cmake_dependent_option(
|
||||
# Suspect users building from source will need this
|
||||
add_definitions(-DFLASHATTENTION_DISABLE_ALIBI)
|
||||
|
||||
# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
|
||||
# Eff Attention won't
|
||||
# CAVEAT: Again, do not check USE_ROCM here Flash Attention2 will error while
|
||||
# building for sm52 while Mem Eff Attention won't
|
||||
cmake_dependent_option(
|
||||
USE_MEM_EFF_ATTENTION
|
||||
"Enable memory-efficient attention for scaled dot product attention.\
|
||||
Will be disabled if not supported by the platform" ON
|
||||
"USE_CUDA OR USE_ROCM" OFF)
|
||||
Will be disabled if not supported by the platform" ON "USE_CUDA" OFF)
|
||||
|
||||
if(DEBUG_CUDA)
|
||||
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
|
||||
|
@ -40,7 +40,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
|
||||
|
||||
### Untrusted inputs during training and prediction
|
||||
|
||||
If you plan to open your model to untrusted inputs, be aware that inputs can also be used as vectors by malicious agents. To minimize risks, make sure to give your model only the permissions strictly required, and keep your libraries updated with the latest security patches.
|
||||
If you plan to open your model to untrusted inputs, be aware that inputs can also be used as vectors by malicious agents. To minimize risks, make sure to give your model only the permisisons strictly required, and keep your libraries updated with the lates security patches.
|
||||
|
||||
If applicable, prepare your model against bad inputs and prompt injections. Some recommendations:
|
||||
- Pre-analysis: check how the model performs by default when exposed to prompt injection (e.g. using fuzzing for prompt injection).
|
||||
|
@ -364,7 +364,7 @@ class TORCH_API Context {
|
||||
bool enabled_flashSDP = true;
|
||||
bool enabled_mem_efficientSDP = true;
|
||||
bool enabled_mathSDP = true;
|
||||
bool enabled_cudnnSDP = true;
|
||||
bool enabled_cudnnSDP = false;
|
||||
#ifdef USE_ROCM
|
||||
bool benchmark_cudnn = true;
|
||||
#else
|
||||
@ -385,11 +385,8 @@ class TORCH_API Context {
|
||||
? at::LinalgBackend::Cusolver
|
||||
: at::LinalgBackend::Default;
|
||||
at::BlasBackend blas_preferred_backend =
|
||||
#ifdef USE_ROCM
|
||||
(c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false)
|
||||
#else
|
||||
(c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true)
|
||||
#endif
|
||||
(c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true ||
|
||||
c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true)
|
||||
? at::BlasBackend::Cublaslt
|
||||
: at::BlasBackend::Cublas;
|
||||
#ifdef C10_MOBILE
|
||||
|
@ -143,7 +143,7 @@ static Device getATenDevice(const DLDevice& ctx, void* data) {
|
||||
return at::detail::getXPUHooks().getDeviceFromPtr(data);
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "Unsupported device_type: ", std::to_string(ctx.device_type));
|
||||
false, "Unsupported device_type: " + c10::to_string(ctx.device_type));
|
||||
}
|
||||
}
|
||||
|
||||
@ -167,7 +167,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "Unsupported kUInt bits ", std::to_string(dtype.bits));
|
||||
false, "Unsupported kUInt bits " + c10::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLInt:
|
||||
@ -186,7 +186,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "Unsupported kInt bits ", std::to_string(dtype.bits));
|
||||
false, "Unsupported kInt bits " + c10::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLFloat:
|
||||
@ -202,7 +202,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
|
||||
false, "Unsupported kFloat bits " + c10::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLBfloat:
|
||||
@ -212,7 +212,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
|
||||
false, "Unsupported kFloat bits " + c10::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLComplex:
|
||||
@ -228,7 +228,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
|
||||
false, "Unsupported kFloat bits " + c10::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLBool:
|
||||
@ -238,11 +238,11 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
|
||||
false, "Unsupported kDLBool bits " + c10::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code));
|
||||
TORCH_CHECK(false, "Unsupported code " + c10::to_string(dtype.code));
|
||||
}
|
||||
return stype;
|
||||
}
|
||||
@ -298,7 +298,9 @@ Tensor fromDLPack(DLManagedTensor* src) {
|
||||
return fromDLPack(src, std::move(deleter));
|
||||
}
|
||||
|
||||
Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter) {
|
||||
Tensor fromDLPack(
|
||||
DLManagedTensor* src,
|
||||
std::function<void(void*)> deleter) {
|
||||
Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data);
|
||||
ScalarType stype = toScalarType(src->dl_tensor.dtype);
|
||||
if (!src->dl_tensor.strides) {
|
||||
|
@ -462,7 +462,7 @@ inline Tensor _sum_to(
|
||||
reduce_dims.push_back(i);
|
||||
}
|
||||
for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
|
||||
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(shape[i - leading_dims], 1)) &&
|
||||
if (shape[i - leading_dims] == 1 &&
|
||||
TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) {
|
||||
reduce_dims.push_back(i);
|
||||
}
|
||||
|
@ -19,13 +19,7 @@ MemOverlap has_internal_overlap(TensorImpl* t) {
|
||||
auto strides = t->sym_strides();
|
||||
auto sizes = t->sym_sizes();
|
||||
for (const auto i : c10::irange(strides.size())) {
|
||||
// NB: The size oblivious test is written very carefully here. When
|
||||
// unbacked SymInts are involved, we should try to conservatively report
|
||||
// if memory overlap /could/ happen under some setting of unbacked
|
||||
// SymInts. Thus, if I have u0 size, we should assume that this has > 1
|
||||
// elements (first expression), but if I have a u0 stride, I should NOT
|
||||
// assume that it is not zero (second expression)
|
||||
if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[i].sym_gt(1)) && strides[i] == 0) {
|
||||
if (strides[i] == 0 && sizes[i] > 1) {
|
||||
return MemOverlap::Yes;
|
||||
}
|
||||
}
|
||||
|
@ -197,7 +197,7 @@ TORCH_API std::ostream& operator<<(
|
||||
const std::vector<TensorIndex>& tensor_indices);
|
||||
|
||||
namespace impl {
|
||||
inline Tensor applySlice(
|
||||
static inline Tensor applySlice(
|
||||
const Tensor& self,
|
||||
int64_t dim,
|
||||
c10::SymInt start,
|
||||
@ -227,7 +227,7 @@ inline Tensor applySlice(
|
||||
dim, std::move(start), std::move(stop), std::move(step));
|
||||
}
|
||||
|
||||
inline Tensor applySelect(
|
||||
static inline Tensor applySelect(
|
||||
const Tensor& self,
|
||||
int64_t dim,
|
||||
SymInt index,
|
||||
@ -266,7 +266,9 @@ inline Tensor applySelect(
|
||||
return self.select_symint(dim, std::move(index));
|
||||
}
|
||||
|
||||
inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
|
||||
static inline Tensor boolToIndexingTensorCPUOrCUDA(
|
||||
const Tensor& self,
|
||||
bool value) {
|
||||
// booleans add a dimension of size 1. true indexes this dimension as if 0:,
|
||||
// false as empty.
|
||||
if (value) {
|
||||
@ -276,7 +278,7 @@ inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
|
||||
}
|
||||
}
|
||||
|
||||
inline Tensor boolToIndexingTensorNonNativeDeviceType(
|
||||
static inline Tensor boolToIndexingTensorNonNativeDeviceType(
|
||||
const Tensor& self,
|
||||
bool value) {
|
||||
// booleans add a dimension of size 1. true indexes this dimension as if 0:,
|
||||
@ -288,7 +290,7 @@ inline Tensor boolToIndexingTensorNonNativeDeviceType(
|
||||
}
|
||||
}
|
||||
|
||||
inline Tensor boolToIndexingTensor(
|
||||
static inline Tensor boolToIndexingTensor(
|
||||
const Tensor& self,
|
||||
bool value,
|
||||
const at::Device& self_device) {
|
||||
@ -299,13 +301,13 @@ inline Tensor boolToIndexingTensor(
|
||||
}
|
||||
}
|
||||
|
||||
inline Tensor scalarToTensorNonNativeDeviceType(
|
||||
static inline Tensor scalarToTensorNonNativeDeviceType(
|
||||
const Scalar& v,
|
||||
const TensorOptions& options) {
|
||||
return at::scalar_tensor(v, options);
|
||||
}
|
||||
|
||||
inline void recordTensorIndex(
|
||||
static inline void recordTensorIndex(
|
||||
const Tensor& tensor,
|
||||
std::vector<Tensor>& outIndices,
|
||||
int64_t* dim_ptr) {
|
||||
@ -315,7 +317,7 @@ inline void recordTensorIndex(
|
||||
(*dim_ptr)++;
|
||||
};
|
||||
|
||||
inline c10::List<::std::optional<Tensor>> typeConvertIndices(
|
||||
static inline c10::List<::std::optional<Tensor>> typeConvertIndices(
|
||||
const Tensor& /*self*/,
|
||||
std::vector<Tensor>&& indices) {
|
||||
c10::List<::std::optional<Tensor>> converted_inds;
|
||||
@ -336,7 +338,7 @@ inline c10::List<::std::optional<Tensor>> typeConvertIndices(
|
||||
// construct a `std::vector` container to be consumed by the C++
|
||||
// `count_specified_dimensions` function, which adds 100s of nanoseconds
|
||||
// overhead and is undesirable.
|
||||
inline int64_t count_specified_dimensions(
|
||||
static inline int64_t count_specified_dimensions(
|
||||
const ArrayRef<TensorIndex>& indices) {
|
||||
// Count the number of indexed dimensions (everything but ellipsis and None)
|
||||
int64_t count = 0;
|
||||
@ -370,7 +372,7 @@ inline int64_t count_specified_dimensions(
|
||||
//
|
||||
// The rest of the functions are in `at::indexing::impl` namespace, signifying
|
||||
// that they shouldn't be used from Python indexing implementation.
|
||||
inline Tensor scalarToTensor(
|
||||
static inline Tensor scalarToTensor(
|
||||
const Scalar& v,
|
||||
const TensorOptions& options,
|
||||
const at::Device& self_device) {
|
||||
@ -385,7 +387,7 @@ inline Tensor scalarToTensor(
|
||||
// To match numpy semantics:
|
||||
// As a special case for backwards compatibility,
|
||||
// strip away unit dimensions from the left of 'src'
|
||||
inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
|
||||
static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
|
||||
size_t first_non1_src = sizes.size();
|
||||
for (const auto i : c10::irange(sizes.size())) {
|
||||
// Unbacked SymInt has different behavior, but this is sound because
|
||||
@ -400,7 +402,7 @@ inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
|
||||
return sizes.slice(first_non1_src);
|
||||
}
|
||||
|
||||
inline void copy_to(const Tensor& dst, const Tensor& src) {
|
||||
static inline void copy_to(const Tensor& dst, const Tensor& src) {
|
||||
if (dst.sym_sizes().equals(src.sym_sizes())) {
|
||||
// A shortcut to avoid generating hard-coded constant sizes during tracing.
|
||||
// This is not a perfect solution: when src & dst have different shapes,
|
||||
@ -419,7 +421,7 @@ inline void copy_to(const Tensor& dst, const Tensor& src) {
|
||||
|
||||
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
|
||||
// indexing functions from Python ]
|
||||
inline Tensor handleDimInMultiDimIndexing(
|
||||
static inline Tensor handleDimInMultiDimIndexing(
|
||||
const Tensor& prev_dim_result,
|
||||
const Tensor& original_tensor,
|
||||
const TensorIndex& index,
|
||||
@ -507,7 +509,7 @@ inline Tensor handleDimInMultiDimIndexing(
|
||||
namespace impl {
|
||||
// This mirrors `applySlicing` in
|
||||
// torch/csrc/autograd/python_variable_indexing.cpp
|
||||
inline Tensor applySlicing(
|
||||
static inline Tensor applySlicing(
|
||||
const Tensor& self,
|
||||
const ArrayRef<TensorIndex>& indices,
|
||||
std::vector<Tensor>& outIndices,
|
||||
@ -548,13 +550,13 @@ inline Tensor applySlicing(
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
inline Tensor dispatch_index(
|
||||
static inline Tensor dispatch_index(
|
||||
const Tensor& self,
|
||||
std::vector<Tensor>&& indices) {
|
||||
return self.index(impl::typeConvertIndices(self, std::move(indices)));
|
||||
}
|
||||
|
||||
inline Tensor dispatch_index_put_(
|
||||
static inline Tensor dispatch_index_put_(
|
||||
Tensor& self,
|
||||
std::vector<Tensor>&& indices,
|
||||
const Tensor& value) {
|
||||
@ -596,7 +598,7 @@ inline Tensor dispatch_index_put_(
|
||||
// torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting
|
||||
// `disable_slice_optimization` when calling C++ tensor indexing functions from
|
||||
// Python ]
|
||||
inline Tensor get_item(
|
||||
static inline Tensor get_item(
|
||||
const Tensor& self,
|
||||
const ArrayRef<TensorIndex>& indices,
|
||||
bool disable_slice_optimization = false) {
|
||||
@ -662,7 +664,7 @@ inline Tensor get_item(
|
||||
// torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a
|
||||
// Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++
|
||||
// tensor indexing functions from Python ]
|
||||
inline void set_item(
|
||||
static inline void set_item(
|
||||
const Tensor& self,
|
||||
const ArrayRef<TensorIndex>& indices,
|
||||
const Tensor& value,
|
||||
|
@ -22,6 +22,7 @@
|
||||
#endif
|
||||
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/string_utils.h>
|
||||
#include <c10/util/SmallBuffer.h>
|
||||
|
||||
#include <array>
|
||||
@ -1397,7 +1398,7 @@ bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) {
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type", std::to_string((int)setup_type));
|
||||
TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type", c10::to_string((int)setup_type));
|
||||
}
|
||||
//coalescing dimensions consists of collapsing dimensions to 1 (we are limited to contiguous no-broadcast cases here)
|
||||
if (ndim() > 1){
|
||||
|
@ -31,7 +31,7 @@ struct TemplateEnv {
|
||||
// Add a number 'v' to the map at key 'k'
|
||||
template <typename T>
|
||||
void d(const std::string& k, const T& v) {
|
||||
strings_[k] = std::to_string(v);
|
||||
strings_[k] = c10::to_string(v);
|
||||
lists_.erase(k);
|
||||
}
|
||||
|
||||
|
@ -150,7 +150,7 @@ Generator make_generator(Args&&... args) {
|
||||
* the backend generator type (CPU/CUDAGeneratorImpl etc.)
|
||||
*/
|
||||
template <typename T>
|
||||
inline T * check_generator(std::optional<Generator> gen) {
|
||||
static inline T * check_generator(std::optional<Generator> gen) {
|
||||
TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt");
|
||||
TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed");
|
||||
TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'");
|
||||
@ -164,7 +164,7 @@ inline T * check_generator(std::optional<Generator> gen) {
|
||||
* the backend generator type (CPU/CUDAGeneratorImpl etc.)
|
||||
*/
|
||||
template <typename T>
|
||||
inline T* get_generator_or_default(const std::optional<Generator>& gen, const Generator& default_gen) {
|
||||
static inline T* get_generator_or_default(const std::optional<Generator>& gen, const Generator& default_gen) {
|
||||
return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen);
|
||||
}
|
||||
|
||||
@ -177,7 +177,7 @@ namespace detail {
|
||||
* - The new state tensor must be a torch.ByteTensor
|
||||
* - Data of the new state tensor must be contiguous
|
||||
*/
|
||||
inline void check_rng_state(const c10::TensorImpl& new_state) {
|
||||
static inline void check_rng_state(const c10::TensorImpl& new_state) {
|
||||
TORCH_CHECK_TYPE(
|
||||
new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte,
|
||||
"RNG state must be a torch.ByteTensor"
|
||||
|
@ -478,6 +478,8 @@ namespace impl {
|
||||
// (maybe except for some internal prim ops).
|
||||
using GenericList = List<IValue>;
|
||||
|
||||
const IValue* ptr_to_first_element(const GenericList& list);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -350,4 +350,11 @@ void List<T>::unsafeSetElementType(TypePtr t) {
|
||||
impl_->elementType = std::move(t);
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
|
||||
inline const IValue* ptr_to_first_element(const GenericList& list) {
|
||||
return &list.impl_->list[0];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -953,7 +953,7 @@ TensorBase make_tensor_base(Args&&... args) {
|
||||
|
||||
} // namespace detail
|
||||
|
||||
inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
|
||||
static inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
|
||||
return legacyExtractDispatchKey(t.key_set());
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,7 @@ namespace impl {
|
||||
// on TLS.
|
||||
//
|
||||
// NB: If there is no valid dispatch key, this will return Undefined
|
||||
inline DispatchKeySet computeDispatchKeySet(
|
||||
static inline DispatchKeySet computeDispatchKeySet(
|
||||
DispatchKeySet ks,
|
||||
// The key mask lets us eliminate (by zero entries) keys which should not
|
||||
// be considered for dispatch. There are two cases when we use this:
|
||||
|
@ -66,51 +66,51 @@ class Operation {
|
||||
|
||||
// treat the last N elements of the stack as a list, looking up
|
||||
// element i
|
||||
inline IValue& peek(Stack& stack, size_t i, size_t N) {
|
||||
static inline IValue& peek(Stack& stack, size_t i, size_t N) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
|
||||
return *(stack.end() - N + i);
|
||||
}
|
||||
inline IValue& peek(Stack* stack, size_t i, size_t N) {
|
||||
static inline IValue& peek(Stack* stack, size_t i, size_t N) {
|
||||
return peek(*stack, i, N);
|
||||
}
|
||||
inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
|
||||
static inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
|
||||
return *(stack.end() - N + i);
|
||||
}
|
||||
inline const IValue& peek(const Stack* stack, size_t i, size_t N) {
|
||||
static inline const IValue& peek(const Stack* stack, size_t i, size_t N) {
|
||||
return peek(*stack, i, N);
|
||||
}
|
||||
// treat the last N elements of the stack as a list, looking up the
|
||||
// slice starting at index i and having length len
|
||||
inline at::ArrayRef<IValue> peekSlice(
|
||||
static inline at::ArrayRef<IValue> peekSlice(
|
||||
const Stack& stack,
|
||||
size_t i,
|
||||
size_t len,
|
||||
size_t N) {
|
||||
return at::ArrayRef<IValue>(stack).slice(stack.size() - N + i, len);
|
||||
}
|
||||
inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
|
||||
static inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
|
||||
return peekSlice(stack, 0, N, N);
|
||||
}
|
||||
inline at::ArrayRef<IValue> last(const Stack* stack, size_t N) {
|
||||
static inline at::ArrayRef<IValue> last(const Stack* stack, size_t N) {
|
||||
return last(*stack, N);
|
||||
}
|
||||
inline void drop(Stack& stack, size_t n) {
|
||||
static inline void drop(Stack& stack, size_t n) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
|
||||
stack.erase(stack.end() - n, stack.end());
|
||||
}
|
||||
inline void drop(Stack* stack, size_t n) {
|
||||
static inline void drop(Stack* stack, size_t n) {
|
||||
drop(*stack, n);
|
||||
}
|
||||
inline IValue pop(Stack& stack) {
|
||||
static inline IValue pop(Stack& stack) {
|
||||
auto r = std::move(stack.back());
|
||||
stack.pop_back();
|
||||
return r;
|
||||
}
|
||||
inline IValue pop(Stack* stack) {
|
||||
static inline IValue pop(Stack* stack) {
|
||||
return pop(*stack);
|
||||
}
|
||||
inline std::vector<IValue> pop(Stack& stack, size_t n) {
|
||||
static inline std::vector<IValue> pop(Stack& stack, size_t n) {
|
||||
std::vector<IValue> result;
|
||||
result.reserve(n);
|
||||
for (const auto i : c10::irange(n)) {
|
||||
@ -127,7 +127,7 @@ inline std::vector<IValue> pop(Stack& stack, size_t n) {
|
||||
// b = pop(stack).toTensor();
|
||||
// a = pop(stack).toInt();
|
||||
template <typename... Types>
|
||||
inline void pop(Stack& stack, Types&... args) {
|
||||
static inline void pop(Stack& stack, Types&... args) {
|
||||
size_t i = 0;
|
||||
constexpr size_t N = sizeof...(args);
|
||||
(void)std::initializer_list<int>{
|
||||
@ -135,15 +135,15 @@ inline void pop(Stack& stack, Types&... args) {
|
||||
drop(stack, N);
|
||||
}
|
||||
template <typename... Types>
|
||||
inline void pop(Stack* stack, Types&... args) {
|
||||
static inline void pop(Stack* stack, Types&... args) {
|
||||
pop(*stack, args...);
|
||||
}
|
||||
template <typename Type>
|
||||
inline void push_one(Stack& stack, Type&& arg) {
|
||||
static inline void push_one(Stack& stack, Type&& arg) {
|
||||
stack.emplace_back(std::forward<Type>(arg));
|
||||
}
|
||||
|
||||
inline void push_one(Stack& stack, c10::TensorOptions options) {
|
||||
static inline void push_one(Stack& stack, c10::TensorOptions options) {
|
||||
stack.emplace_back(c10::typeMetaToScalarType(options.dtype()));
|
||||
stack.emplace_back(options.layout());
|
||||
stack.emplace_back(options.device());
|
||||
@ -151,15 +151,15 @@ inline void push_one(Stack& stack, c10::TensorOptions options) {
|
||||
}
|
||||
|
||||
template <typename... Types>
|
||||
inline void push(Stack& stack, Types&&... args) {
|
||||
static inline void push(Stack& stack, Types&&... args) {
|
||||
(void)std::initializer_list<int>{(push_one(stack, std::forward<Types>(args)), 0)...};
|
||||
}
|
||||
template <typename... Types>
|
||||
inline void push(Stack* stack, Types&&... args) {
|
||||
static inline void push(Stack* stack, Types&&... args) {
|
||||
return push(*stack, std::forward<Types>(args)...);
|
||||
}
|
||||
template <class T>
|
||||
inline void push_list_elements(Stack& stack, const c10::List<T>& elements) {
|
||||
static inline void push_list_elements(Stack& stack, const c10::List<T>& elements) {
|
||||
for (T elem : elements) {
|
||||
stack.push_back(std::move(elem));
|
||||
}
|
||||
|
@ -59,6 +59,13 @@ view_as_complex_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
|
||||
return std::make_tuple(result, 0);
|
||||
}
|
||||
|
||||
std::tuple<Tensor,optional<int64_t>>
|
||||
to_other_batch_rule(const Tensor& self, optional<int64_t> self_bdim,
|
||||
const Tensor& other, optional<int64_t> other_bdim,
|
||||
bool non_blocking,
|
||||
bool copy, std::optional<at::MemoryFormat> memory_format) {
|
||||
return std::make_tuple(self.to(other, non_blocking, copy, memory_format), self_bdim);
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
|
||||
|
@ -23,7 +23,7 @@ enum class GeluType {
|
||||
END
|
||||
};
|
||||
|
||||
inline GeluType get_gelutype_enum(const c10::string_view approximate) {
|
||||
static GeluType get_gelutype_enum(const c10::string_view approximate) {
|
||||
if (approximate == "none") {
|
||||
return GeluType::None;
|
||||
} else if (approximate == "tanh") {
|
||||
@ -33,7 +33,7 @@ inline GeluType get_gelutype_enum(const c10::string_view approximate) {
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string gelutype_to_string(const GeluType type) {
|
||||
static std::string gelutype_to_string(const GeluType type) {
|
||||
switch(type) {
|
||||
case GeluType::None: return "none";
|
||||
case GeluType::Tanh: return "tanh";
|
||||
|
@ -28,15 +28,15 @@ using adaptive_max_pooling3d_backward_fn = void(*)(const Tensor& grad_input, con
|
||||
DECLARE_DISPATCH(adaptive_max_pooling3d_fn, adaptive_max_pool3d_kernel);
|
||||
DECLARE_DISPATCH(adaptive_max_pooling3d_backward_fn, adaptive_max_pool3d_backward_kernel);
|
||||
|
||||
inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
|
||||
static inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
|
||||
return (a / b) * c + ((a % b) * c) / b;
|
||||
}
|
||||
|
||||
inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
|
||||
static inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
|
||||
return 1 + ((a + 1) * c - 1) / b;
|
||||
}
|
||||
|
||||
inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) {
|
||||
static inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) {
|
||||
int64_t ndim = gradOutput_.ndimension();
|
||||
for (const auto i : c10::irange(1, ndim)) {
|
||||
TORCH_CHECK(gradOutput_.size(i) > 0,
|
||||
|
@ -75,7 +75,7 @@ namespace {
|
||||
}
|
||||
}
|
||||
|
||||
inline bool cudnnv8_enabled_check_debug() {
|
||||
static inline bool cudnnv8_enabled_check_debug() {
|
||||
static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
|
||||
static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
|
||||
static uint8_t cudnnv8_debugcount = 0;
|
||||
@ -86,7 +86,7 @@ inline bool cudnnv8_enabled_check_debug() {
|
||||
return cudnnv8_flag == 1;
|
||||
}
|
||||
|
||||
inline bool cudnnv8_use_heur_mode_b() {
|
||||
static inline bool cudnnv8_use_heur_mode_b() {
|
||||
return is_cudnnv8_heuristic_mode_b();
|
||||
}
|
||||
|
||||
@ -186,7 +186,7 @@ static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, co
|
||||
// (which the user can change) and computed inputs (which the user can
|
||||
// only indirectly affect). It would be an interesting exercise to
|
||||
// come up with a general framework to handle such situations.)
|
||||
inline void convolution_shape_check(
|
||||
static void convolution_shape_check(
|
||||
CheckedFrom c,
|
||||
const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
|
||||
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
|
||||
@ -212,7 +212,7 @@ inline void convolution_shape_check(
|
||||
// takes an extra output_padding argument to resolve the ambiguity.
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<T> _conv_output_size(
|
||||
static inline std::vector<T> _conv_output_size(
|
||||
ArrayRef<T> input_size, ArrayRef<T> weight_size,
|
||||
ArrayRef<T> padding, ArrayRef<T> stride, ArrayRef<T> dilation = ArrayRef<T>()
|
||||
) {
|
||||
@ -231,14 +231,14 @@ inline std::vector<T> _conv_output_size(
|
||||
return output_size;
|
||||
}
|
||||
|
||||
inline std::vector<int64_t> conv_output_size(
|
||||
static inline std::vector<int64_t> conv_output_size(
|
||||
IntArrayRef input_size, IntArrayRef weight_size,
|
||||
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
|
||||
) {
|
||||
return _conv_output_size(input_size, weight_size, padding, stride, dilation);
|
||||
}
|
||||
|
||||
inline std::vector<c10::SymInt> conv_output_size(
|
||||
static inline std::vector<c10::SymInt> conv_output_size(
|
||||
SymIntArrayRef input_size, SymIntArrayRef weight_size,
|
||||
SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef()
|
||||
) {
|
||||
@ -264,14 +264,14 @@ std::vector<T> _conv_input_size(
|
||||
return input_size;
|
||||
}
|
||||
|
||||
inline std::vector<c10::SymInt> conv_input_size(
|
||||
static inline std::vector<c10::SymInt> conv_input_size(
|
||||
SymIntArrayRef output_size, SymIntArrayRef weight_size,
|
||||
SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups
|
||||
) {
|
||||
return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
|
||||
}
|
||||
|
||||
inline std::vector<int64_t> conv_input_size(
|
||||
static inline std::vector<int64_t> conv_input_size(
|
||||
IntArrayRef output_size, IntArrayRef weight_size,
|
||||
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
|
||||
) {
|
||||
@ -295,27 +295,27 @@ std::vector<T> _conv_weight_size(
|
||||
return weight_size;
|
||||
}
|
||||
|
||||
inline std::vector<c10::SymInt> conv_weight_size(
|
||||
static inline std::vector<c10::SymInt> conv_weight_size(
|
||||
SymIntArrayRef input_size, SymIntArrayRef output_size,
|
||||
SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
|
||||
) {
|
||||
return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
|
||||
}
|
||||
|
||||
inline std::vector<int64_t> conv_weight_size(
|
||||
static inline std::vector<int64_t> conv_weight_size(
|
||||
IntArrayRef input_size, IntArrayRef output_size,
|
||||
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
|
||||
) {
|
||||
return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
|
||||
}
|
||||
|
||||
inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
|
||||
static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
|
||||
std::vector<int64_t> shape(dim, 1);
|
||||
shape[1] = -1;
|
||||
return bias.reshape(shape);
|
||||
}
|
||||
|
||||
inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
|
||||
static inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
|
||||
// disable NHWC for float64 input.
|
||||
if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
|
||||
input.scalar_type() == at::kDouble ||
|
||||
@ -351,7 +351,7 @@ TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
|
||||
TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
|
||||
|
||||
|
||||
inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
||||
static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
||||
|
||||
// disable NHWC for float64 input.
|
||||
if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
|
||||
@ -378,7 +378,7 @@ inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Ten
|
||||
return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
|
||||
}
|
||||
|
||||
inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
||||
static inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
||||
|
||||
// disable NHWC for float64 input.
|
||||
if (input.scalar_type() == at::kDouble ||
|
||||
@ -405,7 +405,7 @@ inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Ten
|
||||
return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
|
||||
}
|
||||
|
||||
inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
||||
static inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
||||
|
||||
auto input_memory_format = input.suggest_memory_format();
|
||||
auto weight_memory_format = weight.suggest_memory_format();
|
||||
@ -417,7 +417,7 @@ inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tenso
|
||||
return can_use_thnn_channels_last_2d;
|
||||
}
|
||||
|
||||
inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
||||
static inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
||||
|
||||
// check layout only for xpu tensor.
|
||||
if (!input.is_xpu() || !weight.is_xpu()) {
|
||||
|
@ -254,7 +254,7 @@ C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler<a
|
||||
* See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h.
|
||||
*/
|
||||
template<typename scalar_t, typename accscalar_t>
|
||||
C10_DEVICE inline scalar_t digamma_one(scalar_t x) {
|
||||
C10_DEVICE static inline scalar_t digamma_one(scalar_t x) {
|
||||
constexpr accscalar_t PSI_10 = 2.25175258906672110764;
|
||||
if (x == 0) {
|
||||
return INFINITY;
|
||||
@ -376,7 +376,7 @@ C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
|
||||
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
|
||||
// Assumes x is close to zero and uses a Taylor expansion.
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
C10_DEVICE inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
|
||||
C10_DEVICE static inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
|
||||
const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha)
|
||||
- digamma_one<scalar_t, accscalar_t>(alpha + beta) - compat_log(x);
|
||||
scalar_t numer = 1;
|
||||
@ -394,7 +394,7 @@ C10_DEVICE inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, sc
|
||||
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
|
||||
// Assumes x is close to zero and uses a Taylor expansion.
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
C10_DEVICE inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
|
||||
C10_DEVICE static inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
|
||||
const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha + beta) - digamma_one<scalar_t, accscalar_t>(beta);
|
||||
scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha;
|
||||
for (int i = 1; i <= 8; ++i) {
|
||||
@ -412,7 +412,7 @@ C10_DEVICE inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, sca
|
||||
// Assumes alpha and beta are both large and uses a Rice saddle point expansion.
|
||||
// To ensure numerical stability, this computation is performed at higher precision.
|
||||
template<typename scalar_t, typename accscalar_t>
|
||||
C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
|
||||
C10_DEVICE static inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
|
||||
const accscalar_t total = alpha + beta;
|
||||
const accscalar_t mean = alpha / total;
|
||||
const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total;
|
||||
@ -452,7 +452,7 @@ C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha
|
||||
// This function inputs total=alpha+beta to make it easy to implement
|
||||
// Dirichlet reparameterized gradients in terms of Betas.
|
||||
template<typename scalar_t, typename accscalar_t>
|
||||
C10_HOST_DEVICE inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
|
||||
C10_HOST_DEVICE static inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
|
||||
accscalar_t x_ = static_cast<accscalar_t>(x);
|
||||
accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
|
||||
accscalar_t total_ = static_cast<accscalar_t>(total);
|
||||
|
@ -6,7 +6,7 @@
|
||||
namespace at::native {
|
||||
|
||||
template<typename scalar_t>
|
||||
inline std::vector<int> generate_intervals(
|
||||
static inline std::vector<int> generate_intervals(
|
||||
scalar_t sample,
|
||||
int64_t inputSize,
|
||||
int64_t outputSize,
|
||||
@ -28,7 +28,7 @@ inline std::vector<int> generate_intervals(
|
||||
}
|
||||
|
||||
template <int64_t ndim>
|
||||
inline void fractional_max_pool_check_shape(
|
||||
static inline void fractional_max_pool_check_shape(
|
||||
const Tensor& input,
|
||||
const Tensor& randomSamples) {
|
||||
|
||||
|
@ -27,7 +27,7 @@
|
||||
|
||||
namespace at::native {
|
||||
|
||||
inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
|
||||
static inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
|
||||
if (tensor.is_conj()) {
|
||||
return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
|
||||
} else {
|
||||
@ -35,7 +35,7 @@ inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
|
||||
}
|
||||
}
|
||||
|
||||
inline DimVector batched_matrix_contiguous_strides(
|
||||
static inline DimVector batched_matrix_contiguous_strides(
|
||||
const IntArrayRef sizes,
|
||||
const bool f_contig = false) {
|
||||
// f_contig chooses between the strides of a batch of Fortran (F-contiguous)
|
||||
@ -62,7 +62,7 @@ inline DimVector batched_matrix_contiguous_strides(
|
||||
* P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N'
|
||||
* matrix starting at Q.data_ptr()[B * M' * N'].
|
||||
*/
|
||||
inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
|
||||
static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
|
||||
// If src is already in batched column major format, then
|
||||
// this will be efficient (no reordering of the data will occur)
|
||||
// because the first transpose will make the tensor contiguous,
|
||||
@ -75,7 +75,7 @@ inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
|
||||
/*
|
||||
* contig chooses between C-contig (true) and F-contig (false)
|
||||
*/
|
||||
inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
|
||||
static inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
|
||||
return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
|
||||
: c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
|
||||
: cloneBatchedColumnMajor(clone));
|
||||
@ -92,7 +92,7 @@ inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor&
|
||||
* which is either the original batch size of the input, or its larger
|
||||
* broadcasted shape.
|
||||
*/
|
||||
inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
|
||||
static inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
|
||||
at::OptionalIntArrayRef desired_batch_sizes = c10::nullopt) {
|
||||
nrows = (nrows == -1) ? src.size(-2) : nrows;
|
||||
auto copy_sizes = desired_batch_sizes.has_value()
|
||||
@ -109,7 +109,7 @@ inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
|
||||
* Given batches of matrices with arbitrary batch dim,
|
||||
* computes the number of batches.
|
||||
*/
|
||||
inline int64_t batchCount(const Tensor& batched_matrices) {
|
||||
static inline int64_t batchCount(const Tensor& batched_matrices) {
|
||||
int64_t result = 1;
|
||||
for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
|
||||
result *= batched_matrices.size(i);
|
||||
@ -118,15 +118,15 @@ inline int64_t batchCount(const Tensor& batched_matrices) {
|
||||
}
|
||||
|
||||
// Computes the number of elements of a matrix in a batched matrix tensor
|
||||
inline int64_t matrixStride(const Tensor& batched_matrices) {
|
||||
static inline int64_t matrixStride(const Tensor& batched_matrices) {
|
||||
return batched_matrices.size(-1) * batched_matrices.size(-2);
|
||||
}
|
||||
|
||||
// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
|
||||
inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
|
||||
static inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
|
||||
TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions.");
|
||||
}
|
||||
inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
|
||||
static inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
|
||||
checkIsMatrix(self, f_name, arg_name);
|
||||
TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2),
|
||||
f_name,
|
||||
@ -134,7 +134,7 @@ inline void squareCheckInputs(const Tensor& self, const char* const f_name, cons
|
||||
"but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices");
|
||||
}
|
||||
|
||||
inline void checkInputsSolver(const Tensor& A,
|
||||
static inline void checkInputsSolver(const Tensor& A,
|
||||
const Tensor& B,
|
||||
const bool left,
|
||||
const char* const f_name) {
|
||||
@ -146,14 +146,14 @@ inline void checkInputsSolver(const Tensor& A,
|
||||
" (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
|
||||
}
|
||||
|
||||
inline bool is_row_or_column_contiguous(const Tensor& t) {
|
||||
static inline bool is_row_or_column_contiguous(const Tensor& t) {
|
||||
// This could be made more general, similar to how it's checked in matmul, which would allow to
|
||||
// ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
|
||||
// We choose to be conservative for simplicity
|
||||
return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
|
||||
}
|
||||
|
||||
inline TransposeType to_transpose_type(const bool contig, const bool conj) {
|
||||
static inline TransposeType to_transpose_type(const bool contig, const bool conj) {
|
||||
if (conj) {
|
||||
if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
|
||||
else { return TransposeType::ConjTranspose; }
|
||||
@ -261,7 +261,7 @@ void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const fu
|
||||
}
|
||||
|
||||
// Returns the epsilon value for floating types except half
|
||||
inline double _get_epsilon(const ScalarType& sc_type) {
|
||||
static inline double _get_epsilon(const ScalarType& sc_type) {
|
||||
switch (sc_type) {
|
||||
case at::ScalarType::Float:
|
||||
return static_cast<double>(std::numeric_limits<float>::epsilon());
|
||||
@ -274,7 +274,7 @@ inline double _get_epsilon(const ScalarType& sc_type) {
|
||||
|
||||
// Validates input shapes and devices
|
||||
// for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
|
||||
inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
|
||||
static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
|
||||
TORCH_CHECK(self.device() == A.device(),
|
||||
"Expected b and A to be on the same device, but found b on ",
|
||||
self.device(), " and A on ", A.device(), " instead.");
|
||||
@ -293,7 +293,7 @@ inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const ch
|
||||
" but each b matrix is ", self.size(-2), " by ", self.size(-1));
|
||||
}
|
||||
|
||||
inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
|
||||
static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
|
||||
auto dtype = t.scalar_type();
|
||||
TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)),
|
||||
f_name, ": Expected a floating point or complex tensor as input. Got ", dtype);
|
||||
@ -305,13 +305,13 @@ inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, co
|
||||
|
||||
|
||||
// Checks if all the Tensors in a TensorList are of the same dimensions
|
||||
inline void checkAllSameDim(TensorList tensors, int64_t dim) {
|
||||
static inline void checkAllSameDim(TensorList tensors, int64_t dim) {
|
||||
for (auto &t : tensors) {
|
||||
TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead.");
|
||||
}
|
||||
}
|
||||
|
||||
inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
|
||||
static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
|
||||
// broadcast the batch dimensions of arg1 and arg2.
|
||||
IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
|
||||
IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
|
||||
@ -325,7 +325,7 @@ inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_
|
||||
return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size));
|
||||
}
|
||||
|
||||
inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
|
||||
static inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
|
||||
// If there's no name we assume we don't want to check the errors
|
||||
if (name != nullptr) {
|
||||
linearSolveCheckInputs(arg1, arg2, name);
|
||||
@ -338,7 +338,7 @@ inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1
|
||||
return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
|
||||
}
|
||||
|
||||
inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
|
||||
static inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
|
||||
IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims);
|
||||
IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims);
|
||||
auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes);
|
||||
@ -346,7 +346,7 @@ inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor&
|
||||
}
|
||||
|
||||
// Return a permutation with the given axes moved to the end.
|
||||
inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
|
||||
static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
|
||||
const std::vector<int64_t> a = axes.vec();
|
||||
const int64_t ndim = self.ndimension();
|
||||
std::vector<int64_t> perm;
|
||||
@ -368,7 +368,7 @@ inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
|
||||
}
|
||||
|
||||
// parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
|
||||
inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
|
||||
static inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
|
||||
bool compute_q;
|
||||
bool reduced;
|
||||
if (mode == "reduced") {
|
||||
@ -388,7 +388,7 @@ inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
|
||||
}
|
||||
|
||||
// Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition
|
||||
inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
|
||||
static inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
|
||||
const Tensor& input,
|
||||
bool reduced) {
|
||||
int64_t m = input.size(-2), n = input.size(-1);
|
||||
@ -407,7 +407,7 @@ inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
|
||||
return std::make_tuple(q_sizes, q_strides, n_columns_q);
|
||||
}
|
||||
|
||||
inline bool svd_uses_cusolver(const Tensor& A) {
|
||||
static inline bool svd_uses_cusolver(const Tensor& A) {
|
||||
// if cusolver is available, it is used unconditionally
|
||||
return A.is_cuda()
|
||||
&& at::globalContext().hasCuSOLVER()
|
||||
@ -417,7 +417,7 @@ inline bool svd_uses_cusolver(const Tensor& A) {
|
||||
|
||||
// Function used instead of .to so that the original strides are retained
|
||||
// .to doesn't retain strides and make the output tensor contiguous
|
||||
inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
|
||||
static inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
|
||||
auto strided_to = at::empty_strided(original_tensor.sizes(),
|
||||
original_tensor.strides(),
|
||||
options);
|
||||
@ -433,7 +433,7 @@ inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOpti
|
||||
// For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by
|
||||
// calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will
|
||||
// be `vec(0, 2, 1, 3)`.
|
||||
inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
|
||||
static inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
|
||||
TORCH_CHECK(
|
||||
(dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
|
||||
"duplicate or invalid dimensions");
|
||||
@ -453,7 +453,7 @@ inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64
|
||||
// will reverse a given permutation.
|
||||
// The reverse permutation array is created by swapping the indices and their
|
||||
// associated values from the given permutation array.
|
||||
inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
|
||||
static inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
|
||||
int64_t ndim = permutation.size();
|
||||
std::vector<int64_t> reverse_permutation(ndim);
|
||||
for (const auto dim_ind : c10::irange(ndim)) {
|
||||
@ -464,7 +464,7 @@ inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> perm
|
||||
|
||||
// Compute R-work array size for MAGMA/LAPACK cgesdd/zgesdd
|
||||
// See https://github.com/Reference-LAPACK/lapack/blob/122506cd8b6ce050a200920c3d4c0b153b150fd8/SRC/cgesdd.f#L186
|
||||
inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
|
||||
static inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
|
||||
auto mn = std::min(m, n);
|
||||
auto mx = std::max(m, n);
|
||||
if (jobz == 'N') {
|
||||
@ -484,14 +484,14 @@ inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
|
||||
|
||||
// This function checks whether the uplo argument input is valid
|
||||
// Allowed strings are "u", "U", "l", "L"
|
||||
inline void checkUplo(const c10::string_view uplo) {
|
||||
static inline void checkUplo(const c10::string_view uplo) {
|
||||
// To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
|
||||
char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
|
||||
TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
|
||||
"Expected UPLO argument to be 'L' or 'U', but got ", uplo);
|
||||
}
|
||||
|
||||
inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
|
||||
static inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
|
||||
TORCH_CHECK(
|
||||
result.device() == input.device(),
|
||||
fn_name,
|
||||
@ -504,7 +504,7 @@ inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor in
|
||||
// (either floating or complex type input), so we can check whether input's dtype can be casted to result's dtype.
|
||||
// According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
|
||||
// c10::canCast is used for checking the "safe copy" dtype requirements.
|
||||
inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
|
||||
static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
|
||||
bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type());
|
||||
TORCH_CHECK(
|
||||
can_cast,
|
||||
@ -514,7 +514,7 @@ inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result
|
||||
}
|
||||
|
||||
// Alternatively, we can check whether the specific expected output type (result_type) can be safely casted to out tensor dtype (out_type)
|
||||
inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
|
||||
static inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
|
||||
bool can_cast = c10::canCast(result_type, out_type);
|
||||
TORCH_CHECK(
|
||||
can_cast,
|
||||
@ -523,7 +523,7 @@ inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType ou
|
||||
out_name, " with dtype ", out_type);
|
||||
}
|
||||
|
||||
inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
|
||||
static inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
|
||||
TORCH_CHECK(!at::isComplexType(tol.scalar_type()),
|
||||
f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type());
|
||||
}
|
||||
@ -538,7 +538,7 @@ inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f
|
||||
Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m).
|
||||
This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
|
||||
*/
|
||||
inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
|
||||
static inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
|
||||
auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1); // input.shape[:-1]
|
||||
bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape));
|
||||
return vector_case;
|
||||
@ -547,7 +547,7 @@ inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other)
|
||||
/*
|
||||
Computes linear indices for a tensor with original_shape to access its elements like it was a materialized broadcast tensor.
|
||||
*/
|
||||
inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
|
||||
static inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
|
||||
TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
|
||||
return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous();
|
||||
}
|
||||
@ -578,7 +578,7 @@ class BroadcastLinearIndices {
|
||||
}
|
||||
};
|
||||
|
||||
inline bool is_blas_compatible_column_major_order(const Tensor& input) {
|
||||
static inline bool is_blas_compatible_column_major_order(const Tensor& input) {
|
||||
IntArrayRef input_strides = input.strides();
|
||||
IntArrayRef input_sizes = input.sizes();
|
||||
auto ndim = input.dim();
|
||||
@ -599,7 +599,7 @@ inline bool is_blas_compatible_column_major_order(const Tensor& input) {
|
||||
batch_stride_compatible;
|
||||
}
|
||||
|
||||
inline bool is_blas_compatible_row_major_order(const Tensor& input) {
|
||||
static inline bool is_blas_compatible_row_major_order(const Tensor& input) {
|
||||
IntArrayRef input_strides = input.strides();
|
||||
IntArrayRef input_sizes = input.sizes();
|
||||
auto ndim = input.dim();
|
||||
|
@ -675,6 +675,15 @@ Tensor nll_loss_symint(const Tensor & self, const Tensor & target, const std::op
|
||||
return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, std::move(ignore_index)));
|
||||
}
|
||||
|
||||
// Duplicate of above code for non-symbolic ints. Kept for BC purposes and to minimize breakages.
|
||||
static Tensor nll_loss(const Tensor & self, const Tensor & target, const std::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index) {
|
||||
// See [Note: hacky wrapper removal for optional tensor]
|
||||
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
|
||||
const Tensor& weight = *weight_maybe_owned;
|
||||
|
||||
return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, ignore_index));
|
||||
}
|
||||
|
||||
Tensor nll_loss_nd_symint(
|
||||
const Tensor& self,
|
||||
const Tensor& target,
|
||||
|
@ -147,7 +147,7 @@ jiterator_also_stringify_as(jiterator_code(
|
||||
#define CENTRAL_RANGE 0.7
|
||||
|
||||
template <typename T>
|
||||
inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
calc_erfinv(T y) {
|
||||
/* Function to calculate inverse error function. Rational approximation
|
||||
is used to generate an initial approximation, which is then improved to
|
||||
@ -232,7 +232,7 @@ Date: February 1996
|
||||
* See note [3-Clause BSD License for the Cephes Math Library].
|
||||
*/
|
||||
template <typename scalar_t, bool is_cuda=false>
|
||||
C10_HOST_DEVICE inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ {
|
||||
C10_HOST_DEVICE static inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ {
|
||||
using acc_t = at::acc_type<scalar_t, is_cuda>;
|
||||
const acc_t MACHEP = acc_t{1.11022302462515654042E-16};
|
||||
constexpr acc_t zero = acc_t{0.0};
|
||||
@ -324,7 +324,7 @@ C10_HOST_DEVICE inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_floa
|
||||
* N 0
|
||||
*/
|
||||
template <typename T>
|
||||
C10_HOST_DEVICE inline T polevl(const T x, const T A[], size_t len) {
|
||||
C10_HOST_DEVICE static inline T polevl(const T x, const T A[], size_t len) {
|
||||
T result = 0;
|
||||
for (size_t i = 0; i <= len; i++) {
|
||||
result = result * x + A[i];
|
||||
@ -332,7 +332,7 @@ C10_HOST_DEVICE inline T polevl(const T x, const T A[], size_t len) {
|
||||
return result;
|
||||
}
|
||||
|
||||
inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ {
|
||||
static inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ {
|
||||
double sign = +1;
|
||||
double result = 0;
|
||||
if (x < 0.5) {
|
||||
@ -350,7 +350,7 @@ inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ {
|
||||
return sign * result;
|
||||
}
|
||||
|
||||
inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ {
|
||||
static inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ {
|
||||
float sign = +1;
|
||||
float result = 0;
|
||||
if (x < 0.5f) {
|
||||
@ -372,7 +372,7 @@ inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ {
|
||||
* This function is derived from the implementation of the digamma function in the Cephes Math Library.
|
||||
* See note [3-Clause BSD License for the Cephes Math Library].
|
||||
*/
|
||||
inline double calc_digamma(double x) {
|
||||
static inline double calc_digamma(double x) {
|
||||
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
|
||||
static double PSI_10 = 2.25175258906672110764;
|
||||
if (x == 0) {
|
||||
@ -430,7 +430,7 @@ inline double calc_digamma(double x) {
|
||||
* This function is derived from the implementation of the digamma function in the Cephes Math Library.
|
||||
* See note [3-Clause BSD License for the Cephes Math Library].
|
||||
*/
|
||||
inline float calc_digamma(float x) {
|
||||
static inline float calc_digamma(float x) {
|
||||
// See [C++ Standard Reference: Gamma Function]
|
||||
static float PSI_10 = 2.25175258906672110764f;
|
||||
if (x == 0) {
|
||||
@ -485,16 +485,16 @@ inline float calc_digamma(float x) {
|
||||
return result + logf(x) - (0.5f / x) - y;
|
||||
}
|
||||
|
||||
inline c10::BFloat16 calc_digamma(c10::BFloat16 a) {
|
||||
static inline c10::BFloat16 calc_digamma(c10::BFloat16 a) {
|
||||
return calc_digamma(static_cast<float>(a));
|
||||
}
|
||||
|
||||
inline c10::Half calc_digamma(c10::Half a) {
|
||||
static inline c10::Half calc_digamma(c10::Half a) {
|
||||
return calc_digamma(static_cast<float>(a));
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool is_cuda=false>
|
||||
inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) {
|
||||
static inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) {
|
||||
// already blocked if n <= 1
|
||||
const auto one = scalar_t{1};
|
||||
return ((n % 2) ? one : -one) *
|
||||
@ -519,7 +519,7 @@ inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) {
|
||||
* See NOTICE for the licenses.
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
|
||||
static scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
|
||||
const scalar_t denom[], int64_t N) {
|
||||
// evaluating rational function, i.e., the ratio of two polynomials
|
||||
// the coefficients for numerator are given by `num` while coeffs for
|
||||
@ -1061,7 +1061,7 @@ static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) {
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline scalar_t calc_igammac(scalar_t a, scalar_t x) {
|
||||
static inline scalar_t calc_igammac(scalar_t a, scalar_t x) {
|
||||
/* the calculation of the regularized upper incomplete gamma function
|
||||
* is done differently based on the values of a and x:
|
||||
* - if x and/or a is at the boundary of defined region, then assign the
|
||||
@ -1141,7 +1141,7 @@ inline scalar_t calc_igammac(scalar_t a, scalar_t x) {
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
scalar_t calc_igamma(scalar_t a, scalar_t x) {
|
||||
static inline scalar_t calc_igamma(scalar_t a, scalar_t x) {
|
||||
/* the calculation of the regularized lower incomplete gamma function
|
||||
* is done differently based on the values of a and x:
|
||||
* - if x and/or a is at the boundary of defined region, then assign the
|
||||
@ -1203,39 +1203,39 @@ scalar_t calc_igamma(scalar_t a, scalar_t x) {
|
||||
}
|
||||
|
||||
template <>
|
||||
C10_UNUSED inline c10::BFloat16 calc_igamma<c10::BFloat16>(c10::BFloat16 a, c10::BFloat16 x) {
|
||||
C10_UNUSED c10::BFloat16 calc_igamma<c10::BFloat16>(c10::BFloat16 a, c10::BFloat16 x) {
|
||||
return calc_igamma<float>(float(a), float(x));
|
||||
}
|
||||
|
||||
template <>
|
||||
C10_UNUSED inline c10::Half calc_igamma<c10::Half>(c10::Half a, c10::Half x) {
|
||||
C10_UNUSED c10::Half calc_igamma<c10::Half>(c10::Half a, c10::Half x) {
|
||||
return calc_igamma<float>(float(a), float(x));
|
||||
}
|
||||
|
||||
template <>
|
||||
C10_UNUSED inline c10::BFloat16 calc_igammac<c10::BFloat16>(c10::BFloat16 a, c10::BFloat16 x) {
|
||||
C10_UNUSED c10::BFloat16 calc_igammac<c10::BFloat16>(c10::BFloat16 a, c10::BFloat16 x) {
|
||||
return calc_igammac<float>(float(a), float(x));
|
||||
}
|
||||
|
||||
template <>
|
||||
C10_UNUSED inline c10::Half calc_igammac<c10::Half>(c10::Half a, c10::Half x) {
|
||||
C10_UNUSED c10::Half calc_igammac<c10::Half>(c10::Half a, c10::Half x) {
|
||||
return calc_igammac<float>(float(a), float(x));
|
||||
}
|
||||
|
||||
inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); }
|
||||
|
||||
template <typename T>
|
||||
inline T abs_impl(T v) {
|
||||
static T abs_impl(T v) {
|
||||
return std::abs(v);
|
||||
}
|
||||
|
||||
template <>
|
||||
C10_UNUSED inline uint8_t abs_impl(uint8_t v) {
|
||||
C10_UNUSED uint8_t abs_impl(uint8_t v) {
|
||||
return v;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline typename std::enable_if<std::is_integral<T>::value, T>::type
|
||||
static inline typename std::enable_if<std::is_integral<T>::value, T>::type
|
||||
calc_gcd(T a, T b) {
|
||||
a = abs_impl(a);
|
||||
b = abs_impl(b);
|
||||
@ -1284,7 +1284,7 @@ C10_HOST_DEVICE c10::complex<T> exp2_impl(c10::complex<T> x) {
|
||||
* required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 1.
|
||||
*/
|
||||
template <typename T>
|
||||
inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
chbevl(const T x, const T array[], size_t len) {
|
||||
T b0, b1, b2;
|
||||
|
||||
@ -1310,7 +1310,7 @@ chbevl(const T x, const T array[], size_t len) {
|
||||
* of all inputs to convert them into the domain of the approximation.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
|
||||
static inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
|
||||
/* Chebyshev coefficients for exp(-x) I0(x)
|
||||
* in the interval [0,8].
|
||||
*
|
||||
@ -1336,7 +1336,7 @@ inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
|
||||
static inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
|
||||
/* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
|
||||
* in the inverted interval [8,infinity].
|
||||
*
|
||||
@ -1361,7 +1361,7 @@ inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
|
||||
static inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
|
||||
chebyshev_coefficients_i1e_A() {
|
||||
/* Chebyshev coefficients for exp(-x) I1(x)
|
||||
* in the interval [0,8].
|
||||
@ -1388,7 +1388,7 @@ chebyshev_coefficients_i1e_A() {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
|
||||
static inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
|
||||
chebyshev_coefficients_i1e_A() {
|
||||
/* Chebyshev coefficients for exp(-x) I1(x)
|
||||
* in the interval [0,8].
|
||||
@ -1417,7 +1417,7 @@ chebyshev_coefficients_i1e_A() {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
|
||||
static inline typename std::enable_if<std::is_same<double, T>::value, std::tuple<const T*, size_t>>::type
|
||||
chebyshev_coefficients_i1e_B() {
|
||||
/* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
|
||||
* in the inverted interval [8,infinity].
|
||||
@ -1443,7 +1443,7 @@ chebyshev_coefficients_i1e_B() {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
|
||||
static inline typename std::enable_if<std::is_same<float, T>::value, std::tuple<const T*, size_t>>::type
|
||||
chebyshev_coefficients_i1e_B() {
|
||||
/* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
|
||||
* in the inverted interval [8,infinity].
|
||||
@ -1463,7 +1463,7 @@ chebyshev_coefficients_i1e_B() {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
calc_i0(T _x) {
|
||||
T x = std::abs(_x);
|
||||
|
||||
@ -1481,7 +1481,7 @@ calc_i0(T _x) {
|
||||
}
|
||||
|
||||
// Upcast bfloat16 input to float for numerical accuracy purposes
|
||||
inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }
|
||||
static inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }
|
||||
|
||||
/*
|
||||
* This function is derived from the implementation of the i1 function in the Cephes Math Library.
|
||||
@ -1493,7 +1493,7 @@ inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float
|
||||
* of all inputs to convert them into the domain of the approximation.
|
||||
*/
|
||||
template <typename T>
|
||||
inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
calc_i1(T _x) {
|
||||
T x = std::abs(_x);
|
||||
|
||||
@ -1522,7 +1522,7 @@ calc_i1(T _x) {
|
||||
* of all inputs to convert them into the domain of the approximation.
|
||||
*/
|
||||
template <typename T>
|
||||
inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
calc_i1e(T _x) {
|
||||
T x = std::abs(_x);
|
||||
|
||||
@ -1549,7 +1549,7 @@ calc_i1e(T _x) {
|
||||
* (integrated from minus infinity to x) is equal to y.
|
||||
*/
|
||||
template <typename T>
|
||||
inline C10_HOST_DEVICE T calc_ndtri(T y0) {
|
||||
static inline C10_HOST_DEVICE T calc_ndtri(T y0) {
|
||||
|
||||
/* sqrt(2pi) */
|
||||
constexpr T s2pi = 2.50662827463100050242E0;
|
||||
@ -1737,7 +1737,7 @@ inline C10_HOST_DEVICE T calc_ndtri(T y0) {
|
||||
|
||||
|
||||
template <typename T>
|
||||
C10_HOST_DEVICE inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
C10_HOST_DEVICE static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
erfcx_y100(T y100)
|
||||
{
|
||||
switch (static_cast<int>(y100)) {
|
||||
@ -2148,7 +2148,7 @@ return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
C10_HOST_DEVICE inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
C10_HOST_DEVICE static inline typename std::enable_if<std::is_floating_point<T>::value, T>::type
|
||||
calc_erfcx(T x)
|
||||
{
|
||||
if (at::_isnan(x)) {
|
||||
@ -2188,7 +2188,7 @@ calc_erfcx(T x)
|
||||
* See NOTICE for the licenses.
|
||||
*/
|
||||
template <typename T>
|
||||
inline C10_HOST_DEVICE T calc_log_ndtr(T x) {
|
||||
static inline C10_HOST_DEVICE T calc_log_ndtr(T x) {
|
||||
T t = x * c10::frac_sqrt_2<T>;
|
||||
if (x < T{-1.0}) {
|
||||
return std::log(calc_erfcx(-t) / 2) - t * t;
|
||||
@ -2198,7 +2198,7 @@ inline C10_HOST_DEVICE T calc_log_ndtr(T x) {
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T airy_ai_forward(T x) {
|
||||
static inline C10_HOST_DEVICE T airy_ai_forward(T x) {
|
||||
static const T AN[] = {
|
||||
+3.46538101525629032477e-01,
|
||||
+1.20075952739645805542e+01,
|
||||
@ -2377,7 +2377,7 @@ inline C10_HOST_DEVICE T airy_ai_forward(T x) {
|
||||
} // T airy_ai(T x)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T bessel_j0_forward(T x) {
|
||||
static inline C10_HOST_DEVICE T bessel_j0_forward(T x) {
|
||||
static const T PP[] = {
|
||||
+7.96936729297347051624e-04,
|
||||
+8.28352392107440799803e-02,
|
||||
@ -2489,7 +2489,7 @@ inline C10_HOST_DEVICE T bessel_j0_forward(T x) {
|
||||
} // bessel_j0_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T bessel_j1_forward(T x) {
|
||||
static inline C10_HOST_DEVICE T bessel_j1_forward(T x) {
|
||||
static const T PP[] = {
|
||||
+7.62125616208173112003e-04,
|
||||
+7.31397056940917570436e-02,
|
||||
@ -2597,7 +2597,7 @@ inline C10_HOST_DEVICE T bessel_j1_forward(T x) {
|
||||
} // bessel_j1_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T bessel_y0_forward(T x) {
|
||||
static inline C10_HOST_DEVICE T bessel_y0_forward(T x) {
|
||||
static const T PP[] = {
|
||||
+7.96936729297347051624e-04,
|
||||
+8.28352392107440799803e-02,
|
||||
@ -2712,7 +2712,7 @@ inline C10_HOST_DEVICE T bessel_y0_forward(T x) {
|
||||
} // bessel_y0_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T bessel_y1_forward(T x) {
|
||||
static inline C10_HOST_DEVICE T bessel_y1_forward(T x) {
|
||||
static const T PP[] = {
|
||||
+7.62125616208173112003e-04,
|
||||
+7.31397056940917570436e-02,
|
||||
@ -2826,7 +2826,7 @@ inline C10_HOST_DEVICE T bessel_y1_forward(T x) {
|
||||
} // bessel_y1_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) {
|
||||
static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return T(0.0);
|
||||
}
|
||||
@ -2865,12 +2865,12 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) {
|
||||
} // chebyshev_polynomial_t_forward(T x, int64_t n)
|
||||
|
||||
template<typename T, bool is_cuda=false>
|
||||
inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) {
|
||||
return chebyshev_polynomial_t_forward(x, static_cast<int64_t>(n));
|
||||
} // chebyshev_polynomial_t_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) {
|
||||
static inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return T(0.0);
|
||||
}
|
||||
@ -2913,12 +2913,12 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) {
|
||||
} // chebyshev_polynomial_u_forward(T x, int64_t n)
|
||||
|
||||
template<typename T, bool is_cuda=false>
|
||||
inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) {
|
||||
return chebyshev_polynomial_u_forward(x, static_cast<int64_t>(n));
|
||||
} // chebyshev_polynomial_u_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) {
|
||||
static inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return T(0.0);
|
||||
}
|
||||
@ -2969,12 +2969,12 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) {
|
||||
} // chebyshev_polynomial_v_forward(T x, int64_t n)
|
||||
|
||||
template<typename T, bool is_cuda=false>
|
||||
inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) {
|
||||
return chebyshev_polynomial_v_forward(x, static_cast<int64_t>(n));
|
||||
} // chebyshev_polynomial_v_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) {
|
||||
static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return T(0.0);
|
||||
}
|
||||
@ -3029,12 +3029,12 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) {
|
||||
} // chebyshev_polynomial_w_forward(T x, int64_t n)
|
||||
|
||||
template<typename T, bool is_cuda=false>
|
||||
inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) {
|
||||
return chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
|
||||
} // chebyshev_polynomial_w_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) {
|
||||
static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return T(0.0);
|
||||
}
|
||||
@ -3061,17 +3061,17 @@ inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) {
|
||||
} // hermite_polynomial_h_forward(T x, int64_t n)
|
||||
|
||||
template<typename T, bool is_cuda=false, std::enable_if_t<!std::is_floating_point<T>::value, int> = 0>
|
||||
inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
|
||||
return hermite_polynomial_h_forward(x, static_cast<int64_t>(n));
|
||||
} // hermite_polynomial_h_forward(T x, T n)
|
||||
|
||||
template<typename T, bool is_cuda=false, std::enable_if_t<std::is_floating_point<T>::value, int> = 0>
|
||||
inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
|
||||
return hermite_polynomial_h_forward(x, ((!std::isinf(n)) && (!std::isnan(n))) ? static_cast<int64_t>(n) : static_cast<int64_t>(-1));
|
||||
} // hermite_polynomial_h_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) {
|
||||
static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return T(0.0);
|
||||
}
|
||||
@ -3098,12 +3098,12 @@ inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) {
|
||||
} // hermite_polynomial_he_forward(T x, int64_t n)
|
||||
|
||||
template<typename T, bool is_cuda=false>
|
||||
inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) {
|
||||
return hermite_polynomial_he_forward(x, static_cast<int64_t>(n));
|
||||
} // hermite_polynomial_he_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) {
|
||||
static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return T(0.0);
|
||||
}
|
||||
@ -3134,12 +3134,12 @@ inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) {
|
||||
} // laguerre_polynomial_l_forward(T x, int64_t n)
|
||||
|
||||
template<typename T, bool is_cuda=false>
|
||||
inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) {
|
||||
return laguerre_polynomial_l_forward(x, static_cast<int64_t>(n));
|
||||
} // laguerre_polynomial_l_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) {
|
||||
static inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return T(0.0);
|
||||
}
|
||||
@ -3174,12 +3174,12 @@ inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) {
|
||||
} // legendre_polynomial_p_forward(T x, int64_t n)
|
||||
|
||||
template<typename T, bool is_cuda=false>
|
||||
inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) {
|
||||
return legendre_polynomial_p_forward(x, static_cast<int64_t>(n));
|
||||
} // legendre_polynomial_p_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) {
|
||||
static inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) {
|
||||
static const T A[] = {
|
||||
-4.41534164647933937950e-18,
|
||||
+3.33079451882223809783e-17,
|
||||
@ -3268,7 +3268,7 @@ inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) {
|
||||
} // modified_bessel_i0_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) {
|
||||
static inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) {
|
||||
static const T A[] = {
|
||||
+2.77791411276104639959e-18,
|
||||
-2.11142121435816608115e-17,
|
||||
@ -3364,7 +3364,7 @@ inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) {
|
||||
} // modified_bessel_i1_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) {
|
||||
static inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) {
|
||||
static const T A[] = {
|
||||
+1.37446543561352307156e-16,
|
||||
+4.25981614279661018399e-14,
|
||||
@ -3441,7 +3441,7 @@ inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) {
|
||||
} // modified_bessel_k0_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) {
|
||||
static inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) {
|
||||
static const T A[] = {
|
||||
-7.02386347938628759343e-18,
|
||||
-2.42744985051936593393e-15,
|
||||
@ -3519,7 +3519,7 @@ inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) {
|
||||
} // modified_bessel_k1_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) {
|
||||
static inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) {
|
||||
static const T A[] = {
|
||||
+1.37446543561352307156e-16,
|
||||
+4.25981614279661018399e-14,
|
||||
@ -3596,7 +3596,7 @@ inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) {
|
||||
} // T scaled_modified_bessel_k0_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) {
|
||||
static inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) {
|
||||
static const T A[] = {
|
||||
-7.02386347938628759343e-18,
|
||||
-2.42744985051936593393e-15,
|
||||
@ -3674,7 +3674,7 @@ inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) {
|
||||
} // T scaled_modified_bessel_k1_forward(T x)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) {
|
||||
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return T(0.0);
|
||||
}
|
||||
@ -3717,12 +3717,12 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n)
|
||||
} // shifted_chebyshev_polynomial_t_forward(T x, int64_t n)
|
||||
|
||||
template<typename T, bool is_cuda=false>
|
||||
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) {
|
||||
return shifted_chebyshev_polynomial_t_forward(x, static_cast<int64_t>(n));
|
||||
} // shifted_chebyshev_polynomial_t_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) {
|
||||
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return T(0.0);
|
||||
}
|
||||
@ -3769,12 +3769,12 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n)
|
||||
} // shifted_chebyshev_polynomial_u_forward(T x, int64_t n)
|
||||
|
||||
template<typename T, bool is_cuda=false>
|
||||
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) {
|
||||
return shifted_chebyshev_polynomial_u_forward(x, static_cast<int64_t>(n));
|
||||
} // shifted_chebyshev_polynomial_u_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) {
|
||||
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return T(0.0);
|
||||
}
|
||||
@ -3825,12 +3825,12 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n)
|
||||
} // shifted_chebyshev_polynomial_v_forward(T x, int64_t n)
|
||||
|
||||
template<typename T, bool is_cuda=false>
|
||||
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) {
|
||||
return shifted_chebyshev_polynomial_v_forward(x, static_cast<int64_t>(n));
|
||||
} // shifted_chebyshev_polynomial_v_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) {
|
||||
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) {
|
||||
if (n < 0) {
|
||||
return T(0.0);
|
||||
}
|
||||
@ -3881,12 +3881,12 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n)
|
||||
} // shifted_chebyshev_polynomial_w_forward(T x, int64_t n)
|
||||
|
||||
template<typename T, bool is_cuda=false>
|
||||
inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) {
|
||||
static inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) {
|
||||
return shifted_chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
|
||||
} // shifted_chebyshev_polynomial_w_forward(T x, T n)
|
||||
|
||||
template<typename T>
|
||||
inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) {
|
||||
static inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) {
|
||||
if (std::isinf(x)) {
|
||||
return T(0.0);
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel);
|
||||
namespace padding {
|
||||
|
||||
template <int dim>
|
||||
inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
|
||||
static inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
|
||||
|
||||
TORCH_CHECK(padding.size() == 2 * dim,
|
||||
"padding size is expected to be ", 2 * dim,
|
||||
|
@ -48,7 +48,7 @@ DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel);
|
||||
namespace {
|
||||
|
||||
template <typename dest_t, typename src_t>
|
||||
inline dest_t
|
||||
static inline dest_t
|
||||
safe_downcast(src_t v)
|
||||
{
|
||||
TORCH_CHECK(std::numeric_limits<dest_t>::min() <= v && v <= std::numeric_limits<dest_t>::max(),
|
||||
@ -58,7 +58,7 @@ safe_downcast(src_t v)
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline T pooling_output_shape_pad_lr(
|
||||
static inline T pooling_output_shape_pad_lr(
|
||||
T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
|
||||
bool ceil_mode) {
|
||||
T outputSize = div_rtn<T>(
|
||||
@ -75,7 +75,7 @@ inline T pooling_output_shape_pad_lr(
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline T pooling_output_shape(
|
||||
static inline T pooling_output_shape(
|
||||
T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) {
|
||||
TORCH_CHECK(stride != 0, "stride should not be zero");
|
||||
TORCH_CHECK(pad >= 0,
|
||||
@ -117,7 +117,7 @@ inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
|
||||
}
|
||||
|
||||
// AveragePool2d/DilatedMaxPool2d (forward)
|
||||
inline void
|
||||
static inline void
|
||||
pool2d_shape_check(
|
||||
const Tensor& input,
|
||||
int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
|
||||
@ -164,7 +164,7 @@ pool2d_shape_check(
|
||||
}
|
||||
|
||||
// DilatedMaxPool2d (backward)
|
||||
inline void
|
||||
static inline void
|
||||
max_pool2d_backward_shape_check(
|
||||
const Tensor& input,
|
||||
const Tensor& gradOutput,
|
||||
@ -192,7 +192,7 @@ max_pool2d_backward_shape_check(
|
||||
}
|
||||
|
||||
// AveragePool2d (backward)
|
||||
inline void
|
||||
static inline void
|
||||
avg_pool2d_backward_shape_check(
|
||||
const Tensor& input,
|
||||
const Tensor& gradOutput,
|
||||
@ -218,7 +218,7 @@ avg_pool2d_backward_shape_check(
|
||||
}
|
||||
|
||||
// AveragePool3d/DilatedMaxPool3d (forward)
|
||||
inline void
|
||||
static inline void
|
||||
pool3d_shape_check(
|
||||
const Tensor& input,
|
||||
int64_t nslices,
|
||||
@ -280,7 +280,7 @@ pool3d_shape_check(
|
||||
"Output size is too small");
|
||||
}
|
||||
|
||||
inline void
|
||||
static inline void
|
||||
max_pool3d_backward_shape_check(
|
||||
const Tensor& input,
|
||||
const Tensor& gradOutput,
|
||||
@ -317,7 +317,7 @@ max_pool3d_backward_shape_check(
|
||||
check_dim_size(indices, ndim, ndim-1, owidth);
|
||||
}
|
||||
|
||||
inline void
|
||||
static inline void
|
||||
avg_pool3d_backward_shape_check(
|
||||
const Tensor& input,
|
||||
const Tensor& gradOutput,
|
||||
|
@ -24,7 +24,7 @@ namespace native {
|
||||
// only non-zero result.
|
||||
template <class T,
|
||||
typename std::enable_if<std::is_integral<T>::value, T>::type* = nullptr>
|
||||
inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
|
||||
static inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
|
||||
T result = 1;
|
||||
while (b) {
|
||||
if (b & 1) {
|
||||
@ -38,13 +38,13 @@ inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
|
||||
|
||||
template <class T,
|
||||
typename std::enable_if<std::is_integral<T>::value && !std::is_signed<T>::value, T>::type* = nullptr>
|
||||
inline HOST_DEVICE T powi(T a, T b) {
|
||||
static inline HOST_DEVICE T powi(T a, T b) {
|
||||
return powi_impl(a, b);
|
||||
}
|
||||
|
||||
template <class T,
|
||||
typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value, T>::type* = nullptr>
|
||||
inline HOST_DEVICE T powi(T a, T b) {
|
||||
static inline HOST_DEVICE T powi(T a, T b) {
|
||||
if ( b < 0 ) {
|
||||
if ( a == 1 ) {
|
||||
return 1;
|
||||
|
@ -31,7 +31,7 @@ constexpr scalar_t lower_bound() {
|
||||
return lim::has_infinity ? -lim::infinity() : lim::lowest();
|
||||
}
|
||||
|
||||
inline Tensor restride_dim(
|
||||
static inline Tensor restride_dim(
|
||||
const Tensor& src, int64_t dim,
|
||||
IntArrayRef replacement_shape
|
||||
) {
|
||||
@ -96,13 +96,13 @@ inline std::optional<Tensor> _allreduce_return_trivial(
|
||||
" but found ", out.option())\
|
||||
}
|
||||
|
||||
inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) {
|
||||
static inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) {
|
||||
OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self);
|
||||
OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options());
|
||||
OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options());
|
||||
}
|
||||
|
||||
inline Tensor integer_upcast(const Tensor& self, std::optional<ScalarType> dtype) {
|
||||
static inline Tensor integer_upcast(const Tensor& self, std::optional<ScalarType> dtype) {
|
||||
ScalarType scalarType = self.scalar_type();
|
||||
TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented");
|
||||
ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType);
|
||||
@ -111,7 +111,7 @@ inline Tensor integer_upcast(const Tensor& self, std::optional<ScalarType> dtype
|
||||
|
||||
using DimMask = TensorIterator::DimMask;
|
||||
|
||||
inline DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
|
||||
static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
|
||||
if (opt_dims.has_value()) {
|
||||
return DimVector(opt_dims.value());
|
||||
} else {
|
||||
@ -121,7 +121,7 @@ inline DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
|
||||
}
|
||||
}
|
||||
|
||||
inline DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) {
|
||||
static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) {
|
||||
DimMask mask;
|
||||
if (opt_dims.has_value()) {
|
||||
auto dims = opt_dims.value();
|
||||
@ -150,7 +150,7 @@ inline DimVector shape_from_dim_mask(const Tensor& self, DimMask mask, bool keep
|
||||
return shape;
|
||||
}
|
||||
|
||||
inline void resize_reduction_result(
|
||||
static void resize_reduction_result(
|
||||
Tensor& result, const Tensor& self, DimMask mask, bool keepdim,
|
||||
ScalarType /*dtype*/)
|
||||
{
|
||||
@ -167,7 +167,7 @@ inline Tensor create_reduction_result(
|
||||
return at::empty(shape, self.options().dtype(dtype));
|
||||
}
|
||||
|
||||
inline Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
|
||||
static Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
|
||||
if (keepdim) {
|
||||
return result;
|
||||
}
|
||||
@ -182,7 +182,7 @@ inline Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask,
|
||||
return result.as_strided(shape, stride);
|
||||
}
|
||||
|
||||
inline TensorIterator make_reduction(
|
||||
static TensorIterator make_reduction(
|
||||
const char* name, Tensor& result, const Tensor& self,
|
||||
at::OptionalIntArrayRef dim_opt,
|
||||
bool keepdim, ScalarType in_dtype, ScalarType out_dtype) {
|
||||
@ -207,7 +207,7 @@ inline TensorIterator make_reduction(
|
||||
return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
|
||||
}
|
||||
|
||||
inline C10_UNUSED TensorIterator make_reduction(
|
||||
static C10_UNUSED TensorIterator make_reduction(
|
||||
const char* name, Tensor& result, const Tensor& self,
|
||||
at::OptionalIntArrayRef dim, bool keepdim, ScalarType out_dtype) {
|
||||
// special case for type promotion in mixed precision, improves computational
|
||||
@ -222,7 +222,7 @@ inline C10_UNUSED TensorIterator make_reduction(
|
||||
return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype);
|
||||
}
|
||||
|
||||
inline TensorIterator make_reduction(
|
||||
static TensorIterator make_reduction(
|
||||
const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
|
||||
at::OptionalIntArrayRef dim_opt, bool keepdim, ScalarType dtype1,
|
||||
ScalarType dtype2) {
|
||||
@ -259,13 +259,13 @@ inline TensorIterator make_reduction(
|
||||
return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
|
||||
}
|
||||
|
||||
inline C10_UNUSED TensorIterator make_reduction(
|
||||
static C10_UNUSED TensorIterator make_reduction(
|
||||
const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
|
||||
at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype) {
|
||||
return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype);
|
||||
}
|
||||
|
||||
inline void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) {
|
||||
static void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) {
|
||||
if (self.ndimension() == 0) {
|
||||
TORCH_CHECK_INDEX(dim == 0 || dim == -1, fn_name,
|
||||
": Expected reduction dim -1 or 0 for scalar but got ", dim);
|
||||
@ -276,7 +276,7 @@ inline void zero_numel_check_dims(const Tensor& self, const int64_t dim, const c
|
||||
}
|
||||
}
|
||||
|
||||
inline void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
|
||||
static void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
|
||||
TORCH_CHECK(
|
||||
!dim.empty(),
|
||||
fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
|
||||
@ -286,7 +286,7 @@ inline void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, con
|
||||
}
|
||||
}
|
||||
|
||||
inline std::vector<int64_t> get_zero_numel_tensor_size(
|
||||
static std::vector<int64_t> get_zero_numel_tensor_size(
|
||||
const Tensor& self,
|
||||
const int64_t dim,
|
||||
const bool keepdim,
|
||||
@ -313,7 +313,7 @@ inline std::vector<int64_t> get_zero_numel_tensor_size(
|
||||
// This function should be called when you are reducing a zero-numel tensor and want to
|
||||
// resize the output and return it. This function exists for resizing zero-numel
|
||||
// tensors when the size of the reduction dimension is non-zero.
|
||||
inline C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices,
|
||||
static C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices,
|
||||
const Tensor& self, const int64_t dim,
|
||||
const bool keepdim, const char *fn_name) {
|
||||
auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name);
|
||||
@ -349,7 +349,7 @@ inline ScalarType get_dtype_from_result(Tensor& result, std::optional<ScalarType
|
||||
|
||||
namespace at::meta {
|
||||
|
||||
inline C10_UNUSED DimVector get_reduction_shape(
|
||||
static C10_UNUSED DimVector get_reduction_shape(
|
||||
const Tensor& self,
|
||||
IntArrayRef dims,
|
||||
bool keepdim,
|
||||
@ -358,7 +358,7 @@ inline C10_UNUSED DimVector get_reduction_shape(
|
||||
return native::shape_from_dim_mask(self, mask, keepdim);
|
||||
}
|
||||
|
||||
inline void resize_reduction(
|
||||
static void resize_reduction(
|
||||
impl::MetaBase& meta,
|
||||
const Tensor& self,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
@ -379,7 +379,7 @@ inline void resize_reduction(
|
||||
meta.maybe_get_output(), self, dims_, keepdim);
|
||||
}
|
||||
|
||||
inline void resize_reduction_with_indices(
|
||||
static void resize_reduction_with_indices(
|
||||
impl::MetaBase& meta,
|
||||
const Tensor& self,
|
||||
IntArrayRef dims,
|
||||
@ -396,7 +396,7 @@ inline void resize_reduction_with_indices(
|
||||
meta.maybe_get_output(1), self, dims_, keepdim);
|
||||
}
|
||||
|
||||
inline TensorIterator make_reduction(
|
||||
static TensorIterator make_reduction(
|
||||
const Tensor& self,
|
||||
const Tensor& result,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
@ -412,7 +412,7 @@ inline TensorIterator make_reduction(
|
||||
return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
|
||||
}
|
||||
|
||||
inline TensorIterator make_reduction(
|
||||
static TensorIterator make_reduction(
|
||||
const Tensor& self,
|
||||
const Tensor& result1,
|
||||
const Tensor& result2,
|
||||
@ -434,7 +434,7 @@ inline TensorIterator make_reduction(
|
||||
return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
|
||||
}
|
||||
|
||||
inline C10_UNUSED TensorIterator make_reduction_from_out_ty(
|
||||
static C10_UNUSED TensorIterator make_reduction_from_out_ty(
|
||||
const Tensor& self,
|
||||
const Tensor& result,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
|
@ -6,7 +6,7 @@ namespace at::native {
|
||||
|
||||
enum class ReductionType {MAX, MEAN, MIN, SUM, PROD};
|
||||
|
||||
inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
|
||||
static inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
|
||||
if (reduce == "max" || reduce == "amax") {
|
||||
return ReductionType::MAX;
|
||||
} else if (reduce == "mean") {
|
||||
@ -23,7 +23,7 @@ inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
|
||||
}
|
||||
|
||||
// used for `scatter_reduce`, old options for BC.
|
||||
inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
|
||||
static inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
|
||||
if (use_new_options) {
|
||||
return get_reduction_enum(reduce);
|
||||
} else {
|
||||
|
@ -40,7 +40,7 @@ TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
|
||||
TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes);
|
||||
TORCH_API void resize_bytes_nocuda(const Storage& storage, c10::SymInt size_bytes);
|
||||
|
||||
inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
|
||||
static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
|
||||
// It does not make sense to try to resize a storage
|
||||
// to hold 0 elements, and this can break
|
||||
// if storage_offset is positive but
|
||||
@ -79,7 +79,7 @@ template <>
|
||||
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); }
|
||||
|
||||
template <typename T>
|
||||
inline void checkInBoundsForStorage(
|
||||
static inline void checkInBoundsForStorage(
|
||||
ArrayRef<T> size,
|
||||
ArrayRef<T> stride,
|
||||
T storage_offset,
|
||||
@ -111,7 +111,7 @@ inline void checkInBoundsForStorage(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
|
||||
static inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
|
||||
ArrayRef<T> size, ArrayRef<T> stride) {
|
||||
// FIXME: stride should be optional
|
||||
if (stride.data()) {
|
||||
|
@ -440,6 +440,15 @@ TORCH_IMPL_FUNC(log_softmax_backward_cpu_out) (
|
||||
}
|
||||
}
|
||||
|
||||
static Tensor softmax(const Tensor& input_, const int64_t dim_) {
|
||||
auto result = [&]() {
|
||||
NoNamesGuard guard;
|
||||
return at::_softmax(input_, dim_, false);
|
||||
}();
|
||||
namedinference::propagate_names(result, input_);
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
|
||||
auto result = [&]() {
|
||||
NoNamesGuard guard;
|
||||
@ -496,6 +505,15 @@ Tensor special_softmax(const Tensor& input_, const int64_t dim_, std::optional<S
|
||||
return at::softmax(input_, dim_, dtype);
|
||||
}
|
||||
|
||||
static Tensor log_softmax(const Tensor& input_, const int64_t dim_) {
|
||||
auto result = [&]() {
|
||||
NoNamesGuard guard;
|
||||
return at::_log_softmax(input_, dim_, false);
|
||||
}();
|
||||
namedinference::propagate_names(result, input_);
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor log_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
|
||||
auto result = [&]() {
|
||||
NoNamesGuard guard;
|
||||
|
@ -1195,6 +1195,15 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
|
||||
#undef REPR
|
||||
}
|
||||
|
||||
static Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
|
||||
const optional<int64_t> win_lengthOpt, const Tensor& window,
|
||||
const bool center, const bool normalized, const optional<bool> onesidedOpt,
|
||||
const optional<int64_t> lengthOpt) {
|
||||
return at::native::istft(
|
||||
self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized,
|
||||
onesidedOpt, lengthOpt, /*return_complex=*/false);
|
||||
}
|
||||
|
||||
void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) {
|
||||
const auto input_sizes = input.sizes();
|
||||
const auto input_strides = input.strides();
|
||||
|
@ -172,10 +172,18 @@ Tensor arange(
|
||||
return at::arange_out(result, start, end, step);
|
||||
}
|
||||
|
||||
static Tensor& arange_start_out(const Scalar& start, const Scalar& end, Tensor& result) {
|
||||
return at::arange_out(result, start, end, /*step=*/1);
|
||||
}
|
||||
|
||||
Tensor& arange_out(const Scalar& end, Tensor& result) {
|
||||
return at::arange_out(result, /*start=*/0, end, /*step=*/1);
|
||||
}
|
||||
|
||||
static Tensor& arange_out(Tensor& result, const Scalar& start, const Scalar& end) {
|
||||
return at::arange_out(result, start, end, /*step=*/1);
|
||||
}
|
||||
|
||||
Tensor _dim_arange(const Tensor& like, int64_t dim) {
|
||||
return at::arange(like.size(dim), like.options().dtype(at::kLong));
|
||||
}
|
||||
|
@ -105,6 +105,10 @@ Tensor & detach_(Tensor & self) {
|
||||
return self;
|
||||
}
|
||||
|
||||
static Tensor contiguous(const Tensor & self) {
|
||||
return contiguous(self, MemoryFormat::Contiguous);
|
||||
}
|
||||
|
||||
Tensor contiguous(const Tensor& self, MemoryFormat memory_format) {
|
||||
if (self.is_contiguous(memory_format)) {
|
||||
return self;
|
||||
|
@ -210,6 +210,7 @@
|
||||
#include <ATen/ops/zeros_native.h>
|
||||
#endif
|
||||
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
@ -1180,6 +1181,14 @@ Tensor as_strided_tensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef s
|
||||
return result;
|
||||
}
|
||||
|
||||
static Tensor as_strided_tensorimpl_meta(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional<int64_t> storage_offset_) {
|
||||
auto storage_offset = storage_offset_.value_or(self.storage_offset());
|
||||
auto result = at::detail::make_tensor<TensorImpl>(
|
||||
c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
|
||||
setStrided(result, size, stride, storage_offset);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void setStridedUnchecked(
|
||||
const Tensor& self,
|
||||
@ -1240,6 +1249,10 @@ const Tensor &as_strided__symint(const Tensor& self, SymIntArrayRef size, SymInt
|
||||
return self;
|
||||
}
|
||||
|
||||
static Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
|
||||
return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous);
|
||||
}
|
||||
|
||||
// Should just use narrow_copy_out, but this API is used internally at Meta:
|
||||
// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561
|
||||
Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
|
||||
@ -3574,6 +3587,10 @@ Tensor view_as(const Tensor& self, const Tensor& other) {
|
||||
return self.view_symint(other.sym_sizes());
|
||||
}
|
||||
|
||||
static int64_t numel(const Tensor& self) {
|
||||
return self.unsafeGetTensorImpl()->numel();
|
||||
}
|
||||
|
||||
std::vector<Tensor> unbind(const Tensor &self, int64_t dim) {
|
||||
dim = maybe_wrap_dim(dim, self.dim());
|
||||
int64_t size = self.size(dim);
|
||||
|
@ -180,6 +180,10 @@ TORCH_IMPL_FUNC(triu_cpu)(const Tensor& self, int64_t k, const Tensor &result) {
|
||||
compute_triu_tril<UpperTriangle>(self, k, result);
|
||||
}
|
||||
|
||||
static Tensor trace_backward(const Tensor& grad, at::IntArrayRef sizes) {
|
||||
return at::native::trace_backward_symint(grad, c10::fromIntArrayRefSlow(sizes));
|
||||
}
|
||||
|
||||
Tensor trace_backward_symint(const Tensor& grad, c10::SymIntArrayRef sizes) {
|
||||
if (sizes.size() != 2) {
|
||||
throw std::runtime_error("expected matrix input");
|
||||
|
@ -210,7 +210,7 @@ std::array<int64_t, 5> upsample_3d_common_check(IntArrayRef input_size, IntArray
|
||||
return {nbatch, channels, output_depth, output_height, output_width};
|
||||
}
|
||||
|
||||
inline void upsample_2d_shape_check(
|
||||
static inline void upsample_2d_shape_check(
|
||||
const Tensor& input,
|
||||
const Tensor& grad_output,
|
||||
int64_t nbatch,
|
||||
@ -251,7 +251,7 @@ inline void upsample_2d_shape_check(
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline scalar_t compute_scales_value(
|
||||
static inline scalar_t compute_scales_value(
|
||||
const std::optional<double> scale,
|
||||
int64_t input_size,
|
||||
int64_t output_size) {
|
||||
@ -263,7 +263,7 @@ inline scalar_t compute_scales_value(
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline scalar_t area_pixel_compute_scale(
|
||||
static inline scalar_t area_pixel_compute_scale(
|
||||
int64_t input_size,
|
||||
int64_t output_size,
|
||||
bool align_corners,
|
||||
@ -281,7 +281,7 @@ inline scalar_t area_pixel_compute_scale(
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline scalar_t area_pixel_compute_source_index(
|
||||
static inline scalar_t area_pixel_compute_source_index(
|
||||
scalar_t scale,
|
||||
int64_t dst_index,
|
||||
bool align_corners,
|
||||
@ -308,7 +308,7 @@ inline scalar_t area_pixel_compute_source_index(
|
||||
}
|
||||
}
|
||||
|
||||
inline int64_t nearest_neighbor_compute_source_index(
|
||||
static inline int64_t nearest_neighbor_compute_source_index(
|
||||
const float scale,
|
||||
int64_t dst_index,
|
||||
int64_t input_size) {
|
||||
@ -319,7 +319,7 @@ inline int64_t nearest_neighbor_compute_source_index(
|
||||
return src_index;
|
||||
}
|
||||
|
||||
inline int64_t nearest_neighbor_exact_compute_source_index(
|
||||
static inline int64_t nearest_neighbor_exact_compute_source_index(
|
||||
const float scale,
|
||||
int64_t dst_index,
|
||||
int64_t input_size) {
|
||||
@ -331,7 +331,7 @@ inline int64_t nearest_neighbor_exact_compute_source_index(
|
||||
return src_index;
|
||||
}
|
||||
|
||||
inline int64_t nearest_idx(
|
||||
static inline int64_t nearest_idx(
|
||||
int64_t output_index,
|
||||
int64_t input_size,
|
||||
int64_t output_size,
|
||||
@ -352,7 +352,7 @@ inline int64_t nearest_idx(
|
||||
}
|
||||
}
|
||||
|
||||
inline int64_t nearest_exact_idx(
|
||||
static inline int64_t nearest_exact_idx(
|
||||
int64_t output_index,
|
||||
int64_t input_size,
|
||||
int64_t output_size,
|
||||
@ -392,17 +392,17 @@ static void upsample_increment_value_bounded(
|
||||
// Based on
|
||||
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
||||
template <typename scalar_t>
|
||||
inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
|
||||
static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
|
||||
return ((A + 2) * x - (A + 3)) * x * x + 1;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
|
||||
static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
|
||||
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void get_cubic_upsample_coefficients(
|
||||
static inline void get_cubic_upsample_coefficients(
|
||||
scalar_t coeffs[4],
|
||||
scalar_t t) {
|
||||
scalar_t A = -0.75;
|
||||
@ -418,7 +418,7 @@ inline void get_cubic_upsample_coefficients(
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline scalar_t cubic_interp1d(
|
||||
static inline scalar_t cubic_interp1d(
|
||||
scalar_t x0,
|
||||
scalar_t x1,
|
||||
scalar_t x2,
|
||||
@ -434,7 +434,7 @@ inline scalar_t cubic_interp1d(
|
||||
// type can accurately represent, the type casting to `int64_t` might exceed
|
||||
// `input_size`, causing overflow. So we guard it with `std::min` below.
|
||||
template<typename scalar_t, typename opmath_t>
|
||||
inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
|
||||
static inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
|
||||
input_index = std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
|
||||
lambda = std::min(
|
||||
std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
|
||||
@ -443,7 +443,7 @@ inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64
|
||||
}
|
||||
|
||||
template<typename scalar_t, typename opmath_t>
|
||||
inline void compute_source_index_and_lambda(
|
||||
static inline void compute_source_index_and_lambda(
|
||||
int64_t& input_index0,
|
||||
int64_t& input_index1,
|
||||
scalar_t& lambda0,
|
||||
|
@ -82,7 +82,7 @@ dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& o
|
||||
|
||||
template <typename func_t,
|
||||
typename std::enable_if<!std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
|
||||
inline void
|
||||
static inline void
|
||||
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
|
||||
using traits = function_traits<func_t>;
|
||||
using result_type = typename traits::result_type;
|
||||
@ -97,7 +97,7 @@ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t
|
||||
|
||||
template <typename func_t,
|
||||
typename std::enable_if<std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
|
||||
inline void
|
||||
static inline void
|
||||
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
|
||||
using traits = function_traits<func_t>;
|
||||
for (; i < n; i++) {
|
||||
@ -111,7 +111,7 @@ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t
|
||||
// Basic loop operation (one output, N inputs). May be auto-vectorized
|
||||
// by the compiler. Supports inputs and outputs of different types.
|
||||
template <typename func_t>
|
||||
inline void
|
||||
static inline void
|
||||
basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
|
||||
using traits = function_traits<func_t>;
|
||||
constexpr int ntensors = traits::arity + 1;
|
||||
@ -166,7 +166,7 @@ void handle_tuple_outputs(char* C10_RESTRICT data[],
|
||||
// 2. Iterate over the members of the returned tuple, set the corresponding
|
||||
// output tensor by the tuple member in `handle_tuple_outputs` function.
|
||||
template <typename func_t>
|
||||
inline void
|
||||
static inline void
|
||||
multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
|
||||
using traits = function_traits<func_t>;
|
||||
|
||||
@ -195,7 +195,7 @@ multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_
|
||||
// a scalar (stride 0). It's position is indicated by the argument `S`. If `S`
|
||||
// is 0, then there are no scalar inputs.
|
||||
template <typename func_t, typename vec_func_t>
|
||||
inline void
|
||||
static inline void
|
||||
vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
|
||||
using traits = function_traits<vec_func_t>;
|
||||
using scalar_t = typename function_traits<func_t>::result_type;
|
||||
@ -228,7 +228,7 @@ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, ve
|
||||
|
||||
|
||||
template <typename traits, typename cb_t>
|
||||
inline void unroll_contiguous_scalar_checks(
|
||||
static inline void unroll_contiguous_scalar_checks(
|
||||
const int64_t* /*strides*/,
|
||||
std::index_sequence<>,
|
||||
cb_t&& cb) {
|
||||
@ -236,7 +236,7 @@ inline void unroll_contiguous_scalar_checks(
|
||||
}
|
||||
|
||||
template <typename traits, typename cb_t, size_t INDEX0, size_t ...INDEX>
|
||||
inline void unroll_contiguous_scalar_checks(
|
||||
static inline void unroll_contiguous_scalar_checks(
|
||||
const int64_t* strides,
|
||||
std::index_sequence<INDEX0, INDEX...>,
|
||||
cb_t&& cb) {
|
||||
|
@ -21,21 +21,21 @@ using namespace vec;
|
||||
|
||||
// reduction that is contiguous over the input in dim 0
|
||||
template <typename traits>
|
||||
inline bool is_contiguous_reduction(const int64_t* strides) {
|
||||
static inline bool is_contiguous_reduction(const int64_t* strides) {
|
||||
return strides[0] == 0 &&
|
||||
strides[1] == sizeof(typename traits::arg2_t);
|
||||
}
|
||||
|
||||
// reduction that is contiguous over the input in dim 1
|
||||
template <typename traits>
|
||||
inline bool is_outer_reduction(const int64_t* strides) {
|
||||
static inline bool is_outer_reduction(const int64_t* strides) {
|
||||
return strides[0] == 0 &&
|
||||
strides[2] == sizeof(typename traits::result_type) &&
|
||||
strides[3] == sizeof(typename traits::arg2_t);
|
||||
}
|
||||
|
||||
template <typename func_t, typename vec_func_t>
|
||||
inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
|
||||
static inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
|
||||
func_t op, vec_func_t vop, bool reduce) {
|
||||
VEC_LOOP_HEADER(func_t, data)
|
||||
const char* in1_ptr = data[1];
|
||||
@ -69,7 +69,7 @@ inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) {
|
||||
static inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) {
|
||||
for (const auto j C10_UNUSED : c10::irange(n)) {
|
||||
f();
|
||||
data[0] += strides[0];
|
||||
@ -79,7 +79,7 @@ inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n,
|
||||
|
||||
// computes the reduction out = op(out, in)
|
||||
template <typename func_t, typename vec_func_t>
|
||||
inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) {
|
||||
static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) {
|
||||
VEC_LOOP_HEADER(func_t, data)
|
||||
int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
|
||||
int64_t count = n / (4 * Vec::size());
|
||||
@ -93,7 +93,7 @@ inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_fu
|
||||
|
||||
// computes the reduction out = op(out, in)
|
||||
template <typename func_t, typename vec_func_t>
|
||||
inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) {
|
||||
static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) {
|
||||
VEC_LOOP_HEADER(func_t, data)
|
||||
|
||||
// reduce down each column of 4 * Vec::size() elements (128 or 256 bytes)
|
||||
@ -132,13 +132,13 @@ static void set_results(const res_t result, const TensorIteratorBase &iter, cons
|
||||
}
|
||||
|
||||
template<typename traits, std::size_t i = 0, typename... tuple_t>
|
||||
inline typename std::enable_if<i == sizeof...(tuple_t), std::size_t>::type
|
||||
static inline typename std::enable_if<i == sizeof...(tuple_t), std::size_t>::type
|
||||
for_each_in_tuple(const std::tuple<tuple_t...>& /*t*/, const TensorIteratorBase& /*iter*/, const int /*num_outputs*/) {
|
||||
return i;
|
||||
}
|
||||
|
||||
template<typename traits, std::size_t i = 0, typename... tuple_t>
|
||||
inline typename std::enable_if<i < sizeof...(tuple_t), std::size_t>::type
|
||||
static inline typename std::enable_if<i < sizeof...(tuple_t), std::size_t>::type
|
||||
for_each_in_tuple(const std::tuple<tuple_t...>& t, const TensorIteratorBase &iter, const int num_outputs) {
|
||||
if (i < (size_t)num_outputs) {
|
||||
set_result<traits>(i, std::get<i>(t), iter, num_outputs);
|
||||
@ -286,7 +286,7 @@ void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vo
|
||||
// when reduction is on most inner dimension (dim 0 in TensorIterator)
|
||||
// and input has contiguous most inner dimension, `binary_kernel_reduce_lastdim`
|
||||
// can be used.
|
||||
inline bool is_reduce_lastdim(TensorIteratorBase& iter) {
|
||||
static inline bool is_reduce_lastdim(TensorIteratorBase& iter) {
|
||||
return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0)
|
||||
&& iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1);
|
||||
}
|
||||
|
@ -1002,7 +1002,7 @@ std::string generate_code(
|
||||
std::string extra_args = "";
|
||||
for (size_t i = 0; i < extra_args_typenames.size(); i++) {
|
||||
auto type = std::string(extra_args_typenames[i]);
|
||||
auto name = "extra_arg_" + std::to_string(i);
|
||||
auto name = "extra_arg_" + std::string(to_string(i));
|
||||
extra_params += "," + type + " " + name;
|
||||
extra_args += ", " + name;
|
||||
}
|
||||
|
@ -13,8 +13,7 @@ void run_cudnn_SDP_fprop(
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d_qk,
|
||||
int64_t d_v,
|
||||
int64_t d,
|
||||
float scaling_factor,
|
||||
bool isTraining,
|
||||
bool is_causal,
|
||||
@ -35,8 +34,7 @@ void run_cudnn_SDP_bprop(
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d_qk,
|
||||
int64_t d_v,
|
||||
int64_t d,
|
||||
float scaling_factor,
|
||||
bool is_causal,
|
||||
float dropout_probability,
|
||||
@ -130,8 +128,7 @@ struct MHAParams {
|
||||
int64_t h;
|
||||
int64_t s_q;
|
||||
int64_t s_kv;
|
||||
int64_t d_qk;
|
||||
int64_t d_v;
|
||||
int64_t d;
|
||||
double dropout_probability;
|
||||
bool is_causal;
|
||||
bool return_softmaxstats;
|
||||
@ -143,8 +140,7 @@ void setMHAParams(
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d_qk,
|
||||
int64_t d_v,
|
||||
int64_t d,
|
||||
const Tensor& q,
|
||||
const Tensor& k,
|
||||
const Tensor& v,
|
||||
@ -159,8 +155,7 @@ void setMHAParams(
|
||||
}
|
||||
params.b = b;
|
||||
params.h = h;
|
||||
params.d_qk = d_qk;
|
||||
params.d_v = d_v;
|
||||
params.d = d;
|
||||
params.s_q = s_q;
|
||||
params.s_kv = s_kv;
|
||||
params.dropout_probability = dropout_probability;
|
||||
@ -198,8 +193,7 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d_qk,
|
||||
int64_t d_v,
|
||||
int64_t d,
|
||||
const Tensor& q,
|
||||
const Tensor& k,
|
||||
const Tensor& v,
|
||||
@ -212,8 +206,7 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
|
||||
h,
|
||||
s_q,
|
||||
s_kv,
|
||||
d_qk,
|
||||
d_v,
|
||||
d,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
@ -256,8 +249,7 @@ auto build_graph_and_tensors(
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d_qk,
|
||||
int64_t d_v,
|
||||
int64_t d,
|
||||
float scaling_factor,
|
||||
bool return_softmaxstats,
|
||||
bool is_causal,
|
||||
@ -391,8 +383,7 @@ auto build_graph_and_tensors_backward(
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d_qk,
|
||||
int64_t d_v,
|
||||
int64_t d,
|
||||
float scaling_factor,
|
||||
bool is_causal,
|
||||
float dropout_probability,
|
||||
@ -523,8 +514,7 @@ void run_cudnn_SDP_fprop(
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d_qk,
|
||||
int64_t d_v,
|
||||
int64_t d,
|
||||
float scaling_factor,
|
||||
bool return_softmaxstats,
|
||||
bool is_causal,
|
||||
@ -538,7 +528,7 @@ void run_cudnn_SDP_fprop(
|
||||
Tensor& dropoutoffset) {
|
||||
cudnnHandle_t handle = getCudnnHandle();
|
||||
o = at::empty_strided(
|
||||
{b, h, s_q, d_v}, {s_q * h * d_v, d_v, h * d_v, 1}, q.options());
|
||||
{b, h, s_q, d}, {s_q * h * d, d, h * d, 1}, q.options());
|
||||
if (return_softmaxstats) {
|
||||
// TODO(eqy): verify that this is correct
|
||||
softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat));
|
||||
@ -549,8 +539,7 @@ void run_cudnn_SDP_fprop(
|
||||
h,
|
||||
s_q,
|
||||
s_kv,
|
||||
d_qk,
|
||||
d_v,
|
||||
d,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
@ -567,8 +556,7 @@ void run_cudnn_SDP_fprop(
|
||||
h,
|
||||
s_q,
|
||||
s_kv,
|
||||
d_qk,
|
||||
d_v,
|
||||
d,
|
||||
scaling_factor,
|
||||
return_softmaxstats,
|
||||
is_causal,
|
||||
@ -611,8 +599,7 @@ void run_cudnn_SDP_bprop(
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d_qk,
|
||||
int64_t d_v,
|
||||
int64_t d,
|
||||
float scaling_factor,
|
||||
bool is_causal,
|
||||
float dropout_probability,
|
||||
@ -627,27 +614,9 @@ void run_cudnn_SDP_bprop(
|
||||
Tensor& dV,
|
||||
const Tensor& dropoutseed,
|
||||
const Tensor& dropoutoffset) {
|
||||
Tensor dO_ = dO;
|
||||
if (!dO.strides()[dO.strides().size() - 1]) {
|
||||
TORCH_WARN(
|
||||
"cuDNN SDPA backward got an innermost stride of 0 in grad_out, which is unsupported. Materializing a contiguous\
|
||||
tensor which will increase memory usage...");
|
||||
dO_ = dO.contiguous();
|
||||
}
|
||||
cudnnHandle_t handle = getCudnnHandle();
|
||||
auto key = MHACacheKeyWrapper(
|
||||
b,
|
||||
h,
|
||||
s_q,
|
||||
s_kv,
|
||||
d_qk,
|
||||
d_v,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
dropout_probability,
|
||||
is_causal,
|
||||
true);
|
||||
b, h, s_q, s_kv, d, q, k, v, dropout_probability, is_causal, true);
|
||||
auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key);
|
||||
graph_and_tensors_backward graph_and_tensors_backward_values;
|
||||
if (graph_and_tensors_backward_ptr) {
|
||||
@ -658,8 +627,7 @@ void run_cudnn_SDP_bprop(
|
||||
h,
|
||||
s_q,
|
||||
s_kv,
|
||||
d_qk,
|
||||
d_v,
|
||||
d,
|
||||
scaling_factor,
|
||||
is_causal,
|
||||
dropout_probability,
|
||||
@ -667,7 +635,7 @@ void run_cudnn_SDP_bprop(
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
dO_,
|
||||
dO,
|
||||
softmaxstats,
|
||||
dQ,
|
||||
dK,
|
||||
@ -709,4 +677,5 @@ void run_cudnn_SDP_bprop(
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
#endif
|
||||
|
@ -9,8 +9,7 @@ void run_cudnn_SDP_fprop(
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d_k,
|
||||
int64_t d_v,
|
||||
int64_t d,
|
||||
float scaling_factor,
|
||||
bool isTraining,
|
||||
bool is_causal,
|
||||
@ -28,8 +27,7 @@ void run_cudnn_SDP_bprop(
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d_k,
|
||||
int64_t d_v,
|
||||
int64_t d,
|
||||
float scaling_factor,
|
||||
bool is_causal,
|
||||
float dropout_probability,
|
||||
|
@ -5,7 +5,7 @@
|
||||
|
||||
namespace at::native {
|
||||
|
||||
inline void col2im_shape_check(
|
||||
static inline void col2im_shape_check(
|
||||
const Tensor& input,
|
||||
const Tensor& grad_output,
|
||||
int64_t output_height,
|
||||
@ -135,7 +135,7 @@ inline void col2im_shape_check(
|
||||
}
|
||||
}
|
||||
|
||||
inline void im2col_shape_check(
|
||||
static inline void im2col_shape_check(
|
||||
const Tensor& input,
|
||||
const Tensor& grad_output,
|
||||
int64_t kernel_height,
|
||||
|
@ -27,7 +27,53 @@ Tensor mkldnn_convolution(
|
||||
TORCH_CHECK(false, "mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
|
||||
}
|
||||
|
||||
static Tensor mkldnn_convolution_backward_input(
|
||||
IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight,
|
||||
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
|
||||
TORCH_CHECK(false, "mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support");
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
|
||||
IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input,
|
||||
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
|
||||
TORCH_CHECK(false, "mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support");
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
|
||||
const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
|
||||
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask) {
|
||||
TORCH_CHECK(false, "mkldnn_convolution_backward: ATen not compiled with MKLDNN support");
|
||||
}
|
||||
|
||||
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_backward_stub);
|
||||
|
||||
static Tensor mkldnn_convolution_transpose(
|
||||
const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
|
||||
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) {
|
||||
TORCH_CHECK(false, "mkldnn_convolution_transpose: ATen not compiled with MKLDNN support");
|
||||
}
|
||||
|
||||
static Tensor mkldnn_convolution_transpose_backward_input(
|
||||
IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight,
|
||||
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
|
||||
int64_t groups, bool bias_defined) {
|
||||
TORCH_CHECK(false, "mkldnn_convolution_transpose_backward_input: ATen not compiled with MKLDNN support");
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, Tensor> mkldnn_convolution_transpose_backward_weights(
|
||||
IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input,
|
||||
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
|
||||
int64_t groups, bool bias_defined) {
|
||||
TORCH_CHECK(false, "mkldnn_convolution_transpose_backward_weights: ATen not compiled with MKLDNN support");
|
||||
}
|
||||
|
||||
static std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_transpose_backward(
|
||||
const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
|
||||
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
|
||||
int64_t groups, std::array<bool,3> output_mask) {
|
||||
TORCH_CHECK(false, "mkldnn_convolution_transpose_backward: ATen not compiled with MKLDNN support");
|
||||
}
|
||||
|
||||
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_stub);
|
||||
REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub);
|
||||
|
||||
|
@ -14,7 +14,6 @@
|
||||
#include <ATen/ops/native_batch_norm_backward_native.h>
|
||||
#include <ATen/ops/native_batch_norm_native.h>
|
||||
#endif
|
||||
#include <ATen/native/mkldnn/Utils.h>
|
||||
|
||||
#if !AT_MKLDNN_ENABLED()
|
||||
|
||||
@ -38,7 +37,7 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(
|
||||
TORCH_CHECK(false, "mkldnn_batch_norm_backward: ATen not compiled with MKLDNN support");
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> mkldnn_layer_norm_last_index_weight_bias_f32(
|
||||
static std::tuple<Tensor, Tensor, Tensor> mkldnn_layer_norm_last_index_weight_bias_f32(
|
||||
const Tensor& input,
|
||||
IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias,
|
||||
double eps, bool inplace) {
|
||||
@ -82,6 +81,7 @@ std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
|
||||
#else // AT_MKLDNN_ENABLED
|
||||
|
||||
#include <ATen/native/mkldnn/MKLDNNCommon.h>
|
||||
#include <ATen/native/mkldnn/Utils.h>
|
||||
#include <ATen/native/layer_norm.h>
|
||||
#include <ideep/abstract_types.hpp>
|
||||
|
||||
|
@ -36,7 +36,7 @@ void check_mkldnn_binary_fusion_inputs(
|
||||
const Tensor& weight,
|
||||
const Tensor& bias);
|
||||
|
||||
inline std::vector<int64_t> padding_r(
|
||||
static inline std::vector<int64_t> padding_r(
|
||||
IntArrayRef padding, IntArrayRef output_padding)
|
||||
{
|
||||
// ConvTranpose padding adjustment
|
||||
@ -60,7 +60,7 @@ inline std::vector<int64_t> padding_r(
|
||||
// Make sure input has default contiguous strides if it's contiguous tensors for better performance.
|
||||
// For example, for tensor of size = [1, 1280], stride = [0, 1], we'll convert it to size = [1, 1280], stride = [1280, 1]
|
||||
// before calling oneDNN for better performance.
|
||||
inline Tensor may_convert_to_default_contiguous_strides(const Tensor& input) {
|
||||
static inline Tensor may_convert_to_default_contiguous_strides(const Tensor& input) {
|
||||
auto input_size = input.sizes().vec();
|
||||
auto input_stride = input.strides().vec();
|
||||
auto input_default_contiguous_strides = c10::contiguous_strides(input_size);
|
||||
|
@ -18,21 +18,26 @@ kernel void erfinv_mps_kernel( device {0} *output [[buffer(0)]],
|
||||
/* coefficients in rational expansion */
|
||||
|
||||
float y_abs = abs(y);
|
||||
if (y_abs >= 1.0f) {{
|
||||
output[index] = {0}( y_abs > 1.0f ? NAN : copysign(INFINITY, y));
|
||||
if(y_abs > 1.0f){{
|
||||
output[index] = NAN;
|
||||
return;
|
||||
}}
|
||||
if (y_abs <= 0.7f) {{
|
||||
if(y_abs == 1.0f){{
|
||||
output[index] = copysign(INFINITY, y);
|
||||
return;
|
||||
}}
|
||||
if(y_abs <= 0.7f) {{
|
||||
z = y * y;
|
||||
num = ((a[3] * z + a[2]) * z + a[1])*z + a[0];
|
||||
dem = (((b[3] * z + b[2]) * z + b[1]) * z +b[0]) * z + 1.0f;
|
||||
num = (((a[3]*z + a[2])*z + a[1])*z + a[0]);
|
||||
dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + 1.0f);
|
||||
x = y * num / dem;
|
||||
}} else {{
|
||||
}}
|
||||
else{{
|
||||
z = sqrt(-1.0f*log((1.0-y_abs)/2.0));
|
||||
num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
|
||||
dem = (d[1] * z + d[0]) * z + 1.0f;
|
||||
num = ((c[3]*z + c[2])*z + c[1]) * z + c[0];
|
||||
dem = (d[1]*z + d[0])*z + 1.0f;
|
||||
x = copysign(num, y) / dem;
|
||||
}}
|
||||
|
||||
output[index] = {0}(x);
|
||||
}})METAL";
|
||||
output[index] = x;
|
||||
}})METAL";
|
@ -143,7 +143,7 @@ TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_s
|
||||
Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + std::to_string(negative_slope.to<double>());
|
||||
string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to<double>());
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
|
||||
@ -193,8 +193,8 @@ TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
|
||||
Tensor output_ = at::empty_like(self, self.suggest_memory_format());
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" +
|
||||
std::to_string(negative_slope.to<double>());
|
||||
string key =
|
||||
"leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + to_string(negative_slope.to<double>());
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
@ -242,7 +242,7 @@ TORCH_IMPL_FUNC(log_softmax_mps_out)
|
||||
MPSStream* stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + std::to_string(dim);
|
||||
string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + to_string(dim);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
|
||||
@ -285,7 +285,7 @@ TORCH_IMPL_FUNC(log_softmax_backward_mps_out)
|
||||
MPSStream* stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + std::to_string(dim);
|
||||
string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + to_string(dim);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
|
||||
MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output));
|
||||
@ -539,8 +539,8 @@ TORCH_IMPL_FUNC(threshold_out_mps)
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(threshold.to<double>()) +
|
||||
":" + std::to_string(value.to<double>());
|
||||
string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + to_string(threshold.to<double>()) + ":" +
|
||||
to_string(value.to<double>());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -587,7 +587,7 @@ TORCH_IMPL_FUNC(threshold_backward_out_mps)
|
||||
|
||||
@autoreleasepool {
|
||||
string key =
|
||||
"threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + std::to_string(threshold.to<double>());
|
||||
"threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + to_string(threshold.to<double>());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -826,8 +826,8 @@ static void elu_variants_out_mps(const Tensor& self,
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = func_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(alpha.to<double>()) + ":" +
|
||||
std::to_string(scale.to<double>()) + ":" + std::to_string(input_scale.to<double>());
|
||||
string key = func_name + ":" + getTensorsStringKey({self}) + ":" + to_string(alpha.to<double>()) + ":" +
|
||||
to_string(scale.to<double>()) + ":" + to_string(input_scale.to<double>());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -916,8 +916,8 @@ TORCH_IMPL_FUNC(elu_backward_out_mps)
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" +
|
||||
std::to_string(alpha.to<double>()) + ":" + std::to_string(scale.to<double>()) + ":" +
|
||||
std::to_string(input_scale.to<double>()) + ":" + std::to_string(is_result);
|
||||
to_string(alpha.to<double>()) + ":" + to_string(scale.to<double>()) + ":" +
|
||||
to_string(input_scale.to<double>()) + ":" + to_string(is_result);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
@ -1010,7 +1010,7 @@ TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(dim);
|
||||
string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + to_string(dim);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
|
||||
NSArray<MPSGraphTensor*>* outputTensorsArray = [mpsGraph splitTensor:inputTensor
|
||||
@ -1052,7 +1052,7 @@ Tensor& glu_backward_mps_out(const Tensor& grad_output, const Tensor& self, cons
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + std::to_string(dim);
|
||||
string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + to_string(dim);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
|
||||
MPSGraphTensor* gradOutputTensor =
|
||||
@ -1855,8 +1855,8 @@ Tensor& hardtanh_backward_out_mps(const Tensor& grad_output,
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" +
|
||||
std::to_string(min.to<double>()) + ":" + std::to_string(max.to<double>());
|
||||
string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" + to_string(min.to<double>()) +
|
||||
":" + to_string(max.to<double>());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
|
@ -136,8 +136,8 @@ static Tensor& addmv_out_mps_impl(const Tensor& self,
|
||||
Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" +
|
||||
std::to_string(beta_.toDouble()) + ":" + std::to_string(alpha_.toDouble());
|
||||
string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + to_string(beta_.toDouble()) +
|
||||
":" + to_string(alpha_.toDouble());
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* matMulVecTensor = mpsGraphRankedPlaceHolder(mpsGraph, matMulVec);
|
||||
MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
|
@ -33,7 +33,7 @@ static Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + std::to_string(value.toDouble());
|
||||
string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()));
|
||||
|
@ -193,24 +193,24 @@ static Tensor _mps_convolution_impl(const Tensor& input_t,
|
||||
|
||||
string bias_shape_key;
|
||||
if (bias_defined) {
|
||||
bias_shape_key = std::to_string(bias_shape[0]);
|
||||
bias_shape_key = to_string(bias_shape[0]);
|
||||
} else {
|
||||
bias_shape_key = "nobias";
|
||||
}
|
||||
|
||||
string key;
|
||||
if (is3DConv) {
|
||||
key = "mps_3d_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
|
||||
std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
|
||||
std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
|
||||
mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key;
|
||||
key = "mps_3d_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(stride[2]) +
|
||||
":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) + ":" +
|
||||
to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + to_string(padding[2]) + ":" + to_string(groups) +
|
||||
":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" + to_string(bias_defined) + ":" +
|
||||
bias_shape_key;
|
||||
|
||||
} else {
|
||||
key = "mps_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
|
||||
std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
|
||||
mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key;
|
||||
key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(dilation[0]) +
|
||||
":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" +
|
||||
to_string(groups) + ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" +
|
||||
to_string(bias_defined) + ":" + bias_shape_key;
|
||||
}
|
||||
|
||||
MPSShape* inputShape = mps::getMPSShape(input_t, memory_format);
|
||||
@ -388,16 +388,16 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
|
||||
NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key;
|
||||
if (is3DConv) {
|
||||
key = "mps_3d_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
":" + std::to_string(stride[2]) + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
|
||||
std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
|
||||
std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
|
||||
getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);
|
||||
key = "mps_3d_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + ":" +
|
||||
to_string(stride[2]) + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) +
|
||||
":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + to_string(padding[2]) + ":" +
|
||||
to_string(groups) + ":" + mem_format_key + getTensorsStringKey({grad_output_t, weight_t}) + ":" +
|
||||
string([ns_shape_key UTF8String]);
|
||||
|
||||
} else {
|
||||
key = "mps_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
|
||||
std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
|
||||
key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
|
||||
to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
|
||||
to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
|
||||
getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);
|
||||
}
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
@ -547,15 +547,15 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
|
||||
NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key;
|
||||
if (is3DConv) {
|
||||
key = "mps_3d_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
|
||||
std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
|
||||
std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
|
||||
key = "mps_3d_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
|
||||
to_string(stride[2]) + ":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" +
|
||||
to_string(dilation[2]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" +
|
||||
to_string(padding[2]) + ":" + to_string(groups) + ":" + mem_format_key +
|
||||
getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]);
|
||||
} else {
|
||||
key = "mps_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
|
||||
std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
|
||||
key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
|
||||
to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
|
||||
to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
|
||||
getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]);
|
||||
}
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
|
@ -63,7 +63,7 @@ Tensor& random_mps_impl(Tensor& self,
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + getTensorsStringKey({self, mean_opt.value_or(Tensor()), std_opt.value_or(Tensor())}) + ":" +
|
||||
std::to_string(val1) + ":" + std::to_string(val2);
|
||||
to_string(val1) + ":" + to_string(val2);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<RandomCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->stateTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]);
|
||||
@ -469,7 +469,7 @@ static Tensor& multinomial_with_replacement_mps_kernel(const Tensor& self,
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + std::to_string(n_sample);
|
||||
string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + to_string(n_sample);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<RandomCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSShape* prob_shape = getMPSShape(self_v);
|
||||
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @7 ]);
|
||||
|
@ -236,7 +236,7 @@ static std::tuple<Tensor, Tensor> _mps_linear_backward_weights(const Tensor& gra
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "mps_linear_backward_weights:" + std::to_string(bias_defined) + ":" +
|
||||
string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" +
|
||||
getTensorsStringKey({input_reshaped, weight, grad_output_reshaped});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped);
|
||||
|
@ -229,8 +229,8 @@ static Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input,
|
||||
|
||||
@autoreleasepool {
|
||||
string key = (opType == ADDBMM_OP_TYPE) ? ("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl");
|
||||
key += getTensorsStringKey({batch1, batch2, input}) + ":" + std::to_string(beta.toDouble()) + ":" +
|
||||
std::to_string(alpha.toDouble());
|
||||
key += getTensorsStringKey({batch1, batch2, input}) + ":" + to_string(beta.toDouble()) + ":" +
|
||||
to_string(alpha.toDouble());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
@ -331,8 +331,8 @@ static Tensor& addmm_out_mps_impl(const Tensor& bias,
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" +
|
||||
std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble());
|
||||
string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + to_string(beta.toDouble()) +
|
||||
":" + to_string(alpha.toDouble());
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* selfTensor = nil;
|
||||
MPSGraphTensor* otherTensor = nil;
|
||||
@ -615,8 +615,8 @@ Tensor& addr_out_mps(const Tensor& self,
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" +
|
||||
std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble());
|
||||
string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + to_string(beta.toDouble()) +
|
||||
":" + to_string(alpha.toDouble());
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape);
|
||||
MPSGraphTensor* t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape);
|
||||
|
@ -69,7 +69,7 @@ static Tensor& mse_loss_backward_out_impl(const Tensor& grad_output,
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + reductionToString(reduction) + ":" + std::to_string(grad_input.sizes()[1]) +
|
||||
string key = op_name + reductionToString(reduction) + ":" + to_string(grad_input.sizes()[1]) +
|
||||
getTensorsStringKey({input, target, grad_output});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
@ -327,8 +327,8 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
|
||||
}
|
||||
@autoreleasepool {
|
||||
string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight}) +
|
||||
std::to_string(numClasses) + ":" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) +
|
||||
":" + std::to_string(isTargetCasted) + ":" + reductionToString(reduction);
|
||||
to_string(numClasses) + ":" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" +
|
||||
to_string(isTargetCasted) + ":" + reductionToString(reduction);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
@ -463,9 +463,9 @@ static void nllnd_loss_forward_impl(Tensor& output,
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
// TODO: Make the key
|
||||
string key = "nllnd_loss_forward_impl:" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) +
|
||||
":" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" +
|
||||
getMPSTypeString(target) + ":" + std::to_string(isTargetCasted) + ":" + getMPSTypeString(weight);
|
||||
string key = "nllnd_loss_forward_impl:" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" +
|
||||
reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" +
|
||||
getMPSTypeString(target) + ":" + to_string(isTargetCasted) + ":" + getMPSTypeString(weight);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape);
|
||||
MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target), target_shape);
|
||||
@ -598,7 +598,7 @@ static void smooth_l1_loss_impl(const Tensor& input,
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
string key = "smooth_l1_loss_impl:" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" +
|
||||
std::to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target);
|
||||
to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
// smooth_l1_loss_mps:
|
||||
// ln = 0.5 * ( xn - yn ) ^ 2 / beta, if |xn - yn| < beta
|
||||
@ -734,7 +734,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":" +
|
||||
reductionToString(reduction) + ":" + std::to_string(beta);
|
||||
reductionToString(reduction) + ":" + to_string(beta);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
|
@ -106,7 +106,7 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
|
||||
auto stream = getCurrentMPSStream();
|
||||
auto mpsDataType = getMPSDataType(result);
|
||||
@autoreleasepool {
|
||||
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size);
|
||||
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
|
||||
auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<RangeCachedGraph>(key, ^MPSCachedGraph*() {
|
||||
@ -173,7 +173,7 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step
|
||||
auto stream = getCurrentMPSStream();
|
||||
auto mpsDataType = getMPSDataType(result);
|
||||
@autoreleasepool {
|
||||
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size);
|
||||
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
|
||||
auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<RangeCachedGraph>(key, ^MPSCachedGraph*() {
|
||||
@ -221,8 +221,8 @@ Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps,
|
||||
bool start_less_end = (start.to<double>() <= end.to<double>());
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + std::to_string(steps) +
|
||||
std::to_string(start_less_end);
|
||||
string key =
|
||||
"linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end);
|
||||
auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key);
|
||||
|
||||
if (!cachedGraph) {
|
||||
|
@ -359,8 +359,8 @@ static void impl_func_norm_mps(const Tensor& input_tensor,
|
||||
NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
|
||||
string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : getTensorsStringKey({input_t});
|
||||
string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + std::to_string(p) + ":" +
|
||||
keepdim_info + ":" + toString(in_dtype) + ":" + std::to_string(castInputData);
|
||||
string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" +
|
||||
keepdim_info + ":" + toString(in_dtype) + ":" + to_string(castInputData);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<MPSBinaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
|
||||
@ -572,7 +572,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
|
||||
string op_key = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps";
|
||||
NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased ";
|
||||
string use_dim_info = (use_dim) ? "use_dim=1:" + std::to_string(dim_value.size()) : "use_dim=0";
|
||||
string use_dim_info = (use_dim) ? "use_dim=1:" + to_string(dim_value.size()) : "use_dim=0";
|
||||
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
|
||||
string key = op_key + ":" + getTensorsStringKey(input_t) + ":" + use_dim_info + ":" + keepdim_info + ":" +
|
||||
string([ns_key UTF8String]) + ":" + bessel_corrected + ":" + std::to_string(correction_value);
|
||||
@ -700,7 +700,7 @@ static void min_max_out_mps(const Tensor& input_t,
|
||||
auto stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + std::to_string(dim_);
|
||||
string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + to_string(dim_);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
@ -860,7 +860,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t,
|
||||
@autoreleasepool {
|
||||
NSString* ns_key = [[apparent_in_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key =
|
||||
func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]);
|
||||
func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
auto inputScalarType = input_t.scalar_type();
|
||||
MPSGraphTensor* inputTensor =
|
||||
@ -1217,7 +1217,7 @@ TORCH_IMPL_FUNC(any_out_mps)
|
||||
|
||||
@autoreleasepool {
|
||||
MPSShape* input_t_shape = getMPSShape(input_t);
|
||||
string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + std::to_string(dim_) + ":" +
|
||||
string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" +
|
||||
getMPSTypeString(input_t);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSDataType input_type = getMPSDataType(input_t);
|
||||
@ -1313,7 +1313,7 @@ TORCH_IMPL_FUNC(all_out_mps)
|
||||
|
||||
@autoreleasepool {
|
||||
MPSShape* input_t_shape = getMPSShape(input_t);
|
||||
string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + std::to_string(dim_) + ":" +
|
||||
string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" +
|
||||
getMPSTypeString(input_t);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSDataType input_type = getMPSDataType(input_t);
|
||||
@ -1531,8 +1531,8 @@ static void median_out_mps(const Tensor& input_t,
|
||||
auto stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" +
|
||||
getTensorsStringKey(indices_t);
|
||||
string key =
|
||||
func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + getTensorsStringKey(indices_t);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
MPSGraphTensor* castInputTensor =
|
||||
|
@ -108,8 +108,8 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
// Input as placeholders
|
||||
MPSShape* input_shape = getMPSShape(self);
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + std::to_string(k) +
|
||||
":dim" + std::to_string(dim_) + ":largest" + std::to_string(largest);
|
||||
string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + to_string(k) +
|
||||
":dim" + to_string(dim_) + ":largest" + to_string(largest);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
|
||||
|
||||
@ -320,12 +320,12 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "cat_out_mps:" + std::to_string(dimension) + ":" +
|
||||
(memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
|
||||
string key =
|
||||
"cat_out_mps:" + to_string(dimension) + ":" + (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
|
||||
if (!all_same_dtype) {
|
||||
key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride);
|
||||
} else {
|
||||
key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size());
|
||||
key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + to_string(inputs.size());
|
||||
}
|
||||
for (auto idx : skipped_tensor_indices) {
|
||||
key += "," + std::to_string(idx);
|
||||
|
@ -60,8 +60,8 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
|
||||
// Input as placeholders
|
||||
MPSShape* input_shape = getMPSShape(self);
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" +
|
||||
std::to_string(dim) + ":descending" + std::to_string(descending);
|
||||
string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + to_string(dim) +
|
||||
":descending" + to_string(descending);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
|
||||
|
||||
|
@ -240,8 +240,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
|
||||
|
||||
@autoreleasepool {
|
||||
// the optional min/max refs could affect how we build the cached graph
|
||||
string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
if (has_min)
|
||||
newCachedGraph->minTensor = [mpsGraph
|
||||
|
@ -13,6 +13,32 @@
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace at::native {
|
||||
static const std::string& getMetalType(const c10::ScalarType& t) {
|
||||
// Mapping from c10::ScalarType to integral type that can be used for unary ops
|
||||
static std::unordered_map<c10::ScalarType, std::string> scalar_to_metal_type = {
|
||||
{c10::ScalarType::Half, "half"},
|
||||
{c10::ScalarType::Float, "float"},
|
||||
{c10::ScalarType::Long, "long"},
|
||||
{c10::ScalarType::Int, "int"},
|
||||
{c10::ScalarType::Short, "short"},
|
||||
{c10::ScalarType::Bool, "bool"},
|
||||
{c10::ScalarType::Char, "int8_t"},
|
||||
{c10::ScalarType::Byte, "uint8_t"},
|
||||
};
|
||||
|
||||
auto it = scalar_to_metal_type.find(t);
|
||||
TORCH_CHECK(it != scalar_to_metal_type.end(), "Unsupported type ", t);
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static const std::string& getMetalType(const c10::Scalar& s) {
|
||||
return getMetalType(s.type());
|
||||
}
|
||||
|
||||
static const std::string& getMetalType(const Tensor& t) {
|
||||
return getMetalType(t.scalar_type());
|
||||
}
|
||||
|
||||
static mps::MetalShaderLibrary lib(UNARY_KERNEL_TEMPLATE, 2);
|
||||
|
||||
TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
|
||||
@ -31,8 +57,7 @@ TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
|
||||
}
|
||||
using namespace mps;
|
||||
@autoreleasepool {
|
||||
auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel",
|
||||
{scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)});
|
||||
auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel", {getMetalType(outputTensor), getMetalType(self)});
|
||||
|
||||
if (!self.is_contiguous()) {
|
||||
inputTensor = inputTensor.contiguous();
|
||||
|
@ -36,8 +36,8 @@ static std::string getUniqueKey(const ScalarType& dtype,
|
||||
const bool consecutive,
|
||||
c10::optional<int64_t> dimOpt) {
|
||||
return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + "]:[" +
|
||||
(dimOpt.has_value() ? std::to_string(dimOpt.value()) : "None") + "]:[" + std::to_string(return_inverse) + "]:[" +
|
||||
std::to_string(return_counts) + "]:[" + std::to_string(consecutive) + "]";
|
||||
(dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) + "]:[" +
|
||||
to_string(return_counts) + "]:[" + to_string(consecutive) + "]";
|
||||
}
|
||||
|
||||
// dim arg not supported when non consecutive, ie sorted
|
||||
|
@ -99,7 +99,7 @@ static void upsample_out_template(const Tensor& input,
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") +
|
||||
getTensorsStringKey({input}) + ":[" + std::to_string(scale_h) + "," + std::to_string(scale_w) + "]:[" +
|
||||
getTensorsStringKey({input}) + ":[" + to_string(scale_h) + "," + to_string(scale_w) + "]:[" +
|
||||
(is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]";
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
|
@ -42,7 +42,7 @@ static std::string getStridedKey(const ScalarType& self_dtype,
|
||||
}
|
||||
|
||||
return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + getArrayRefString(base_shape) + "]:[" +
|
||||
getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + std::to_string(storage_offset) + "]";
|
||||
getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]";
|
||||
}
|
||||
|
||||
// initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op
|
||||
|
@ -14728,12 +14728,12 @@
|
||||
CUDA: _scaled_dot_product_efficient_attention_backward_cuda
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
|
||||
- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
dispatch:
|
||||
CUDA: _scaled_dot_product_cudnn_attention_cuda
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
|
||||
- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
|
||||
dispatch:
|
||||
CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
|
||||
tags: nondeterministic_seeded
|
||||
|
@ -5,7 +5,6 @@
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <torch/library.h>
|
||||
#include <ATen/quantized/Quantizer.h>
|
||||
#include <ATen/native/quantized/cpu/BinaryOps.h>
|
||||
#include <ATen/native/quantized/cpu/QuantizedOps.h>
|
||||
#include <ATen/native/quantized/cpu/init_qnnpack.h>
|
||||
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
|
||||
@ -501,7 +500,7 @@ TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
|
||||
|
||||
} // namespace
|
||||
|
||||
Tensor quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point){
|
||||
static Tensor quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point){
|
||||
return qadd<false>(std::move(qa), std::move(qb), scale, zero_point);
|
||||
}
|
||||
|
||||
|
@ -172,6 +172,16 @@ Tensor mean_quantized_cpu(
|
||||
return result;
|
||||
}
|
||||
|
||||
static Tensor& mean_out_quantized_cpu(
|
||||
Tensor& result,
|
||||
const Tensor& self,
|
||||
DimnameList dim,
|
||||
bool keepdim,
|
||||
std::optional<ScalarType> opt_dtype) {
|
||||
return mean_out_quantized_cpu(
|
||||
self, dimnames_to_positions(self, dim), keepdim, opt_dtype, result);
|
||||
}
|
||||
|
||||
// qstd
|
||||
inline bool is_std_inner_dim_fast_path(
|
||||
const Tensor& self,
|
||||
|
@ -216,6 +216,20 @@ Tensor upsample_bilinear2d_quantized_cpu(
|
||||
}
|
||||
}
|
||||
|
||||
using at::native::upsample::compute_output_size;
|
||||
using at::native::upsample::get_scale_value;
|
||||
|
||||
static Tensor upsample_bilinear2d_quantized_cpu(
|
||||
const Tensor& input,
|
||||
at::OptionalIntArrayRef output_size,
|
||||
bool align_corners,
|
||||
std::optional<ArrayRef<double>> scale_factors) {
|
||||
auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
|
||||
auto scale_h = get_scale_value(scale_factors, 0);
|
||||
auto scale_w = get_scale_value(scale_factors, 1);
|
||||
return upsample_bilinear2d_quantized_cpu(input, osize, align_corners, scale_h, scale_w);
|
||||
}
|
||||
|
||||
DEFINE_DISPATCH(qupsample_bilinear2d_nhwc_stub);
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -218,5 +218,25 @@ Tensor _upsample_nearest_exact2d_quantized_cpu(
|
||||
return _upsample_nearest2d_quantized_cpu<nearest_neighbor_exact_compute_source_index>(input, osize, scale_h, scale_w);
|
||||
}
|
||||
|
||||
static Tensor upsample_nearest2d_quantized_cpu(
|
||||
const Tensor& input,
|
||||
at::OptionalIntArrayRef output_size,
|
||||
std::optional<ArrayRef<double>> scale_factors) {
|
||||
auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
|
||||
auto scale_h = get_scale_value(scale_factors, 0);
|
||||
auto scale_w = get_scale_value(scale_factors, 1);
|
||||
return upsample_nearest2d_quantized_cpu(input, osize, scale_h, scale_w);
|
||||
}
|
||||
|
||||
static Tensor _upsample_nearest_exact2d_quantized_cpu(
|
||||
const Tensor& input,
|
||||
at::OptionalIntArrayRef output_size,
|
||||
std::optional<ArrayRef<double>> scale_factors) {
|
||||
auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
|
||||
auto scale_h = get_scale_value(scale_factors, 0);
|
||||
auto scale_w = get_scale_value(scale_factors, 1);
|
||||
return _upsample_nearest_exact2d_quantized_cpu(input, osize, scale_h, scale_w);
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -1,7 +1,6 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/core/Tensor.h>
|
||||
@ -36,6 +35,7 @@
|
||||
#endif
|
||||
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/string_utils.h>
|
||||
|
||||
namespace {
|
||||
// To have a sanity check for maximum matrix size.
|
||||
@ -1848,15 +1848,15 @@ class QConvInt8ForBC final {
|
||||
int64_t output_zero_point) {
|
||||
if (kReluFused) {
|
||||
TORCH_WARN_ONCE(
|
||||
"Arguments [stride, padding, dilation, groups] in ops.quantized.conv" +
|
||||
std::to_string(kSpatialDim),
|
||||
"d_relu, have been removed, please update your model to remove these arguments.");
|
||||
"Arguments [stride, padding, dilation, groups] in ops.quantized.conv"
|
||||
+ c10::to_string(kSpatialDim) + "d_relu, " +
|
||||
"have been removed, please update your model to remove these arguments.");
|
||||
return packed_weight->apply_relu(act, output_scale, output_zero_point);
|
||||
} else {
|
||||
TORCH_WARN_ONCE(
|
||||
"Arguments [stride, padding, dilation, groups] in ops.quantized.conv",
|
||||
std::to_string(kSpatialDim),
|
||||
"d, have been removed, please update your model to remove these arguments.");
|
||||
"Arguments [stride, padding, dilation, groups] in ops.quantized.conv"
|
||||
+ c10::to_string(kSpatialDim) + "d, " +
|
||||
"have been removed, please update your model to remove these arguments.");
|
||||
return packed_weight->apply(act, output_scale, output_zero_point);
|
||||
}
|
||||
}
|
||||
|
@ -342,10 +342,7 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
|
||||
output_shape[cols_dim] = output_columns;
|
||||
at::SymDimVector output_shape_vec(output_shape);
|
||||
|
||||
return at::empty_symint(
|
||||
output_shape_vec,
|
||||
weight.options().dtype(weight.scalar_type()),
|
||||
weight.suggest_memory_format());
|
||||
return at::empty_symint(output_shape_vec, weight.options().dtype(weight.scalar_type()), weight.suggest_memory_format());
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -376,10 +373,9 @@ Tensor _qembeddingbag_nbit_prepack_helper(
|
||||
int NUM_ELEM_PER_BYTE = 8 / bit_width;
|
||||
TORCH_CHECK(
|
||||
weight_contig.size(weight.dim() - 1) % NUM_ELEM_PER_BYTE == 0,
|
||||
"qembeddingbag_",
|
||||
std::to_string(bit_width),
|
||||
"bit_prepack only works for the number of columns a multiple of ",
|
||||
std::to_string(NUM_ELEM_PER_BYTE));
|
||||
"qembeddingbag_" + c10::to_string(bit_width) +
|
||||
"bit_prepack only works for the number of columns a multiple of " +
|
||||
c10::to_string(NUM_ELEM_PER_BYTE));
|
||||
|
||||
// The "fused" representation stores the scale and bias with the
|
||||
// row-wise quantized data in one tensor.
|
||||
@ -555,9 +551,11 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
|
||||
TORCH_FN(QEmbeddingPackWeights::run));
|
||||
}
|
||||
|
||||
|
||||
TORCH_LIBRARY_IMPL(quantized, Meta, m) {
|
||||
m.impl(
|
||||
"quantized::embedding_bag_byte_prepack", qembeddingbag_byte_prepack_meta);
|
||||
"quantized::embedding_bag_byte_prepack",
|
||||
qembeddingbag_byte_prepack_meta);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -242,7 +242,7 @@ Tensor add(Tensor qa, Tensor qb, double output_scale, int64_t output_zero_point)
|
||||
run(plan_desc);
|
||||
execution_plan_cache[key] = plan_desc;
|
||||
return quantized_output.view(orig_sizes);
|
||||
} catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {}
|
||||
} catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Add Cudnn");
|
||||
|
@ -252,7 +252,7 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
|
||||
run(plan);
|
||||
execution_plan_cache.emplace(key, plan);
|
||||
return;
|
||||
} catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {}
|
||||
} catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Conv2D Cudnn");
|
||||
|
@ -286,7 +286,7 @@ void PackedLinearWeightCudnn::apply_impl_helper(const at::Tensor& quantized_outp
|
||||
run(plan);
|
||||
execution_plan_cache.emplace(key, plan);
|
||||
return;
|
||||
} catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {}
|
||||
} catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "Unable to find an engine to execute this computation Quantized Linear Cudnn");
|
||||
|
@ -624,6 +624,15 @@ Tensor _sparse_softmax(const Tensor& self, Dimname dim, optional<ScalarType> dty
|
||||
return at::_sparse_softmax(self, dimname_to_position(self, dim), dtype);
|
||||
}
|
||||
|
||||
static Tensor _sparse_log_softmax(const Tensor& input_, const int64_t dim_) {
|
||||
auto result = [&]() {
|
||||
NoNamesGuard guard;
|
||||
return at::_sparse_log_softmax(input_, dim_, false);
|
||||
}();
|
||||
namedinference::propagate_names(result, input_);
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor _sparse_log_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
|
||||
auto result = [&]() {
|
||||
NoNamesGuard guard;
|
||||
|
@ -270,6 +270,10 @@ Tensor& div_sparse_(Tensor& self, const Tensor& value) {
|
||||
return div_out_sparse_zerodim(self, value, self);
|
||||
}
|
||||
|
||||
static SparseTensor& div_out_sparse_scalar(const SparseTensor& t, Scalar value, SparseTensor& r) {
|
||||
return div_out_sparse_zerodim(t, wrapped_scalar_tensor(value), r);
|
||||
}
|
||||
|
||||
Tensor div_sparse(const Tensor& self, const Tensor& value, std::optional<c10::string_view> rounding_mode) {
|
||||
auto commonDtype = at::result_type(self, value);
|
||||
if (c10::isIntegralType(commonDtype, /*includeBool=*/true) && !rounding_mode.has_value()) {
|
||||
@ -283,6 +287,10 @@ Tensor& div_sparse_(Tensor& self, const Tensor& value, std::optional<c10::string
|
||||
return div_out_sparse_zerodim(self, value, std::move(rounding_mode), self);
|
||||
}
|
||||
|
||||
static SparseTensor& div_out_sparse_scalar(const SparseTensor& t, Scalar value, std::optional<c10::string_view> rounding_mode, SparseTensor& r) {
|
||||
return div_out_sparse_zerodim(t, wrapped_scalar_tensor(value), std::move(rounding_mode), r);
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// floor_divide(SparseTensor, Scalar)
|
||||
// --------------------------------------------------------------------
|
||||
@ -342,6 +350,10 @@ Tensor& floor_divide_sparse_(Tensor& self, const Tensor& value) {
|
||||
return floor_divide_out_sparse_zerodim(self, value, self);
|
||||
}
|
||||
|
||||
static SparseTensor& floor_divide_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, const Scalar& value) {
|
||||
return floor_divide_out_sparse_zerodim(t, wrapped_scalar_tensor(value), r);
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// norm(SparseTensor, Scalar)
|
||||
// --------------------------------------------------------------------
|
||||
|
@ -666,7 +666,7 @@ Tensor scaled_dot_product_attention(
|
||||
case sdp::SDPBackend::cudnn_attention: {
|
||||
bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
|
||||
auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention(
|
||||
query_, key, value, compute_logsumexp, dropout_p, is_causal, scale);
|
||||
query_, key, value, dropout_p, is_causal, compute_logsumexp, scale);
|
||||
return std::get<0>(out_lse_softmax);
|
||||
}
|
||||
case sdp::SDPBackend::flash_attention: {
|
||||
|
@ -72,17 +72,10 @@
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
|
||||
#endif
|
||||
#ifdef USE_MEM_EFF_ATTENTION
|
||||
#ifndef USE_ROCM
|
||||
// MemoryEfficient Attention Specific Imports for CUDA
|
||||
// MemoryEfficient Attention Specific Imports
|
||||
#include <ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h>
|
||||
#include <ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h>
|
||||
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
|
||||
#else
|
||||
// MemoryEfficient Attention Specific Imports for ROCM
|
||||
#include <ATen/native/transformers/hip/aotriton_adapter.h>
|
||||
#include <aotriton/flash.h>
|
||||
#include <aotriton/runtime.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace at {
|
||||
@ -735,27 +728,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
|
||||
return std::make_tuple(attention, logsumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, philox_seed, philox_offset, debug_attn_mask);
|
||||
}
|
||||
|
||||
// Adapted from TE
|
||||
// extract seed and offset from PhiloxCudaState
|
||||
__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr) {
|
||||
if (arg.captured_) {
|
||||
*seed_ptr = static_cast<int64_t>(*arg.seed_.ptr);
|
||||
*offset_ptr = static_cast<int64_t>(
|
||||
*(arg.offset_.ptr) + static_cast<int64_t>(arg.offset_intragraph_));
|
||||
} else {
|
||||
*seed_ptr = static_cast<int64_t>(arg.seed_.val);
|
||||
*offset_ptr = static_cast<int64_t>(arg.offset_.val);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_cuda(
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_cuda(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
bool compute_logsumexp,
|
||||
double dropout_p,
|
||||
bool is_causal,
|
||||
c10::optional<double> scale) {
|
||||
bool training,
|
||||
std::optional<double> scale) {
|
||||
// Used for tracking usage statistics
|
||||
C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn");
|
||||
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
|
||||
@ -764,8 +744,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_c
|
||||
const int64_t batch_size = query.size(0);
|
||||
const int64_t num_heads = query.size(1);
|
||||
const int64_t max_seqlen_batch_q = query.size(2);
|
||||
const int64_t head_dim_qk = query.size(3);
|
||||
const int64_t head_dim_v = value.size(3);
|
||||
const int64_t head_dim = query.size(3);
|
||||
|
||||
const int64_t max_seqlen_batch_k = key.size(2);
|
||||
const int64_t max_seqlen_batch_v = value.size(2);
|
||||
TORCH_CHECK(
|
||||
@ -774,42 +754,17 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_c
|
||||
|
||||
Tensor attention, log_sumexp;
|
||||
|
||||
at::Tensor cudnn_seed, cudnn_offset;
|
||||
cudnn_seed = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
cudnn_offset = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
|
||||
const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;
|
||||
|
||||
// See Note [Seed and Offset Device] in _efficient_attention_forward
|
||||
at::PhiloxCudaState philox_state;
|
||||
const bool in_capture_stream =
|
||||
at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None;
|
||||
if (use_dropout) {
|
||||
// Device
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
// if using dropout, we produce 1 random number for each element of the
|
||||
// attention tensor
|
||||
// TODO(eqy): should state be advanced per thread (local) amount or per call/launch (global) amount
|
||||
philox_state = gen->philox_cuda_state(batch_size * num_heads * max_seqlen_batch_q * max_seqlen_batch_k);
|
||||
unpack_cudnn<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
philox_state, static_cast<int64_t*>(cudnn_seed.data_ptr()), static_cast<int64_t*>(cudnn_offset.data_ptr()));
|
||||
}
|
||||
|
||||
auto cudnn_seed = at::zeros({1}, query.options().dtype(kLong));
|
||||
auto cudnn_offset = at::zeros({1}, query.options().dtype(kLong));
|
||||
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
|
||||
Tensor debugmask;
|
||||
|
||||
run_cudnn_SDP_fprop(batch_size/*int64_t b*/,
|
||||
num_heads/*int64_t h*/,
|
||||
max_seqlen_batch_q/*int64_t s_q*/,
|
||||
max_seqlen_batch_k/*int64_t s_kv*/,
|
||||
head_dim_qk/*int64_t d_qk*/,
|
||||
head_dim_v/*int64_t d_v*/,
|
||||
head_dim/*int64_t d*/,
|
||||
softmax_scale/*float scaling_factor*/,
|
||||
compute_logsumexp/* bool */,
|
||||
training/* bool */,
|
||||
is_causal/* bool */,
|
||||
dropout_p/*double dropout_probability*/,
|
||||
query/* Tensor q*/,
|
||||
@ -820,7 +775,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_c
|
||||
cudnn_seed/*Tensor dropoutseed*/,
|
||||
cudnn_offset/*Tensor dropoutoffset*/);
|
||||
|
||||
return std::make_tuple(attention, log_sumexp, cudnn_seed, cudnn_offset);
|
||||
return std::make_tuple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, cudnn_seed, cudnn_offset, Tensor());
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
|
||||
@ -1107,64 +1062,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(device));
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// ROCM Implementation
|
||||
auto ret = aotriton::v2::flash::check_gpu(stream);
|
||||
if (hipSuccess != ret) {
|
||||
TORCH_CHECK(false,
|
||||
"[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)")
|
||||
}
|
||||
|
||||
// AOTriton may accept aligned on logsumexp tensor in the future for better
|
||||
// performance, but for now it requires compact logsumexp tensor, even if
|
||||
// compute_logsumexp is false
|
||||
constexpr int kAlignLSE = 1;
|
||||
res = at::empty({B, M, num_heads, Kv}, query.options());
|
||||
logsumexp = at::empty(
|
||||
{ B, num_heads, max_seqlen_q },
|
||||
query.options().dtype(at::ScalarType::Float));
|
||||
at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
|
||||
at::Tensor q_t = query.transpose(1, 2);
|
||||
at::Tensor k_t = key.transpose(1, 2);
|
||||
at::Tensor v_t = value.transpose(1, 2);
|
||||
at::Tensor output_t = res.transpose(1, 2);
|
||||
bool is_causal;
|
||||
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
|
||||
is_causal = true;
|
||||
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
|
||||
is_causal = false;
|
||||
} else {
|
||||
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
|
||||
}
|
||||
|
||||
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
|
||||
|
||||
using aotriton::v2::flash::attn_fwd;
|
||||
using sdp::aotriton_adapter::mk_aotensor;
|
||||
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
|
||||
at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
|
||||
hipError_t err; // TODO: Error handling
|
||||
err = attn_fwd(mk_aotensor(q_t, "q"),
|
||||
mk_aotensor(k_t, "k"),
|
||||
mk_aotensor(v_t, "v"),
|
||||
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
|
||||
softmax_scale,
|
||||
mk_aotensor<2>(softmax_lse, "M"),
|
||||
mk_aotensor(output_t, "Out"),
|
||||
dropout_p,
|
||||
use_dropout ? *seed_t.data_ptr<int64_t>() : 0,
|
||||
use_dropout ? *offset_t.data_ptr<int64_t>() : 0,
|
||||
mk_aotensor(softmax_fa_t, "encoded_softmax"),
|
||||
is_causal,
|
||||
stream);
|
||||
if (!compute_logsumexp) {
|
||||
// Set the tensor to empty when compute_logsumexp is false
|
||||
logsumexp = at::empty(
|
||||
{ B * num_heads, max_seqlen_q, 0 },
|
||||
query.options().dtype(at::ScalarType::Float));
|
||||
}
|
||||
#else
|
||||
// CUDA Implementation
|
||||
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
|
||||
const int computeCapability = p->major * 10 + p->minor;
|
||||
|
||||
@ -1334,7 +1231,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
|
||||
TORCH_CHECK(kernel_launched, "cutlassF: no kernel found to launch!");
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
#endif // USE_ROCM
|
||||
return std::make_tuple(
|
||||
std::move(res),
|
||||
std::move(logsumexp),
|
||||
@ -1355,7 +1251,7 @@ Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tenso
|
||||
|
||||
REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda);
|
||||
|
||||
#if defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM)
|
||||
#ifdef USE_MEM_EFF_ATTENTION
|
||||
namespace {
|
||||
/**
|
||||
* simple kernel that populates a tensor with rand uniform values.
|
||||
@ -1405,7 +1301,7 @@ __global__ void rand_uniform_kernel(
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
#endif // defined(USE_MEM_EFF_ATTENTION) and !defined(USE_ROCM)
|
||||
#endif
|
||||
/**
|
||||
* fill tensor with random uniform values. only used for testing, not much
|
||||
* attention is paid to performance
|
||||
@ -1423,17 +1319,6 @@ at::Tensor& _fill_mem_eff_dropout_mask_(
|
||||
const int64_t n_keys = self.size(3);
|
||||
#if defined(USE_MEM_EFF_ATTENTION)
|
||||
|
||||
#ifdef USE_ROCM
|
||||
using aotriton::v2::flash::debug_fill_dropout_rng;
|
||||
using sdp::aotriton_adapter::mk_aotensor;
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
hipError_t err; // TODO: Error handling
|
||||
|
||||
err = debug_fill_dropout_rng(mk_aotensor(self, "r"),
|
||||
static_cast<uint64_t>(seed),
|
||||
static_cast<uint64_t>(offset),
|
||||
stream);
|
||||
#else
|
||||
at::PhiloxCudaState rng_engine_inputs;
|
||||
rng_engine_inputs = at::PhiloxCudaState(seed, offset);
|
||||
at::cuda::CUDAGuard device_guard(self.device());
|
||||
@ -1447,7 +1332,6 @@ at::Tensor& _fill_mem_eff_dropout_mask_(
|
||||
rng_engine_inputs,
|
||||
reinterpret_cast<float*>(self.data_ptr()),
|
||||
self.numel());
|
||||
#endif
|
||||
|
||||
return self;
|
||||
#endif
|
||||
|
@ -36,18 +36,11 @@
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
|
||||
#endif
|
||||
#ifdef USE_MEM_EFF_ATTENTION
|
||||
#ifndef USE_ROCM
|
||||
// MemoryEfficient Attention Specific Imports for CUDA
|
||||
// MemoryEfficient Attention Specific Imports
|
||||
#include <ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>
|
||||
#include <ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h>
|
||||
#include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
|
||||
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
|
||||
#else
|
||||
// MemoryEfficient Attention Specific Imports for ROCM
|
||||
#include <ATen/native/transformers/hip/aotriton_adapter.h>
|
||||
#include <aotriton/flash.h>
|
||||
#include <aotriton/runtime.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
@ -171,34 +164,21 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
|
||||
const Tensor& value,
|
||||
const Tensor& out,
|
||||
const Tensor& logsumexp,
|
||||
const Tensor& philox_seed,
|
||||
const Tensor& philox_offset,
|
||||
// const Tensor& cumulative_sequence_length_q,
|
||||
// const Tensor& cumulative_sequence_length_k,
|
||||
// const int64_t max_seqlen_batch_q,
|
||||
// const int64_t max_seqlen_batch_k,
|
||||
const Tensor& cumulative_sequence_length_q,
|
||||
const Tensor& cumulative_sequence_length_k,
|
||||
const int64_t max_seqlen_batch_q,
|
||||
const int64_t max_seqlen_batch_k,
|
||||
double dropout_p,
|
||||
bool is_causal,
|
||||
c10::optional<double> scale) {
|
||||
|
||||
|
||||
auto& ctx = at::globalContext();
|
||||
if (ctx.deterministicAlgorithms()) {
|
||||
if (ctx.deterministicAlgorithmsWarnOnly()) {
|
||||
TORCH_WARN_ONCE(
|
||||
"cuDNN Attention defaults to a non-deterministic algorithm. ",
|
||||
"To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False).");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
const Tensor& philox_seed,
|
||||
const Tensor& philox_offset,
|
||||
std::optional<double> scale) {
|
||||
const int64_t batch_size = query.size(0);
|
||||
const int64_t num_heads = query.size(1);
|
||||
const int64_t head_dim_qk = query.size(3);
|
||||
const int64_t head_dim_v = value.size(3);
|
||||
const int64_t max_seqlen_batch_q = query.size(1);
|
||||
const int64_t max_seqlen_batch_k = key.size(1);
|
||||
const int64_t head_dim = query.size(3);
|
||||
|
||||
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
|
||||
|
||||
auto dq = at::empty_like(query);
|
||||
auto dk = at::empty_like(key);
|
||||
auto dv = at::empty_like(value);
|
||||
@ -206,8 +186,7 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
|
||||
num_heads /*int64_t h*/,
|
||||
max_seqlen_batch_q /*int64_t s_q*/,
|
||||
max_seqlen_batch_k /*int64_t s_kv*/,
|
||||
head_dim_qk /*int64_t d_qk*/,
|
||||
head_dim_v /*int64_t d_v*/,
|
||||
head_dim /*int64_t d*/,
|
||||
softmax_scale /*float scaling_factor*/,
|
||||
is_causal /*bool is_causal*/,
|
||||
dropout_p /*float dropout_probability*/,
|
||||
@ -369,6 +348,7 @@ _efficient_attention_backward(
|
||||
grad_bias = at::empty(sz, bias->options())
|
||||
.slice(/*dim=*/-1, /*start=*/0, /*end=*/lastDim);
|
||||
}
|
||||
at::Tensor workspace;
|
||||
|
||||
const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;
|
||||
|
||||
@ -388,62 +368,6 @@ _efficient_attention_backward(
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// ROCM Implementation
|
||||
TORCH_CHECK(!num_splits_key.has_value(),
|
||||
"ROCM does not support num_split_keys in _efficient_attention_forward");
|
||||
TORCH_CHECK(!window_size.has_value(),
|
||||
"ROCM does not support window_size in _efficient_attention_forward");
|
||||
auto ret = aotriton::v2::flash::check_gpu(stream);
|
||||
if (hipSuccess != ret) {
|
||||
TORCH_CHECK(false,
|
||||
"[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)")
|
||||
}
|
||||
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
|
||||
bool is_causal;
|
||||
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
|
||||
is_causal = true;
|
||||
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
|
||||
is_causal = false;
|
||||
} else {
|
||||
TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now");
|
||||
}
|
||||
at::Tensor q_t = query.permute({0,2,1,3});
|
||||
at::Tensor k_t = key.permute({0,2,1,3});
|
||||
at::Tensor v_t = value.permute({0,2,1,3});
|
||||
at::Tensor out_t = out.permute({0,2,1,3});
|
||||
at::Tensor dq_t = grad_q.permute({0,2,1,3});
|
||||
at::Tensor dk_t = grad_k.permute({0,2,1,3});
|
||||
at::Tensor dv_t = grad_v.permute({0,2,1,3});
|
||||
at::Tensor dout_t = grad_out.permute({0,2,1,3});
|
||||
at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q});
|
||||
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
|
||||
|
||||
hipError_t err;
|
||||
using aotriton::v2::flash::attn_bwd;
|
||||
using sdp::aotriton_adapter::mk_aotensor;
|
||||
using sdp::aotriton_adapter::cast_dtype;
|
||||
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
|
||||
err = attn_bwd(mk_aotensor(q_t, "q"),
|
||||
mk_aotensor(k_t, "k"),
|
||||
mk_aotensor(v_t, "v"),
|
||||
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
|
||||
softmax_scale,
|
||||
mk_aotensor(out_t, "out"),
|
||||
mk_aotensor(dout_t, "dout"),
|
||||
mk_aotensor(dq_t, "dq"),
|
||||
mk_aotensor(dk_t, "dk"),
|
||||
mk_aotensor(dv_t, "dv"),
|
||||
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
|
||||
mk_aotensor<2>(softmax_lse, "L"),
|
||||
mk_aotensor<2>(delta, "delta"),
|
||||
float(dropout_p),
|
||||
rng_engine_inputs.seed_.val,
|
||||
rng_engine_inputs.offset_.val,
|
||||
is_causal,
|
||||
stream);
|
||||
#else
|
||||
at::Tensor workspace;
|
||||
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
|
||||
const int computeCapability = p->major * 10 + p->minor;
|
||||
|
||||
@ -700,9 +624,8 @@ _efficient_attention_backward(
|
||||
}));
|
||||
TORCH_CHECK(kernel_launched, "cutlassB: no kernel found to launch!");
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
#endif // USE_ROCM
|
||||
return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v), std::move(grad_bias));
|
||||
#endif // defined(USE_MEM_EFF_ATTENTION)
|
||||
#endif
|
||||
TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
|
||||
return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
|
||||
}
|
||||
|
@ -6,7 +6,6 @@
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/core/grad_mode.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDAConfig.h>
|
||||
#include <ATen/detail/CUDAHooksInterface.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/transformers/cuda/sdp_utils.h>
|
||||
@ -45,28 +44,14 @@
|
||||
|
||||
namespace sdp {
|
||||
namespace {
|
||||
|
||||
// TODO(eqy): more benchmarking to determine whether this should include sm86/89
|
||||
// Needs to be kept in-sync with test_fused_chocie in test_transformers.py
|
||||
bool check_prefer_cudnn_attention() {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
return dprops->major >= 9;
|
||||
}
|
||||
|
||||
// flash_attention V2 is universally faster than efficient_attention and Math
|
||||
std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
|
||||
constexpr std::array<SDPBackend, num_backends> default_order{
|
||||
SDPBackend::flash_attention,
|
||||
SDPBackend::cudnn_attention,
|
||||
SDPBackend::efficient_attention,
|
||||
SDPBackend::math};
|
||||
constexpr std::array<SDPBackend, num_backends> cudnn_order{
|
||||
SDPBackend::cudnn_attention,
|
||||
SDPBackend::flash_attention,
|
||||
SDPBackend::efficient_attention,
|
||||
SDPBackend::math};
|
||||
static const bool prefer_cudnn = check_prefer_cudnn_attention();
|
||||
return prefer_cudnn ? cudnn_order : default_order;
|
||||
return default_order;
|
||||
}
|
||||
|
||||
bool use_tensor_cores(sdp_params const& params, cudaDeviceProp* dprops, bool is_half) {
|
||||
@ -230,17 +215,6 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
|
||||
// Mem Efficient attention supports hardware in the range [sm_50, sm_90]
|
||||
using sm50 = SMVersion<5, 0>;
|
||||
using sm90 = SMVersion<9, 0>;
|
||||
#if USE_ROCM
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (debug) {
|
||||
TORCH_WARN(
|
||||
"Mem Efficient attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (!check_sm_version<sm50, sm90>(dprops)) {
|
||||
if (debug) {
|
||||
@ -253,7 +227,6 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -466,6 +439,17 @@ bool check_cudnn_hardware_support(sdp_params const& params, bool debug) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool check_is_causal(sdp_params const& params, bool debug) {
|
||||
// Check that the input is causal
|
||||
if (!params.is_causal) {
|
||||
if (debug) {
|
||||
TORCH_WARN("CuDNN requires is_causal=True.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool check_for_nested_inputs(sdp_params const& params, bool debug) {
|
||||
// Check that the input is nested
|
||||
if (has_for_nested_inputs(params)) {
|
||||
@ -489,6 +473,22 @@ bool check_dtypes_low_precision(sdp_params const& params, bool debug) {
|
||||
}
|
||||
}
|
||||
|
||||
bool check_runtime_enabled_cudnn(sdp_params const& params, bool debug) {
|
||||
static c10::once_flag supported_flag;
|
||||
static bool supported = false;
|
||||
c10::call_once(supported_flag, []() {
|
||||
supported = (c10::utils::check_env("TORCH_CUDNN_SDPA_ENABLED") == true);
|
||||
});
|
||||
if (!supported) {
|
||||
if (debug) {
|
||||
TORCH_WARN(
|
||||
"The CuDNN backend needs to be enabled by setting the enviornment variable`TORCH_CUDNN_SDPA_ENABLED=1`");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) {
|
||||
// We check the global context to see if user has explicitly turned of cudnn
|
||||
// sdp kernels
|
||||
@ -501,15 +501,13 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool check_cudnn_deterministic(const sdp_params& params, bool debug) {
|
||||
auto& ctx = at::globalContext();
|
||||
if (ctx.deterministicAlgorithms()) {
|
||||
if (!ctx.deterministicAlgorithmsWarnOnly()) {
|
||||
if (debug) {
|
||||
TORCH_WARN("cuDNN SDPA is not deterministic.");
|
||||
}
|
||||
return false;
|
||||
bool check_cudnn_requires_grad(sdp_params const& params, bool debug) {
|
||||
// Check that the input is causal
|
||||
if (input_requires_grad(params)) {
|
||||
if (debug) {
|
||||
TORCH_WARN("CuDNN does not currently support inputs with requires_grad=True.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -517,29 +515,21 @@ bool check_cudnn_deterministic(const sdp_params& params, bool debug) {
|
||||
} // namespace
|
||||
|
||||
bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
|
||||
#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \
|
||||
(defined(CUDNN_VERSION) && CUDNN_VERSION < 8900)
|
||||
TORCH_WARN_ONCE(!debug, "Torch was not compiled with cuDNN attention.");
|
||||
return false;
|
||||
#endif
|
||||
|
||||
// Define gate functions that determine if a flash kernel can be ran
|
||||
// Replace with std::to_array when we migrate to c++20
|
||||
constexpr auto general_constraints =
|
||||
array_of<bool (*)(sdp_params const&, bool)>(
|
||||
check_for_nested_inputs,
|
||||
check_nonzero_sequence_lengths_dense,
|
||||
check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim>*/>,
|
||||
check_all_tensors_on_device,
|
||||
check_tensor_shapes,
|
||||
check_cudnn_tensor_shapes,
|
||||
check_runtime_enabled_cudnn,
|
||||
check_runtime_disabled_cudnn,
|
||||
check_cudnn_deterministic,
|
||||
// check_cudnn_layout,
|
||||
check_cudnn_hardware_support,
|
||||
check_all_tensors_on_device,
|
||||
check_cudnn_tensor_shapes,
|
||||
check_cudnn_layout,
|
||||
// check_is_causal,
|
||||
check_dtypes_low_precision,
|
||||
check_for_attn_mask_cudnn,
|
||||
check_cudnn_hardware_support
|
||||
);
|
||||
check_for_nested_inputs,
|
||||
check_cudnn_requires_grad,
|
||||
check_dtypes_low_precision);
|
||||
for (auto& constraint : general_constraints) {
|
||||
if (!constraint(params, debug)) {
|
||||
return false;
|
||||
@ -607,10 +597,6 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
|
||||
array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
|
||||
constexpr auto less_than_sm80_mem_efficient_dtypes =
|
||||
array_of<at::ScalarType>(at::kHalf, at::kFloat);
|
||||
#ifdef USE_ROCM
|
||||
constexpr auto aotriton_mem_efficient_dtypes =
|
||||
array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
|
||||
#endif
|
||||
|
||||
// Define gate functions that determine if a mem efficient kernel can be ran
|
||||
constexpr auto general_constraints = array_of<bool (*)(sdp_params const&, bool)>(
|
||||
@ -626,10 +612,6 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
|
||||
}
|
||||
|
||||
if (has_for_nested_inputs(params)) {
|
||||
#ifdef USE_ROCM
|
||||
TORCH_WARN_ONCE(false, "[ROCM] no support for nested tensors in memory efficient attention.");
|
||||
return false;
|
||||
#endif
|
||||
constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
|
||||
check_requires_grad_and_nested,
|
||||
check_batch_size_nested,
|
||||
@ -652,14 +634,10 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
|
||||
#else
|
||||
auto dprop = at::cuda::getCurrentDeviceProperties();
|
||||
if (dprop->major >= 8) {
|
||||
return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug);
|
||||
}
|
||||
#endif
|
||||
return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
|
||||
}
|
||||
|
||||
|
@ -1,130 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
#include <aotriton/dtypes.h>
|
||||
#include <aotriton/util.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \
|
||||
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
|
||||
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
|
||||
TORCH_CHECK(TENSOR.is_contiguous());
|
||||
|
||||
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
|
||||
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
|
||||
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
|
||||
TORCH_CHECK( \
|
||||
TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");
|
||||
|
||||
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
|
||||
TORCH_CHECK( \
|
||||
uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
|
||||
|
||||
#define ASSIGN_CHECK_OVERFLOW(A, B) \
|
||||
{ \
|
||||
A = B; \
|
||||
TORCH_CHECK( \
|
||||
B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
|
||||
}
|
||||
|
||||
namespace sdp {
|
||||
|
||||
namespace aotriton_adapter {
|
||||
|
||||
inline aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype)
|
||||
{
|
||||
#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname
|
||||
CAST_TYPE(kByte, kUInt8);
|
||||
CAST_TYPE(kUInt16, kUInt16);
|
||||
CAST_TYPE(kUInt32, kUInt32);
|
||||
CAST_TYPE(kUInt64, kUInt64);
|
||||
CAST_TYPE(kChar, kInt8);
|
||||
CAST_TYPE(kShort, kInt16);
|
||||
CAST_TYPE(kInt, kInt32);
|
||||
CAST_TYPE(kLong, kInt64);
|
||||
CAST_TYPE(kHalf, kFloat16);
|
||||
CAST_TYPE(kFloat, kFloat32);
|
||||
CAST_TYPE(kBFloat16, kBFloat16);
|
||||
return aotriton::DType::kUnknown;
|
||||
#undef CAST_TYPE
|
||||
}
|
||||
|
||||
template<typename TargetType, int Rank>
|
||||
struct IntArrayRefCaster {
|
||||
// std::array<TargetType, Rank> cast(IntArrayRef);
|
||||
};
|
||||
|
||||
template<typename TargetType>
|
||||
struct IntArrayRefCaster<TargetType, 1> {
|
||||
static auto cast(at::IntArrayRef ref) {
|
||||
return std::array<TargetType, 1>{{ static_cast<TargetType>(ref.at(0)) }};
|
||||
}
|
||||
};
|
||||
|
||||
template<typename TargetType>
|
||||
struct IntArrayRefCaster<TargetType, 2> {
|
||||
static auto cast(at::IntArrayRef ref) {
|
||||
return std::array<TargetType, 2>{{
|
||||
static_cast<TargetType>(ref.at(0)),
|
||||
static_cast<TargetType>(ref.at(1))
|
||||
}};
|
||||
}
|
||||
};
|
||||
|
||||
template<typename TargetType>
|
||||
struct IntArrayRefCaster<TargetType, 3> {
|
||||
static auto cast(at::IntArrayRef ref) {
|
||||
return std::array<TargetType, 3>{{
|
||||
static_cast<TargetType>(ref.at(0)),
|
||||
static_cast<TargetType>(ref.at(1)),
|
||||
static_cast<TargetType>(ref.at(2))
|
||||
}};
|
||||
}
|
||||
};
|
||||
|
||||
template<typename TargetType>
|
||||
struct IntArrayRefCaster<TargetType, 4> {
|
||||
static auto cast(at::IntArrayRef ref) {
|
||||
return std::array<TargetType, 4>{{
|
||||
static_cast<TargetType>(ref.at(0)),
|
||||
static_cast<TargetType>(ref.at(1)),
|
||||
static_cast<TargetType>(ref.at(2)),
|
||||
static_cast<TargetType>(ref.at(3))
|
||||
}};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template<int Rank = 4>
|
||||
aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view tensor_name)
|
||||
{
|
||||
const auto strides = q.strides();
|
||||
int real_rank = strides.size();
|
||||
if (real_rank != Rank) { // Lazy convertion of tensor_name
|
||||
TORCH_CHECK(false,
|
||||
std::string(tensor_name) + "'s rank should be " + std::to_string(Rank)
|
||||
+ " but is " + std::to_string(real_rank));
|
||||
}
|
||||
return aotriton::TensorView<Rank>(reinterpret_cast<intptr_t>(q.data_ptr()),
|
||||
IntArrayRefCaster<uint64_t, Rank>::cast(q.sizes()),
|
||||
IntArrayRefCaster<uint64_t, Rank>::cast(strides),
|
||||
cast_dtype(q.dtype()));
|
||||
}
|
||||
|
||||
} // namespace aotriton_adapter
|
||||
|
||||
} // namespace sdp
|
||||
|
||||
namespace at::native {
|
||||
|
||||
inline int64_t ceil_div(int64_t numerator, int64_t denominator) {
|
||||
return (numerator + (denominator - 1)) / denominator;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif // USE_ROCM
|
@ -54,15 +54,16 @@
|
||||
#include <ATen/ops/pad.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/native/transformers/hip/aotriton_adapter.h>
|
||||
#include <ATen/native/transformers/hip/flash_attn/flash_api.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/CallOnce.h>
|
||||
|
||||
// AOTriton headers
|
||||
#include <aotriton/dtypes.h>
|
||||
#include <aotriton/flash.h>
|
||||
#include <aotriton/runtime.h>
|
||||
#include <aotriton/util.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
@ -72,10 +73,90 @@ void check_gpu_arch(hipStream_t stream) {
|
||||
auto ret = aotriton::v2::flash::check_gpu(stream);
|
||||
if (hipSuccess != ret) {
|
||||
TORCH_CHECK(false,
|
||||
"FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)")
|
||||
"FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)")
|
||||
}
|
||||
}
|
||||
|
||||
aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype)
|
||||
{
|
||||
#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname
|
||||
CAST_TYPE(kByte, kUInt8);
|
||||
CAST_TYPE(kUInt16, kUInt16);
|
||||
CAST_TYPE(kUInt32, kUInt32);
|
||||
CAST_TYPE(kUInt64, kUInt64);
|
||||
CAST_TYPE(kChar, kInt8);
|
||||
CAST_TYPE(kShort, kInt16);
|
||||
CAST_TYPE(kInt, kInt32);
|
||||
CAST_TYPE(kLong, kInt64);
|
||||
CAST_TYPE(kHalf, kFloat16);
|
||||
CAST_TYPE(kFloat, kFloat32);
|
||||
CAST_TYPE(kBFloat16, kBFloat16);
|
||||
return aotriton::DType::kUnknown;
|
||||
#undef CAST_TYPE
|
||||
}
|
||||
|
||||
template<typename TargetType, int Rank>
|
||||
struct IntArrayRefCaster {
|
||||
// std::array<TargetType, Rank> cast(IntArrayRef);
|
||||
};
|
||||
|
||||
template<typename TargetType>
|
||||
struct IntArrayRefCaster<TargetType, 1> {
|
||||
static auto cast(at::IntArrayRef ref) {
|
||||
return std::array<TargetType, 1>{{ static_cast<TargetType>(ref.at(0)) }};
|
||||
}
|
||||
};
|
||||
|
||||
template<typename TargetType>
|
||||
struct IntArrayRefCaster<TargetType, 2> {
|
||||
static auto cast(at::IntArrayRef ref) {
|
||||
return std::array<TargetType, 2>{{
|
||||
static_cast<TargetType>(ref.at(0)),
|
||||
static_cast<TargetType>(ref.at(1))
|
||||
}};
|
||||
}
|
||||
};
|
||||
|
||||
template<typename TargetType>
|
||||
struct IntArrayRefCaster<TargetType, 3> {
|
||||
static auto cast(at::IntArrayRef ref) {
|
||||
return std::array<TargetType, 3>{{
|
||||
static_cast<TargetType>(ref.at(0)),
|
||||
static_cast<TargetType>(ref.at(1)),
|
||||
static_cast<TargetType>(ref.at(2))
|
||||
}};
|
||||
}
|
||||
};
|
||||
|
||||
template<typename TargetType>
|
||||
struct IntArrayRefCaster<TargetType, 4> {
|
||||
static auto cast(at::IntArrayRef ref) {
|
||||
return std::array<TargetType, 4>{{
|
||||
static_cast<TargetType>(ref.at(0)),
|
||||
static_cast<TargetType>(ref.at(1)),
|
||||
static_cast<TargetType>(ref.at(2)),
|
||||
static_cast<TargetType>(ref.at(3))
|
||||
}};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template<int Rank = 4>
|
||||
aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view tensor_name)
|
||||
{
|
||||
const auto strides = q.strides();
|
||||
int real_rank = strides.size();
|
||||
if (real_rank != Rank) { // Lazy convertion of tensor_name
|
||||
TORCH_CHECK(false,
|
||||
std::string(tensor_name) + "'s rank should be " + std::to_string(Rank)
|
||||
+ " but is " + std::to_string(real_rank));
|
||||
}
|
||||
return aotriton::TensorView<Rank>(reinterpret_cast<intptr_t>(q.data_ptr()),
|
||||
IntArrayRefCaster<uint64_t, Rank>::cast(q.sizes()),
|
||||
IntArrayRefCaster<uint64_t, Rank>::cast(strides),
|
||||
cast_dtype(q.dtype()));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
||||
@ -219,13 +300,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
|
||||
hipError_t err; // TODO: Error handling
|
||||
using aotriton::v2::flash::attn_fwd;
|
||||
using sdp::aotriton_adapter::mk_aotensor;
|
||||
using sdp::aotriton_adapter::cast_dtype;
|
||||
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
|
||||
err = attn_fwd(mk_aotensor(q_t, "q"),
|
||||
mk_aotensor(k_t, "k"),
|
||||
mk_aotensor(v_t, "v"),
|
||||
empty_bias,
|
||||
softmax_scale,
|
||||
mk_aotensor<2>(M, "M"),
|
||||
mk_aotensor(output_t, "Out"),
|
||||
@ -418,20 +495,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
hipError_t err; // TODO: Error handling
|
||||
{
|
||||
using aotriton::v2::flash::attn_bwd;
|
||||
using sdp::aotriton_adapter::mk_aotensor;
|
||||
using sdp::aotriton_adapter::cast_dtype;
|
||||
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
|
||||
err = attn_bwd(mk_aotensor(q_t, "q"),
|
||||
mk_aotensor(k_t, "k"),
|
||||
mk_aotensor(v_t, "v"),
|
||||
empty_bias,
|
||||
softmax_scale,
|
||||
mk_aotensor(out_t, "out"),
|
||||
mk_aotensor(dout_t, "dout"),
|
||||
mk_aotensor(dq_t, "dq"),
|
||||
mk_aotensor(dk_t, "dk"),
|
||||
mk_aotensor(dv_t, "dv"),
|
||||
empty_bias,
|
||||
mk_aotensor<2>(softmax_lse_cont, "L"),
|
||||
mk_aotensor<2>(delta, "delta"),
|
||||
p_dropout,
|
||||
|
@ -266,18 +266,7 @@ inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug)
|
||||
inline bool check_for_attn_mask(sdp_params const& params, bool debug) {
|
||||
if (params.attn_mask.has_value()) {
|
||||
if (debug) {
|
||||
TORCH_WARN("Flash Attention do not support non-null attn_mask.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// TODO(eqy): remove this once support is added
|
||||
inline bool check_for_attn_mask_cudnn(sdp_params const& params, bool debug) {
|
||||
if (params.attn_mask.has_value()) {
|
||||
if (debug) {
|
||||
TORCH_WARN("cuDNN Attention does not support non-null attn_mask.");
|
||||
TORCH_WARN("Flash Attention does not support non-null attn_mask.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -324,7 +313,7 @@ inline bool check_tensor_shapes(sdp_params const& params, bool debug) {
|
||||
(query_dim == 4))) {
|
||||
if (debug) {
|
||||
TORCH_WARN(
|
||||
"All fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ",
|
||||
"Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ",
|
||||
query_dim,
|
||||
", Key dim: ",
|
||||
params.key.dim(),
|
||||
@ -436,7 +425,7 @@ inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool
|
||||
if (zero_seq_len_q || zero_seq_len_k) {
|
||||
if (debug) {
|
||||
TORCH_WARN(
|
||||
"All fused kernels do not support zero seq_len_q or seq_len_kv.");
|
||||
"Both fused kernels do not support zero seq_len_q or seq_len_kv.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -471,7 +460,7 @@ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool
|
||||
}
|
||||
epilogue_message << " instead.";
|
||||
TORCH_WARN(
|
||||
"All fused kernels require the last dimension of the input to have stride 1. ",
|
||||
"Both fused kernels require the last dimension of the input to have stride 1. ",
|
||||
"Got Query.stride(-1): ",
|
||||
params.query.sym_stride(-1),
|
||||
", Key.stride(-1): ",
|
||||
|
@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
|
||||
|
||||
|
||||
|
||||
yolov3,pass,0
|
||||
yolov3,pass,2
|
||||
|
|
@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
|
||||
|
||||
|
||||
|
||||
yolov3,fail_accuracy,8
|
||||
yolov3,pass,9
|
||||
|
|
@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0
|
||||
|
||||
|
||||
|
||||
pyhpc_isoneutral_mixing,pass,0
|
||||
pyhpc_isoneutral_mixing,fail_to_run,0
|
||||
|
||||
|
||||
|
||||
@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0
|
||||
|
||||
|
||||
|
||||
yolov3,pass,0
|
||||
yolov3,fail_to_run,0
|
||||
|
|
@ -338,4 +338,4 @@ vision_maskrcnn,pass,28
|
||||
|
||||
|
||||
|
||||
yolov3,pass,0
|
||||
yolov3,pass,2
|
||||
|
|
@ -338,4 +338,4 @@ vision_maskrcnn,pass,28
|
||||
|
||||
|
||||
|
||||
yolov3,pass,0
|
||||
yolov3,pass,2
|
||||
|
|
@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0
|
||||
|
||||
|
||||
|
||||
pyhpc_isoneutral_mixing,pass,0
|
||||
pyhpc_isoneutral_mixing,fail_to_run,0
|
||||
|
||||
|
||||
|
||||
@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0
|
||||
|
||||
|
||||
|
||||
yolov3,pass,0
|
||||
yolov3,fail_to_run,0
|
||||
|
|
@ -374,4 +374,4 @@ vision_maskrcnn,pass,17
|
||||
|
||||
|
||||
|
||||
yolov3,pass,0
|
||||
yolov3,pass,2
|
||||
|
|
@ -282,4 +282,4 @@ vision_maskrcnn,pass,34
|
||||
|
||||
|
||||
|
||||
yolov3,fail_accuracy,8
|
||||
yolov3,pass,9
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user