Add stateful XNNPack deconvolution2d operator to torch. (#43233)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43233

XNNPack is already being used for the convolution2d operation. Add the
ability for it to be used with transpose convolution.

Test Plan: buck run caffe2/test:xnnpack_integration

Reviewed By: kimishpatel

Differential Revision: D23184249

fbshipit-source-id: 3fa728ce1eaca154d24e60f800d5e946d768c8b7
This commit is contained in:
mspryn@fb.com
2020-08-28 10:29:54 -07:00
committed by Facebook GitHub Bot
parent 58a7e73a95
commit b630c1870d
12 changed files with 647 additions and 38 deletions

View File

@ -268,7 +268,8 @@ auto ConvParams::use_xnnpack(
padding,
stride,
dilation,
groups);
groups,
transposed);
}
#endif
return false;

View File

@ -33,15 +33,19 @@ struct ContextLinear final {
static constexpr float kMax = std::numeric_limits<float>::infinity();
};
// This contains information for both the transpose and non-transpose cases.
struct ContextConv2D final {
Operator op;
std::array<int64_t, 4> weight_size_;
std::array<int64_t, 2> padding_;
std::array<int64_t, 2> output_padding_;
std::array<int64_t, 2> stride_;
std::array<int64_t, 2> dilation_;
const float* cached_input_ptr{nullptr};
const float* cached_output_ptr{nullptr};
size_t input_height{0}, input_width{0}, batch_size{0}, input_channels{0};
bool transposed_;
int64_t groups_;
ContextConv2D() = delete;
@ -49,13 +53,19 @@ struct ContextConv2D final {
Operator&& o,
std::array<int64_t, 4> weight_size,
std::array<int64_t, 2> padding,
std::array<int64_t, 2> output_padding,
std::array<int64_t, 2> stride,
std::array<int64_t, 2> dilation)
std::array<int64_t, 2> dilation,
bool transposed,
int64_t groups)
: op(std::move(o)),
weight_size_(weight_size),
padding_(padding),
output_padding_(output_padding),
stride_(stride),
dilation_(dilation) {}
dilation_(dilation),
transposed_(transposed),
groups_(groups) {}
static constexpr float kMin = -std::numeric_limits<float>::infinity();
static constexpr float kMax = std::numeric_limits<float>::infinity();
};

View File

