mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Intel GPU] convolution fusion at XPU backend (#154202)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154202 Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/etaf ghstack dependencies: #140365
This commit is contained in:
committed by
PyTorch MergeBot
parent
c6fc11af76
commit
f12ce4e36b
@ -3,6 +3,7 @@
|
||||
#include <ATen/core/ATen_fwd.h>
|
||||
#include <ATen/core/interned_strings.h>
|
||||
#include <ATen/native/ConvUtils.h>
|
||||
#include <ATen/native/mkldnn/xpu/FusionUtils.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
|
||||
#include <ATen/native/utils/ParamUtils.h>
|
||||
#include <ATen/ops/full.h>
|
||||
@ -309,81 +310,6 @@ static at::Tensor view3d(const at::Tensor& tensor) {
|
||||
return tensor.squeeze(2);
|
||||
}
|
||||
|
||||
// Extends `attr` with the post-ops needed to fuse `scale * accumu` into the
// convolution of `input_r` with `weight_r`.
//
// `is_fused` is set to true on entry and flipped to false only when
// onednn::binary_valid() rejects the (convolution output, accumu) pair; in
// that case `attr` is returned without any post-op appended. A zero `scale`
// also returns `attr` untouched, with `is_fused` left true.
//
// With `force_inplace`, `output` is rebound to `accumu` and a post-sum is
// appended; otherwise `accumu` is attached as a binary-add operand.
Attr get_onednn_conv_sum_attr(
    const Tensor& input_r,
    const Tensor& weight_r,
    IntArrayRef stride_,
    IntArrayRef padding_,
    IntArrayRef dilation_,
    Tensor& accumu,
    double scale,
    Tensor& output,
    bool& is_fused,
    Attr attr = Attr(),
    bool force_inplace = false) {
  is_fused = true;
  if (scale == 0.f) {
    return attr;
  }

  const auto is_channels_last = [](at::MemoryFormat fmt) {
    return fmt == at::MemoryFormat::ChannelsLast ||
        fmt == at::MemoryFormat::ChannelsLast3d;
  };

  auto rank = input_r.ndimension();
  // The same padding is passed for both padding arguments of conv_dst_size.
  auto dst_sizes = conv_dst_size(
      rank,
      input_r.sizes(),
      weight_r.sizes(),
      padding_,
      padding_,
      stride_,
      dilation_);

  // A channels-last input or weight switches the probe tensor to the
  // channels-last tag matching the input rank.
  const bool src_is_cl = is_channels_last(input_r.suggest_memory_format());
  const bool wgt_is_cl = is_channels_last(weight_r.suggest_memory_format());
  MemoryFormat dst_fmt = at::MemoryFormat::Contiguous;
  if (src_is_cl || wgt_is_cl) {
    dst_fmt = get_cl_tag_by_ndim(rank);
  }

  // Tensor with the convolution's output shape/format, used only to ask
  // oneDNN whether `accumu` can be combined with it.
  Tensor probe = at::empty(dst_sizes, input_r.options().memory_format(dst_fmt));
  if (!onednn::binary_valid(probe, accumu)) {
    is_fused = false;
    return attr;
  }

  // For post-sum and post-binary-add, onednn needs sum/binary scale=1.f,
  // so the expression
  //   conv(src, wei) + scale * accumu
  // is rewritten as
  //   scale * (1/scale * conv(src, wei) + sum (or binary))
  const bool needs_rescale = (scale != 1.f);

  if (needs_rescale) {
    attr.append_post_eltwise(
        /* scale */ 1.f,
        /* alpha */ 1.f / scale,
        /* beta */ 0.f,
        attr.kind_with_linear);
  }

  if (!force_inplace) {
    // Out-of-place: accumu is a separate binary-add operand.
    attr.append_post_binary(attr.kind_with_binary_add, accumu);
  } else {
    // In-place: the result is accumulated directly into accumu.
    output = accumu;
    attr.append_post_sum(/* sum_scale */ 1.f);
  }

  if (needs_rescale) {
    attr.append_post_eltwise(
        /* scale */ 1.f,
        /* alpha */ scale,
        /* beta */ 0.f,
        attr.kind_with_linear);
  }

  return attr;
}
|
||||
|
||||
} // namespace impl
|
||||
|
||||
using namespace impl;
|
||||
@ -476,6 +402,8 @@ Tensor _convolution_out(
|
||||
params.output_padding,
|
||||
params.groups);
|
||||
output = at::empty(dst_tz, input.options(), mfmt);
|
||||
} else {
|
||||
output = output_r;
|
||||
}
|
||||
|
||||
onednn::deconvolution(
|
||||
@ -518,6 +446,8 @@ Tensor _convolution_out(
|
||||
params.stride,
|
||||
params.dilation);
|
||||
output = at::empty(dst_tz, input.options(), mfmt);
|
||||
} else {
|
||||
output = output_r;
|
||||
}
|
||||
onednn::convolution(
|
||||
output,
|
||||
@ -751,6 +681,119 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
|
||||
return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
|
||||
}
|
||||
|
||||
// Convolution with a fused unary post-op.
//
// Builds an Attr from the unary description (`attr`, `scalars`, `algorithm`)
// via construct_unary_attr() and forwards everything to _convolution() as a
// non-transposed convolution. A missing bias is passed as an undefined Tensor.
//
// Note: this function receives (padding, stride, ...) but _convolution() is
// called with (stride, padding, ...).
Tensor convolution_pointwise(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    std::string_view attr,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<std::string_view> algorithm) {
  c10::DeviceGuard device_guard(input_t.device());

  Attr unary_post_op;
  unary_post_op = construct_unary_attr(unary_post_op, attr, scalars, algorithm);

  const Tensor bias = bias_opt.value_or(at::Tensor());

  return _convolution(
      input_t,
      weight_t,
      bias,
      stride,
      padding,
      dilation,
      /*transposed*/ false,
      /*output_padding*/ {0},
      groups,
      unary_post_op);
}
|
||||
|
||||
// Convolution with a fused binary post-op against `other_t`, optionally
// followed by a unary post-op.
//
// `binary_attr` names the binary operation handed to construct_binary_attr();
// when `unary_attr` is set, construct_unary_attr() appends the unary post-op
// described by `unary_attr`, `unary_scalars` and `unary_algorithm`.
//
// `alpha` is not applied anywhere in this implementation, so any value other
// than none/1.0 is rejected instead of silently producing an unscaled result.
//
// Returns the tensor produced by _convolution_out(); the output passed in is
// an undefined Tensor.
Tensor convolution_pointwise_binary(
    const Tensor& input_t,
    const Tensor& other_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    std::string_view binary_attr,
    std::optional<at::Scalar> alpha,
    std::optional<std::string_view> unary_attr,
    torch::List<std::optional<at::Scalar>> unary_scalars,
    std::optional<std::string_view> unary_algorithm) {
  TORCH_CHECK(
      !alpha.has_value() || alpha.value().toDouble() == 1.0,
      "convolution_pointwise_binary: the alpha value should be none or 1.0");

  c10::DeviceGuard device_guard(input_t.device());
  Tensor output;
  Tensor bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();

  // Step1: Construct binary attr
  Attr attr;
  attr = construct_binary_attr(attr, binary_attr, other_t);

  // Step2: Append unary attr
  if (unary_attr.has_value()) {
    attr = construct_unary_attr(
        attr, unary_attr.value(), unary_scalars, unary_algorithm);
  }

  // Step3: Run conv
  // Note: parameters arrive as (padding, stride) but are forwarded as
  // (stride, padding).
  return _convolution_out(
      output,
      input_t,
      weight_t,
      bias,
      stride,
      padding,
      dilation,
      /*transposed*/ false,
      /*output_padding*/ {0},
      groups,
      attr);
}
|
||||
|
||||
// In-place variant of the binary-fused convolution: `other_t` is both the
// operand of the binary post-op and the output tensor handed to
// _convolution_out(), and it is returned by reference.
//
// `binary_attr` names the binary operation handed to construct_binary_attr();
// when `unary_attr` is set, construct_unary_attr() appends the unary post-op
// described by `unary_attr`, `unary_scalars` and `unary_algorithm`.
//
// `alpha` is not applied anywhere in this implementation, so any value other
// than none/1.0 is rejected instead of silently producing an unscaled result.
Tensor& convolution_pointwise_binary_(
    Tensor& other_t,
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    std::string_view binary_attr,
    std::optional<at::Scalar> alpha,
    std::optional<std::string_view> unary_attr,
    torch::List<std::optional<at::Scalar>> unary_scalars,
    std::optional<std::string_view> unary_algorithm) {
  TORCH_CHECK(
      !alpha.has_value() || alpha.value().toDouble() == 1.0,
      "convolution_pointwise_binary_: the alpha value should be none or 1.0");

  c10::DeviceGuard device_guard(input_t.device());
  Tensor bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();

  // Step1: Construct binary attr
  Attr attr;
  attr = construct_binary_attr(attr, binary_attr, other_t);

  // Step2: Append unary attr
  if (unary_attr.has_value()) {
    attr = construct_unary_attr(
        attr, unary_attr.value(), unary_scalars, unary_algorithm);
  }

  // Step3: Run conv, writing the result into other_t.
  // Note: parameters arrive as (padding, stride) but are forwarded as
  // (stride, padding).
  _convolution_out(
      other_t,
      input_t,
      weight_t,
      bias,
      stride,
      padding,
      dilation,
      /*transposed*/ false,
      /*output_padding*/ {0},
      groups,
      attr);

  return other_t;
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, XPU, m) {
|
||||
m.impl("convolution_overrideable", TORCH_FN(convolution_overrideable));
|
||||
m.impl(
|
||||
@ -758,4 +801,16 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
|
||||
TORCH_FN(convolution_backward_overrideable));
|
||||
}
|
||||
|
||||
// Binds the fused-convolution entry points to the `mkldnn` operator namespace
// for the XPU dispatch key.
TORCH_LIBRARY_IMPL(mkldnn, XPU, m) {
  // Convolution + unary post-op.
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise"),
      TORCH_FN(convolution_pointwise));

  // Convolution + binary (+ optional unary) post-op, out-of-place overload.
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise.binary"),
      TORCH_FN(convolution_pointwise_binary));

  // Convolution + binary (+ optional unary) post-op, in-place overload.
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise_.binary"),
      TORCH_FN(convolution_pointwise_binary_));
}
|
||||
|
||||
} // namespace at::native::xpu
|
||||
|
@ -122,6 +122,165 @@ class TestoneDNNFusion(TestCase):
|
||||
)
|
||||
self.assertEqual(ref, fused)
|
||||
|
||||
def test_conv_unary_fusion_ops(self):
    """Check mkldnn._convolution_pointwise against eager conv + unary on XPU.

    For every entry of ``self._unary_list()`` the test sweeps 2-d and 3-d
    convolutions over bias on/off, dilation 1/2, groups 1/4 and
    contiguous/channels-last layouts, and asserts that the fused operator
    matches the eager module output.
    """

    class M(nn.Module):
        # Convolution followed by the unary function under test.
        def __init__(
            self,
            unary_fn,
            dim,
            in_channels,
            out_channels,
            dilation,
            groups,
            bias,
            **kwargs,
        ):
            super().__init__()
            self.conv = CONV_MODULES[dim](
                in_channels,
                out_channels,
                dilation=dilation,
                groups=groups,
                bias=bias,
                **kwargs,
            )
            self.unary = unary_fn

        def forward(self, x):
            return self.unary(self.conv(x))

    spatial_by_dim = {2: (112, 112), 3: (55, 55, 55)}
    channels_last_by_dim = {2: torch.channels_last, 3: torch.channels_last_3d}

    for pointwise_info in self._unary_list().values():
        for dim in (2, 3):
            configs = itertools.product(
                (True, False),
                (1, 2),
                (1, 4),
                (torch.contiguous_format, channels_last_by_dim[dim]),
            )
            for bias, dilation, groups, memory_format in configs:
                in_ch = 3 * groups
                out_ch = 32 * groups

                # Input is created on CPU in the requested layout first.
                x = torch.randn(
                    (1, in_ch) + spatial_by_dim[dim], dtype=torch.float32
                ).to(memory_format=memory_format)
                mod = M(
                    pointwise_info.pointwise_module,
                    dim,
                    in_ch,
                    out_ch,
                    dilation,
                    groups,
                    bias,
                    kernel_size=3,
                )
                mod = mod.to(memory_format=memory_format).eval()

                with torch.no_grad():
                    x = x.to("xpu")
                    mod = mod.to("xpu")
                    ref = mod(x)
                    conv = mod.conv
                    fused = torch.ops.mkldnn._convolution_pointwise(
                        x,
                        conv.weight,
                        conv.bias,
                        conv.padding,
                        conv.stride,
                        conv.dilation,
                        conv.groups,
                        pointwise_info.attr,
                        pointwise_info.scalars,
                        pointwise_info.algorithm,
                    )
                self.assertEqual(ref, fused)
