[ROCm] fix conv relu fusion (#162856)

Fixes #162816.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162856
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Author: Jeff Daily
Date: 2025-09-15 22:49:29 +00:00
Committed by: PyTorch MergeBot
Parent: 8590c3a66b
Commit: 0def79fdd9

2 changed files with 81 additions and 189 deletions


@@ -1770,10 +1770,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> miopen_depthwise_convolution_back
 // fusions
 // ---------------------------------------------------------------------
-void raw_miopen_convolution_relu_out(
+void raw_miopen_convolution_add_relu_out(
     const Tensor& output,
     const Tensor& input,
     const Tensor& weight,
+    const Tensor& z,
+    float alpha,
     const Tensor& bias,
     IntArrayRef stride,
     IntArrayRef padding,
@@ -1781,68 +1783,20 @@ void raw_miopen_convolution_relu_out(
     int64_t groups,
     bool benchmark,
     bool deterministic) {
-  auto dataType = getMiopenDataType(input);
-  miopenConvolutionMode_t c_mode = miopenConvolution;
-  ConvolutionArgs args{ input, output, weight };
-  args.handle = getMiopenHandle();
-  at::MemoryFormat memory_format = miopen_conv_suggest_memory_format(input, weight);
-  setConvolutionParams(
-      &args.params,
-      args.handle,
+  raw_miopen_convolution_forward_out(
+      output,
       input,
       weight,
       padding,
       stride,
       dilation,
       groups,
-      deterministic,
-      memory_format);
-  args.idesc.set(input, memory_format);
-  args.wdesc.set(weight, memory_format, 0);
-  args.odesc.set(output, memory_format);
-  args.cdesc.set(
-      dataType,
-      c_mode,
-      input.dim() - 2,
-      args.params.padding,
-      args.params.stride,
-      args.params.dilation,
-      args.params.groups,
       benchmark,
       deterministic);
-  TensorDescriptor bdesc;
-  bdesc.set(bias.expand({1, bias.size(0)}), output.dim());
-  // Create the fusion plan
-  miopenFusionPlanDescriptor_t fusePlanDesc;
-  miopenFusionOpDescriptor_t convoOp;
-  miopenFusionOpDescriptor_t biasOp;
-  miopenFusionOpDescriptor_t activOp;
-  MIOPEN_CHECK(miopenCreateFusionPlan(&fusePlanDesc, miopenVerticalFusion, args.idesc.desc()));
-  MIOPEN_CHECK(miopenCreateOpConvForward(fusePlanDesc, &convoOp, args.cdesc.desc(), args.wdesc.desc()));
-  MIOPEN_CHECK(miopenCreateOpBiasForward(fusePlanDesc, &biasOp, bdesc.desc()));
-  MIOPEN_CHECK(miopenCreateOpActivationForward(fusePlanDesc, &activOp, miopenActivationRELU));
-  // compile fusion plan
-  MIOPEN_CHECK(miopenCompileFusionPlan(args.handle, fusePlanDesc));
-  // Set the Args
-  float alpha = static_cast<float>(1);
-  float beta = static_cast<float>(0);
-  float activ_alpha = static_cast<float>(0);
-  float activ_beta = static_cast<float>(0);
-  float activ_gamma = static_cast<float>(0);
-  miopenOperatorArgs_t fusionArgs;
-  MIOPEN_CHECK(miopenCreateOperatorArgs(&fusionArgs));
-  MIOPEN_CHECK(miopenSetOpArgsConvForward(fusionArgs, convoOp, &alpha, &beta, weight.const_data_ptr()));
-  MIOPEN_CHECK(miopenSetOpArgsBiasForward(fusionArgs, biasOp, &alpha, &beta, bias.const_data_ptr()));
-  MIOPEN_CHECK(miopenSetOpArgsActivForward(fusionArgs, activOp, &alpha, &beta, activ_alpha, activ_beta, activ_gamma));
-  miopenExecuteFusionPlan(args.handle, fusePlanDesc, args.idesc.desc(), input.const_data_ptr(), args.odesc.desc(), output.data_ptr(), fusionArgs);
-  // Cleanup
-  miopenDestroyFusionPlan(fusePlanDesc);
+  at::Tensor alpha_mul_z_add_bias =
+      at::native::reshape_bias(input.dim(), bias).add(z, alpha);
+  output.add_(alpha_mul_z_add_bias);
+  output.relu_();
 }
 static at::Tensor self_or_new_memory_format(at::Tensor& self, at::MemoryFormat memory_format) {
@@ -1855,171 +1809,107 @@ static at::Tensor self_or_new_memory_format(at::Tensor& self, at::MemoryFormat m
 Tensor miopen_convolution_add_relu(
     const Tensor& input_t,
     const Tensor& weight_t,
-    const Tensor& z,
+    const Tensor& z_t,
     const std::optional<Scalar>& alpha,
-    const std::optional<Tensor>& bias,
+    const std::optional<Tensor>& bias_t,
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
     int64_t groups) {
-  // MIOpen does not support fusion of add, the alpha2 * z step of the below cuDNN function:
-  // y = act ( alpha1 * conv(x) + alpha2 * z + bias )
   auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
+  const Tensor input = input_t.contiguous(memory_format);
+  const Tensor weight = weight_t.contiguous(memory_format);
+  Tensor z = z_t;
+  if (z.suggest_memory_format() != memory_format) {
+    z = z.to(memory_format);
+  }
+  z = z.contiguous(memory_format);
+  // FuseFrozenConvAddRelu performs some tensor shape checking
+  Tensor output_t = at::detail::empty_cuda(
+      conv_output_size(
+          input.sizes(), weight.sizes(), padding, stride, dilation),
+      input.options().memory_format(memory_format));
+  if (output_t.numel() == 0) {
+    return output_t;
+  }
   auto& ctx = at::globalContext();
   bool benchmark = ctx.benchmarkCuDNN();
-  TensorArg input { input_t, "input", 1 },
-            weight { weight_t, "weight", 2 };
-  Tensor output_t = at::detail::empty_cuda(
-      conv_output_size(
-          input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
-      input_t.options().memory_format(memory_format));
-  if (output_t.numel() == 0){
-    return output_t;
-  }
-  // Avoid ambiguity of "output" when this is being used as backwards
-  TensorArg output{output_t, "result", 0};
-  miopen_convolution_forward_out(
-      output,
-      "miopen_convolution_add_relu",
+  auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
+  auto _bias = bias_t.has_value()
+      ? bias_t.value()
+      : at::zeros(
+            {output_t.size(1)},
+            optTypeMetaToScalarType(output_t.options().dtype_opt()),
+            output_t.options().layout_opt(),
+            output_t.options().device_opt(),
+            output_t.options().pinned_memory_opt());
+  raw_miopen_convolution_add_relu_out(
+      output_t,
       input,
       weight,
-      padding,
+      z,
+      _alpha,
+      _bias,
       stride,
+      padding,
       dilation,
       groups,
       benchmark,
-      false // deterministic
-  );
-  auto contig_output_t = self_or_new_memory_format(output_t, memory_format);
-  if (!output_t.is_same(contig_output_t)) {
-    contig_output_t.copy_(output_t);
-  }
-  auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
-  auto _bias = bias.has_value()
-      ? bias.value()
-      : at::zeros(
-            {contig_output_t.size(1)},
-            optTypeMetaToScalarType(contig_output_t.options().dtype_opt()),
-            contig_output_t.options().layout_opt(),
-            contig_output_t.options().device_opt(),
-            contig_output_t.options().pinned_memory_opt());
-  at::Tensor alpha_mul_z_add_bias = at::native::reshape_bias(input_t.dim(), _bias).add(z, _alpha);
-  contig_output_t.add_(alpha_mul_z_add_bias);
-  contig_output_t.relu_();
-  return contig_output_t;
+      true); // deterministic
+  return output_t;
 }
 Tensor miopen_convolution_relu(
     const Tensor& input_t,
     const Tensor& weight_t,
-    const std::optional<Tensor>& bias,
+    const std::optional<Tensor>& bias_t,
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
     int64_t groups) {
+  auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
+  const Tensor input = input_t.contiguous(memory_format);
+  const Tensor weight = weight_t.contiguous(memory_format);
+  // FuseFrozenConvAddRelu performs some tensor shape checking
+  Tensor output_t = at::detail::empty_cuda(
+      conv_output_size(
+          input.sizes(), weight.sizes(), padding, stride, dilation),
+      input.options().memory_format(memory_format));
+  if (output_t.numel() == 0) {
+    return output_t;
+  }
   auto& ctx = at::globalContext();
   bool benchmark = ctx.benchmarkCuDNN();
-  // MIOpen currently only supports MemoryFormat::Contiguous and fp32 and 2d
-  if (input_t.suggest_memory_format() == at::MemoryFormat::Contiguous
-      && input_t.scalar_type() == at::kFloat
-      && input_t.ndimension() == 4) {
-    // FuseFrozenConvAddRelu performs some tensor shape checking
-    Tensor output_t = at::detail::empty_cuda(
-        conv_output_size(
-            input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
-        input_t.options().memory_format(input_t.suggest_memory_format()));
-    if (output_t.numel() == 0) {
-      return output_t;
-    }
-    auto _bias = bias.has_value()
-        ? bias.value()
-        : at::zeros(
-              {output_t.size(1)},
-              optTypeMetaToScalarType(output_t.options().dtype_opt()),
-              output_t.options().layout_opt(),
-              output_t.options().device_opt(),
-              output_t.options().pinned_memory_opt());
-    raw_miopen_convolution_relu_out(
-        output_t,
-        input_t,
-        weight_t,
-        _bias,
-        stride,
-        padding,
-        dilation,
-        groups,
-        benchmark, // benchmark
-        false // deterministic
-    );
-    return output_t;
-  }
-  else {
-    // fallback
-    auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
-    TensorArg input { input_t, "input", 1 },
-              weight { weight_t, "weight", 2 };
-    Tensor output_t = at::detail::empty_cuda(
-        conv_output_size(
-            input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
-        input->options().memory_format(memory_format));
-    if (output_t.numel() == 0){
-      return output_t;
-    }
-    // Avoid ambiguity of "output" when this is being used as backwards
-    TensorArg output{output_t, "result", 0};
-    miopen_convolution_forward_out(
-        output,
-        "miopen_convolution_relu",
-        input,
-        weight,
-        padding,
-        stride,
-        dilation,
-        groups,
-        benchmark,
-        false // deterministic
-    );
-    auto contig_output_t = self_or_new_memory_format(output_t, memory_format);
-    if (!output_t.is_same(contig_output_t)) {
-      contig_output_t.copy_(output_t);
-    }
-    auto _bias = bias.has_value()
-        ? bias.value()
-        : at::zeros(
-              {contig_output_t.size(1)},
-              optTypeMetaToScalarType(contig_output_t.options().dtype_opt()),
-              contig_output_t.options().layout_opt(),
-              contig_output_t.options().device_opt(),
-              contig_output_t.options().pinned_memory_opt());
-    at::Tensor reshaped_bias = at::native::reshape_bias(input_t.dim(), _bias);
-    contig_output_t.add_(reshaped_bias);
-    contig_output_t.relu_();
-    return contig_output_t;
-  }
+  auto _bias = bias_t.has_value()
+      ? bias_t.value()
+      : at::zeros(
+            {output_t.size(1)},
+            optTypeMetaToScalarType(output_t.options().dtype_opt()),
+            output_t.options().layout_opt(),
+            output_t.options().device_opt(),
+            output_t.options().pinned_memory_opt());
+  raw_miopen_convolution_add_relu_out(
+      output_t,
+      input,
+      weight,
+      output_t, // use output_t as z to satisfy MIOpen API
+      0, // alpha
+      _bias,
+      stride,
+      padding,
+      dilation,
+      groups,
+      benchmark, // benchmark
+      true); // deterministic
+  return output_t;
 }
 REGISTER_CUDA_DISPATCH(miopen_convolution_backward_stub, &miopen_convolution_backward)
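For readers skimming the change: the new raw_miopen_convolution_add_relu_out computes y = relu(conv(x, w) + alpha * z + bias) by running a plain MIOpen convolution forward and then applying the alpha * z + bias add and the ReLU as ordinary tensor ops, instead of building a MIOpen fusion plan. The sketch below restates that decomposition with stock PyTorch ops; it is an illustration only (the helper name and the 4-d NCHW shapes are assumptions, not part of this commit):

import torch
import torch.nn.functional as F

def conv_add_relu_reference(x, w, z, alpha, bias, stride, padding, dilation, groups):
    # Plain convolution first, without bias, mirroring raw_miopen_convolution_forward_out.
    y = F.conv2d(x, w, bias=None, stride=stride, padding=padding,
                 dilation=dilation, groups=groups)
    # Then bias + alpha * z, with the bias broadcast over N, H, W
    # (the C++ code does the same via reshape_bias followed by add_).
    y = y + bias.view(1, -1, 1, 1) + alpha * z
    # Finally the in-place ReLU.
    return y.relu_()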


@@ -3865,6 +3865,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
     @onlyCUDA
     @skipCUDAIfNoCudnn
     @dtypes(torch.float, torch.float16)
+    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
     def test_cudnn_convolution_relu(self, device, dtype):
         for batch, groups, image_size, kernel_size, memory_format in product(
@@ -3898,6 +3899,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
     @onlyCUDA
     @skipCUDAIfNoCudnn
     @dtypes(torch.float, torch.float16)
+    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
     def test_cudnn_convolution_add_relu(self, device, dtype):
         for batch, groups, image_size, kernel_size, memory_format in product(