mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ROCm] fix conv relu fusion (#162856)
Fixes #162816. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162856 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
8590c3a66b
commit
0def79fdd9
@ -1770,10 +1770,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> miopen_depthwise_convolution_back
|
||||
// fusions
|
||||
// ---------------------------------------------------------------------
|
||||
|
||||
void raw_miopen_convolution_relu_out(
|
||||
void raw_miopen_convolution_add_relu_out(
|
||||
const Tensor& output,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const Tensor& z,
|
||||
float alpha,
|
||||
const Tensor& bias,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
@ -1781,68 +1783,20 @@ void raw_miopen_convolution_relu_out(
|
||||
int64_t groups,
|
||||
bool benchmark,
|
||||
bool deterministic) {
|
||||
auto dataType = getMiopenDataType(input);
|
||||
miopenConvolutionMode_t c_mode = miopenConvolution;
|
||||
ConvolutionArgs args{ input, output, weight };
|
||||
args.handle = getMiopenHandle();
|
||||
at::MemoryFormat memory_format = miopen_conv_suggest_memory_format(input, weight);
|
||||
setConvolutionParams(
|
||||
&args.params,
|
||||
args.handle,
|
||||
raw_miopen_convolution_forward_out(
|
||||
output,
|
||||
input,
|
||||
weight,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
deterministic,
|
||||
memory_format);
|
||||
args.idesc.set(input, memory_format);
|
||||
args.wdesc.set(weight, memory_format, 0);
|
||||
args.odesc.set(output, memory_format);
|
||||
args.cdesc.set(
|
||||
dataType,
|
||||
c_mode,
|
||||
input.dim() - 2,
|
||||
args.params.padding,
|
||||
args.params.stride,
|
||||
args.params.dilation,
|
||||
args.params.groups,
|
||||
benchmark,
|
||||
deterministic);
|
||||
|
||||
TensorDescriptor bdesc;
|
||||
bdesc.set(bias.expand({1, bias.size(0)}), output.dim());
|
||||
|
||||
// Create the fusion plan
|
||||
miopenFusionPlanDescriptor_t fusePlanDesc;
|
||||
miopenFusionOpDescriptor_t convoOp;
|
||||
miopenFusionOpDescriptor_t biasOp;
|
||||
miopenFusionOpDescriptor_t activOp;
|
||||
MIOPEN_CHECK(miopenCreateFusionPlan(&fusePlanDesc, miopenVerticalFusion, args.idesc.desc()));
|
||||
MIOPEN_CHECK(miopenCreateOpConvForward(fusePlanDesc, &convoOp, args.cdesc.desc(), args.wdesc.desc()));
|
||||
MIOPEN_CHECK(miopenCreateOpBiasForward(fusePlanDesc, &biasOp, bdesc.desc()));
|
||||
MIOPEN_CHECK(miopenCreateOpActivationForward(fusePlanDesc, &activOp, miopenActivationRELU));
|
||||
|
||||
// compile fusion plan
|
||||
MIOPEN_CHECK(miopenCompileFusionPlan(args.handle, fusePlanDesc));
|
||||
|
||||
// Set the Args
|
||||
float alpha = static_cast<float>(1);
|
||||
float beta = static_cast<float>(0);
|
||||
float activ_alpha = static_cast<float>(0);
|
||||
float activ_beta = static_cast<float>(0);
|
||||
float activ_gamma = static_cast<float>(0);
|
||||
miopenOperatorArgs_t fusionArgs;
|
||||
MIOPEN_CHECK(miopenCreateOperatorArgs(&fusionArgs));
|
||||
MIOPEN_CHECK(miopenSetOpArgsConvForward(fusionArgs, convoOp, &alpha, &beta, weight.const_data_ptr()));
|
||||
MIOPEN_CHECK(miopenSetOpArgsBiasForward(fusionArgs, biasOp, &alpha, &beta, bias.const_data_ptr()));
|
||||
MIOPEN_CHECK(miopenSetOpArgsActivForward(fusionArgs, activOp, &alpha, &beta, activ_alpha, activ_beta, activ_gamma));
|
||||
|
||||
miopenExecuteFusionPlan(args.handle, fusePlanDesc, args.idesc.desc(), input.const_data_ptr(), args.odesc.desc(), output.data_ptr(), fusionArgs);
|
||||
|
||||
// Cleanup
|
||||
miopenDestroyFusionPlan(fusePlanDesc);
|
||||
at::Tensor alpha_mul_z_add_bias =
|
||||
at::native::reshape_bias(input.dim(), bias).add(z, alpha);
|
||||
output.add_(alpha_mul_z_add_bias);
|
||||
output.relu_();
|
||||
}
|
||||
|
||||
static at::Tensor self_or_new_memory_format(at::Tensor& self, at::MemoryFormat memory_format) {
|
||||
@ -1855,171 +1809,107 @@ static at::Tensor self_or_new_memory_format(at::Tensor& self, at::MemoryFormat m
|
||||
Tensor miopen_convolution_add_relu(
|
||||
const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const Tensor& z,
|
||||
const Tensor& z_t,
|
||||
const std::optional<Scalar>& alpha,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<Tensor>& bias_t,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
int64_t groups) {
|
||||
|
||||
// MIOpen does not support fusion of add, the alpha2 * z step of the below cuDNN function:
|
||||
// y = act ( alpha1 * conv(x) + alpha2 * z + bias )
|
||||
|
||||
auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
|
||||
const Tensor input = input_t.contiguous(memory_format);
|
||||
const Tensor weight = weight_t.contiguous(memory_format);
|
||||
Tensor z = z_t;
|
||||
if (z.suggest_memory_format() != memory_format) {
|
||||
z = z.to(memory_format);
|
||||
}
|
||||
z = z.contiguous(memory_format);
|
||||
|
||||
// FuseFrozenConvAddRelu performs some tensor shape checking
|
||||
Tensor output_t = at::detail::empty_cuda(
|
||||
conv_output_size(
|
||||
input.sizes(), weight.sizes(), padding, stride, dilation),
|
||||
input.options().memory_format(memory_format));
|
||||
if (output_t.numel() == 0) {
|
||||
return output_t;
|
||||
}
|
||||
|
||||
auto& ctx = at::globalContext();
|
||||
bool benchmark = ctx.benchmarkCuDNN();
|
||||
auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
|
||||
auto _bias = bias_t.has_value()
|
||||
? bias_t.value()
|
||||
: at::zeros(
|
||||
{output_t.size(1)},
|
||||
optTypeMetaToScalarType(output_t.options().dtype_opt()),
|
||||
output_t.options().layout_opt(),
|
||||
output_t.options().device_opt(),
|
||||
output_t.options().pinned_memory_opt());
|
||||
|
||||
TensorArg input { input_t, "input", 1 },
|
||||
weight { weight_t, "weight", 2 };
|
||||
|
||||
Tensor output_t = at::detail::empty_cuda(
|
||||
conv_output_size(
|
||||
input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
|
||||
input_t.options().memory_format(memory_format));
|
||||
if (output_t.numel() == 0){
|
||||
return output_t;
|
||||
}
|
||||
// Avoid ambiguity of "output" when this is being used as backwards
|
||||
TensorArg output{output_t, "result", 0};
|
||||
miopen_convolution_forward_out(
|
||||
output,
|
||||
"miopen_convolution_add_relu",
|
||||
raw_miopen_convolution_add_relu_out(
|
||||
output_t,
|
||||
input,
|
||||
weight,
|
||||
padding,
|
||||
z,
|
||||
_alpha,
|
||||
_bias,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
benchmark,
|
||||
false // deterministic
|
||||
);
|
||||
true); // deterministic
|
||||
|
||||
auto contig_output_t = self_or_new_memory_format(output_t, memory_format);
|
||||
|
||||
if (!output_t.is_same(contig_output_t)) {
|
||||
contig_output_t.copy_(output_t);
|
||||
}
|
||||
|
||||
auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
|
||||
auto _bias = bias.has_value()
|
||||
? bias.value()
|
||||
: at::zeros(
|
||||
{contig_output_t.size(1)},
|
||||
optTypeMetaToScalarType(contig_output_t.options().dtype_opt()),
|
||||
contig_output_t.options().layout_opt(),
|
||||
contig_output_t.options().device_opt(),
|
||||
contig_output_t.options().pinned_memory_opt());
|
||||
|
||||
at::Tensor alpha_mul_z_add_bias = at::native::reshape_bias(input_t.dim(), _bias).add(z, _alpha);
|
||||
contig_output_t.add_(alpha_mul_z_add_bias);
|
||||
contig_output_t.relu_();
|
||||
|
||||
return contig_output_t;
|
||||
return output_t;
|
||||
}
|
||||
|
||||
Tensor miopen_convolution_relu(
|
||||
const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<Tensor>& bias_t,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
int64_t groups) {
|
||||
auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
|
||||
const Tensor input = input_t.contiguous(memory_format);
|
||||
const Tensor weight = weight_t.contiguous(memory_format);
|
||||
|
||||
// FuseFrozenConvAddRelu performs some tensor shape checking
|
||||
Tensor output_t = at::detail::empty_cuda(
|
||||
conv_output_size(
|
||||
input.sizes(), weight.sizes(), padding, stride, dilation),
|
||||
input.options().memory_format(memory_format));
|
||||
if (output_t.numel() == 0) {
|
||||
return output_t;
|
||||
}
|
||||
|
||||
auto& ctx = at::globalContext();
|
||||
bool benchmark = ctx.benchmarkCuDNN();
|
||||
auto _bias = bias_t.has_value()
|
||||
? bias_t.value()
|
||||
: at::zeros(
|
||||
{output_t.size(1)},
|
||||
optTypeMetaToScalarType(output_t.options().dtype_opt()),
|
||||
output_t.options().layout_opt(),
|
||||
output_t.options().device_opt(),
|
||||
output_t.options().pinned_memory_opt());
|
||||
|
||||
// MIOpen currently only supports MemoryFormat::Contiguous and fp32 and 2d
|
||||
if (input_t.suggest_memory_format() == at::MemoryFormat::Contiguous
|
||||
&& input_t.scalar_type() == at::kFloat
|
||||
&& input_t.ndimension() == 4) {
|
||||
raw_miopen_convolution_add_relu_out(
|
||||
output_t,
|
||||
input,
|
||||
weight,
|
||||
output_t, // use output_t as z to satisfy MIOpen API
|
||||
0, // alpha
|
||||
_bias,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
benchmark, // benchmark
|
||||
true); // deterministic
|
||||
|
||||
// FuseFrozenConvAddRelu performs some tensor shape checking
|
||||
Tensor output_t = at::detail::empty_cuda(
|
||||
conv_output_size(
|
||||
input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
|
||||
input_t.options().memory_format(input_t.suggest_memory_format()));
|
||||
if (output_t.numel() == 0) {
|
||||
return output_t;
|
||||
}
|
||||
|
||||
auto _bias = bias.has_value()
|
||||
? bias.value()
|
||||
: at::zeros(
|
||||
{output_t.size(1)},
|
||||
optTypeMetaToScalarType(output_t.options().dtype_opt()),
|
||||
output_t.options().layout_opt(),
|
||||
output_t.options().device_opt(),
|
||||
output_t.options().pinned_memory_opt());
|
||||
|
||||
raw_miopen_convolution_relu_out(
|
||||
output_t,
|
||||
input_t,
|
||||
weight_t,
|
||||
_bias,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
benchmark, // benchmark
|
||||
false // deterministic
|
||||
);
|
||||
|
||||
return output_t;
|
||||
}
|
||||
else {
|
||||
// fallback
|
||||
|
||||
auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
|
||||
|
||||
TensorArg input { input_t, "input", 1 },
|
||||
weight { weight_t, "weight", 2 };
|
||||
|
||||
Tensor output_t = at::detail::empty_cuda(
|
||||
conv_output_size(
|
||||
input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
|
||||
input->options().memory_format(memory_format));
|
||||
if (output_t.numel() == 0){
|
||||
return output_t;
|
||||
}
|
||||
// Avoid ambiguity of "output" when this is being used as backwards
|
||||
TensorArg output{output_t, "result", 0};
|
||||
miopen_convolution_forward_out(
|
||||
output,
|
||||
"miopen_convolution_relu",
|
||||
input,
|
||||
weight,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
benchmark,
|
||||
false // deterministic
|
||||
);
|
||||
|
||||
auto contig_output_t = self_or_new_memory_format(output_t, memory_format);
|
||||
|
||||
if (!output_t.is_same(contig_output_t)) {
|
||||
contig_output_t.copy_(output_t);
|
||||
}
|
||||
|
||||
auto _bias = bias.has_value()
|
||||
? bias.value()
|
||||
: at::zeros(
|
||||
{contig_output_t.size(1)},
|
||||
optTypeMetaToScalarType(contig_output_t.options().dtype_opt()),
|
||||
contig_output_t.options().layout_opt(),
|
||||
contig_output_t.options().device_opt(),
|
||||
contig_output_t.options().pinned_memory_opt());
|
||||
|
||||
at::Tensor reshaped_bias = at::native::reshape_bias(input_t.dim(), _bias);
|
||||
contig_output_t.add_(reshaped_bias);
|
||||
contig_output_t.relu_();
|
||||
|
||||
return contig_output_t;
|
||||
}
|
||||
return output_t;
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(miopen_convolution_backward_stub, &miopen_convolution_backward)
|
||||
|
@ -3865,6 +3865,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
|
||||
@onlyCUDA
|
||||
@skipCUDAIfNoCudnn
|
||||
@dtypes(torch.float, torch.float16)
|
||||
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
||||
@precisionOverride({torch.half: 0.002, torch.float: 1e-4})
|
||||
def test_cudnn_convolution_relu(self, device, dtype):
|
||||
for batch, groups, image_size, kernel_size, memory_format in product(
|
||||
@ -3898,6 +3899,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
|
||||
@onlyCUDA
|
||||
@skipCUDAIfNoCudnn
|
||||
@dtypes(torch.float, torch.float16)
|
||||
@torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
|
||||
@precisionOverride({torch.half: 0.002, torch.float: 1e-4})
|
||||
def test_cudnn_convolution_add_relu(self, device, dtype):
|
||||
for batch, groups, image_size, kernel_size, memory_format in product(
|
||||
|
Reference in New Issue
Block a user