Don't go through dispatch for *_dot_with_fp32_arith (#140834)

We don't need to dispatch for these because they're only used from within ATen/native/cpu, which is rebuilt per-CPU_CAPABILITY anyway.
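
In a nutshell (a condensed sketch pieced together from the hunks below; r, x, a, len are illustrative locals, everything else appears in the diff):

// Before: the dot product went through a DispatchStub, i.e. a runtime
// CPU-capability lookup at every call site.
using fp16_dot_fn = float(*)(const Half*, const Half*, int64_t);
DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_with_fp32_arith_stub)   // header
DEFINE_DISPATCH(fp16_dot_with_fp32_arith_stub);                // generic translation unit
REGISTER_DISPATCH(fp16_dot_with_fp32_arith_stub, &fp16_dot_with_fp32_arith)  // per-capability kernel file
float r = fp16_dot_with_fp32_arith_stub(kCPU, x, a, len);      // indirect call

// After: every caller lives in ATen/native/cpu and is compiled once per
// CPU_CAPABILITY, so a plain declaration in the inline namespace is enough.
inline namespace CPU_CAPABILITY {
float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len);
} // inline namespace CPU_CAPABILITY
float r = at::native::CPU_CAPABILITY::fp16_dot_with_fp32_arith(x, a, len);  // direct call

The fp16_gemv_trans / bf16_gemv_trans stubs keep their dispatch registrations, since those are still reached through the BlasKernel.cpp wrappers outside ATen/native/cpu (kept in the first hunk below).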

Differential Revision: [D66012283](https://our.internmc.facebook.com/intern/diff/D66012283/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140834
Approved by: https://github.com/malfet
Author: Scott Wolchok
Date: 2024-11-15 09:46:34 -08:00
Committed by: PyTorch MergeBot
Parent: baf756a785
Commit: 5df9207ba9
4 changed files with 8 additions and 33 deletions


@@ -85,9 +85,7 @@ extern "C" void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int
 namespace at::native {
 #if !defined(C10_MOBILE)
-DEFINE_DISPATCH(fp16_dot_with_fp32_arith_stub);
 DEFINE_DISPATCH(fp16_gemv_trans_stub);
-DEFINE_DISPATCH(bf16_dot_with_fp32_arith_stub);
 DEFINE_DISPATCH(bf16_gemv_trans_stub);
 #endif // !defined(C10_MOBILE)
@@ -105,18 +103,6 @@ void fp16_gemv_trans(
     Half* y,
     const int incy);
-float fp16_dot_with_fp32_arith(
-    const Half* vec1,
-    const Half* vec2,
-    int64_t len);
-float fp16_dot_with_fp32_arith(
-    const Half* x,
-    const Half* a,
-    int64_t len) {
-  return fp16_dot_with_fp32_arith_stub(kCPU, x, a, len);
-}
 void fp16_gemv_trans(
     const int m,
     const int n,
@@ -143,17 +129,6 @@ void bf16_gemv_trans(
     at::BFloat16* y,
     const int incy);
-float bf16_dot_with_fp32_arith(
-    const at::BFloat16* vec1,
-    const at::BFloat16* vec2,
-    int64_t len);
-float bf16_dot_with_fp32_arith(
-    const at::BFloat16* vec1,
-    const at::BFloat16* vec2,
-    int64_t len) {
-  return bf16_dot_with_fp32_arith_stub(kCPU, vec1, vec2, len);
-}
 #endif // !defined(C10_MOBILE)
 #if defined(__aarch64__) && !defined(C10_MOBILE)


@@ -3,6 +3,7 @@
 #include <ATen/Parallel.h>
 #include <ATen/native/CPUBlas.h>
 #include <ATen/native/cpu/zmath.h>
+#include <ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h>
 #include <c10/util/irange.h>
 #include <c10/util/Unroll.h>
@@ -366,7 +367,7 @@ void gemm_notrans_(
 #if !defined(C10_MOBILE)
 static float compute_dot(const at::Half* a, const at::Half* b, int64_t len) {
-  return at::native::blas_impl::fp16_dot_with_fp32_arith(
+  return at::native::CPU_CAPABILITY::fp16_dot_with_fp32_arith(
       a, b, len);
 }
@@ -403,7 +404,7 @@ void gemm_transa_(
 }
 static float compute_dot(const at::BFloat16* a, const at::BFloat16* b, int64_t len) {
-  return at::native::blas_impl::bf16_dot_with_fp32_arith(a, b, len);
+  return at::native::CPU_CAPABILITY::bf16_dot_with_fp32_arith(a, b, len);
 }
 template <>


@@ -474,9 +474,7 @@ void bf16_gemv_trans(
 } // namespace CPU_CAPABILITY
 #if !defined(C10_MOBILE)
-REGISTER_DISPATCH(fp16_dot_with_fp32_arith_stub, &fp16_dot_with_fp32_arith)
 REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans)
-REGISTER_DISPATCH(bf16_dot_with_fp32_arith_stub, &bf16_dot_with_fp32_arith)
 REGISTER_DISPATCH(bf16_gemv_trans_stub, &bf16_gemv_trans)
 #endif //!defined(C10_MOBILE)


@@ -7,14 +7,15 @@
 namespace at::native {
 #if !defined(C10_MOBILE)
-using fp16_dot_fn = float(*)(const Half*, const Half*, int64_t);
 using fp16_gemv_fn = void(*)(int, int, float, const Half*, int, const Half*, int, float, Half*, int);
-DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_with_fp32_arith_stub)
 DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub)
-using bf16_dot_fn = float(*)(const BFloat16*, const BFloat16*, int64_t);
 using bf16_gemv_fn = void(*)(int, int, BFloat16, const BFloat16*, int, const BFloat16*, int, BFloat16, BFloat16*, int);
-DECLARE_DISPATCH(bf16_dot_fn, bf16_dot_with_fp32_arith_stub)
 DECLARE_DISPATCH(bf16_gemv_fn, bf16_gemv_trans_stub)
+inline namespace CPU_CAPABILITY {
+float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len);
+float bf16_dot_with_fp32_arith(const BFloat16* vec1, const BFloat16* vec2, int64_t len);
+} // inline namespace CPU_CAPABILITY
 #endif // !defined(C10_MOBILE)
 } // namespace at::native