[float16]: Fast path for torch.dot with float16/bfloat16 (#152799)

Fixes #152798

Add a fast path for `dot` with contiguous float16/bfloat16 tensors.
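For reference, a minimal libtorch snippet exercising both paths (the tensor sizes and the use of `slice` to force a non-unit stride are illustrative, not from the PR):

```cpp
#include <ATen/ATen.h>

int main() {
  auto x = at::randn({4096}, at::kHalf);
  auto y = at::randn({4096}, at::kHalf);

  // Contiguous (stride-1) inputs: eligible for the new fp16_dot fast path.
  auto fast = at::dot(x, y);

  // Non-unit stride: dot falls back to the naive strided loop.
  auto xs = at::randn({8192}, at::kHalf).slice(/*dim=*/0, 0, 8192, /*step=*/2);
  auto slow = at::dot(xs, y);
  return 0;
}
```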

Performance with the patch (see the issue for the benchmark and baseline performance):

![Improved dot performance](https://github.com/user-attachments/assets/57f64e90-8191-4710-adb0-f430644827de)

**We see a more than 10x performance improvement in some cases.**
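A rough timing harness in the same spirit (not the benchmark from the issue; size and iteration count here are arbitrary):

```cpp
#include <ATen/ATen.h>
#include <chrono>
#include <cstdio>

int main() {
  const int64_t n = 1 << 22;
  auto a = at::randn({n}, at::kHalf);
  auto b = at::randn({n}, at::kHalf);
  at::dot(a, b);  // warm-up

  constexpr int iters = 100;
  auto t0 = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) {
    at::dot(a, b);
  }
  auto t1 = std::chrono::steady_clock::now();
  std::chrono::duration<double, std::micro> us = t1 - t0;
  std::printf("fp16 dot, n=%lld: %.1f us/iter\n", (long long)n, us.count() / iters);
  return 0;
}
```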

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152799
Approved by: https://github.com/malfet
Author: Krishna Bindumadhavan
Date: 2025-05-05 18:29:39 +00:00
Committed by: PyTorch MergeBot
Parent: 172a7c942e
Commit: d57bf53225
3 changed files with 79 additions and 3 deletions

---

@@ -90,6 +90,8 @@ namespace at::native {
 #if !defined(C10_MOBILE)
 DEFINE_DISPATCH(fp16_gemv_trans_stub);
 DEFINE_DISPATCH(bf16_gemv_trans_stub);
+DEFINE_DISPATCH(fp16_dot_stub);
+DEFINE_DISPATCH(bf16_dot_stub);
 #endif // !defined(C10_MOBILE)
 namespace blas_impl {
@@ -120,6 +122,15 @@ void fp16_gemv_trans(
   fp16_gemv_trans_stub(kCPU, m, n, alpha, a, lda, x, incx, beta, y, incy);
 }
+
+static float fp16_dot(
+    const int64_t n,
+    const Half* x,
+    const int64_t incx,
+    const Half* y,
+    const int64_t incy) {
+  return fp16_dot_stub(kCPU, n, x, incx, y, incy);
+}
 #endif // !defined(C10_MOBILE)

 #if defined(__aarch64__) && !defined(C10_MOBILE)
@@ -384,6 +395,16 @@ void gemv_fast_path<at::BFloat16>(
       y,
       *incy);
 }
+
+static float bf16_dot(
+    const int64_t n,
+    const BFloat16* x,
+    const int64_t incx,
+    const BFloat16* y,
+    const int64_t incy) {
+  return bf16_dot_stub(kCPU, n, x, incx, y, incy);
+}
+
 #if !defined(__aarch64__)
 // Currently, only fp16_gemv_trans is built for non-aarch64.
 template <>
@@ -695,6 +716,34 @@ c10::complex<float> dot_impl(int64_t n, const c10::complex<float>* x, int64_t in
   return dot_impl_floating(n, x, incx, y, incy);
 }

+template <>
+Half dot_impl(int64_t n, const Half* x, int64_t incx, const Half* y, int64_t incy) {
+  if (n == 1) {
+    incx = 1;
+    incy = 1;
+  }
+#if !defined(C10_MOBILE)
+  if (incx == 1 && incy == 1) {
+    return blas_impl::fp16_dot(n, x, incx, y, incy);
+  }
+#endif // !defined(C10_MOBILE)
+  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<float>{});
+}
+
+template <>
+BFloat16 dot_impl(int64_t n, const BFloat16* x, int64_t incx, const BFloat16* y, int64_t incy) {
+  if (n == 1) {
+    incx = 1;
+    incy = 1;
+  }
+#if !defined(C10_MOBILE)
+  if (incx == 1 && incy == 1) {
+    return blas_impl::bf16_dot(n, x, incx, y, incy);
+  }
+#endif // !defined(C10_MOBILE)
+  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<float>{});
+}
+
 namespace {
 template <typename scalar_t>
 struct vdot_op {
@@ -721,7 +770,7 @@ scalar_t vdot_impl(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y
 #endif
 }

-// Skip reinstantiating the explicitly specialized types `float` and `double`.
+// Skip reinstantiating the explicitly specialized types `float`, `double`, `half` & `bfloat16`.
 #define INSTANTIATE_DOT_IMPL(scalar_t) \
   template scalar_t dot_impl<scalar_t>( \
       int64_t n, const scalar_t * x, int64_t incx, const scalar_t * y, int64_t incy);
@@ -730,8 +779,6 @@ INSTANTIATE_DOT_IMPL(int8_t)
 INSTANTIATE_DOT_IMPL(int16_t)
 INSTANTIATE_DOT_IMPL(int)
 INSTANTIATE_DOT_IMPL(int64_t)
-INSTANTIATE_DOT_IMPL(c10::Half)
-INSTANTIATE_DOT_IMPL(c10::BFloat16)

 #define INSTANTIATE_VDOT_IMPL(scalar_t) \
   template scalar_t vdot_impl<scalar_t>( \
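For context, the strided fallback `blas_impl::dot_naive(..., std::multiplies<float>{})` amounts to a widening scalar loop. A rough model of its numerics (my paraphrase, not the actual helper):

```cpp
#include <cstdint>

// Rough model of the strided fallback: each Half/BFloat16 element is widened
// to float, the products are accumulated in float, and the sum is narrowed
// back to the reduced-precision result type at the end.
template <typename scalar_t, typename Op>
scalar_t dot_naive_model(int64_t n, const scalar_t* x, int64_t incx,
                         const scalar_t* y, int64_t incy, Op op) {
  float sum = 0.0f;
  for (int64_t i = 0; i < n; ++i) {
    sum += op(static_cast<float>(x[i * incx]), static_cast<float>(y[i * incy]));
  }
  return static_cast<scalar_t>(sum);
}
```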

---

@@ -475,12 +475,35 @@ void bf16_gemv_trans(
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0);
   return bf16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy);
 }
+
+float fp16_dot(
+    const int64_t n,
+    const at::Half* x,
+    const int64_t incx,
+    const at::Half* y,
+    const int64_t incy) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && incy == 1);
+  return fp16_dot_with_fp32_arith(x, y, n);
+}
+
+float bf16_dot(
+    const int64_t n,
+    const at::BFloat16* x,
+    const int64_t incx,
+    const at::BFloat16* y,
+    const int64_t incy) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && incy == 1);
+  return bf16_dot_with_fp32_arith(x, y, n);
+}
 #endif // !defined(C10_MOBILE)

 } // namespace CPU_CAPABILITY

 #if !defined(C10_MOBILE)
 REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans)
 REGISTER_DISPATCH(bf16_gemv_trans_stub, &bf16_gemv_trans)
+REGISTER_DISPATCH(fp16_dot_stub, &fp16_dot)
+REGISTER_DISPATCH(bf16_dot_stub, &bf16_dot)
 #endif //!defined(C10_MOBILE)
 } // namespace at::native
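The registered kernels delegate to `fp16_dot_with_fp32_arith` / `bf16_dot_with_fp32_arith`, which are vectorized per CPU capability. Numerically, they compute the following (scalar sketch only, assuming nothing beyond `c10::Half`'s implicit float conversion):

```cpp
#include <c10/util/Half.h>
#include <cstdint>

// Scalar sketch of the kernel's numerics: widen to fp32 before multiplying
// and keep the running sum in fp32, so the reduction neither loses precision
// nor overflows in the 16-bit type. The real kernel uses SIMD lanes and
// multiple accumulators for throughput.
float fp16_dot_fp32_arith_sketch(const c10::Half* x, const c10::Half* y, int64_t n) {
  float sum = 0.0f;
  for (int64_t i = 0; i < n; ++i) {
    sum += static_cast<float>(x[i]) * static_cast<float>(y[i]);
  }
  return sum;  // the stub returns float; dot_impl narrows it to Half
}
```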

---

@@ -13,6 +13,12 @@ DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub)
 using bf16_gemv_fn = void(*)(int, int, BFloat16, const BFloat16*, int, const BFloat16*, int, BFloat16, BFloat16*, int);
 DECLARE_DISPATCH(bf16_gemv_fn, bf16_gemv_trans_stub)
+
+using fp16_dot_fn = float(*)(const int64_t, const Half*, const int64_t, const Half*, const int64_t);
+DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_stub)
+using bf16_dot_fn = float(*)(const int64_t, const BFloat16*, const int64_t, const BFloat16*, const int64_t);
+DECLARE_DISPATCH(bf16_dot_fn, bf16_dot_stub)
+
 inline namespace CPU_CAPABILITY {
 float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len);
 float bf16_dot_with_fp32_arith(const BFloat16* vec1, const BFloat16* vec2, int64_t len);
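The header completes the usual DispatchStub triad: `DECLARE_DISPATCH` here, `DEFINE_DISPATCH` in the first file, and `REGISTER_DISPATCH` in the per-capability kernel file. A much-simplified model of that wiring (ATen's real `DispatchStub` additionally selects among implementations by detected CPU capability):

```cpp
#include <cstdint>

// ~DECLARE_DISPATCH: the function-pointer type and the stub's declaration.
using dot_fn = float (*)(int64_t, const float*, const float*);
extern dot_fn dot_stub;

// ~DEFINE_DISPATCH: storage for the stub, initially unset.
dot_fn dot_stub = nullptr;

// A kernel compiled for some CPU capability.
static float dot_kernel(int64_t n, const float* x, const float* y) {
  float s = 0.0f;
  for (int64_t i = 0; i < n; ++i) s += x[i] * y[i];
  return s;
}

// ~REGISTER_DISPATCH: point the stub at the kernel at static-init time.
static struct RegisterDot {
  RegisterDot() { dot_stub = dot_kernel; }
} register_dot;
```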