@ -31,6 +31,7 @@ bool available(
const IntArrayRef stride,
const IntArrayRef dilation,
const int64_t groups,
const bool transposed,
const float output_min,
const float output_max) {
// XNNPACK
@ -43,9 +44,10 @@ bool available(
(kFloat == weight.scalar_type()) &&
// Bias
((bias && bias->defined()) ? ((1 == bias->ndimension()) &&
(c10::DeviceType::CPU == bias->device().type()) &&
(kFloat == bias->scalar_type()) &&
(weight.size(Layout::Filter::output)) == bias->size(0))
(c10::DeviceType::CPU == bias->device().type()) &&
(kFloat == bias->scalar_type()) &&
((transposed ? (weight.size(Layout::Filter::input) == (bias->size(0) / groups))
: (weight.size(Layout::Filter::output) == (bias->size(0))))))
: true) &&
// Padding
(padding[Layout::Parameter::height] >= 0) &&
@ -88,35 +90,97 @@ Tensor create_and_run(
const Tensor& weight,
const Tensor& bias,
const IntArrayRef padding,
const IntArrayRef output_padding,
const IntArrayRef stride,
const IntArrayRef dilation,
const int64_t groups,
const bool transposed,
const float output_min,
const float output_max) {
auto op_context = create(
weight,
bias,
padding,
output_padding,
stride,
dilation,
groups,
transposed,
output_min,
output_max);
return run(op_context, input);
}
// Repacks transposed-convolution weights into the element order consumed by
// XNNPACK's deconvolution operator.
//
// XNNPACK indexes deconvolution weights in the following order:
//   * Groups
//   * Group Output Channels
//   * Kernel Height
//   * Kernel Width
//   * Group Input Channels
//
// (ref: https://github.com/google/XNNPACK/blob/ecd8311c8fd3d9ab47edbc3df5f2b5de7dabe75f/test/deconvolution-operator-tester.h#L678)
//
// PyTorch stores transposed conv2d weights with sizes
//   {input_channels, output_channels_per_group, kernel_height, kernel_width}.
// `weight_nhwc` is expected to be that tensor, holding floats and contiguous in
// MemoryFormat::ChannelsLast, i.e. element (ic, oc, kh, kw) is located at
//   ((ic * kernel_height + kh) * kernel_width + kw) * output_channels_per_group + oc.
//
// The returned tensor keeps the 4-D sizes, dtype and names of `weight_nhwc`
// (groups and input channels stay combined in dimension 0, as the rest of the
// framework expects 4 dimensions); only the contents of its buffer are written
// linearly in the XNNPACK order listed above.
//
// Parameters:
//   weight_nhwc - channels-last contiguous float weight tensor (see above).
//   num_groups  - number of convolution groups; must be positive and divide
//                 weight_nhwc.size(0).
//
// Raises (via TORCH_CHECK) when num_groups is not positive or does not divide
// the number of input channels.
const Tensor reorder_weights_for_transpose_conv(const Tensor& weight_nhwc,
    int num_groups) {
  // Guard first: the divisibility check below would otherwise divide by zero.
  TORCH_CHECK(num_groups > 0, "The number of groups must be positive.");
  TORCH_CHECK(weight_nhwc.size(0) % num_groups == 0, "The number of groups cannot be satisfied by the provided weight tensor.");

  // 64-bit arithmetic so that index products cannot overflow a 32-bit int.
  const int64_t input_channels_per_group = weight_nhwc.size(0) / num_groups;
  const int64_t output_channels_per_group = weight_nhwc.size(1);
  // Kernel rows and columns are adjacent in both the source and destination
  // layouts, so a single linear spatial index (kh * kernel_width + kw) is
  // enough to address a kernel position on either side.
  const int64_t kernel_area = weight_nhwc.size(2) * weight_nhwc.size(3);

  Tensor reordered = mobile::empty_with_tail_padding(
      weight_nhwc.sizes(),
      weight_nhwc.options().dtype(),
      MemoryFormat::ChannelsLast,
      weight_nhwc.names());

  float* const out_ptr = reordered.data_ptr<float>();
  const float* const in_ptr = weight_nhwc.data_ptr<float>();

  // The destination is written sequentially; the loop nest order is therefore
  // exactly the XNNPACK order: group, output channel, spatial, input channel.
  int64_t out_index = 0;
  for (int64_t g = 0; g < num_groups; ++g) {
    for (int64_t o = 0; o < output_channels_per_group; ++o) {
      for (int64_t s = 0; s < kernel_area; ++s) {
        for (int64_t i = 0; i < input_channels_per_group; ++i) {
          // Absolute input channel = dimension 0 of the PyTorch weight.
          const int64_t in_channel = g * input_channels_per_group + i;
          const int64_t in_index =
              (in_channel * kernel_area + s) * output_channels_per_group + o;
          out_ptr[out_index++] = in_ptr[in_index];
        }
      }
    }
  }

  return reordered;
}
} // namespace
ContextConv2D create(
const Tensor& weight,
const c10::optional<Tensor>& bias,
const IntArrayRef padding,
const IntArrayRef output_padding,
const IntArrayRef stride,
const IntArrayRef dilation,
const int64_t groups,
const bool transposed,
const float output_min,
const float output_max) {
const auto padding_expanded = expand_param_if_needed(padding, "padding", 2);
const auto output_padding_expanded = expand_param_if_needed(output_padding, "output_padding", 2);
const auto stride_expanded = expand_param_if_needed(stride, "stride", 2);
const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", 2);
const Tensor weight_nhwc = weight.contiguous(MemoryFormat::ChannelsLast);
@ -129,15 +193,52 @@ ContextConv2D create(
stride_expanded,
dilation_expanded,
groups,
transposed,
output_min,
output_max),
"xnnpack::convolution not available! "
"Reason: The provided (weight, bias, padding, stride, dilation, groups, output_min, output_max) "
"Reason: The provided (weight, bias, padding, stride, dilation, groups, transposed, output_min, output_max) "
"parameters are either invalid individually or their combination is not supported by XNNPACK.");
xnn_operator_t convolution_op{};
const xnn_status create_status = xnn_create_convolution2d_nhwc_f32(
xnn_operator_t convolution_op{};
xnn_status create_status;
std::array<int64_t, 4> weight_sizes;
if (transposed) {
const Tensor weight_reordered = reorder_weights_for_transpose_conv(weight_nhwc, groups);
for (int i = 0; i < 4; i++) {
weight_sizes[i] = weight_reordered.size(i);
}
create_status = xnn_create_deconvolution2d_nhwc_f32(
padding_expanded[Layout::Parameter::height], // output_padding_top
padding_expanded[Layout::Parameter::width], // output_padding_right
padding_expanded[Layout::Parameter::height], // output_padding_bottom
padding_expanded[Layout::Parameter::width], // output_padding_left
weight_reordered.size(Layout::Filter::height), // kernel_height
weight_reordered.size(Layout::Filter::width), // kernel_width
stride_expanded[Layout::Parameter::height], // subsampling_height
stride_expanded[Layout::Parameter::width], // subsampling_width
dilation_expanded[Layout::Parameter::height], // dilation_height
dilation_expanded[Layout::Parameter::width], // dilation_width
groups, // groups
weight_reordered.size(Layout::Filter::output) / groups, // group_input_channels
weight_reordered.size(Layout::Filter::input), // group_output_channels
weight_reordered.size(Layout::Filter::output), // input_pixel_stride
weight_reordered.size(Layout::Filter::input) * groups, // output_pixel_stride
weight_reordered.data_ptr<float>(), // kernel
(bias && bias->defined())
? bias->contiguous().data_ptr<float>()
: nullptr, // bias
output_min, // output_min
output_max, // output_max
0u, // flags
&convolution_op); // operator
} else {
for (int i = 0; i < 4; i++) {
weight_sizes[i] = weight_nhwc.size(i);
}
create_status = xnn_create_convolution2d_nhwc_f32(
padding_expanded[Layout::Parameter::height], // input_padding_top
padding_expanded[Layout::Parameter::width], // input_padding_right
padding_expanded[Layout::Parameter::height], // input_padding_bottom
@ -161,18 +262,21 @@ ContextConv2D create(
output_max, // output_max
0u, // flags
&convolution_op); // operator
}
TORCH_CHECK(
xnn_status_success == create_status,
"xnn_create_convolution2d_nhwc_f32 failed!");
(transposed ? "xnn_create_deconvolution2d_nhwc_f32 failed!"
: "xnn_create_convolution2d_nhwc_f32 failed!"));
return ContextConv2D{
Operator(convolution_op),
{weight_nhwc.sizes()[0], weight_nhwc.sizes()[1],
weight_nhwc.sizes()[2], weight_nhwc.sizes()[3]},
weight_sizes,
{padding_expanded[0], padding_expanded[1]},
{output_padding_expanded[0], output_padding_expanded[1]},
{stride_expanded[0], stride_expanded[1]},
{dilation_expanded[0], dilation_expanded[1]}
{dilation_expanded[0], dilation_expanded[1]},
transposed, groups
};
}
@ -189,7 +293,21 @@ Tensor run(
"XNNPACK Convolution not usable! "
"Reason: The provided input tensor is either invalid or unsupported by XNNPACK.");
Tensor output = mobile::empty_with_tail_padding(
Tensor output;
if (context.transposed_) {
output = mobile::empty_with_tail_padding(
conv_input_size(padded_input_nhwc.sizes(),
context.weight_size_,
context.padding_,
context.output_padding_,
context.stride_,
context.dilation_,
context.groups_),
padded_input_nhwc.options().dtype(),
MemoryFormat::ChannelsLast,
padded_input_nhwc.names());
} else {
output = mobile::empty_with_tail_padding(
conv_output_size(
padded_input_nhwc.sizes(),
context.weight_size_,
@ -199,7 +317,9 @@ Tensor run(
padded_input_nhwc.options().dtype(),
MemoryFormat::ChannelsLast,
padded_input_nhwc.names());
}
xnn_status setup_status;
if ((context.cached_input_ptr != padded_input_nhwc.data_ptr<float>()) ||
(context.cached_output_ptr != output.data_ptr<float>()) ||
(padded_input_nhwc.size(Layout::Activation4D::batch) !=
@ -211,26 +331,42 @@ Tensor run(
(padded_input_nhwc.size(Layout::Activation4D::width) !=
context.input_width)
) {
const xnn_status setup_status = xnn_setup_convolution2d_nhwc_f32(
context.op.get(), // operator
padded_input_nhwc.size(Layout::Activation4D::batch), // batch_size
padded_input_nhwc.size(Layout::Activation4D::height), // input_height
padded_input_nhwc.size(Layout::Activation4D::width), // input_width
padded_input_nhwc.data_ptr<float>(), // input
output.data_ptr<float>(), // output
caffe2::pthreadpool_()); // threadpool
TORCH_CHECK(
xnn_status_success == setup_status,
"xnn_setup_convolution2d_nhwc_f32 failed!");
if (context.transposed_) {
setup_status = xnn_setup_deconvolution2d_nhwc_f32(
context.op.get(), // operator
padded_input_nhwc.size(Layout::Activation4D::batch), // batch_size
padded_input_nhwc.size(Layout::Activation4D::height), // input_height
padded_input_nhwc.size(Layout::Activation4D::width), // input_width
context.output_padding_[0], // adjustment_height
context.output_padding_[1], // adjustment_width
padded_input_nhwc.data_ptr<float>(), // input
output.data_ptr<float>(), // output
caffe2::pthreadpool_()); // threadpool
// Cache values to avoid setup for the next round.
context.cached_input_ptr = padded_input_nhwc.data_ptr<float>();
context.cached_output_ptr = output.data_ptr<float>();
context.batch_size = padded_input_nhwc.size(Layout::Activation4D::batch);
context.input_channels = padded_input_nhwc.size(Layout::Activation4D::channels);
context.input_height = padded_input_nhwc.size(Layout::Activation4D::height);
context.input_width = padded_input_nhwc.size(Layout::Activation4D::width);
} else {
setup_status = xnn_setup_convolution2d_nhwc_f32(
context.op.get(), // operator
padded_input_nhwc.size(Layout::Activation4D::batch), // batch_size
padded_input_nhwc.size(Layout::Activation4D::height), // input_height
padded_input_nhwc.size(Layout::Activation4D::width), // input_width
padded_input_nhwc.data_ptr<float>(), // input
output.data_ptr<float>(), // output
caffe2::pthreadpool_());
}
TORCH_CHECK(
xnn_status_success == setup_status,
(context.transposed_ ? "xnn_setup_deconvolution2d_nhwc_f32 failed!"
: "xnn_setup_convolution2d_nhwc_f32 failed!"));
// Cache values to avoid setup for the next round
context.cached_input_ptr = padded_input_nhwc.data_ptr<float>();
context.cached_output_ptr = output.data_ptr<float>();
context.batch_size = padded_input_nhwc.size(Layout::Activation4D::batch);
context.input_channels = padded_input_nhwc.size(Layout::Activation4D::channels);
context.input_height = padded_input_nhwc.size(Layout::Activation4D::height);
context.input_width = padded_input_nhwc.size(Layout::Activation4D::width);
}
const xnn_status run_status = xnn_run_operator(
@ -265,12 +401,41 @@ c10::intrusive_ptr<xnnpack::Conv2dOpContext>
output_max);
}
// Prepack entry point for transposed 2-D convolution.
//
// Forwards the weight, optional bias and convolution parameters to
// XNNPackTransposeConv2dOpContext::create_context and returns the resulting
// op context. Arguments arrive in (stride, padding, output_padding, dilation)
// order but create_context takes (padding, output_padding, stride, dilation),
// hence the reordering in the call below.
c10::intrusive_ptr<xnnpack::TransposeConv2dOpContext>
createConv2dTransposeClampPrePackOpContext(
    Tensor weight,
    c10::optional<Tensor> bias,
    std::vector<int64_t> stride,
    std::vector<int64_t> padding,
    std::vector<int64_t> output_padding,
    std::vector<int64_t> dilation,
    int64_t groups,
    c10::optional<Scalar> output_min,
    c10::optional<Scalar> output_max) {
  auto prepacked_context =
      xnnpack::XNNPackTransposeConv2dOpContext::create_context(
          std::move(weight),
          std::move(bias),
          std::move(padding),
          std::move(output_padding),
          std::move(stride),
          std::move(dilation),
          groups,
          output_min,
          output_max);
  return prepacked_context;
}
// Runs a prepacked 2-D convolution: delegates to the run() method of the
// supplied Conv2dOpContext and returns its result.
//   input      - activation tensor handed to the context unchanged.
//   op_context - prepacked convolution context.
Tensor conv2d_clamp_run(
const Tensor& input,
const c10::intrusive_ptr<xnnpack::Conv2dOpContext>& op_context) {
return op_context->run(input);
}
// Runs a prepacked transposed 2-D convolution: delegates to the run() method
// of the supplied TransposeConv2dOpContext and returns its result.
//   input      - activation tensor handed to the context unchanged.
//   op_context - prepacked transposed-convolution context.
Tensor conv2d_transpose_clamp_run(
const Tensor& input,
const c10::intrusive_ptr<xnnpack::TransposeConv2dOpContext>& op_context) {
return op_context->run(input);
}
} // namespace convolution2d
} // namespace internal
@ -281,7 +446,8 @@ bool use_convolution2d(
const IntArrayRef padding,
const IntArrayRef stride,
const IntArrayRef dilation,
const int64_t groups) {
const int64_t groups,
const bool transposed) {
return internal::convolution2d::available(
weight,
bias,
@ -289,6 +455,7 @@ bool use_convolution2d(
stride,
dilation,
groups,
transposed,
ContextConv2D::kMin,
ContextConv2D::kMax) &&
internal::convolution2d::usable(input);
@ -307,9 +474,11 @@ Tensor convolution2d(
weight,
bias,
padding,
{0, 0}, // output_padding
stride,
dilation,
groups,
false, // transposed
ContextConv2D::kMin,
ContextConv2D::kMax);
}

View File

@ -23,17 +23,35 @@ c10::intrusive_ptr<xnnpack::Conv2dOpContext>
c10::optional<Scalar> output_min,
c10::optional<Scalar> output_max);
// Creates the prepacked op context for a transposed 2-D convolution.
// Note the parameter order: stride, padding, output_padding, dilation.
// output_min / output_max are optional clamp bounds for the output.
c10::intrusive_ptr<xnnpack::TransposeConv2dOpContext>
createConv2dTransposeClampPrePackOpContext(
Tensor weight,
c10::optional<Tensor> bias,
std::vector<int64_t> stride,
std::vector<int64_t> padding,
std::vector<int64_t> output_padding,
std::vector<int64_t> dilation,
int64_t groups,
c10::optional<Scalar> output_min,
c10::optional<Scalar> output_max);
// Runs the prepacked (non-transposed) 2-D convolution held by op_context on
// input and returns the output tensor.
Tensor conv2d_clamp_run(
const Tensor& input,
const c10::intrusive_ptr<xnnpack::Conv2dOpContext>& op_context);
// Runs the prepacked transposed 2-D convolution held by op_context on input
// and returns the output tensor.
Tensor conv2d_transpose_clamp_run(
const Tensor& input,
const c10::intrusive_ptr<xnnpack::TransposeConv2dOpContext>& op_context);
// Builds a ContextConv2D for either a regular or a transposed 2-D convolution.
//   weight, bias   - convolution parameters; bias is optional.
//   padding        - input padding.
//   output_padding - output padding; the regular convolution callers pass {0, 0}.
//   stride, dilation, groups - usual convolution hyper-parameters.
//   transposed     - true selects the transposed (deconvolution) operator,
//                    false the regular convolution operator.
//   output_min/max - clamp bounds applied to the output.
ContextConv2D create(
const Tensor& weight,
const c10::optional<Tensor>& bias,
const IntArrayRef padding,
const IntArrayRef output_padding,
const IntArrayRef stride,
const IntArrayRef dilation,
const int64_t groups,
const bool transposed,
const float output_min,
const float output_max);

View File

@ -17,7 +17,8 @@ bool use_convolution2d(
const IntArrayRef padding,
const IntArrayRef stride,
const IntArrayRef dilation,
const int64_t groups);
const int64_t groups,
const bool transposed);
Tensor convolution2d(
const Tensor& input,

View File

@ -48,13 +48,16 @@ XNNPackConv2dOpContext::create_context(at::Tensor&& weight,
weight,
bias,
padding,
{0, 0}, // output_padding
stride,
dilation,
groups,
false, // transposed
output_min ? output_min->to<float>()
: xnnpack::ContextConv2D::kMin,
output_max ? output_max->to<float>()
: xnnpack::ContextConv2D::kMax);
auto conv2d_op_context =
c10::make_intrusive<XNNPackConv2dOpContext>(
std::move(weight),
@ -66,6 +69,48 @@ XNNPackConv2dOpContext::create_context(at::Tensor&& weight,
output_min,
output_max,
std::move(op_context));
return conv2d_op_context;
}
// Factory for the XNNPACK-backed transposed convolution context.
//
// First builds the underlying operator through
// xnnpack::internal::convolution2d::create (with transposed == true), then
// wraps it, together with the original arguments, in a
// XNNPackTransposeConv2dOpContext. Missing clamp bounds fall back to
// ContextConv2D::kMin / ContextConv2D::kMax.
c10::intrusive_ptr<TransposeConv2dOpContext>
XNNPackTransposeConv2dOpContext::create_context(at::Tensor&& weight,
    c10::optional<at::Tensor>&& bias,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& output_padding,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    const c10::optional<Scalar> output_min,
    const c10::optional<Scalar> output_max) {
  // Resolve the optional clamp bounds to concrete floats.
  const float clamp_min = output_min ? output_min->to<float>()
                                     : xnnpack::ContextConv2D::kMin;
  const float clamp_max = output_max ? output_max->to<float>()
                                     : xnnpack::ContextConv2D::kMax;

  // The arguments are only read here; they are moved into the context below.
  auto packed_op = xnnpack::internal::convolution2d::create(
      weight,
      bias,
      padding,
      output_padding,
      stride,
      dilation,
      groups,
      true, // transposed
      clamp_min,
      clamp_max);

  return c10::make_intrusive<XNNPackTransposeConv2dOpContext>(
      std::move(weight),
      std::move(bias),
      std::move(padding),
      std::move(output_padding),
      std::move(stride),
      std::move(dilation),
      groups,
      output_min,
      output_max,
      std::move(packed_op));
}
@ -73,6 +118,10 @@ Tensor XNNPackConv2dOpContext::run(const Tensor& input) {
return xnnpack::internal::convolution2d::run(op_context_, input);
}
// Executes the prepacked transposed convolution on input by handing the stored
// ContextConv2D to xnnpack::internal::convolution2d::run.
Tensor XNNPackTransposeConv2dOpContext::run(const Tensor& input) {
return xnnpack::internal::convolution2d::run(op_context_, input);
}
} // namespace xnnpack
} // namespace native
} // namespace at

