ReducedPrecisionFloatGemvFastPathKernel: Correctly type parallel_for lambda arguments as int64_t (#152233)

This plus the previous irangeification PR seems like a better fix for #150637 than #150949 to me -- together they should make sure we are using 64-bit math for indexing everywhere.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152233
Approved by: https://github.com/Skylion007, https://github.com/cyyever
ghstack dependencies: #152232
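
For context, the pattern being fixed can be shown with a minimal sketch. This is not the PyTorch kernel itself -- the function scale_rows and its body are invented for illustration -- but it uses the real at::parallel_for and c10::irange APIs that the diff touches:

#include <ATen/Parallel.h>
#include <c10/util/irange.h>
#include <cstdint>

// incy is deliberately an int, mirroring the BLAS-style signatures in the kernel.
void scale_rows(float* y, int64_t n, int incy) {
  // Problematic form: at::parallel_for passes int64_t bounds, but a lambda that
  // declares them as int silently narrows them, and index math derived from
  // them (e.g. i * incy) is done in 32-bit arithmetic:
  //
  //   at::parallel_for(0, n, 1, [&](int begin, int end) { ... });
  //
  // Fixed form: keep the bounds as int64_t so c10::irange yields int64_t and
  // all indexing stays 64-bit.
  at::parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      y[i * incy] *= 2.0f;
    }
  });
}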
Author:    Scott Wolchok
Date:      2025-04-25 16:38:37 -07:00
Committer: PyTorch MergeBot
Parent:    3b7d6bbe8b
Commit:    7e8b9b3f51

@@ -105,22 +105,21 @@ float fp16_dot_with_fp16_arith(const Half* x, const Half* a, int len) {
 // Rather than unrolling to process multiple rows (transposed columns)
 // of matrix A at once as done in fp16_gemv_trans_fp16_arith, unroll
 // along an individual dot product.
-// NB: lda must be long, otherwise it can cause int32 overflow
 static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, const Half* a, const int64_t lda, const Half *x, const float beta, Half* y, int incy) {
   if (beta == 0.0f) {
-    parallel_for(0, n, 1, [&](int begin, int end) {
+    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
       for (const auto i : c10::irange(begin, end)) {
         y[i * incy] = fp16_dot_with_fp16_arith(x, a + lda * i, m);
       }
     });
   } else if (beta == 1.0f) {
-    parallel_for(0, n, 1, [&](int begin, int end) {
+    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
       for (const auto i : c10::irange(begin, end)) {
         y[i * incy] += fp16_dot_with_fp16_arith(x, a + lda * i, m);
       }
     });
   } else {
-    parallel_for(0, n, 1, [&](int begin, int end) {
+    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
       for (const auto i : c10::irange(begin, end)) {
         y[i * incy] = beta * y[i * incy] + fp16_dot_with_fp16_arith(x, a + lda * i, m);
       }
@@ -396,23 +395,22 @@ float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len)
   return dot_with_fp32_arith_no_bfdot(vec1, vec2, len);
 }

-// NB: lda must be long, otherwise it can cause int32 overflow
 void fp16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const Half* a, const int64_t lda, const Half *x, const float beta, Half* y, int incy) {
   if (beta == 0.0f) {
-    parallel_for(0, n, 1, [&](int begin, int end) {
+    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
       for (const auto i : c10::irange(begin, end)) {
         y[i * incy] = fp16_dot_with_fp32_arith(x, a + lda * i, m);
       }
     });
   } else if (beta == 1.0f) {
-    parallel_for(0, n, 1, [&](int begin, int end) {
+    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
       for (const auto i : c10::irange(begin, end)) {
         // We need to accumulate in fp32; y[i * incy] += ... gets wrong results.
         y[i * incy] = static_cast<float>(y[i * incy]) + fp16_dot_with_fp32_arith(x, a + lda * i, m);
       }
     });
   } else {
-    parallel_for(0, n, 1, [&](int begin, int end) {
+    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
       for (const auto i : c10::irange(begin, end)) {
         y[i * incy] = beta * y[i * incy] + fp16_dot_with_fp32_arith(x, a + lda * i, m);
       }
@@ -451,9 +449,8 @@ float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len)
   }
 }

-// NB: lda must be long, otherwise it can cause int32 overflow
 void bf16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const at::BFloat16* a, const int64_t lda, const at::BFloat16 *x, at::BFloat16* y, int incy) {
-  parallel_for(0, n, 1, [&](int begin, int end) {
+  parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
     for (const auto i : c10::irange(begin, end)) {
       y[i * incy] = bf16_dot_with_fp32_arith(x, a + lda * i, m);
     }
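
As a rough illustration of the int32 overflow that the "NB: lda must be long" comments and the commit message refer to (the shape below is hypothetical, not taken from the linked issue): once a flat offset such as lda * i or i * incy exceeds INT_MAX, 32-bit signed index arithmetic overflows, which is why both the leading dimension and the loop indices need to be 64-bit.

#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical GEMV shape: leading dimension 70000, row index 40000.
  const int64_t lda = 70000;
  const int64_t i = 40000;

  // Computed in 64-bit the offset is fine; the same product in 32-bit signed
  // arithmetic would exceed INT_MAX (about 2.1e9) and overflow, so it is not
  // evaluated here.
  const int64_t offset = lda * i;  // 2800000000
  std::printf("64-bit offset: %lld (INT_MAX is %d)\n",
              static_cast<long long>(offset), INT_MAX);
  return 0;
}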