[float16]: Fast path for torch.dot with float16/bfloat16 (#152799)
Fixes #152798

Add a fast path for dot with contiguous tensors of the float16/bfloat16 types. Performance with the patch (see the issue for the benchmark and current performance): **we see up to a 10x+ improvement.**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152799
Approved by: https://github.com/malfet
committed by PyTorch MergeBot
parent b06cbd49f1
commit f47bf38e30
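The fast path only applies to contiguous 1-D CPU inputs (unit stride in both operands). A minimal sketch of how to exercise it follows; the sizes, iteration count, and timing harness are illustrative, not the benchmark from the linked issue:

```python
# Illustrative timing sketch (not the benchmark from the linked issue):
# contiguous float16/bfloat16 CPU tensors are the case the new
# fp16_dot_stub / bf16_dot_stub fast path targets.
import timeit

import torch

def bench_dot(dtype, n=1_000_000, iters=100):
    # Build contiguous 1-D operands so incx == 1 and incy == 1.
    x = torch.randn(n).to(dtype)
    y = torch.randn(n).to(dtype)
    t = timeit.timeit(lambda: torch.dot(x, y), number=iters)
    print(f"{dtype}: {t / iters * 1e3:.3f} ms per torch.dot call")

for dt in (torch.float16, torch.bfloat16):
    bench_dot(dt)
```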
@@ -90,6 +90,8 @@ namespace at::native {
 #if !defined(C10_MOBILE)
 DEFINE_DISPATCH(fp16_gemv_trans_stub);
 DEFINE_DISPATCH(bf16_gemv_trans_stub);
+DEFINE_DISPATCH(fp16_dot_stub);
+DEFINE_DISPATCH(bf16_dot_stub);
 #endif // !defined(C10_MOBILE)
 
 namespace blas_impl {
@@ -120,6 +122,24 @@ void fp16_gemv_trans(
   fp16_gemv_trans_stub(kCPU, m, n, alpha, a, lda, x, incx, beta, y, incy);
 }
 
+static float fp16_dot(
+    const int64_t n,
+    const Half* x,
+    const int64_t incx,
+    const Half* y,
+    const int64_t incy) {
+  return fp16_dot_stub(kCPU, n, x, incx, y, incy);
+}
+
+static float bf16_dot(
+    const int64_t n,
+    const BFloat16* x,
+    const int64_t incx,
+    const BFloat16* y,
+    const int64_t incy) {
+  return bf16_dot_stub(kCPU, n, x, incx, y, incy);
+}
+
 #endif // !defined(C10_MOBILE)
 
 #if defined(__aarch64__) && !defined(C10_MOBILE)
@@ -695,6 +715,34 @@ c10::complex<float> dot_impl(int64_t n, const c10::complex<float>* x, int64_t in
   return dot_impl_floating(n, x, incx, y, incy);
 }
 
+template <>
+Half dot_impl(int64_t n, const Half* x, int64_t incx, const Half* y, int64_t incy) {
+  if (n == 1) {
+    incx = 1;
+    incy = 1;
+  }
+#if !defined(C10_MOBILE)
+  if (incx == 1 && incy == 1) {
+    return blas_impl::fp16_dot(n, x, incx, y, incy);
+  }
+#endif // !defined(C10_MOBILE)
+  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<float>{});
+}
+
+template <>
+BFloat16 dot_impl(int64_t n, const BFloat16* x, int64_t incx, const BFloat16* y, int64_t incy) {
+  if (n == 1) {
+    incx = 1;
+    incy = 1;
+  }
+#if !defined(C10_MOBILE)
+  if (incx == 1 && incy == 1) {
+    return blas_impl::bf16_dot(n, x, incx, y, incy);
+  }
+#endif // !defined(C10_MOBILE)
+  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<float>{});
+}
+
 namespace {
 template <typename scalar_t>
 struct vdot_op {
@@ -721,7 +769,7 @@ scalar_t vdot_impl(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y
 #endif
 }
 
-// Skip reinstantiating the explicitly specialized types `float` and `double`.
+// Skip reinstantiating the explicitly specialized types `float`, `double`, `half` & `bfloat16`.
 #define INSTANTIATE_DOT_IMPL(scalar_t)  \
   template scalar_t dot_impl<scalar_t>( \
       int64_t n, const scalar_t * x, int64_t incx, const scalar_t * y, int64_t incy);
@@ -730,8 +778,6 @@ INSTANTIATE_DOT_IMPL(int8_t)
 INSTANTIATE_DOT_IMPL(int16_t)
 INSTANTIATE_DOT_IMPL(int)
 INSTANTIATE_DOT_IMPL(int64_t)
-INSTANTIATE_DOT_IMPL(c10::Half)
-INSTANTIATE_DOT_IMPL(c10::BFloat16)
 
 #define INSTANTIATE_VDOT_IMPL(scalar_t)  \
   template scalar_t vdot_impl<scalar_t>( \

@@ -475,12 +475,35 @@ void bf16_gemv_trans(
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0);
   return bf16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy);
 }
+
+float fp16_dot(
+    const int64_t n,
+    const at::Half* x,
+    const int64_t incx,
+    const at::Half* y,
+    const int64_t incy) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && incy == 1);
+  return fp16_dot_with_fp32_arith(x, y, n);
+}
+
+float bf16_dot(
+    const int64_t n,
+    const at::BFloat16* x,
+    const int64_t incx,
+    const at::BFloat16* y,
+    const int64_t incy) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && incy == 1);
+  return bf16_dot_with_fp32_arith(x, y, n);
+}
+
 #endif // !defined(C10_MOBILE)
 } // namespace CPU_CAPABILITY
 
 #if !defined(C10_MOBILE)
 REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans)
 REGISTER_DISPATCH(bf16_gemv_trans_stub, &bf16_gemv_trans)
+REGISTER_DISPATCH(fp16_dot_stub, &fp16_dot)
+REGISTER_DISPATCH(bf16_dot_stub, &bf16_dot)
 #endif //!defined(C10_MOBILE)
 
 } // namespace at::native

@@ -13,6 +13,12 @@ DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub)
 using bf16_gemv_fn = void(*)(int, int, BFloat16, const BFloat16*, int, const BFloat16*, int, BFloat16, BFloat16*, int);
 DECLARE_DISPATCH(bf16_gemv_fn, bf16_gemv_trans_stub)
 
+using fp16_dot_fn = float(*)(const int64_t, const Half*, const int64_t, const Half*, const int64_t);
+DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_stub)
+
+using bf16_dot_fn = float(*)(const int64_t, const BFloat16*, const int64_t, const BFloat16*, const int64_t);
+DECLARE_DISPATCH(bf16_dot_fn, bf16_dot_stub)
+
 inline namespace CPU_CAPABILITY {
 float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len);
 float bf16_dot_with_fp32_arith(const BFloat16* vec1, const BFloat16* vec2, int64_t len);
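For reference, the new `dot_impl` specializations above only dispatch to the stub-backed kernels when `incx == 1 && incy == 1`; strided inputs keep going through `dot_naive`, which also accumulates in float. A small, hypothetical sanity check of both paths against a float32 reference:

```python
# Hypothetical sanity check (not part of the PR): contiguous inputs hit the
# stub-backed fast path, strided views (incx/incy != 1) fall back to the
# naive float-accumulating loop; both should track a float32 reference.
import torch

x = torch.randn(10_000).to(torch.float16)
y = torch.randn(10_000).to(torch.float16)

fast = torch.dot(x, y)               # contiguous: unit stride in both inputs
strided = torch.dot(x[::2], y[::2])  # stride-2 views: naive fallback path

print(torch.allclose(fast.float(), torch.dot(x.float(), y.float()), atol=1e-1, rtol=1e-3))
print(torch.allclose(strided.float(), torch.dot(x[::2].float(), y[::2].float()), atol=1e-1, rtol=1e-3))
```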