View File

@ -24,6 +24,18 @@ using SerializationTypeConv2dPrePack = std::tuple<
int64_t,
c10::optional<Scalar>,
c10::optional<Scalar>>;
// Serialized form of a prepacked transposed conv2d context. Element order
// matches TransposeConv2dOpContext::unpack().
using SerializationTypeTransposeConv2dPrePack = std::tuple<
Tensor, // weight
c10::optional<Tensor>, // bias
std::vector<int64_t>, // stride
std::vector<int64_t>, // padding
std::vector<int64_t>, // output_padding
std::vector<int64_t>, // dilation
int64_t, // groups
c10::optional<Scalar>, // output_min
c10::optional<Scalar>>; // output_max
class LinearOpContext : public torch::jit::CustomClassHolder {
protected:
@ -94,6 +106,35 @@ class Conv2dOpContext : public torch::jit::CustomClassHolder {
virtual Tensor run(const Tensor& input) = 0;
};
// Abstract base for prepacked transposed conv2d contexts.
//
// Keeps the original (un-packed) construction arguments so that a context can
// be turned back into a SerializationTypeTransposeConv2dPrePack tuple, and
// declares the run() entry point that concrete backends implement.
class TransposeConv2dOpContext : public torch::jit::CustomClassHolder {
 protected:
  // Arguments exactly as supplied at prepack time.
  Tensor orig_weight_;
  c10::optional<Tensor> orig_bias_;
  std::vector<int64_t> stride_;
  std::vector<int64_t> padding_;
  std::vector<int64_t> output_padding_;
  std::vector<int64_t> dilation_;
  int64_t groups_;
  // Optional clamp bounds for the output.
  c10::optional<Scalar> output_min_;
  c10::optional<Scalar> output_max_;

