[Intel GPU] convolution fusion at XPU backend (#154202)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154202
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/etaf
ghstack dependencies: #140365
Author: ZhiweiYan-96
Date: 2025-05-27 02:07:04 +00:00
Committed by: PyTorch MergeBot
Parent: c6fc11af76
Commit: f12ce4e36b
2 changed files with 289 additions and 75 deletions
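For context, the ops this PR wires up on XPU are the existing mkldnn fusion entry points. Below is a minimal usage sketch of the unary-fusion path, mirroring the call pattern in the added test further down (illustrative only; assumes a PyTorch build with XPU support, and the module, shapes, and tolerances here are arbitrary):

import torch
import torch.nn as nn

# Plain conv + relu as the reference path.
conv = nn.Conv2d(3, 32, kernel_size=3, bias=True).to("xpu").eval()
x = torch.randn(1, 3, 112, 112, device="xpu")

with torch.no_grad():
    ref = torch.relu(conv(x))
    # Fused path: the conv and the unary post-op are handed to oneDNN together.
    fused = torch.ops.mkldnn._convolution_pointwise(
        x,
        conv.weight,
        conv.bias,
        conv.padding,
        conv.stride,
        conv.dilation,
        conv.groups,
        "relu",  # unary attr
        [],      # scalars for the unary op (none needed for relu)
        None,    # algorithm
    )
    torch.testing.assert_close(ref, fused, atol=5e-4, rtol=5e-4)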

View File

@@ -3,6 +3,7 @@
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/interned_strings.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/mkldnn/xpu/FusionUtils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/ops/full.h>
@@ -309,81 +310,6 @@ static at::Tensor view3d(const at::Tensor& tensor) {
return tensor.squeeze(2);
}
Attr get_onednn_conv_sum_attr(
const Tensor& input_r,
const Tensor& weight_r,
IntArrayRef stride_,
IntArrayRef padding_,
IntArrayRef dilation_,
Tensor& accumu,
double scale,
Tensor& output,
bool& is_fused,
Attr attr = Attr(),
bool force_inplace = false) {
is_fused = true;
if (scale == 0.f)
return attr;
auto ndim = input_r.ndimension();
auto output_size = conv_dst_size(
ndim,
input_r.sizes(),
weight_r.sizes(),
padding_,
padding_,
stride_,
dilation_);
MemoryFormat mem_fmt = at::MemoryFormat::Contiguous;
auto input_fmt = input_r.suggest_memory_format();
auto input_is_cl =
(input_fmt == at::MemoryFormat::ChannelsLast ||
input_fmt == at::MemoryFormat::ChannelsLast3d);
auto weight_fmt = weight_r.suggest_memory_format();
auto weight_is_cl =
(weight_fmt == at::MemoryFormat::ChannelsLast ||
weight_fmt == at::MemoryFormat::ChannelsLast3d);
bool propagate_channels_last = input_is_cl || weight_is_cl;
if (propagate_channels_last)
mem_fmt = get_cl_tag_by_ndim(ndim);
Tensor out = at::empty(output_size, input_r.options().memory_format(mem_fmt));
if (!onednn::binary_valid(out, accumu)) {
is_fused = false;
return attr;
}
// For post-sum and post-binary-add, onednn needs sum/binary scale=1.f
// Thus we need the following transformation
// conv(src, wei) + scale * accumu
// scale * (1/scale * conv(src, wei) + sum (or binary))
if (scale != 1.f)
attr.append_post_eltwise(
/* scale */ 1.f,
/* alpha */ 1.f / scale,
/* beta */ 0.f,
attr.kind_with_linear);
if (force_inplace) {
// If sizes are the same, post sum is used.
output = accumu;
attr.append_post_sum(/* sum_scale */ 1.f);
} else {
// If sizes are different, post binary is used.
attr.append_post_binary(attr.kind_with_binary_add, accumu);
}
if (scale != 1.f)
attr.append_post_eltwise(
/* scale */ 1.f,
/* alpha */ scale,
/* beta */ 0.f,
attr.kind_with_linear);
return attr;
}
} // namespace impl
using namespace impl;
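Side note on the helper removed above: oneDNN wants the sum/binary post-op to run with scale = 1.f, so get_onednn_conv_sum_attr rewrites conv(src, wei) + scale * accumu as scale * (1/scale * conv(src, wei) + accumu) by bracketing the sum with two linear post-ops. A quick plain-tensor check of that identity (illustrative only, no oneDNN involved):

import torch

# conv(src, wei) + scale * accumu  ==  scale * (1/scale * conv(src, wei) + accumu)
conv_out = torch.randn(1, 8, 16, 16)  # stand-in for conv(src, wei)
accumu = torch.randn(1, 8, 16, 16)
scale = 0.25

direct = conv_out + scale * accumu
rewritten = scale * (1.0 / scale * conv_out + accumu)
torch.testing.assert_close(direct, rewritten)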
@@ -476,6 +402,8 @@ Tensor _convolution_out(
params.output_padding,
params.groups);
output = at::empty(dst_tz, input.options(), mfmt);
} else {
output = output_r;
}
onednn::deconvolution(
@@ -518,6 +446,8 @@ Tensor _convolution_out(
params.stride,
params.dilation);
output = at::empty(dst_tz, input.options(), mfmt);
} else {
output = output_r;
}
onednn::convolution(
output,
@@ -751,6 +681,119 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}
Tensor convolution_pointwise(
const Tensor& input_t,
const Tensor& weight_t,
const std::optional<Tensor>& bias_opt,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
std::string_view attr,
torch::List<std::optional<at::Scalar>> scalars,
std::optional<std::string_view> algorithm) {
c10::DeviceGuard device_guard(input_t.device());
Attr att;
att = construct_unary_attr(att, attr, scalars, algorithm);
const Tensor bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
return _convolution(
input_t,
weight_t,
bias,
stride,
padding,
dilation,
/*transposed*/ false,
/*output_padding*/ {0},
groups,
att);
}
Tensor convolution_pointwise_binary(
const Tensor& input_t,
const Tensor& other_t,
const Tensor& weight_t,
const std::optional<Tensor>& bias_opt,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
std::string_view binary_attr,
std::optional<at::Scalar> alpha,
std::optional<std::string_view> unary_attr,
torch::List<std::optional<at::Scalar>> unary_scalars,
std::optional<std::string_view> unary_algorithm) {
c10::DeviceGuard device_guard(input_t.device());
Tensor output;
Tensor bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
// Step1: Construct binary attr
Attr attr;
attr = construct_binary_attr(attr, binary_attr, other_t);
// Step2: Append unary attr
if (unary_attr.has_value())
attr = construct_unary_attr(
attr, unary_attr.value(), unary_scalars, unary_algorithm);
Tensor res = _convolution_out(
output,
input_t,
weight_t,
bias,
stride,
padding,
dilation,
/*transposed*/ false,
/*output_padding*/ {0},
groups,
attr);
// Step3: Run conv
return res;
}
Tensor& convolution_pointwise_binary_(
Tensor& other_t,
const Tensor& input_t,
const Tensor& weight_t,
const std::optional<Tensor>& bias_opt,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
std::string_view binary_attr,
std::optional<at::Scalar> alpha,
std::optional<std::string_view> unary_attr,
torch::List<std::optional<at::Scalar>> unary_scalars,
std::optional<std::string_view> unary_algorithm) {
c10::DeviceGuard device_guard(input_t.device());
Tensor bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
// Step1: Construct binary attr
Attr attr;
attr = construct_binary_attr(attr, binary_attr, other_t);
// Step2: Append unary attr
if (unary_attr.has_value())
attr = construct_unary_attr(
attr, unary_attr.value(), unary_scalars, unary_algorithm);
_convolution_out(
other_t,
input_t,
weight_t,
bias,
stride,
padding,
dilation,
/*transposed*/ false,
/*output_padding*/ {0},
groups,
attr);
// Step3: Run conv
return other_t;
}
TORCH_LIBRARY_IMPL(aten, XPU, m) {
m.impl("convolution_overrideable", TORCH_FN(convolution_overrideable));
m.impl(
@@ -758,4 +801,16 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
TORCH_FN(convolution_backward_overrideable));
}
TORCH_LIBRARY_IMPL(mkldnn, XPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise"),
TORCH_FN(convolution_pointwise));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise.binary"),
TORCH_FN(convolution_pointwise_binary));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise_.binary"),
TORCH_FN(convolution_pointwise_binary_));
}
} // namespace at::native::xpu
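A companion sketch for the binary overloads registered just above, again mirroring the added test in the second file (illustrative only; assumes an XPU-enabled build). The out-of-place call returns a new tensor, while the in-place variant accumulates the conv result into other:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, kernel_size=3).to("xpu").eval()
x = torch.randn(1, 3, 112, 112, device="xpu")

with torch.no_grad():
    other = torch.randn_like(conv(x))
    ref = conv(x) + other

    # Out-of-place binary fusion: conv output combined with other via "add".
    fused = torch.ops.mkldnn._convolution_pointwise(
        x, other, conv.weight, conv.bias,
        conv.padding, conv.stride, conv.dilation, conv.groups,
        "add", None, None, [], None,
    )

    # In-place variant: writes conv(x) + other back into other.
    torch.ops.mkldnn._convolution_pointwise_(
        other, x, conv.weight, conv.bias,
        conv.padding, conv.stride, conv.dilation, conv.groups,
        "add", None, None, [], None,
    )

    torch.testing.assert_close(ref, fused, atol=5e-4, rtol=5e-4)
    torch.testing.assert_close(ref, other, atol=5e-4, rtol=5e-4)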

View File

@@ -122,6 +122,165 @@ class TestoneDNNFusion(TestCase):
)
self.assertEqual(ref, fused)
def test_conv_unary_fusion_ops(self):
class M(nn.Module):
def __init__(
self,
unary_fn,
dim,
in_channels,
out_channels,
dilation,
groups,
bias,
**kwargs,
):
super().__init__()
self.conv = CONV_MODULES[dim](
in_channels,
out_channels,
dilation=dilation,
groups=groups,
bias=bias,
**kwargs,
)
self.unary = unary_fn
def forward(self, x):
x = self.conv(x)
x = self.unary(x)
return x
input_shapes = {2: (112, 112), 3: (55, 55, 55)}
for pointwise_info in self._unary_list().values():
for dim in [2, 3]:
channels_last = (
torch.channels_last if dim == 2 else torch.channels_last_3d
)
options = itertools.product(
[True, False],
[1, 2],
[1, 4],
[torch.contiguous_format, channels_last],
)
for bias, dilation, groups, memory_format in options:
oC = 32 * groups
iC = 3 * groups
x_shape = (1, iC) + input_shapes[dim]
x = torch.randn(x_shape, dtype=torch.float32).to(
memory_format=memory_format
)
mod = M(
pointwise_info.pointwise_module,
dim,
iC,
oC,
dilation,
groups,
bias,
kernel_size=3,
)
mod = mod.to(memory_format=memory_format).eval()
with torch.no_grad():
x = x.to("xpu")
mod = mod.to("xpu")
ref = mod(x)
attr = pointwise_info.attr
scalars = pointwise_info.scalars
algorithm = pointwise_info.algorithm
fused = torch.ops.mkldnn._convolution_pointwise(
x,
mod.conv.weight,
mod.conv.bias,
mod.conv.padding,
mod.conv.stride,
mod.conv.dilation,
mod.conv.groups,
attr,
scalars,
algorithm,
)
self.assertEqual(ref, fused)
def test_conv_binary_fusion_ops(self):
class M(nn.Module):
def __init__(
self,
binary_fn,
dim,
in_channels,
out_channels,
dilation,
groups,
bias,
**kwargs,
):
super().__init__()
self.conv = CONV_MODULES[dim](
in_channels,
out_channels,
dilation=dilation,
groups=groups,
bias=bias,
**kwargs,
)
self.binary = binary_fn
def forward(self, x, other):
x = self.conv(x)
x = self.binary(x, other)
return x
for pointwise_name, pointwise_fn in self._binary_list().items():
x = torch.randn(
(
1,
3,
112,
112,
)
).to("xpu")
mod = M(pointwise_fn, 2, 3, 3, 1, 1, True, kernel_size=3).to("xpu")
other = torch.randn_like(mod.conv(x))
with torch.no_grad():
ref = mod(x, other)
unary_attr = None
attr = pointwise_name
fused = torch.ops.mkldnn._convolution_pointwise(
x,
other,
mod.conv.weight,
mod.conv.bias,
mod.conv.padding,
mod.conv.stride,
mod.conv.dilation,
mod.conv.groups,
attr,
None,
unary_attr,
[],
None,
)
if attr == "add":
fused_inplace = torch.ops.mkldnn._convolution_pointwise_(
other,
x,
mod.conv.weight,
mod.conv.bias,
mod.conv.padding,
mod.conv.stride,
mod.conv.dilation,
mod.conv.groups,
attr,
None,
unary_attr,
[],
None,
)
self.assertEqual(ref, other)
self.assertEqual(ref, fused_inplace)
self.assertEqual(ref, fused, atol=5e-4, rtol=5e-4)
instantiate_device_type_tests(
TestoneDNNFusion, globals(), only_for="xpu", allow_xpu=True