[BE] Cleanup stale comments/copy from gemm (#162001)

Followup after https://github.com/pytorch/pytorch/pull/154012

Since the introduction of `gemm_no_downcast_stub` it's no longer necessary to allocate temporary array and then manually implement the `beta` logic in the codebase
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162001
Approved by: https://github.com/drisspg
ghstack dependencies: #161999
This commit is contained in:
Nikita Shulga
2025-09-02 14:06:36 -07:00
committed by PyTorch MergeBot
parent 3f6d88f04c
commit 24492cbab2

View File

@ -457,24 +457,9 @@ void gemm(
return; return;
} }
#endif #endif
// for the fallback path, first compute gemm with beta = 0,
// and then add c in full precision.
int64_t c_size = n * m;
std::vector<float> float_c(c_size, 0.f);
gemm_no_downcast_stub( gemm_no_downcast_stub(
at::kCPU, at::kBFloat16, at::kCPU, at::kBFloat16,
transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m); transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
for (const auto j : c10::irange(n)) {
for (const auto i : c10::irange(m)) {
auto offset = j * ldc + i;
// beta == 0 won't propagate NaN from C
if (beta == 0.f) {
c[offset] = float_c[j * m + i];
} else {
c[offset] = beta * c[offset] + float_c[j * m + i];
}
}
}
} }
void gemm( void gemm(
@ -493,24 +478,9 @@ void gemm(
return; return;
} }
#endif #endif
// for the fallback path, first compute gemm with beta = 0,
// and then add c in full precision.
int64_t c_size = n * m;
std::vector<float> float_c(c_size, 0.f);
gemm_no_downcast_stub( gemm_no_downcast_stub(
at::kCPU, at::kHalf, at::kCPU, at::kHalf,
transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m); transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
for (const auto j : c10::irange(n)) {
for (const auto i : c10::irange(m)) {
auto offset = j * ldc + i;
// beta == 0 won't propagate NaN from C
if (beta == 0.f) {
c[offset] = float_c[j * m + i];
} else {
c[offset] = beta * c[offset] + float_c[j * m + i];
}
}
}
} }
void gemm( void gemm(