[ROCm] Add cast to kFloat in amax calculation (#123872)

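Adds the necessary cast to kFloat that was missed in the previous amax PR, so the separate amax calculation on ROCm runs on a float copy of the output.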

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123872
Approved by: https://github.com/drisspg
This commit is contained in:
Andres Lugo-Reyes
2024-04-12 15:38:41 +00:00
committed by PyTorch MergeBot
parent b024c0c2ef
commit 2cb3301f80

@@ -912,7 +912,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
 #if defined(USE_ROCM) && ROCM_VERSION >= 60000
 // rocm's hipblaslt does not yet support amax, so calculate separately
-amax = at::max(at::abs(out));
+amax = at::max(at::abs(out.to(kFloat)));
 #endif
 return {out, amax};
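
For illustration only, the cast-then-reduce pattern in the changed line can be sketched in standard C++ without ATen. The helper name `amax_as_float` is hypothetical and is not part of PyTorch; it mirrors the order of operations in `at::max(at::abs(out.to(kFloat)))` on a plain `std::vector`, and unlike `at::max` it returns 0 for empty input, which is an assumption of this sketch.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical helper (not an ATen function): converts each element to
// float first, then takes the absolute value, then reduces with max.
template <typename T>
float amax_as_float(const std::vector<T>& values) {
  // Absolute values are non-negative, so 0 works as the starting value.
  // Assumption of this sketch: empty input yields 0.
  float amax = 0.0f;
  for (const T& v : values) {
    const float as_float = static_cast<float>(v);  // cast before abs
    amax = std::max(amax, std::fabs(as_float));    // abs, then running max
  }
  return amax;
}
```

Calling `amax_as_float(std::vector<signed char>{-128, 127})` returns `128.0f`: because the conversion happens before the absolute value, the result is representable even though 128 does not fit in the 8-bit element type.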