mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Cleanup stale comments/copy from gemm
(#162001)
Follow-up to https://github.com/pytorch/pytorch/pull/154012

Since the introduction of `gemm_no_downcast_stub`, it is no longer necessary to allocate a temporary array and then manually implement the `beta` logic in the codebase.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162001
Approved by: https://github.com/drisspg
ghstack dependencies: #161999
This commit is contained in:
committed by
PyTorch MergeBot
parent
3f6d88f04c
commit
24492cbab2
@ -457,24 +457,9 @@ void gemm(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
// for the fallback path, first compute gemm with beta = 0,
|
|
||||||
// and then add c in full precision.
|
|
||||||
int64_t c_size = n * m;
|
|
||||||
std::vector<float> float_c(c_size, 0.f);
|
|
||||||
gemm_no_downcast_stub(
|
gemm_no_downcast_stub(
|
||||||
at::kCPU, at::kBFloat16,
|
at::kCPU, at::kBFloat16,
|
||||||
transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m);
|
transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
||||||
for (const auto j : c10::irange(n)) {
|
|
||||||
for (const auto i : c10::irange(m)) {
|
|
||||||
auto offset = j * ldc + i;
|
|
||||||
// beta == 0 won't propagate NaN from C
|
|
||||||
if (beta == 0.f) {
|
|
||||||
c[offset] = float_c[j * m + i];
|
|
||||||
} else {
|
|
||||||
c[offset] = beta * c[offset] + float_c[j * m + i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void gemm(
|
void gemm(
|
||||||
@ -493,24 +478,9 @@ void gemm(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
// for the fallback path, first compute gemm with beta = 0,
|
|
||||||
// and then add c in full precision.
|
|
||||||
int64_t c_size = n * m;
|
|
||||||
std::vector<float> float_c(c_size, 0.f);
|
|
||||||
gemm_no_downcast_stub(
|
gemm_no_downcast_stub(
|
||||||
at::kCPU, at::kHalf,
|
at::kCPU, at::kHalf,
|
||||||
transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m);
|
transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
||||||
for (const auto j : c10::irange(n)) {
|
|
||||||
for (const auto i : c10::irange(m)) {
|
|
||||||
auto offset = j * ldc + i;
|
|
||||||
// beta == 0 won't propagate NaN from C
|
|
||||||
if (beta == 0.f) {
|
|
||||||
c[offset] = float_c[j * m + i];
|
|
||||||
} else {
|
|
||||||
c[offset] = beta * c[offset] + float_c[j * m + i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void gemm(
|
void gemm(
|
||||||
|
Reference in New Issue
Block a user