 public:
  // Returns copies of the stored arguments in the order
  // (weight, bias, stride, padding, output_padding, dilation, groups,
  //  output_min, output_max).
  SerializationTypeTransposeConv2dPrePack unpack() {
    return SerializationTypeTransposeConv2dPrePack(
        orig_weight_,
        orig_bias_,
        stride_,
        padding_,
        output_padding_,
        dilation_,
        groups_,
        output_min_,
        output_max_);
  }

  // Applies the prepacked transposed convolution to input.
  virtual Tensor run(const Tensor& input) = 0;
};
class XNNPackConv2dOpContext final : public Conv2dOpContext {
private:
ContextConv2D op_context_;
@ -120,7 +161,7 @@ class XNNPackConv2dOpContext final : public Conv2dOpContext {
output_max_ = max;
}
Tensor run(const Tensor& input);
Tensor run(const Tensor& input) override;
static c10::intrusive_ptr<Conv2dOpContext> create_context(
Tensor&& weight,
@ -132,6 +173,49 @@ class XNNPackConv2dOpContext final : public Conv2dOpContext {
const c10::optional<Scalar> output_min,
const c10::optional<Scalar> output_max);
};
// XNNPACK-backed implementation of TransposeConv2dOpContext.
//
// Owns the ContextConv2D that holds the created XNNPACK operator, and stores
// the original arguments in the base class so the context can be unpacked.
class XNNPackTransposeConv2dOpContext final : public TransposeConv2dOpContext {
 private:
  ContextConv2D op_context_;

