Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
ReducedPrecisionFloatGemvFastPathKernel: Correctly type parallel_for lambda arguments as int64_t (#152233)
This plus the previous irangeification PR seems like a better fix for #150637 than #150949 to me -- it should make sure we are using 64-bit math for indexing everywhere.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152233
Approved by: https://github.com/Skylion007, https://github.com/cyyever
ghstack dependencies: #152232
Committed by: PyTorch MergeBot
Parent: 3b7d6bbe8b
Commit: 7e8b9b3f51
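For context on the commit message's point about 64-bit index math: when a loop index is a 32-bit int, a product of two 32-bit values such as i * incy is evaluated in 32-bit arithmetic and can wrap for large tensors, while an int64_t index keeps the whole expression 64-bit. A minimal standalone sketch of the difference (not part of this PR; the sizes are purely illustrative):

// Minimal sketch (not part of this PR; sizes are illustrative): the same
// element offset computed in 64-bit vs. 32-bit arithmetic.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t lda = 1 << 20;  // large leading dimension
  const int64_t i   = 4095;     // column index near the end of a big matrix

  // 64-bit math: the offset (4,293,918,720) is representable and correct.
  const int64_t offset64 = lda * i;

  // What 32-bit index math would yield: the value no longer fits in
  // int32_t, so it wraps around to a negative offset.
  const int32_t offset32 = static_cast<int32_t>(offset64);

  std::printf("64-bit offset: %lld\n32-bit offset: %d\n",
              static_cast<long long>(offset64), offset32);
  return 0;
}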
@@ -105,22 +105,21 @@ float fp16_dot_with_fp16_arith(const Half* x, const Half* a, int len) {
 // Rather than unrolling to process multiple rows (transposed columns)
 // of matrix A at once as done in fp16_gemv_trans_fp16_arith, unroll
 // along an individual dot product.
 // NB: lda must be long, otherwise it can cause int32 overflow
 static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, const Half* a, const int64_t lda, const Half *x, const float beta, Half* y, int incy) {
   if (beta == 0.0f) {
-    parallel_for(0, n, 1, [&](int begin, int end) {
+    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
       for (const auto i : c10::irange(begin, end)) {
         y[i * incy] = fp16_dot_with_fp16_arith(x, a + lda * i, m);
       }
     });
   } else if (beta == 1.0f) {
-    parallel_for(0, n, 1, [&](int begin, int end) {
+    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
       for (const auto i : c10::irange(begin, end)) {
         y[i * incy] += fp16_dot_with_fp16_arith(x, a + lda * i, m);
       }
     });
   } else {
-    parallel_for(0, n, 1, [&](int begin, int end) {
+    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
       for (const auto i : c10::irange(begin, end)) {
         y[i * incy] = beta * y[i * incy] + fp16_dot_with_fp16_arith(x, a + lda * i, m);
       }
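The change in the hunk above is a single pattern repeated three times: at::parallel_for hands the callback int64_t range bounds, and declaring the lambda parameters as int64_t (rather than narrowing them to int) makes c10::irange deduce an int64_t loop index, so index expressions such as i * incy stay in 64-bit arithmetic. A minimal self-contained sketch of that pattern, assuming the ATen/c10 headers are available; the scale_strided function and its arguments are hypothetical, not taken from the kernel:

// Minimal sketch of the parallel_for + c10::irange pattern; assumes the
// ATen/c10 headers. The function and its arguments are hypothetical.
#include <ATen/Parallel.h>
#include <c10/util/irange.h>
#include <cstdint>

void scale_strided(float* y, int64_t n, int64_t incy, float alpha) {
  // parallel_for passes int64_t bounds; declaring the lambda parameters as
  // int64_t avoids narrowing them on the way in.
  at::parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
    // With int64_t begin/end, irange yields an int64_t i, so i * incy is
    // evaluated in 64-bit arithmetic even for very large n and incy.
    for (const auto i : c10::irange(begin, end)) {
      y[i * incy] *= alpha;
    }
  });
}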
@@ -396,23 +395,22 @@ float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len)
   return dot_with_fp32_arith_no_bfdot(vec1, vec2, len);
 }
 
 // NB: lda must be long, otherwise it can cause int32 overflow
 void fp16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const Half* a, const int64_t lda, const Half *x, const float beta, Half* y, int incy) {
   if (beta == 0.0f) {
-    parallel_for(0, n, 1, [&](int begin, int end) {
+    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
       for (const auto i : c10::irange(begin, end)) {
         y[i * incy] = fp16_dot_with_fp32_arith(x, a + lda * i, m);
       }
     });
   } else if (beta == 1.0f) {
-    parallel_for(0, n, 1, [&](int begin, int end) {
+    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
       for (const auto i : c10::irange(begin, end)) {
         // We need to accumulate in fp32; y[i * incy] += ... gets wrong results.
         y[i * incy] = static_cast<float>(y[i * incy]) + fp16_dot_with_fp32_arith(x, a + lda * i, m);
       }
     });
   } else {
-    parallel_for(0, n, 1, [&](int begin, int end) {
+    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
       for (const auto i : c10::irange(begin, end)) {
         y[i * incy] = beta * y[i * incy] + fp16_dot_with_fp32_arith(x, a + lda * i, m);
       }
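As a reference for the two fp16 hunks above (the bf16 hunk below has the same shape, just without a beta parameter): each *_gemv_trans_*_by_dot_products kernel produces one output element per column, y[i * incy] = beta * y[i * incy] + dot(x, column i of A), where consecutive columns are lda elements apart; the beta == 0 branch never reads y, mirroring the usual BLAS convention, and the reduced-precision variants accumulate the update in fp32. A plain scalar sketch of that contract, written in float for clarity (not the ATen kernel: no parallelism, no vectorization, no reduced-precision arithmetic):

// Plain scalar reference for the gemv-transpose-by-dot-products contract.
// Not the ATen kernel: single-threaded and float-only.
#include <cstdint>

void gemv_trans_by_dot_products_ref(int64_t m, int64_t n, const float* a,
                                    int64_t lda, const float* x, float beta,
                                    float* y, int64_t incy) {
  for (int64_t i = 0; i < n; ++i) {
    // Dot product of x with column i of A (columns are lda elements apart).
    float dot = 0.0f;
    for (int64_t j = 0; j < m; ++j) {
      dot += x[j] * a[lda * i + j];
    }
    if (beta == 0.0f) {
      // Do not read y at all: the output is just the dot product.
      y[i * incy] = dot;
    } else {
      y[i * incy] = beta * y[i * incy] + dot;
    }
  }
}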
@@ -451,9 +449,8 @@ float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec
   }
 }
 
 // NB: lda must be long, otherwise it can cause int32 overflow
 void bf16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const at::BFloat16* a, const int64_t lda, const at::BFloat16 *x, at::BFloat16* y, int incy) {
-  parallel_for(0, n, 1, [&](int begin, int end) {
+  parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
     for (const auto i : c10::irange(begin, end)) {
       y[i * incy] = bf16_dot_with_fp32_arith(x, a + lda * i, m);
     }