|
||||
|
||||
def test_conv_binary_fusion_ops(self):
    """Check the binary-fused convolution operators against eager on XPU.

    For every (name, function) pair of ``self._binary_list()`` a 2-d
    convolution followed by the binary function is compared with
    ``mkldnn._convolution_pointwise`` (binary form). For ``"add"`` the
    in-place ``mkldnn._convolution_pointwise_`` is checked as well, both
    through its return value and through the mutated ``other`` tensor.
    """

    class M(nn.Module):
        # Convolution followed by the binary function under test.
        def __init__(
            self,
            binary_fn,
            dim,
            in_channels,
            out_channels,
            dilation,
            groups,
            bias,
            **kwargs,
        ):
            super().__init__()
            self.conv = CONV_MODULES[dim](
                in_channels,
                out_channels,
                dilation=dilation,
                groups=groups,
                bias=bias,
                **kwargs,
            )
            self.binary = binary_fn

        def forward(self, x, other):
            return self.binary(self.conv(x), other)

    for pointwise_name, pointwise_fn in self._binary_list().items():
        x = torch.randn((1, 3, 112, 112)).to("xpu")
        mod = M(pointwise_fn, 2, 3, 3, 1, 1, True, kernel_size=3).to("xpu")
        # Second operand has the same shape as the convolution output.
        other = torch.randn_like(mod.conv(x))

        with torch.no_grad():
            ref = mod(x, other)
            unary_attr = None
            attr = pointwise_name
            conv = mod.conv
            conv_args = (
                conv.weight,
                conv.bias,
                conv.padding,
                conv.stride,
                conv.dilation,
                conv.groups,
            )
            # binary attr, alpha, unary attr, unary scalars, unary algorithm
            fusion_args = (attr, None, unary_attr, [], None)

            fused = torch.ops.mkldnn._convolution_pointwise(
                x, other, *conv_args, *fusion_args
            )

            if attr == "add":
                # The in-place overload takes `other` first and writes into it.
                fused_inplace = torch.ops.mkldnn._convolution_pointwise_(
                    other, x, *conv_args, *fusion_args
                )
                self.assertEqual(ref, other)
                self.assertEqual(ref, fused_inplace)

            self.assertEqual(ref, fused, atol=5e-4, rtol=5e-4)
|
||||
|
||||
|
||||
instantiate_device_type_tests(
|
||||
TestoneDNNFusion, globals(), only_for="xpu", allow_xpu=True
|
||||
|
Reference in New Issue
Block a user