 public:
  // Takes ownership of all arguments.
  // Note the order (padding, output_padding, stride, dilation), which differs
  // from the order used by unpack().
  // `groups` is int64_t so it matches both the groups_ member it is stored in
  // and the int64_t accepted by create_context; no signed/unsigned conversion
  // takes place.
  XNNPackTransposeConv2dOpContext(
      Tensor&& weight,
      c10::optional<Tensor>&& bias,
      std::vector<int64_t>&& padding,
      std::vector<int64_t>&& output_padding,
      std::vector<int64_t>&& stride,
      std::vector<int64_t>&& dilation,
      int64_t groups,
      c10::optional<Scalar> min,
      c10::optional<Scalar> max,
      ContextConv2D&& op_context)
      : op_context_(std::move(op_context)) {
    // These members live in the base class, so they are assigned here rather
    // than in the member-initializer list.
    orig_weight_ = std::move(weight);
    orig_bias_ = std::move(bias);
    padding_ = std::move(padding);
    output_padding_ = std::move(output_padding);
    stride_ = std::move(stride);
    dilation_ = std::move(dilation);
    groups_ = groups;
    output_min_ = min;
    output_max_ = max;
  }

  // Runs the prepacked transposed convolution on input.
  Tensor run(const Tensor& input) override;

  // Builds the XNNPACK operator from the given arguments and wraps it in a
  // new context. output_min / output_max are optional clamp bounds.
  static c10::intrusive_ptr<TransposeConv2dOpContext> create_context(
      Tensor&& weight,
      c10::optional<Tensor>&& bias,
      std::vector<int64_t>&& padding,
      std::vector<int64_t>&& output_padding,
      std::vector<int64_t>&& stride,
      std::vector<int64_t>&& dilation,
      int64_t groups,
      const c10::optional<Scalar> output_min,
      const c10::optional<Scalar> output_max);
};
} // namespace xnnpack
} // namespace native

View File

