ReflectionPad supports BFloat16 (#84949)
Looking through past commits, I could not find a reason why BFloat16 was left out of the ReflectionPad kernels, so this enables it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84949
Approved by: https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: fdd3665413
Commit: 26b5986297
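The hunks below replace AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, ...) with AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, ...), so the ReflectionPad CPU and CUDA kernels are also compiled and dispatched for BFloat16. As a minimal usage sketch (not part of this commit; shapes and pad sizes are arbitrary), the 3D reflection-pad paths touched here can then run directly on bfloat16 CPU tensors instead of hitting the usual AT_DISPATCH "not implemented for 'BFloat16'" error:

import torch
import torch.nn.functional as F

# 5-D NCDHW input; a 6-element pad with mode="reflect" routes to reflection_pad3d.
x = torch.randn(1, 2, 4, 8, 8, dtype=torch.bfloat16, requires_grad=True)

# Forward: reflection_pad3d_out_cpu now dispatches for BFloat16.
y = F.pad(x, (1, 1, 1, 1, 1, 1), mode="reflect")

# Backward: reflection_pad3d_backward_out_cpu likewise covers BFloat16.
y.sum().backward()
print(y.dtype, x.grad.dtype)  # torch.bfloat16 torch.bfloat16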
@@ -965,8 +965,8 @@ TORCH_IMPL_FUNC(reflection_pad3d_out_cpu)
   auto input = input_.contiguous();

   if (batch_mode) {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
-        kHalf, input.scalar_type(), "reflection_pad3d_cpu", [&] {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+        kHalf, kBFloat16, input.scalar_type(), "reflection_pad3d_cpu", [&] {
           auto input_data = input.data_ptr<scalar_t>();
           auto output_data = output.data_ptr<scalar_t>();
           auto nbatch = input.size(0);
@@ -986,8 +986,8 @@ TORCH_IMPL_FUNC(reflection_pad3d_out_cpu)
               pad_front);
         });
   } else {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
-        kHalf, input.scalar_type(), "reflection_pad3d_cpu", [&] {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+        kHalf, kBFloat16, input.scalar_type(), "reflection_pad3d_cpu", [&] {
           auto input_data = input.data_ptr<scalar_t>();
           auto output_data = output.data_ptr<scalar_t>();
           reflection_pad3d_out_frame(
@@ -1043,8 +1043,8 @@ TORCH_IMPL_FUNC(reflection_pad3d_backward_out_cpu)(const Tensor& grad_output,
   grad_input.zero_();

   if (batch_mode) {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
-        kHalf, input.scalar_type(), "reflection_pad3d_backward_cpu", [&] {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+        kHalf, kBFloat16, input.scalar_type(), "reflection_pad3d_backward_cpu", [&] {
           reflection_pad3d_backward_out_loop<scalar_t>(
               grad_input.data_ptr<scalar_t>(),
               grad_output_.data_ptr<scalar_t>(),
@@ -1061,8 +1061,8 @@ TORCH_IMPL_FUNC(reflection_pad3d_backward_out_cpu)(const Tensor& grad_output,
               pad_front);
         });
   } else {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
-        kHalf, input.scalar_type(), "reflection_pad3d_backward_cpu", [&] {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+        kHalf, kBFloat16, input.scalar_type(), "reflection_pad3d_backward_cpu", [&] {
           reflection_pad3d_backward_out_frame<scalar_t>(
               grad_input.data_ptr<scalar_t>(),
               grad_output_.data_ptr<scalar_t>(),

@@ -335,7 +335,7 @@ void reflection_pad2d_out_template(
   int64_t size_y = nplane;
   int64_t size_z = nbatch;

-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
     input.scalar_type(), "reflection_pad2d_out_template", [&] {

       for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
@@ -407,7 +407,7 @@ void reflection_pad2d_backward_out_template(
   int64_t size_y = nplane;
   int64_t size_z = nbatch;

-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
     input.scalar_type(), "reflection_pad2d_backward_out_template", [&] {

       for (int64_t block_y = 0; block_y < size_y; block_y += 65535) {
@@ -463,8 +463,8 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)

   Tensor input = input_.contiguous();

-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
-      kHalf, input.scalar_type(), "reflection_pad1d_out_template", [&] {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+      kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
        reflection_pad1d_out_kernel<<<
          grid_size,
          block_size,
@@ -520,7 +520,7 @@ TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,
   dim3 block_size(output_w > 256 ? 256 : output_w);
   dim3 grid_size((int) ::ceil(output_w / 256.0), nplane, nbatch);

-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
     grad_input.scalar_type(), "reflection_pad1d_backward_out_cuda", [&] {
       reflection_pad1d_backward_out_kernel<<<
         grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
@@ -589,7 +589,7 @@ TORCH_IMPL_FUNC(reflection_pad3d_out_cuda) (
   auto input = input_.contiguous();
   bool batch_mode = (input.dim() == 5);

-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
     input.scalar_type(), "reflection_pad3d_out_cuda", [&] {
       auto input_inner = input;
       auto output_inner = output;
@@ -641,7 +641,7 @@ TORCH_IMPL_FUNC(reflection_pad3d_backward_out_cuda) (
   int64_t pad_top = padding[2];
   int64_t pad_front = padding[4];

-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf,
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16,
     input.scalar_type(), "reflection_pad3d_backward_out_cuda", [&] {
       auto grad_input_ = grad_input;
       auto grad_output_ = grad_output;

@@ -10734,7 +10734,7 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            dtypes=floating_and_complex_types(),
-           dtypesIfCUDA=floating_and_complex_types_and(torch.half),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
            sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'),
            skips=(
                # Doesn't have a corresponding aten operator.
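The last hunk adds torch.bfloat16 to dtypesIfCUDA for the reflect-mode nn.functional.pad OpInfo entry, so the standard OpInfo tests also exercise this dtype on CUDA. A hedged spot check (not part of the commit's test changes; shapes and tolerances are arbitrary) against a float32 reference could look like:

import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    x = torch.randn(2, 3, 16, 16, device="cuda", dtype=torch.bfloat16)
    out_bf16 = F.pad(x, (3, 3, 3, 3), mode="reflect")
    out_ref = F.pad(x.float(), (3, 3, 3, 3), mode="reflect")
    # Reflection padding only copies existing values, so the bfloat16 output
    # should match the float32 reference exactly after downcasting.
    torch.testing.assert_close(out_bf16, out_ref.to(torch.bfloat16))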