mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add NHWC order support in the cost inference function of 3d conv (#19170)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19170 As title The quantized resnext3d model in production got the following failures without the fix: ``` Caffe2 operator Int8ConvRelu logging error: [enforce fail at conv_pool_op_base.h:463] order == StorageOrder::NCHW. 1 vs 2. Conv3D only supports NCHW on the production quantized model ``` Reviewed By: jspark1105 Differential Revision: D14894276 fbshipit-source-id: ef97772277f322ed45215e382c3b4a3702e47e59
This commit is contained in:
committed by
Facebook Github Bot
parent
ffc9e29844
commit
84b264b17d
@ -460,15 +460,25 @@ class ConvPoolOpBase : public Operator<Context> {
|
||||
N = X.dims(0);
|
||||
if (X.dims_size() == 5) {
|
||||
// 3D convolution
|
||||
CAFFE_ENFORCE_EQ(order, StorageOrder::NCHW, "Conv3D only supports NCHW");
|
||||
Y_t = Y.dims(2);
|
||||
Y_h = Y.dims(3);
|
||||
Y_w = Y.dims(4);
|
||||
kernel_t = W.dims(2);
|
||||
kernel_h = W.dims(3);
|
||||
kernel_w = W.dims(4);
|
||||
in_channels = W.dims(1);
|
||||
out_channels = W.dims(0);
|
||||
if (order == StorageOrder::NHWC) {
|
||||
Y_t = Y.dims(1);
|
||||
Y_h = Y.dims(2);
|
||||
Y_w = Y.dims(3);
|
||||
kernel_t = W.dims(1);
|
||||
kernel_h = W.dims(2);
|
||||
kernel_w = W.dims(3);
|
||||
in_channels = W.dims(4);
|
||||
out_channels = W.dims(0);
|
||||
} else {
|
||||
Y_t = Y.dims(2);
|
||||
Y_h = Y.dims(3);
|
||||
Y_w = Y.dims(4);
|
||||
kernel_t = W.dims(2);
|
||||
kernel_h = W.dims(3);
|
||||
kernel_w = W.dims(4);
|
||||
in_channels = W.dims(1);
|
||||
out_channels = W.dims(0);
|
||||
}
|
||||
} else if (X.dims_size() == 4) {
|
||||
// 2D convolution
|
||||
CAFFE_ENFORCE_EQ(W.dims_size(), 4, "Conv2D should have 4D filter tensor");
|
||||
|
Reference in New Issue
Block a user