@ -13,6 +13,7 @@ namespace xnnpack {
using internal::linear::createLinearClampPrePackOpContext;
using internal::convolution2d::createConv2dClampPrePackOpContext;
using internal::convolution2d::createConv2dTransposeClampPrePackOpContext;
TORCH_LIBRARY(xnnpack, m) {
m.class_<LinearOpContext>("LinearOpContext")
@ -48,20 +49,45 @@ TORCH_LIBRARY(xnnpack, m) {
std::move(std::get<6>(state)),
std::move(std::get<7>(state)));
});
m.class_<TransposeConv2dOpContext>("TransposeConv2dOpContext")
.def_pickle(
[](const c10::intrusive_ptr<TransposeConv2dOpContext>& op_context)
-> SerializationTypeTransposeConv2dPrePack { // __getstate__
return op_context->unpack();
},
[](SerializationTypeTransposeConv2dPrePack state)
-> c10::intrusive_ptr<TransposeConv2dOpContext> { // __setstate__
return createConv2dTransposeClampPrePackOpContext(
std::move(std::get<0>(state)),
std::move(std::get<1>(state)),
std::move(std::get<2>(state)),
std::move(std::get<3>(state)),
std::move(std::get<4>(state)),
std::move(std::get<5>(state)),
std::move(std::get<6>(state)),
std::move(std::get<7>(state)),
std::move(std::get<8>(state)));
});
}
// Declares the operator schemas of the `prepacked` namespace: a prepack op and
// a run op for linear, conv2d and transposed conv2d. Each prepack op returns
// the matching xnnpack custom-class context, which the run op then consumes.
TORCH_LIBRARY(prepacked, m) {
m.def("linear_clamp_prepack(Tensor W, Tensor? B=None, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.LinearOpContext");
m.def("linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y");
m.def("conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.Conv2dOpContext");
// The transposed variant additionally takes output_padding (between padding and dilation).
m.def("conv2d_transpose_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.TransposeConv2dOpContext");
m.def("conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y");
m.def("conv2d_transpose_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.TransposeConv2dOpContext W_prepack) -> Tensor Y");
}
// Registers the CPU implementations of the `prepacked` operators, pairing each
// schema name with the function that implements it.
TORCH_LIBRARY_IMPL(prepacked, CPU, m) {
m.impl("linear_clamp_prepack", TORCH_FN(createLinearClampPrePackOpContext));
m.impl("linear_clamp_run", TORCH_FN(internal::linear::linear_clamp_run));
m.impl("conv2d_clamp_prepack", TORCH_FN(createConv2dClampPrePackOpContext));
m.impl("conv2d_transpose_clamp_prepack", TORCH_FN(createConv2dTransposeClampPrePackOpContext));
m.impl("conv2d_clamp_run", TORCH_FN(internal::convolution2d::conv2d_clamp_run));
m.impl("conv2d_transpose_clamp_run", TORCH_FN(internal::convolution2d::conv2d_transpose_clamp_run));
}
} // namespace xnnpack

View File

@ -35,7 +35,8 @@ bool use_convolution2d(
const IntArrayRef,
const IntArrayRef,
const IntArrayRef,
const int64_t) {
const int64_t,
bool) {
return false;
}

View File

