[ROCm] Implement float32 copy kernel (#163869)

* Add `float32_copy_kernel` for vectorizing float16/bfloat16 to float32 conversion
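For reference, the new path is taken when `copy_` writes a float16/bfloat16 source into a float32 destination. A minimal sketch of such a call (sizes are illustrative; assumes a ROCm build, where `device="cuda"` maps to HIP):

```python
import torch

# bfloat16 source and float32 destination on the GPU; on a ROCm build
# this copy dispatches to the specialized conversion kernels added here.
src = torch.randn(1 << 20, dtype=torch.bfloat16, device="cuda")
dst = torch.empty_like(src, dtype=torch.float32)
dst.copy_(src)  # bfloat16 -> float32 conversion copy

# bfloat16 -> float32 is exact, so the copy must match an explicit upcast.
assert torch.equal(dst, src.to(torch.float32))
```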

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163869
Approved by: https://github.com/jeffdaily
Jerry Mannil
2025-09-26 00:39:27 +00:00
committed by PyTorch MergeBot
parent 5b8fef3f17
commit b4be380480

@@ -42,6 +42,19 @@ void bfloat16_copy_kernel_cuda(TensorIteratorBase &iter) {
   });
 }
+#ifdef USE_ROCM
+void bfloat16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
+  gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::BFloat16 value) {
+    return static_cast<float>(value);
+  });
+}
+void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
+  gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::Half value) {
+    return static_cast<float>(value);
+  });
+}
+#endif
 void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
   ScalarType dtype = iter.dtype(0);
   ScalarType other_dtype = iter.dtype(1);
@@ -187,7 +200,17 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     } else {
       float16_copy_kernel_cuda(iter);
     }
-  } else if (isBitsType(dtype)) {
+  }
+#ifdef USE_ROCM
+  else if ((iter.dtype(1) == kBFloat16 || iter.dtype(1) == kHalf) && dtype == kFloat) {
+    if (iter.dtype(1) == kBFloat16) {
+      bfloat16tofloat32_copy_kernel_cuda(iter);
+    } else {
+      float16tofloat32_copy_kernel_cuda(iter);
+    }
+  }
+#endif
+  else if (isBitsType(dtype)) {
     TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
         "bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype);
     AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {