Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Don't go through dispatch for *_dot_with_fp32_arith (#140834)
We don't need to dispatch for these because they're only used from within ATen/native/cpu, which is rebuilt per-CPU_CAPABILITY anyway.

Differential Revision: [D66012283](https://our.internmc.facebook.com/intern/diff/D66012283/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140834
Approved by: https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: baf756a785
Commit: 5df9207ba9
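As the commit message notes, these dot kernels are only called from code that itself lives in ATen/native/cpu and is rebuilt once per CPU_CAPABILITY, so routing the call through a dispatch stub adds an indirect call for no benefit. A minimal sketch of the difference follows; it is not PyTorch's real DispatchStub machinery, and HalfLike and the kernel body are made-up stand-ins:

#include <cstdint>

struct HalfLike { uint16_t bits; };  // stand-in for at::Half
using fp16_dot_fn = float (*)(const HalfLike*, const HalfLike*, int64_t);

// The capability-specific kernel; in ATen/native/cpu it is compiled once per
// CPU_CAPABILITY (DEFAULT, AVX2, ...) with the matching compiler flags.
inline float fp16_dot_with_fp32_arith(const HalfLike* a, const HalfLike* b, int64_t len) {
  float acc = 0.0f;  // stand-in math; the real kernel widens fp16 to fp32 and vectorizes
  for (int64_t i = 0; i < len; ++i) {
    acc += static_cast<float>(a[i].bits) * static_cast<float>(b[i].bits);
  }
  return acc;
}

// Before: callers reached the kernel through a runtime-filled function pointer
// (what DEFINE_DISPATCH/REGISTER_DISPATCH set up), paying an indirect call.
fp16_dot_fn fp16_dot_with_fp32_arith_stub = &fp16_dot_with_fp32_arith;

float dot_via_stub(const HalfLike* a, const HalfLike* b, int64_t len) {
  return fp16_dot_with_fp32_arith_stub(a, b, len);  // indirect call per dot product
}

// After: a caller compiled in the same per-capability translation unit names
// the kernel directly, so the call resolves at compile time and can inline.
float dot_direct(const HalfLike* a, const HalfLike* b, int64_t len) {
  return fp16_dot_with_fp32_arith(a, b, len);
}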
@@ -85,9 +85,7 @@ extern "C" void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int
 
 namespace at::native {
 #if !defined(C10_MOBILE)
-DEFINE_DISPATCH(fp16_dot_with_fp32_arith_stub);
 DEFINE_DISPATCH(fp16_gemv_trans_stub);
-DEFINE_DISPATCH(bf16_dot_with_fp32_arith_stub);
 DEFINE_DISPATCH(bf16_gemv_trans_stub);
 #endif // !defined(C10_MOBILE)
 
@@ -105,18 +103,6 @@ void fp16_gemv_trans(
     Half* y,
     const int incy);
 
-float fp16_dot_with_fp32_arith(
-    const Half* vec1,
-    const Half* vec2,
-    int64_t len);
-
-float fp16_dot_with_fp32_arith(
-    const Half* x,
-    const Half* a,
-    int64_t len) {
-  return fp16_dot_with_fp32_arith_stub(kCPU, x, a, len);
-}
-
 void fp16_gemv_trans(
     const int m,
     const int n,
@@ -143,17 +129,6 @@ void bf16_gemv_trans(
     at::BFloat16* y,
     const int incy);
 
-float bf16_dot_with_fp32_arith(
-    const at::BFloat16* vec1,
-    const at::BFloat16* vec2,
-    int64_t len);
-
-float bf16_dot_with_fp32_arith(
-    const at::BFloat16* vec1,
-    const at::BFloat16* vec2,
-    int64_t len) {
-  return bf16_dot_with_fp32_arith_stub(kCPU, vec1, vec2, len);
-}
 #endif // !defined(C10_MOBILE)
 
 #if defined(__aarch64__) && !defined(C10_MOBILE)
@@ -3,6 +3,7 @@
 #include <ATen/Parallel.h>
 #include <ATen/native/CPUBlas.h>
 #include <ATen/native/cpu/zmath.h>
+#include <ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h>
 #include <c10/util/irange.h>
 #include <c10/util/Unroll.h>
 
@@ -366,7 +367,7 @@ void gemm_notrans_(
 
 #if !defined(C10_MOBILE)
 static float compute_dot(const at::Half* a, const at::Half* b, int64_t len) {
-  return at::native::blas_impl::fp16_dot_with_fp32_arith(
+  return at::native::CPU_CAPABILITY::fp16_dot_with_fp32_arith(
       a, b, len);
 }
 
@@ -403,7 +404,7 @@ void gemm_transa_(
 }
 
 static float compute_dot(const at::BFloat16* a, const at::BFloat16* b, int64_t len) {
-  return at::native::blas_impl::bf16_dot_with_fp32_arith(a, b, len);
+  return at::native::CPU_CAPABILITY::bf16_dot_with_fp32_arith(a, b, len);
 }
 
 template <>
@@ -474,9 +474,7 @@ void bf16_gemv_trans(
 } // namespace CPU_CAPABILITY
 
 #if !defined(C10_MOBILE)
-REGISTER_DISPATCH(fp16_dot_with_fp32_arith_stub, &fp16_dot_with_fp32_arith)
 REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans)
-REGISTER_DISPATCH(bf16_dot_with_fp32_arith_stub, &bf16_dot_with_fp32_arith)
 REGISTER_DISPATCH(bf16_gemv_trans_stub, &bf16_gemv_trans)
 #endif //!defined(C10_MOBILE)
 
@@ -7,14 +7,15 @@
 
 namespace at::native {
 #if !defined(C10_MOBILE)
-using fp16_dot_fn = float(*)(const Half*, const Half*, int64_t);
 using fp16_gemv_fn = void(*)(int, int, float, const Half*, int, const Half*, int, float, Half*, int);
-DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_with_fp32_arith_stub)
 DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub)
 
-using bf16_dot_fn = float(*)(const BFloat16*, const BFloat16*, int64_t);
 using bf16_gemv_fn = void(*)(int, int, BFloat16, const BFloat16*, int, const BFloat16*, int, BFloat16, BFloat16*, int);
-DECLARE_DISPATCH(bf16_dot_fn, bf16_dot_with_fp32_arith_stub)
 DECLARE_DISPATCH(bf16_gemv_fn, bf16_gemv_trans_stub)
+
+inline namespace CPU_CAPABILITY {
+float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len);
+float bf16_dot_with_fp32_arith(const BFloat16* vec1, const BFloat16* vec2, int64_t len);
+} // inline namespace CPU_CAPABILITY
 #endif // !defined(C10_MOBILE)
 } // namespace at::native
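The header change above is what makes the direct call work: the dot declarations move into an inline namespace named after the CPU_CAPABILITY macro, so each per-capability rebuild of ATen/native/cpu sees and defines its own copy of the symbol. A rough, self-contained sketch of that mechanism, with CPU_CAPABILITY pinned to AVX2 and a scalar float body purely for illustration (the real kernels take Half/BFloat16 inputs and vectorize):

#include <cstdint>
#include <cstdio>

#define CPU_CAPABILITY AVX2  // normally injected by the build system, once per object file

namespace sketch::native {
inline namespace CPU_CAPABILITY {  // expands to sketch::native::AVX2
float fp16_dot_with_fp32_arith(const float* a, const float* b, int64_t len) {
  float acc = 0.0f;  // placeholder math; accumulation in fp32 mirrors the real kernel's intent
  for (int64_t i = 0; i < len; ++i) acc += a[i] * b[i];
  return acc;
}
} // inline namespace CPU_CAPABILITY
} // namespace sketch::native

int main() {
  float a[3] = {1.0f, 2.0f, 3.0f};
  float b[3] = {4.0f, 5.0f, 6.0f};
  // The qualified name below is rewritten by the preprocessor to
  // sketch::native::AVX2::fp16_dot_with_fp32_arith, i.e. the copy built for
  // this capability; no dispatch-table lookup is involved.
  std::printf("%f\n", sketch::native::CPU_CAPABILITY::fp16_dot_with_fp32_arith(a, b, 3));
  return 0;
}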