@ -93,6 +93,72 @@ class TestXNNPACKOps(TestCase):
xnnpack_result = torch.ops.prepacked.conv2d_clamp_run(input_data, packed_weight_bias)
torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
# Property-based test: for hypothesis-generated shapes and hyper-parameters,
# the prepacked XNNPACK transposed conv2d must match F.conv_transpose2d
# within rtol=1e-2 / atol=1e-3.
@given(batch_size=st.integers(1, 3),
input_channels_per_group=st.integers(1, 32),
height=st.integers(5, 64),
width=st.integers(5, 64),
output_channels_per_group=st.integers(1, 32),
groups=st.integers(1, 16),
kernel_h=st.integers(1, 7),
kernel_w=st.integers(1, 7),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
output_pad_h=st.integers(0, 2),
output_pad_w=st.integers(0, 2),
dilation=st.integers(1, 2),
use_bias=st.booleans(),
format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
def test_conv2d_transpose(self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
output_pad_h,
output_pad_w,
dilation,
use_bias,
format):
# Total channel counts are the per-group counts times the number of groups.
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
kernels = (kernel_h, kernel_w)
strides = (stride_h, stride_w)
paddings = (pad_h, pad_w)
output_paddings = (output_pad_h, output_pad_w)
# The same dilation value is used for both spatial dimensions.
dilations = (dilation, dilation)
# Discard examples where the padded input is smaller than the dilated kernel.
assume(height + 2 * paddings[0]
>= dilations[0] * (kernels[0] - 1) + 1)
assume(width + 2 * paddings[1]
>= dilations[1] * (kernels[1] - 1) + 1)
# Only keep examples whose output padding is below both stride and dilation.
assume((output_pad_h < stride_h) and (output_pad_h < dilation))
assume((output_pad_w < stride_w) and (output_pad_w < dilation))
input_data = torch.rand((batch_size, input_channels, height, width))
if (format is not None):
input_data = input_data.contiguous(memory_format=format)
# Transposed-conv weight layout: (in_channels, out_channels_per_group, kH, kW).
weight = torch.rand((input_channels, output_channels_per_group, kernel_h, kernel_w))
bias = None
if use_bias:
bias = torch.rand((output_channels))
# Note: F.conv_transpose2d takes groups before dilation, the reverse of F.conv2d.
ref_result = F.conv_transpose2d(input_data, weight, bias,
strides, paddings, output_paddings, groups, dilation)
packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(weight, bias,
strides, paddings,
output_paddings, dilations,
groups)
xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run(input_data, packed_weight_bias)
# Both results are made contiguous before the comparison.
torch.testing.assert_allclose(ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3)
@unittest.skipUnless(torch.backends.xnnpack.enabled,
" XNNPACK must be enabled for these tests."
@ -244,6 +310,114 @@ class TestXNNPACKSerDes(TestCase):
xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
# Property-based serialization test: a scripted reference module and a scripted
# prepacked module must agree both before and after a torch.jit.save /
# torch.jit.load round trip (rtol=1e-2, atol=1e-3).
@given(batch_size=st.integers(0, 3),
input_channels_per_group=st.integers(1, 32),
height=st.integers(5, 64),
width=st.integers(5, 64),
output_channels_per_group=st.integers(1, 32),
groups=st.integers(1, 16),
kernel_h=st.integers(1, 7),
kernel_w=st.integers(1, 7),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
output_pad_h=st.integers(0, 2),
output_pad_w=st.integers(0, 2),
dilation=st.integers(1, 2),
use_bias=st.booleans(),
format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last]))
def test_conv2d_transpose(self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
output_pad_h,
output_pad_w,
dilation,
use_bias,
format):
# Reference module: calls F.conv_transpose2d with the stored parameters.
class Conv2DT(torch.nn.Module):
def __init__(self, weight, bias, strides, paddings, output_paddings, dilations, groups):
super(Conv2DT, self).__init__()
self.weight = weight
self.bias = bias
self.strides = strides
self.paddings = paddings
self.output_paddings = output_paddings
self.dilations = dilations
self.groups = groups
def forward(self, x):
return F.conv_transpose2d(x, self.weight, self.bias,
self.strides, self.paddings, self.output_paddings, self.groups, self.dilations)
# Module under test: prepacks once in __init__ and runs the prepacked op in forward.
class Conv2DTPrePacked(torch.nn.Module):
def __init__(self, weight, bias, strides, paddings, output_paddings, dilations, groups):
super(Conv2DTPrePacked, self).__init__()
self.packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(weight, bias,
strides, paddings,
output_paddings,
dilations, groups)
def forward(self, x):
return torch.ops.prepacked.conv2d_transpose_clamp_run(x, self.packed_weight_bias)
# Total channel counts are the per-group counts times the number of groups.
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
kernels = (kernel_h, kernel_w)
strides = (stride_h, stride_w)
paddings = (pad_h, pad_w)
output_paddings = (output_pad_h, output_pad_w)
dilations = (dilation, dilation)
# Discard examples where the padded input is smaller than the dilated kernel.
assume(height + 2 * paddings[0] >=
dilations[0] * (kernels[0] - 1) + 1)
assume(width + 2 * paddings[1] >=
dilations[1] * (kernels[1] - 1) + 1)
# Only keep examples whose output padding is below both stride and dilation.
assume((output_pad_h < stride_h) and (output_pad_h < dilation))
assume((output_pad_w < stride_w) and (output_pad_w < dilation))
input_data = torch.rand((batch_size, input_channels, height, width))
if (format is not None):
input_data = input_data.contiguous(memory_format=format)
# Transposed-conv weight layout: (in_channels, out_channels_per_group, kH, kW).
weight = torch.rand((input_channels, output_channels_per_group, kernel_h, kernel_w))
bias = None
if use_bias:
bias = torch.rand((output_channels))
scripted_conv2d = torch.jit.script(Conv2DT(weight, bias,
strides, paddings,
output_paddings, dilations, groups))
scripted_conv2d_clamp_prepacked = torch.jit.script(Conv2DTPrePacked(
weight, bias, strides, paddings, output_paddings, dilations, groups))
# First comparison: freshly scripted modules.
ref_result = scripted_conv2d(input_data)
xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
# Serialize the modules and then deserialize
# (new random input is generated for the second comparison)
input_data = torch.rand((batch_size, input_channels, height, width))
if (format is not None):
input_data = input_data.contiguous(memory_format=format)
buffer = io.BytesIO()
torch.jit.save(scripted_conv2d, buffer)
buffer.seek(0)
deserialized_conv2d = torch.jit.load(buffer)
buffer = io.BytesIO()
torch.jit.save(scripted_conv2d_clamp_prepacked, buffer)
buffer.seek(0)
deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
# Second comparison: modules restored from their serialized form.
ref_result = deserialized_conv2d(input_data)
xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
@given(batch_size=st.integers(0, 3),
input_channels_per_group=st.integers(1, 32),
height=st.integers(5, 64),
@ -454,14 +628,17 @@ class TestXNNPACKRewritePass(TestCase):
kernel_h = kernel_w = 3
stride_h = stride_w = 1
pad_h = pad_w = 1
output_pad_h = output_pad_w = 0
dilation = 1
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
kernels = (kernel_h, kernel_w)
strides = (stride_h, stride_w)
paddings = (pad_h, pad_w)
output_paddings = (output_pad_h, output_pad_w)
dilations = (dilation, dilation)
conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
conv_transpose_weight_shape = (input_channels, output_channels_per_group, kernel_h, kernel_w)
conv_bias_shape = (output_channels)
class Conv2D(torch.nn.Module):
@ -478,12 +655,34 @@ class TestXNNPACKRewritePass(TestCase):
return F.conv2d(x, self.weight, self.bias,
self.strides, self.paddings, self.dilations, self.groups)
class Conv2DT(torch.nn.Module):
def __init__(self):
super(Conv2DT, self).__init__()
self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_transpose_weight_shape)), requires_grad=False)
self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
self.strides = strides
self.paddings = paddings
self.output_paddings = output_paddings
self.dilations = dilations
self.groups = groups
def forward(self, x):
return F.conv_transpose2d(x, self.weight, self.bias,
self.strides, self.paddings, self.output_paddings, self.groups, self.dilations)
data_shape = (batch_size, input_channels, height, width)
pattern_count_map = {"Tensor = aten::conv2d": -1,
"prepacked::conv2d_clamp_prepack": 1,
"prepacked::conv2d_clamp_run": 1}
TestXNNPACKRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
transpose_data_shape = (batch_size, input_channels, height, width)
transpose_pattern_count_map = {"Tensor = aten::conv_transpose2d": -1,
"prepacked::conv2d_transpose_clamp_prepack": 1,
"prepacked::conv2d_transpose_clamp_run": 1}
TestXNNPACKRewritePass.validate_transformed_module(Conv2DT(), transpose_pattern_count_map, data_shape)
input_data = torch.rand((batch_size, input_channels, height, width))
conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
conv_bias = torch.rand((output_channels))

