mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add stateful XNNPack deconvolution2d operator to torch. (#43233)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43233 XNNPack is already being used for the convolution2d operation. Add the ability for it to be used with transpose convolution. Test Plan: buck run caffe2/test:xnnpack_integration Reviewed By: kimishpatel Differential Revision: D23184249 fbshipit-source-id: 3fa728ce1eaca154d24e60f800d5e946d768c8b7
This commit is contained in:
committed by
Facebook GitHub Bot
parent
58a7e73a95
commit
b630c1870d
@ -268,7 +268,8 @@ auto ConvParams::use_xnnpack(
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups);
|
||||
groups,
|
||||
transposed);
|
||||
}
|
||||
#endif
|
||||
return false;
|
||||
|
@ -33,15 +33,19 @@ struct ContextLinear final {
|
||||
static constexpr float kMax = std::numeric_limits<float>::infinity();
|
||||
};
|
||||
|
||||
// This contains information for both the transpose and non-transpose cases.
|
||||
struct ContextConv2D final {
|
||||
Operator op;
|
||||
std::array<int64_t, 4> weight_size_;
|
||||
std::array<int64_t, 2> padding_;
|
||||
std::array<int64_t, 2> output_padding_;
|
||||
std::array<int64_t, 2> stride_;
|
||||
std::array<int64_t, 2> dilation_;
|
||||
const float* cached_input_ptr{nullptr};
|
||||
const float* cached_output_ptr{nullptr};
|
||||
size_t input_height{0}, input_width{0}, batch_size{0}, input_channels{0};
|
||||
bool transposed_;
|
||||
int64_t groups_;
|
||||
|
||||
ContextConv2D() = delete;
|
||||
|
||||
@ -49,13 +53,19 @@ struct ContextConv2D final {
|
||||
Operator&& o,
|
||||
std::array<int64_t, 4> weight_size,
|
||||
std::array<int64_t, 2> padding,
|
||||
std::array<int64_t, 2> output_padding,
|
||||
std::array<int64_t, 2> stride,
|
||||
std::array<int64_t, 2> dilation)
|
||||
std::array<int64_t, 2> dilation,
|
||||
bool transposed,
|
||||
int64_t groups)
|
||||
: op(std::move(o)),
|
||||
weight_size_(weight_size),
|
||||
padding_(padding),
|
||||
output_padding_(output_padding),
|
||||
stride_(stride),
|
||||
dilation_(dilation) {}
|
||||
dilation_(dilation),
|
||||
transposed_(transposed),
|
||||
groups_(groups) {}
|
||||
static constexpr float kMin = -std::numeric_limits<float>::infinity();
|
||||
static constexpr float kMax = std::numeric_limits<float>::infinity();
|
||||
};
|
||||
|
@ -31,6 +31,7 @@ bool available(
|
||||
const IntArrayRef stride,
|
||||
const IntArrayRef dilation,
|
||||
const int64_t groups,
|
||||
const bool transposed,
|
||||
const float output_min,
|
||||
const float output_max) {
|
||||
// XNNPACK
|
||||
@ -43,9 +44,10 @@ bool available(
|
||||
(kFloat == weight.scalar_type()) &&
|
||||
// Bias
|
||||
((bias && bias->defined()) ? ((1 == bias->ndimension()) &&
|
||||
(c10::DeviceType::CPU == bias->device().type()) &&
|
||||
(kFloat == bias->scalar_type()) &&
|
||||
(weight.size(Layout::Filter::output)) == bias->size(0))
|
||||
(c10::DeviceType::CPU == bias->device().type()) &&
|
||||
(kFloat == bias->scalar_type()) &&
|
||||
((transposed ? (weight.size(Layout::Filter::input) == (bias->size(0) / groups))
|
||||
: (weight.size(Layout::Filter::output) == (bias->size(0))))))
|
||||
: true) &&
|
||||
// Padding
|
||||
(padding[Layout::Parameter::height] >= 0) &&
|
||||
@ -88,35 +90,97 @@ Tensor create_and_run(
|
||||
const Tensor& weight,
|
||||
const Tensor& bias,
|
||||
const IntArrayRef padding,
|
||||
const IntArrayRef output_padding,
|
||||
const IntArrayRef stride,
|
||||
const IntArrayRef dilation,
|
||||
const int64_t groups,
|
||||
const bool transposed,
|
||||
const float output_min,
|
||||
const float output_max) {
|
||||
auto op_context = create(
|
||||
weight,
|
||||
bias,
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
transposed,
|
||||
output_min,
|
||||
output_max);
|
||||
return run(op_context, input);
|
||||
}
|
||||
|
||||
// XNNPack's deconvolution operator expects weights to be indexed in the following order:
|
||||
// * Groups
|
||||
// * Group Output Channels
|
||||
// * Kernel Height
|
||||
// * Kernel Width
|
||||
// * Group Input Channels
|
||||
//
|
||||
// (ref: https://github.com/google/XNNPACK/blob/ecd8311c8fd3d9ab47edbc3df5f2b5de7dabe75f/test/deconvolution-operator-tester.h#L678)
|
||||
//
|
||||
// This function takes in a contiguous NHWC pytorch tensor (e.g. MemoryFormat == ChannelsLast) and rearranges the weights in preparation for use with xnnpack.
|
||||
// By default, for pytorch, transpose conv2d weights are {input_channels, output_Channels_per_group, kernel_height, kernel_width}.
|
||||
// In addition, it condenses the tensor from 5 to 4 dimensions as expected by the rest of the pytorch framework by combining the groups and input_channels dimension.
|
||||
const Tensor reorder_weights_for_transpose_conv(const Tensor& weight_nhwc,
|
||||
int num_groups) {
|
||||
|
||||
TORCH_CHECK(weight_nhwc.size(0) % num_groups == 0, "The number of groups cannot be satisfied by the provided weight tensor.");
|
||||
|
||||
int input_channels_per_group = weight_nhwc.size(0) / num_groups;
|
||||
int output_channels_per_group = weight_nhwc.size(1);
|
||||
int kernel_width = weight_nhwc.size(3);
|
||||
int kernel_height = weight_nhwc.size(2);
|
||||
|
||||
int o_offset = 1;
|
||||
int h_offset = (output_channels_per_group);
|
||||
int w_offset = (output_channels_per_group)*(kernel_height);
|
||||
int i_offset = (output_channels_per_group)*(kernel_height)*(kernel_width);
|
||||
int g_offset = (output_channels_per_group)*(kernel_height)*(kernel_width)*(input_channels_per_group);
|
||||
|
||||
Tensor reordered = mobile::empty_with_tail_padding(
|
||||
weight_nhwc.sizes(),
|
||||
weight_nhwc.options().dtype(),
|
||||
MemoryFormat::ChannelsLast,
|
||||
weight_nhwc.names());
|
||||
|
||||
float* out_ptr = reordered.data_ptr<float>();
|
||||
float* in_ptr = weight_nhwc.data_ptr<float>();
|
||||
|
||||
int out_index = 0;
|
||||
for (int g = 0; g < num_groups; g++) {
|
||||
for (int o = 0; o < output_channels_per_group; o++) {
|
||||
for (int w = 0; w < kernel_width; w++) {
|
||||
for (int h = 0; h < kernel_height; h++) {
|
||||
for (int i = 0; i < input_channels_per_group; i++) {
|
||||
int in_index = (g*g_offset) + (i*i_offset) + (h*h_offset) + (w*w_offset) + (o*o_offset);
|
||||
out_ptr[out_index] = in_ptr[in_index];
|
||||
out_index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return reordered;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ContextConv2D create(
|
||||
const Tensor& weight,
|
||||
const c10::optional<Tensor>& bias,
|
||||
const IntArrayRef padding,
|
||||
const IntArrayRef output_padding,
|
||||
const IntArrayRef stride,
|
||||
const IntArrayRef dilation,
|
||||
const int64_t groups,
|
||||
const bool transposed,
|
||||
const float output_min,
|
||||
const float output_max) {
|
||||
const auto padding_expanded = expand_param_if_needed(padding, "padding", 2);
|
||||
const auto output_padding_expanded = expand_param_if_needed(output_padding, "output_padding", 2);
|
||||
const auto stride_expanded = expand_param_if_needed(stride, "stride", 2);
|
||||
const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", 2);
|
||||
const Tensor weight_nhwc = weight.contiguous(MemoryFormat::ChannelsLast);
|
||||
@ -129,15 +193,52 @@ ContextConv2D create(
|
||||
stride_expanded,
|
||||
dilation_expanded,
|
||||
groups,
|
||||
transposed,
|
||||
output_min,
|
||||
output_max),
|
||||
"xnnpack::convolution not available! "
|
||||
"Reason: The provided (weight, bias, padding, stride, dilation, groups, output_min, output_max) "
|
||||
"Reason: The provided (weight, bias, padding, stride, dilation, groups, transposed, output_min, output_max) "
|
||||
"parameters are either invalid individually or their combination is not supported by XNNPACK.");
|
||||
|
||||
xnn_operator_t convolution_op{};
|
||||
|
||||
const xnn_status create_status = xnn_create_convolution2d_nhwc_f32(
|
||||
xnn_operator_t convolution_op{};
|
||||
xnn_status create_status;
|
||||
std::array<int64_t, 4> weight_sizes;
|
||||
|
||||
if (transposed) {
|
||||
const Tensor weight_reordered = reorder_weights_for_transpose_conv(weight_nhwc, groups);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
weight_sizes[i] = weight_reordered.size(i);
|
||||
}
|
||||
create_status = xnn_create_deconvolution2d_nhwc_f32(
|
||||
padding_expanded[Layout::Parameter::height], // output_padding_top
|
||||
padding_expanded[Layout::Parameter::width], // output_padding_right
|
||||
padding_expanded[Layout::Parameter::height], // output_padding_bottom
|
||||
padding_expanded[Layout::Parameter::width], // output_padding_left
|
||||
weight_reordered.size(Layout::Filter::height), // kernel_height
|
||||
weight_reordered.size(Layout::Filter::width), // kernel_width
|
||||
stride_expanded[Layout::Parameter::height], // subsampling_height
|
||||
stride_expanded[Layout::Parameter::width], // subsampling_width
|
||||
dilation_expanded[Layout::Parameter::height], // dilation_height
|
||||
dilation_expanded[Layout::Parameter::width], // dilation_width
|
||||
groups, // groups
|
||||
weight_reordered.size(Layout::Filter::output) / groups, // group_input_channels
|
||||
weight_reordered.size(Layout::Filter::input), // group_output_channels
|
||||
weight_reordered.size(Layout::Filter::output), // input_pixel_stride
|
||||
weight_reordered.size(Layout::Filter::input) * groups, // output_pixel_stride
|
||||
weight_reordered.data_ptr<float>(), // kernel
|
||||
(bias && bias->defined())
|
||||
? bias->contiguous().data_ptr<float>()
|
||||
: nullptr, // bias
|
||||
output_min, // output_min
|
||||
output_max, // output_max
|
||||
0u, // flags
|
||||
&convolution_op); // operator
|
||||
} else {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
weight_sizes[i] = weight_nhwc.size(i);
|
||||
}
|
||||
create_status = xnn_create_convolution2d_nhwc_f32(
|
||||
padding_expanded[Layout::Parameter::height], // input_padding_top
|
||||
padding_expanded[Layout::Parameter::width], // input_padding_right
|
||||
padding_expanded[Layout::Parameter::height], // input_padding_bottom
|
||||
@ -161,18 +262,21 @@ ContextConv2D create(
|
||||
output_max, // output_max
|
||||
0u, // flags
|
||||
&convolution_op); // operator
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
xnn_status_success == create_status,
|
||||
"xnn_create_convolution2d_nhwc_f32 failed!");
|
||||
(transposed ? "xnn_create_deconvolution2d_nhwc_f32 failed!"
|
||||
: "xnn_create_convolution2d_nhwc_f32 failed!"));
|
||||
|
||||
return ContextConv2D{
|
||||
Operator(convolution_op),
|
||||
{weight_nhwc.sizes()[0], weight_nhwc.sizes()[1],
|
||||
weight_nhwc.sizes()[2], weight_nhwc.sizes()[3]},
|
||||
weight_sizes,
|
||||
{padding_expanded[0], padding_expanded[1]},
|
||||
{output_padding_expanded[0], output_padding_expanded[1]},
|
||||
{stride_expanded[0], stride_expanded[1]},
|
||||
{dilation_expanded[0], dilation_expanded[1]}
|
||||
{dilation_expanded[0], dilation_expanded[1]},
|
||||
transposed, groups
|
||||
};
|
||||
}
|
||||
|
||||
@ -189,7 +293,21 @@ Tensor run(
|
||||
"XNNPACK Convolution not usable! "
|
||||
"Reason: The provided input tensor is either invalid or unsupported by XNNPACK.");
|
||||
|
||||
Tensor output = mobile::empty_with_tail_padding(
|
||||
Tensor output;
|
||||
if (context.transposed_) {
|
||||
output = mobile::empty_with_tail_padding(
|
||||
conv_input_size(padded_input_nhwc.sizes(),
|
||||
context.weight_size_,
|
||||
context.padding_,
|
||||
context.output_padding_,
|
||||
context.stride_,
|
||||
context.dilation_,
|
||||
context.groups_),
|
||||
padded_input_nhwc.options().dtype(),
|
||||
MemoryFormat::ChannelsLast,
|
||||
padded_input_nhwc.names());
|
||||
} else {
|
||||
output = mobile::empty_with_tail_padding(
|
||||
conv_output_size(
|
||||
padded_input_nhwc.sizes(),
|
||||
context.weight_size_,
|
||||
@ -199,7 +317,9 @@ Tensor run(
|
||||
padded_input_nhwc.options().dtype(),
|
||||
MemoryFormat::ChannelsLast,
|
||||
padded_input_nhwc.names());
|
||||
}
|
||||
|
||||
xnn_status setup_status;
|
||||
if ((context.cached_input_ptr != padded_input_nhwc.data_ptr<float>()) ||
|
||||
(context.cached_output_ptr != output.data_ptr<float>()) ||
|
||||
(padded_input_nhwc.size(Layout::Activation4D::batch) !=
|
||||
@ -211,26 +331,42 @@ Tensor run(
|
||||
(padded_input_nhwc.size(Layout::Activation4D::width) !=
|
||||
context.input_width)
|
||||
) {
|
||||
const xnn_status setup_status = xnn_setup_convolution2d_nhwc_f32(
|
||||
context.op.get(), // operator
|
||||
padded_input_nhwc.size(Layout::Activation4D::batch), // batch_size
|
||||
padded_input_nhwc.size(Layout::Activation4D::height), // input_height
|
||||
padded_input_nhwc.size(Layout::Activation4D::width), // input_width
|
||||
padded_input_nhwc.data_ptr<float>(), // input
|
||||
output.data_ptr<float>(), // output
|
||||
caffe2::pthreadpool_()); // threadpool
|
||||
|
||||
TORCH_CHECK(
|
||||
xnn_status_success == setup_status,
|
||||
"xnn_setup_convolution2d_nhwc_f32 failed!");
|
||||
if (context.transposed_) {
|
||||
setup_status = xnn_setup_deconvolution2d_nhwc_f32(
|
||||
context.op.get(), // operator
|
||||
padded_input_nhwc.size(Layout::Activation4D::batch), // batch_size
|
||||
padded_input_nhwc.size(Layout::Activation4D::height), // input_height
|
||||
padded_input_nhwc.size(Layout::Activation4D::width), // input_width
|
||||
context.output_padding_[0], // adjustment_height
|
||||
context.output_padding_[1], // adjustment_width
|
||||
padded_input_nhwc.data_ptr<float>(), // input
|
||||
output.data_ptr<float>(), // output
|
||||
caffe2::pthreadpool_()); // threadpool
|
||||
|
||||
// Cache values to avoid setup for the next round.
|
||||
context.cached_input_ptr = padded_input_nhwc.data_ptr<float>();
|
||||
context.cached_output_ptr = output.data_ptr<float>();
|
||||
context.batch_size = padded_input_nhwc.size(Layout::Activation4D::batch);
|
||||
context.input_channels = padded_input_nhwc.size(Layout::Activation4D::channels);
|
||||
context.input_height = padded_input_nhwc.size(Layout::Activation4D::height);
|
||||
context.input_width = padded_input_nhwc.size(Layout::Activation4D::width);
|
||||
} else {
|
||||
setup_status = xnn_setup_convolution2d_nhwc_f32(
|
||||
context.op.get(), // operator
|
||||
padded_input_nhwc.size(Layout::Activation4D::batch), // batch_size
|
||||
padded_input_nhwc.size(Layout::Activation4D::height), // input_height
|
||||
padded_input_nhwc.size(Layout::Activation4D::width), // input_width
|
||||
padded_input_nhwc.data_ptr<float>(), // input
|
||||
output.data_ptr<float>(), // output
|
||||
caffe2::pthreadpool_());
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
xnn_status_success == setup_status,
|
||||
(context.transposed_ ? "xnn_setup_deconvolution2d_nhwc_f32 failed!"
|
||||
: "xnn_setup_convolution2d_nhwc_f32 failed!"));
|
||||
|
||||
// Cache values to avoid setup for the next round
|
||||
context.cached_input_ptr = padded_input_nhwc.data_ptr<float>();
|
||||
context.cached_output_ptr = output.data_ptr<float>();
|
||||
context.batch_size = padded_input_nhwc.size(Layout::Activation4D::batch);
|
||||
context.input_channels = padded_input_nhwc.size(Layout::Activation4D::channels);
|
||||
context.input_height = padded_input_nhwc.size(Layout::Activation4D::height);
|
||||
context.input_width = padded_input_nhwc.size(Layout::Activation4D::width);
|
||||
}
|
||||
|
||||
const xnn_status run_status = xnn_run_operator(
|
||||
@ -265,12 +401,41 @@ c10::intrusive_ptr<xnnpack::Conv2dOpContext>
|
||||
output_max);
|
||||
}
|
||||
|
||||
// Builds a prepacked transposed-conv2d context from the original operator
// arguments by delegating to XNNPackTransposeConv2dOpContext::create_context.
//
// Note the argument order: this function receives `stride` before `padding`,
// whereas create_context is handed padding and output_padding first, followed
// by stride and dilation.
c10::intrusive_ptr<xnnpack::TransposeConv2dOpContext>
    createConv2dTransposeClampPrePackOpContext(
        Tensor weight,
        c10::optional<Tensor> bias,
        std::vector<int64_t> stride,
        std::vector<int64_t> padding,
        std::vector<int64_t> output_padding,
        std::vector<int64_t> dilation,
        int64_t groups,
        c10::optional<Scalar> output_min,
        c10::optional<Scalar> output_max) {
  auto prepacked = xnnpack::XNNPackTransposeConv2dOpContext::create_context(
      std::move(weight),
      std::move(bias),
      std::move(padding),
      std::move(output_padding),
      std::move(stride),
      std::move(dilation),
      groups,
      output_min,
      output_max);
  return prepacked;
}
|
||||
|
||||
// Runs a prepacked conv2d context on `input` and returns the result of the
// context's virtual run() method.
Tensor conv2d_clamp_run(
    const Tensor& input,
    const c10::intrusive_ptr<xnnpack::Conv2dOpContext>& op_context) {
  Tensor output = op_context->run(input);
  return output;
}
|
||||
|
||||
// Runs a prepacked transposed-conv2d context on `input` and returns the result
// of the context's virtual run() method.
Tensor conv2d_transpose_clamp_run(
    const Tensor& input,
    const c10::intrusive_ptr<xnnpack::TransposeConv2dOpContext>& op_context) {
  Tensor output = op_context->run(input);
  return output;
}
|
||||
|
||||
} // namespace convolution2d
|
||||
} // namespace internal
|
||||
|
||||
@ -281,7 +446,8 @@ bool use_convolution2d(
|
||||
const IntArrayRef padding,
|
||||
const IntArrayRef stride,
|
||||
const IntArrayRef dilation,
|
||||
const int64_t groups) {
|
||||
const int64_t groups,
|
||||
const bool transposed) {
|
||||
return internal::convolution2d::available(
|
||||
weight,
|
||||
bias,
|
||||
@ -289,6 +455,7 @@ bool use_convolution2d(
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
transposed,
|
||||
ContextConv2D::kMin,
|
||||
ContextConv2D::kMax) &&
|
||||
internal::convolution2d::usable(input);
|
||||
@ -307,9 +474,11 @@ Tensor convolution2d(
|
||||
weight,
|
||||
bias,
|
||||
padding,
|
||||
{0, 0}, // output_padding
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
false, // transposed
|
||||
ContextConv2D::kMin,
|
||||
ContextConv2D::kMax);
|
||||
}
|
||||
|
@ -23,17 +23,35 @@ c10::intrusive_ptr<xnnpack::Conv2dOpContext>
|
||||
c10::optional<Scalar> output_min,
|
||||
c10::optional<Scalar> output_max);
|
||||
|
||||
c10::intrusive_ptr<xnnpack::TransposeConv2dOpContext>
|
||||
createConv2dTransposeClampPrePackOpContext(
|
||||
Tensor weight,
|
||||
c10::optional<Tensor> bias,
|
||||
std::vector<int64_t> stride,
|
||||
std::vector<int64_t> padding,
|
||||
std::vector<int64_t> output_padding,
|
||||
std::vector<int64_t> dilation,
|
||||
int64_t groups,
|
||||
c10::optional<Scalar> output_min,
|
||||
c10::optional<Scalar> output_max);
|
||||
|
||||
Tensor conv2d_clamp_run(
|
||||
const Tensor& input,
|
||||
const c10::intrusive_ptr<xnnpack::Conv2dOpContext>& op_context);
|
||||
|
||||
Tensor conv2d_transpose_clamp_run(
|
||||
const Tensor& input,
|
||||
const c10::intrusive_ptr<xnnpack::TransposeConv2dOpContext>& op_context);
|
||||
|
||||
ContextConv2D create(
|
||||
const Tensor& weight,
|
||||
const c10::optional<Tensor>& bias,
|
||||
const IntArrayRef padding,
|
||||
const IntArrayRef output_padding,
|
||||
const IntArrayRef stride,
|
||||
const IntArrayRef dilation,
|
||||
const int64_t groups,
|
||||
const bool transposed,
|
||||
const float output_min,
|
||||
const float output_max);
|
||||
|
||||
|
@ -17,7 +17,8 @@ bool use_convolution2d(
|
||||
const IntArrayRef padding,
|
||||
const IntArrayRef stride,
|
||||
const IntArrayRef dilation,
|
||||
const int64_t groups);
|
||||
const int64_t groups,
|
||||
const bool transposed);
|
||||
|
||||
Tensor convolution2d(
|
||||
const Tensor& input,
|
||||
|
@ -48,13 +48,16 @@ XNNPackConv2dOpContext::create_context(at::Tensor&& weight,
|
||||
weight,
|
||||
bias,
|
||||
padding,
|
||||
{0, 0}, // output_padding
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
false, // transposed
|
||||
output_min ? output_min->to<float>()
|
||||
: xnnpack::ContextConv2D::kMin,
|
||||
output_max ? output_max->to<float>()
|
||||
: xnnpack::ContextConv2D::kMax);
|
||||
|
||||
auto conv2d_op_context =
|
||||
c10::make_intrusive<XNNPackConv2dOpContext>(
|
||||
std::move(weight),
|
||||
@ -66,6 +69,48 @@ XNNPackConv2dOpContext::create_context(at::Tensor&& weight,
|
||||
output_min,
|
||||
output_max,
|
||||
std::move(op_context));
|
||||
|
||||
return conv2d_op_context;
|
||||
}
|
||||
|
||||
// Factory for the XNNPACK-backed transposed-conv2d context.
//
// First creates the underlying ContextConv2D via
// xnnpack::internal::convolution2d::create with transposed == true, then wraps
// it, together with the original arguments, in an
// XNNPackTransposeConv2dOpContext.
//
// A missing output_min / output_max falls back to ContextConv2D::kMin / kMax.
c10::intrusive_ptr<TransposeConv2dOpContext>
XNNPackTransposeConv2dOpContext::create_context(at::Tensor&& weight,
    c10::optional<at::Tensor>&& bias,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& output_padding,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    const c10::optional<Scalar> output_min,
    const c10::optional<Scalar> output_max) {
  const float clamp_min =
      output_min ? output_min->to<float>() : xnnpack::ContextConv2D::kMin;
  const float clamp_max =
      output_max ? output_max->to<float>() : xnnpack::ContextConv2D::kMax;

  // The named rvalue-reference parameters are lvalues here, so this call only
  // reads them; ownership is handed over in the make_intrusive call below.
  auto packed_op = xnnpack::internal::convolution2d::create(
      weight,
      bias,
      padding,
      output_padding,
      stride,
      dilation,
      groups,
      true, // transposed
      clamp_min,
      clamp_max);

  return c10::make_intrusive<XNNPackTransposeConv2dOpContext>(
      std::move(weight),
      std::move(bias),
      std::move(padding),
      std::move(output_padding),
      std::move(stride),
      std::move(dilation),
      groups,
      output_min,
      output_max,
      std::move(packed_op));
}
|
||||
|
||||
@ -73,6 +118,10 @@ Tensor XNNPackConv2dOpContext::run(const Tensor& input) {
|
||||
return xnnpack::internal::convolution2d::run(op_context_, input);
|
||||
}
|
||||
|
||||
// Executes the stored ContextConv2D on `input` through the shared
// convolution2d::run entry point.
Tensor XNNPackTransposeConv2dOpContext::run(const Tensor& input) {
  Tensor output = xnnpack::internal::convolution2d::run(op_context_, input);
  return output;
}
|
||||
|
||||
} // namespace xnnpack
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -24,6 +24,18 @@ using SerializationTypeConv2dPrePack = std::tuple<
|
||||
int64_t,
|
||||
c10::optional<Scalar>,
|
||||
c10::optional<Scalar>>;
|
||||
using SerializationTypeTransposeConv2dPrePack = std::tuple<
|
||||
Tensor,
|
||||
c10::optional<Tensor>,
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>,
|
||||
int64_t,
|
||||
c10::optional<Scalar>,
|
||||
c10::optional<Scalar>>;
|
||||
|
||||
|
||||
|
||||
class LinearOpContext : public torch::jit::CustomClassHolder {
|
||||
protected:
|
||||
@ -94,6 +106,35 @@ class Conv2dOpContext : public torch::jit::CustomClassHolder {
|
||||
virtual Tensor run(const Tensor& input) = 0;
|
||||
};
|
||||
|
||||
class TransposeConv2dOpContext : public torch::jit::CustomClassHolder {
|
||||
protected:
|
||||
Tensor orig_weight_;
|
||||
c10::optional<Tensor> orig_bias_;
|
||||
std::vector<int64_t> stride_;
|
||||
std::vector<int64_t> padding_;
|
||||
std::vector<int64_t> output_padding_;
|
||||
std::vector<int64_t> dilation_;
|
||||
int64_t groups_;
|
||||
c10::optional<Scalar> output_min_;
|
||||
c10::optional<Scalar> output_max_;
|
||||
|
||||
public:
|
||||
SerializationTypeTransposeConv2dPrePack unpack() {
|
||||
return std::make_tuple(
|
||||
orig_weight_,
|
||||
orig_bias_,
|
||||
stride_,
|
||||
padding_,
|
||||
output_padding_,
|
||||
dilation_,
|
||||
groups_,
|
||||
output_min_,
|
||||
output_max_);
|
||||
}
|
||||
|
||||
virtual Tensor run(const Tensor& input) = 0;
|
||||
};
|
||||
|
||||
class XNNPackConv2dOpContext final : public Conv2dOpContext {
|
||||
private:
|
||||
ContextConv2D op_context_;
|
||||
@ -120,7 +161,7 @@ class XNNPackConv2dOpContext final : public Conv2dOpContext {
|
||||
output_max_ = max;
|
||||
}
|
||||
|
||||
Tensor run(const Tensor& input);
|
||||
Tensor run(const Tensor& input) override;
|
||||
|
||||
static c10::intrusive_ptr<Conv2dOpContext> create_context(
|
||||
Tensor&& weight,
|
||||
@ -132,6 +173,49 @@ class XNNPackConv2dOpContext final : public Conv2dOpContext {
|
||||
const c10::optional<Scalar> output_min,
|
||||
const c10::optional<Scalar> output_max);
|
||||
};
|
||||
|
||||
class XNNPackTransposeConv2dOpContext final : public TransposeConv2dOpContext {
|
||||
private:
|
||||
ContextConv2D op_context_;
|
||||
|
||||
public:
|
||||
XNNPackTransposeConv2dOpContext(
|
||||
Tensor&& weight,
|
||||
c10::optional<Tensor>&& bias,
|
||||
std::vector<int64_t>&& padding,
|
||||
std::vector<int64_t>&& output_padding,
|
||||
std::vector<int64_t>&& stride,
|
||||
std::vector<int64_t>&& dilation,
|
||||
uint64_t groups,
|
||||
c10::optional<Scalar> min,
|
||||
c10::optional<Scalar> max,
|
||||
ContextConv2D&& op_context)
|
||||
: op_context_(std::move(op_context)) {
|
||||
orig_weight_ = std::move(weight);
|
||||
orig_bias_ = std::move(bias);
|
||||
padding_ = std::move(padding);
|
||||
output_padding_ = std::move(output_padding);
|
||||
stride_ = std::move(stride);
|
||||
dilation_ = std::move(dilation);
|
||||
groups_ = groups;
|
||||
output_min_ = min;
|
||||
output_max_ = max;
|
||||
}
|
||||
|
||||
Tensor run(const Tensor& input) override;
|
||||
|
||||
static c10::intrusive_ptr<TransposeConv2dOpContext> create_context(
|
||||
Tensor&& weight,
|
||||
c10::optional<Tensor>&& bias,
|
||||
std::vector<int64_t>&& padding,
|
||||
std::vector<int64_t>&& output_padding,
|
||||
std::vector<int64_t>&& stride,
|
||||
std::vector<int64_t>&& dilation,
|
||||
int64_t groups,
|
||||
const c10::optional<Scalar> output_min,
|
||||
const c10::optional<Scalar> output_max);
|
||||
};
|
||||
|
||||
} // namespace xnnpack
|
||||
|
||||
} // namespace native
|
||||
|
@ -13,6 +13,7 @@ namespace xnnpack {
|
||||
|
||||
using internal::linear::createLinearClampPrePackOpContext;
|
||||
using internal::convolution2d::createConv2dClampPrePackOpContext;
|
||||
using internal::convolution2d::createConv2dTransposeClampPrePackOpContext;
|
||||
|
||||
TORCH_LIBRARY(xnnpack, m) {
|
||||
m.class_<LinearOpContext>("LinearOpContext")
|
||||
@ -48,20 +49,45 @@ TORCH_LIBRARY(xnnpack, m) {
|
||||
std::move(std::get<6>(state)),
|
||||
std::move(std::get<7>(state)));
|
||||
});
|
||||
|
||||
m.class_<TransposeConv2dOpContext>("TransposeConv2dOpContext")
|
||||
.def_pickle(
|
||||
[](const c10::intrusive_ptr<TransposeConv2dOpContext>& op_context)
|
||||
-> SerializationTypeTransposeConv2dPrePack { // __getstate__
|
||||
return op_context->unpack();
|
||||
},
|
||||
[](SerializationTypeTransposeConv2dPrePack state)
|
||||
-> c10::intrusive_ptr<TransposeConv2dOpContext> { // __setstate__
|
||||
return createConv2dTransposeClampPrePackOpContext(
|
||||
std::move(std::get<0>(state)),
|
||||
std::move(std::get<1>(state)),
|
||||
std::move(std::get<2>(state)),
|
||||
std::move(std::get<3>(state)),
|
||||
std::move(std::get<4>(state)),
|
||||
std::move(std::get<5>(state)),
|
||||
std::move(std::get<6>(state)),
|
||||
std::move(std::get<7>(state)),
|
||||
std::move(std::get<8>(state)));
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
// Operator schemas of the `prepacked` namespace: a prepack/run pair each for
// linear, conv2d and transposed conv2d.
TORCH_LIBRARY(prepacked, m) {
  // Linear.
  m.def("linear_clamp_prepack(Tensor W, Tensor? B=None, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.LinearOpContext");
  m.def("linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y");

  // Prepack schemas for conv2d and transposed conv2d.
  m.def("conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.Conv2dOpContext");
  m.def("conv2d_transpose_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.TransposeConv2dOpContext");

  // Run schemas for conv2d and transposed conv2d.
  m.def("conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y");
  m.def("conv2d_transpose_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.TransposeConv2dOpContext W_prepack) -> Tensor Y");
}
|
||||
|
||||
// CPU implementations bound to the `prepacked` operator schemas.
TORCH_LIBRARY_IMPL(prepacked, CPU, m) {
  // Linear.
  m.impl("linear_clamp_prepack", TORCH_FN(createLinearClampPrePackOpContext));
  m.impl("linear_clamp_run", TORCH_FN(internal::linear::linear_clamp_run));

  // Prepack entry points for conv2d and transposed conv2d.
  m.impl("conv2d_clamp_prepack", TORCH_FN(createConv2dClampPrePackOpContext));
  m.impl("conv2d_transpose_clamp_prepack", TORCH_FN(createConv2dTransposeClampPrePackOpContext));

  // Run entry points for conv2d and transposed conv2d.
  m.impl("conv2d_clamp_run", TORCH_FN(internal::convolution2d::conv2d_clamp_run));
  m.impl("conv2d_transpose_clamp_run", TORCH_FN(internal::convolution2d::conv2d_transpose_clamp_run));
}
|
||||
|
||||
} // namespace xnnpack
|
||||
|
@ -35,7 +35,8 @@ bool use_convolution2d(
|
||||
const IntArrayRef,
|
||||
const IntArrayRef,
|
||||
const IntArrayRef,
|
||||
const int64_t) {
|
||||
const int64_t,
|
||||
bool) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -93,6 +93,72 @@ class TestXNNPACKOps(TestCase):
|
||||
xnnpack_result = torch.ops.prepacked.conv2d_clamp_run(input_data, packed_weight_bias)
|
||||
torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
|
||||
|
||||
@given(batch_size=st.integers(1, 3),
       input_channels_per_group=st.integers(1, 32),
       height=st.integers(5, 64),
       width=st.integers(5, 64),
       output_channels_per_group=st.integers(1, 32),
       groups=st.integers(1, 16),
       kernel_h=st.integers(1, 7),
       kernel_w=st.integers(1, 7),
       stride_h=st.integers(1, 2),
       stride_w=st.integers(1, 2),
       pad_h=st.integers(0, 2),
       pad_w=st.integers(0, 2),
       output_pad_h=st.integers(0, 2),
       output_pad_w=st.integers(0, 2),
       dilation=st.integers(1, 2),
       use_bias=st.booleans(),
       format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
def test_conv2d_transpose(self,
                          batch_size,
                          input_channels_per_group,
                          height,
                          width,
                          output_channels_per_group,
                          groups,
                          kernel_h,
                          kernel_w,
                          stride_h,
                          stride_w,
                          pad_h,
                          pad_w,
                          output_pad_h,
                          output_pad_w,
                          dilation,
                          use_bias,
                          format):
    """Compare the prepacked transposed conv2d against F.conv_transpose2d.

    The prepacked op is built with conv2d_transpose_clamp_prepack, executed with
    conv2d_transpose_clamp_run, and its output must match the reference within
    rtol=1e-2 / atol=1e-3.
    """
    in_channels = input_channels_per_group * groups
    out_channels = output_channels_per_group * groups
    kernel_hw = (kernel_h, kernel_w)
    stride_hw = (stride_h, stride_w)
    pad_hw = (pad_h, pad_w)
    out_pad_hw = (output_pad_h, output_pad_w)
    dilation_hw = (dilation, dilation)

    # Discard drawn examples that do not satisfy the size constraints below.
    assume(height + 2 * pad_hw[0]
           >= dilation_hw[0] * (kernel_hw[0] - 1) + 1)
    assume(width + 2 * pad_hw[1]
           >= dilation_hw[1] * (kernel_hw[1] - 1) + 1)
    assume((output_pad_h < stride_h) and (output_pad_h < dilation))
    assume((output_pad_w < stride_w) and (output_pad_w < dilation))

    x = torch.rand((batch_size, in_channels, height, width))
    if format is not None:
        x = x.contiguous(memory_format=format)
    # Transposed-conv weights are (in_channels, out_channels / groups, kH, kW).
    w = torch.rand((in_channels, output_channels_per_group, kernel_h, kernel_w))
    b = torch.rand(out_channels) if use_bias else None

    # Note that groups/dilation is in reverse order from conv2d
    expected = F.conv_transpose2d(x, w, b,
                                  stride_hw, pad_hw, out_pad_hw, groups, dilation)
    packed = torch.ops.prepacked.conv2d_transpose_clamp_prepack(w, b,
                                                                stride_hw, pad_hw,
                                                                out_pad_hw, dilation_hw,
                                                                groups)
    actual = torch.ops.prepacked.conv2d_transpose_clamp_run(x, packed)
    torch.testing.assert_allclose(expected.contiguous(), actual.contiguous(), rtol=1e-2, atol=1e-3)
|
||||
|
||||
@unittest.skipUnless(torch.backends.xnnpack.enabled,
|
||||
" XNNPACK must be enabled for these tests."
|
||||
@ -244,6 +310,114 @@ class TestXNNPACKSerDes(TestCase):
|
||||
xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
|
||||
torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
|
||||
|
||||
@given(batch_size=st.integers(0, 3),
       input_channels_per_group=st.integers(1, 32),
       height=st.integers(5, 64),
       width=st.integers(5, 64),
       output_channels_per_group=st.integers(1, 32),
       groups=st.integers(1, 16),
       kernel_h=st.integers(1, 7),
       kernel_w=st.integers(1, 7),
       stride_h=st.integers(1, 2),
       stride_w=st.integers(1, 2),
       pad_h=st.integers(0, 2),
       pad_w=st.integers(0, 2),
       output_pad_h=st.integers(0, 2),
       output_pad_w=st.integers(0, 2),
       dilation=st.integers(1, 2),
       use_bias=st.booleans(),
       format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
def test_conv2d_transpose(self,
                          batch_size,
                          input_channels_per_group,
                          height,
                          width,
                          output_channels_per_group,
                          groups,
                          kernel_h,
                          kernel_w,
                          stride_h,
                          stride_w,
                          pad_h,
                          pad_w,
                          output_pad_h,
                          output_pad_w,
                          dilation,
                          use_bias,
                          format):
    """Compare the prepacked transposed conv op against ``F.conv_transpose2d``.

    Both variants are wrapped in modules and compiled with
    ``torch.jit.script``. Their outputs are compared once directly and once
    more after each scripted module has been saved to and reloaded from an
    in-memory buffer.
    """

    class Conv2DT(torch.nn.Module):
        # Reference implementation: calls F.conv_transpose2d on every forward.
        def __init__(self, weight, bias, strides, paddings, output_paddings, dilations, groups):
            super().__init__()
            self.weight = weight
            self.bias = bias
            self.strides = strides
            self.paddings = paddings
            self.output_paddings = output_paddings
            self.dilations = dilations
            self.groups = groups

        def forward(self, x):
            # conv_transpose2d takes groups before dilation (unlike conv2d).
            return F.conv_transpose2d(x, self.weight, self.bias,
                                      self.strides, self.paddings, self.output_paddings,
                                      self.groups, self.dilations)

    class Conv2DTPrePacked(torch.nn.Module):
        # Variant under test: packs weight/bias once, then runs the packed op.
        def __init__(self, weight, bias, strides, paddings, output_paddings, dilations, groups):
            super().__init__()
            self.packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(
                weight, bias, strides, paddings, output_paddings, dilations, groups)

        def forward(self, x):
            return torch.ops.prepacked.conv2d_transpose_clamp_run(x, self.packed_weight_bias)

    total_in_channels = input_channels_per_group * groups
    total_out_channels = output_channels_per_group * groups
    stride_hw = (stride_h, stride_w)
    padding_hw = (pad_h, pad_w)
    out_padding_hw = (output_pad_h, output_pad_w)
    dilation_hw = (dilation, dilation)

    # Discard generated examples whose dilated kernel does not fit in the
    # padded input, or whose output padding is not smaller than both the
    # stride and the dilation.
    dilated_kernel_h = dilation_hw[0] * (kernel_h - 1) + 1
    dilated_kernel_w = dilation_hw[1] * (kernel_w - 1) + 1
    assume(height + 2 * padding_hw[0] >= dilated_kernel_h)
    assume(width + 2 * padding_hw[1] >= dilated_kernel_w)
    assume((output_pad_h < stride_h) and (output_pad_h < dilation))
    assume((output_pad_w < stride_w) and (output_pad_w < dilation))

    def make_input():
        # Fresh random input, converted to the requested memory format if any.
        data = torch.rand((batch_size, total_in_channels, height, width))
        if format is not None:
            data = data.contiguous(memory_format=format)
        return data

    def save_and_load(scripted_module):
        # Serialize to an in-memory buffer and load the module back.
        stream = io.BytesIO()
        torch.jit.save(scripted_module, stream)
        stream.seek(0)
        return torch.jit.load(stream)

    input_data = make_input()
    weight = torch.rand((total_in_channels, output_channels_per_group, kernel_h, kernel_w))
    bias = torch.rand((total_out_channels)) if use_bias else None

    module_args = (weight, bias, stride_hw, padding_hw, out_padding_hw, dilation_hw, groups)
    scripted_conv2d = torch.jit.script(Conv2DT(*module_args))
    scripted_conv2d_clamp_prepacked = torch.jit.script(Conv2DTPrePacked(*module_args))

    expected = scripted_conv2d(input_data)
    actual = scripted_conv2d_clamp_prepacked(input_data)
    torch.testing.assert_allclose(expected, actual, rtol=1e-2, atol=1e-3)

    # Serialize the modules and then deserialize
    input_data = make_input()
    deserialized_conv2d = save_and_load(scripted_conv2d)
    deserialized_conv2d_clamp_prepacked = save_and_load(scripted_conv2d_clamp_prepacked)

    expected = deserialized_conv2d(input_data)
    actual = deserialized_conv2d_clamp_prepacked(input_data)
    torch.testing.assert_allclose(expected, actual, rtol=1e-2, atol=1e-3)
|
||||
|
||||
@given(batch_size=st.integers(0, 3),
|
||||
input_channels_per_group=st.integers(1, 32),
|
||||
height=st.integers(5, 64),
|
||||
@ -454,14 +628,17 @@ class TestXNNPACKRewritePass(TestCase):
|
||||
kernel_h = kernel_w = 3
|
||||
stride_h = stride_w = 1
|
||||
pad_h = pad_w = 1
|
||||
output_pad_h = output_pad_w = 0
|
||||
dilation = 1
|
||||
input_channels = input_channels_per_group * groups
|
||||
output_channels = output_channels_per_group * groups
|
||||
kernels = (kernel_h, kernel_w)
|
||||
strides = (stride_h, stride_w)
|
||||
paddings = (pad_h, pad_w)
|
||||
output_paddings = (output_pad_h, output_pad_w)
|
||||
dilations = (dilation, dilation)
|
||||
conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
|
||||
conv_transpose_weight_shape = (input_channels, output_channels_per_group, kernel_h, kernel_w)
|
||||
conv_bias_shape = (output_channels)
|
||||
|
||||
class Conv2D(torch.nn.Module):
|
||||
@ -478,12 +655,34 @@ class TestXNNPACKRewritePass(TestCase):
|
||||
return F.conv2d(x, self.weight, self.bias,
|
||||
self.strides, self.paddings, self.dilations, self.groups)
|
||||
|
||||
class Conv2DT(torch.nn.Module):
    """Transposed-convolution module with randomly initialised, frozen parameters.

    The weight/bias shapes and the convolution hyper-parameters are read from
    names in the enclosing scope (``conv_transpose_weight_shape``,
    ``conv_bias_shape``, ``strides``, ``paddings``, ``output_paddings``,
    ``dilations``, ``groups``).
    """

    def __init__(self):
        super().__init__()
        # Parameters are created with requires_grad=False.
        weight_init = torch.Tensor(torch.rand(conv_transpose_weight_shape))
        self.weight = torch.nn.Parameter(weight_init, requires_grad=False)
        bias_init = torch.Tensor(torch.rand(conv_bias_shape))
        self.bias = torch.nn.Parameter(bias_init, requires_grad=False)

        self.strides = strides
        self.paddings = paddings
        self.output_paddings = output_paddings
        self.dilations = dilations
        self.groups = groups

    def forward(self, x):
        # conv_transpose2d takes groups before dilation (unlike conv2d).
        return F.conv_transpose2d(x, self.weight, self.bias,
                                  self.strides, self.paddings, self.output_paddings,
                                  self.groups, self.dilations)
|
||||
|
||||
|
||||
data_shape = (batch_size, input_channels, height, width)
|
||||
pattern_count_map = {"Tensor = aten::conv2d": -1,
|
||||
"prepacked::conv2d_clamp_prepack": 1,
|
||||
"prepacked::conv2d_clamp_run": 1}
|
||||
TestXNNPACKRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
|
||||
|
||||
transpose_data_shape = (batch_size, input_channels, height, width)
|
||||
transpose_pattern_count_map = {"Tensor = aten::conv_transpose2d": -1,
|
||||
"prepacked::conv2d_transpose_clamp_prepack": 1,
|
||||
"prepacked::conv2d_transpose_clamp_run": 1}
|
||||
TestXNNPACKRewritePass.validate_transformed_module(Conv2DT(), transpose_pattern_count_map, data_shape)
|
||||
|
||||
input_data = torch.rand((batch_size, input_channels, height, width))
|
||||
conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
|
||||
conv_bias = torch.rand((output_channels))
|
||||
|
@ -78,6 +78,13 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
||||
%r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv2d_transpose = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
%deterministic:bool, %cudnn_enabled:bool):
|
||||
%r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv1d = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
@ -124,6 +131,22 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
||||
(calc_value_map["output_padding"].toIntList()[0] == 0) &&
|
||||
(calc_value_map["output_padding"].toIntList()[1] == 0);
|
||||
};
|
||||
// Match filter for rewriting a convolution node into aten::conv_transpose2d.
// Accepts a match only when output_padding, stride, padding and dilation each
// hold exactly two entries, the `transposed` flag is set, `benchmark` and
// `deterministic` are both unset, and `cudnn_enabled` is set.
auto filter_conv2d_transpose =
    [](const Match& match,
       const std::unordered_map<std::string, Value*>& vmap) {
      auto conv_params = getConvParams(match, vmap);

      // True when the int list stored under `key` has exactly two entries.
      const auto has_two_entries = [&conv_params](const char* key) {
        return conv_params[key].toIntList().size() == 2;
      };

      // Checked in the same order as the accepted parameters are listed
      // above; evaluation stops at the first list that is not 2-element.
      const bool all_lists_are_2d = has_two_entries("output_padding") &&
          has_two_entries("stride") && has_two_entries("padding") &&
          has_two_entries("dilation");
      if (!all_lists_are_2d) {
        return false;
      }

      const bool is_transposed = conv_params["transposed"].toBool();
      return is_transposed && !conv_params["benchmark"].toBool() &&
          !conv_params["deterministic"].toBool() &&
          conv_params["cudnn_enabled"].toBool();
    };
|
||||
auto filter_conv3d = [](const Match& match,
|
||||
const std::unordered_map<std::string, Value*>& vmap) {
|
||||
auto calc_value_map = getConvParams(match, vmap);
|
||||
@ -148,6 +171,10 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
||||
SubgraphRewriter rewriter_conv2d;
|
||||
rewriter_conv2d.RegisterRewritePattern(convolution, conv2d);
|
||||
rewriter_conv2d.runOnGraph(graph, filter_conv2d);
|
||||
SubgraphRewriter rewriter_conv2d_transpose;
|
||||
rewriter_conv2d_transpose.RegisterRewritePattern(
|
||||
convolution, conv2d_transpose);
|
||||
rewriter_conv2d_transpose.runOnGraph(graph, filter_conv2d_transpose);
|
||||
SubgraphRewriter rewriter_conv3d;
|
||||
rewriter_conv3d.RegisterRewritePattern(convolution, conv3d);
|
||||
rewriter_conv3d.runOnGraph(graph, filter_conv3d);
|
||||
|
@ -143,6 +143,26 @@ void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
|
||||
rewriter.RegisterRewritePattern(
|
||||
conv_2d_pattern, prepacked_ops_conv2d_pattern);
|
||||
rewriter.runOnGraph(graph);
|
||||
|
||||
std::string conv_2d_transpose_pattern = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%output_padding:int[], %groups:int):
|
||||
%r = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
|
||||
return (%r) )";
|
||||
|
||||
std::string prepacked_ops_conv2d_transpose_pattern = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int):
|
||||
%output_min_max : None = prim::Constant()
|
||||
%packed_weight_bias = prepacked::conv2d_transpose_clamp_prepack(
|
||||
%weight, %bias, %stride, %padding, %output_padding, %dilation, %groups,
|
||||
%output_min_max, %output_min_max)
|
||||
%r = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias)
|
||||
return (%r) )";
|
||||
|
||||
SubgraphRewriter transpose_rewriter;
|
||||
transpose_rewriter.RegisterRewritePattern(
|
||||
conv_2d_transpose_pattern, prepacked_ops_conv2d_transpose_pattern);
|
||||
transpose_rewriter.runOnGraph(graph);
|
||||
}
|
||||
|
||||
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
|
||||
@ -321,7 +341,11 @@ void FoldPrePackingOps(script::Module& m) {
|
||||
return (
|
||||
(n->kind() ==
|
||||
Symbol::fromQualString("prepacked::linear_clamp_prepack")) ||
|
||||
n->kind() == Symbol::fromQualString("prepacked::conv2d_clamp_prepack"));
|
||||
n->kind() ==
|
||||
Symbol::fromQualString("prepacked::conv2d_clamp_prepack") ||
|
||||
n->kind() ==
|
||||
Symbol::fromQualString(
|
||||
"prepacked::conv2d_transpose_clamp_prepack"));
|
||||
};
|
||||
PrePackingOpsFolder(m, filter_fn, "prepack_folding");
|
||||
}
|
||||
@ -397,7 +421,7 @@ script::Module optimizeForMobile(
|
||||
const std::set<MobileOptimizerType>& blocklist,
|
||||
const std::vector<std::string>& preserved_methods) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
"Mobile optimizaiton only available with XNNPACK at the moment. "
|
||||
"Mobile optimization only available with XNNPACK at the moment. "
|
||||
"XNNPACK is not enabled. Please build with USE_XNNPACK=1");
|
||||
return module;
|
||||
}
|
||||
|
Reference in New Issue
Block a user