Add NHWC order support in the cost inference function of 3d conv (#19170)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19170

As title
The quantized resnext3d model in production got the following failures without the fix:

```
 Caffe2 operator Int8ConvRelu logging error: [enforce fail at conv_pool_op_base.h:463] order == StorageOrder::NCHW. 1 vs 2. Conv3D only supports NCHW on the production quantized model
```

Reviewed By: jspark1105

Differential Revision: D14894276

fbshipit-source-id: ef97772277f322ed45215e382c3b4a3702e47e59
This commit is contained in:
Summer Deng
2019-04-15 16:43:58 -07:00
committed by Facebook Github Bot
parent ffc9e29844
commit 84b264b17d

View File

@ -460,15 +460,25 @@ class ConvPoolOpBase : public Operator<Context> {
N = X.dims(0);
if (X.dims_size() == 5) {
// 3D convolution
CAFFE_ENFORCE_EQ(order, StorageOrder::NCHW, "Conv3D only supports NCHW");
Y_t = Y.dims(2);
Y_h = Y.dims(3);
Y_w = Y.dims(4);
kernel_t = W.dims(2);
kernel_h = W.dims(3);
kernel_w = W.dims(4);
in_channels = W.dims(1);
out_channels = W.dims(0);
if (order == StorageOrder::NHWC) {
Y_t = Y.dims(1);
Y_h = Y.dims(2);
Y_w = Y.dims(3);
kernel_t = W.dims(1);
kernel_h = W.dims(2);
kernel_w = W.dims(3);
in_channels = W.dims(4);
out_channels = W.dims(0);
} else {
Y_t = Y.dims(2);
Y_h = Y.dims(3);
Y_w = Y.dims(4);
kernel_t = W.dims(2);
kernel_h = W.dims(3);
kernel_w = W.dims(4);
in_channels = W.dims(1);
out_channels = W.dims(0);
}
} else if (X.dims_size() == 4) {
// 2D convolution
CAFFE_ENFORCE_EQ(W.dims_size(), 4, "Conv2D should have 4D filter tensor");