View File

@ -78,6 +78,13 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
%r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
return (%r) )";
// Replacement graph for transposed 2-D convolutions. It declares twelve inputs
// and forwards eight of them to aten::conv_transpose2d, with %output_padding
// and %groups placed before %dilation in the call. %transposed, %benchmark,
// %deterministic and %cudnn_enabled are accepted but not forwarded.
std::string conv2d_transpose = R"(
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
%deterministic:bool, %cudnn_enabled:bool):
%r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
return (%r) )";
std::string conv1d = R"(
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
@ -124,6 +131,22 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
(calc_value_map["output_padding"].toIntList()[0] == 0) &&
(calc_value_map["output_padding"].toIntList()[1] == 0);
};
// Match filter for the transposed 2-D rewrite. A match is accepted only when
// output_padding, stride, padding and dilation are all two-element int lists,
// `transposed` and `cudnn_enabled` are true, and `benchmark` and
// `deterministic` are false.
auto filter_conv2d_transpose =
    [](const Match& match,
       const std::unordered_map<std::string, Value*>& vmap) {
      auto conv_params = getConvParams(match, vmap);

      // Checked in the same order as before; stops at the first mismatch.
      const bool all_two_dimensional =
          conv_params["output_padding"].toIntList().size() == 2 &&
          conv_params["stride"].toIntList().size() == 2 &&
          conv_params["padding"].toIntList().size() == 2 &&
          conv_params["dilation"].toIntList().size() == 2;
      if (!all_two_dimensional) {
        return false;
      }

      // Flag checks, one guard per flag, in the original evaluation order.
      if (!conv_params["transposed"].toBool()) {
        return false;
      }
      if (conv_params["benchmark"].toBool()) {
        return false;
      }
      if (conv_params["deterministic"].toBool()) {
        return false;
      }
      return conv_params["cudnn_enabled"].toBool();
    };
auto filter_conv3d = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
auto calc_value_map = getConvParams(match, vmap);
@ -148,6 +171,10 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
SubgraphRewriter rewriter_conv2d;
rewriter_conv2d.RegisterRewritePattern(convolution, conv2d);
rewriter_conv2d.runOnGraph(graph, filter_conv2d);
// Rewrite matches of the `convolution` pattern into the `conv2d_transpose`
// graph, applying only to matches accepted by filter_conv2d_transpose.
SubgraphRewriter rewriter_conv2d_transpose;
rewriter_conv2d_transpose.RegisterRewritePattern(
convolution, conv2d_transpose);
rewriter_conv2d_transpose.runOnGraph(graph, filter_conv2d_transpose);
SubgraphRewriter rewriter_conv3d;
rewriter_conv3d.RegisterRewritePattern(convolution, conv3d);
rewriter_conv3d.runOnGraph(graph, filter_conv3d);

View File

@ -143,6 +143,26 @@ void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
rewriter.RegisterRewritePattern(
conv_2d_pattern, prepacked_ops_conv2d_pattern);
rewriter.runOnGraph(graph);
// Pattern to match: a direct aten::conv_transpose2d call. The graph inputs are
// declared in the same order as in the replacement graph below.
std::string conv_2d_transpose_pattern = R"(
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[],
%output_padding:int[], %groups:int):
%r = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
return (%r) )";
// Replacement: pack weight, bias and the convolution arguments with
// prepacked::conv2d_transpose_clamp_prepack, then apply the packed context to
// %input with prepacked::conv2d_transpose_clamp_run. The same None constant is
// passed for both trailing prepack arguments.
std::string prepacked_ops_conv2d_transpose_pattern = R"(
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int):
%output_min_max : None = prim::Constant()
%packed_weight_bias = prepacked::conv2d_transpose_clamp_prepack(
%weight, %bias, %stride, %padding, %output_padding, %dilation, %groups,
%output_min_max, %output_min_max)
%r = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias)
return (%r) )";
// A separate rewriter instance applies the transposed rewrite to the graph.
SubgraphRewriter transpose_rewriter;
transpose_rewriter.RegisterRewritePattern(
conv_2d_transpose_pattern, prepacked_ops_conv2d_transpose_pattern);
transpose_rewriter.runOnGraph(graph);
}
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
@ -321,7 +341,11 @@ void FoldPrePackingOps(script::Module& m) {
return (
(n->kind() ==
Symbol::fromQualString("prepacked::linear_clamp_prepack")) ||
n->kind() == Symbol::fromQualString("prepacked::conv2d_clamp_prepack"));
n->kind() ==
Symbol::fromQualString("prepacked::conv2d_clamp_prepack") ||
n->kind() ==
Symbol::fromQualString(
"prepacked::conv2d_transpose_clamp_prepack"));
};
PrePackingOpsFolder(m, filter_fn, "prepack_folding");
}
@ -397,7 +421,7 @@ script::Module optimizeForMobile(
const std::set<MobileOptimizerType>& blocklist,
const std::vector<std::string>& preserved_methods) {
TORCH_INTERNAL_ASSERT(
"Mobile optimizaiton only available with XNNPACK at the moment. "
"Mobile optimization only available with XNNPACK at the moment. "
"XNNPACK is not enabled. Please build with USE_XNNPACK=1");
return module;
}