mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCm] Add cast to kFloat in amax calculation (#123872)
Adds the necessary cast to kFloat that was missed in the previous amax PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123872
Approved by: https://github.com/drisspg
This commit is contained in:
committed by
PyTorch MergeBot
parent
b024c0c2ef
commit
2cb3301f80
@@ -912,7 +912,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 60000
|
||||
// rocm's hipblaslt does not yet support amax, so calculate separately
|
||||
amax = at::max(at::abs(out));
|
||||
amax = at::max(at::abs(out.to(kFloat)));
|
||||
#endif
|
||||
|
||||
return {out, amax};
|
||||
|
Reference in New Issue
Block a user