Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ROCm] Implement float32 copy kernel (#163869)
* Add `float32_copy_kernel` for vectorizing float16/bfloat16 to float32 conversion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163869
Approved by: https://github.com/jeffdaily
This commit is contained in:
Committed by: PyTorch MergeBot
Parent commit: 5b8fef3f17
This commit: b4be380480
@ -42,6 +42,19 @@ void bfloat16_copy_kernel_cuda(TensorIteratorBase &iter) {
|
||||
});
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
void bfloat16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
|
||||
gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::BFloat16 value) {
|
||||
return static_cast<float>(value);
|
||||
});
|
||||
}
|
||||
void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
|
||||
gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::Half value) {
|
||||
return static_cast<float>(value);
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
|
||||
ScalarType dtype = iter.dtype(0);
|
||||
ScalarType other_dtype = iter.dtype(1);
|
||||
@ -187,7 +200,17 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
|
||||
} else {
|
||||
float16_copy_kernel_cuda(iter);
|
||||
}
|
||||
} else if (isBitsType(dtype)) {
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
else if ((iter.dtype(1) == kBFloat16 || iter.dtype(1) == kHalf) && dtype == kFloat) {
|
||||
if (iter.dtype(1) == kBFloat16) {
|
||||
bfloat16tofloat32_copy_kernel_cuda(iter);
|
||||
} else {
|
||||
float16tofloat32_copy_kernel_cuda(iter);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
else if (isBitsType(dtype)) {
|
||||
TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
|
||||
"bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype);
|
||||
AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
|
||||
|
Reference in New Issue
Block a user