mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
average pooling format change to fit the cudnn interface
This commit is contained in:
@ -97,7 +97,7 @@ bool AveragePoolOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
|
||||
template <>
|
||||
bool AveragePoolGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
|
||||
auto& X = Input(0);
|
||||
auto& dY = Input(1);
|
||||
auto& dY = Input(2);
|
||||
auto* dX = Output(0);
|
||||
// TODO(Yangqing): Add shape checks.
|
||||
dX->ReshapeLike(X);
|
||||
@ -143,7 +143,7 @@ bool AveragePoolGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
|
||||
template <>
|
||||
bool AveragePoolGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
|
||||
auto& X = Input(0);
|
||||
auto& dY = Input(1);
|
||||
auto& dY = Input(2);
|
||||
CAFFE_CHECK_EQ(dY.ndim(), 4);
|
||||
auto* dX = Output(0);
|
||||
// TODO(Yangqing): Add shape checks.
|
||||
@ -196,7 +196,7 @@ struct GetAveragePoolGradient : public GetGradientDefBase {
|
||||
vector<OperatorDef>* Create(const OperatorDef& def) override {
|
||||
return SingleGradientDef(
|
||||
"AveragePoolGradient", "",
|
||||
vector<string>{I(def, 0), GO(def, 0)},
|
||||
vector<string>{I(def, 0), O(def, 0), GO(def, 0)},
|
||||
vector<string>{GI(def, 0)});
|
||||
}
|
||||
};
|
||||
|
@ -179,7 +179,7 @@ bool AveragePoolOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
|
||||
template <>
|
||||
bool AveragePoolGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
|
||||
auto& X = Input(0);
|
||||
auto& dY = Input(1);
|
||||
auto& dY = Input(2);
|
||||
CAFFE_CHECK_EQ(dY.ndim(), 4);
|
||||
auto* dX = Output(0);
|
||||
dX->ReshapeLike(X);
|
||||
@ -196,7 +196,7 @@ bool AveragePoolGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
|
||||
template <>
|
||||
bool AveragePoolGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
|
||||
auto& X = Input(0);
|
||||
auto& dY = Input(1);
|
||||
auto& dY = Input(2);
|
||||
CAFFE_CHECK_EQ(dY.ndim(), 4);
|
||||
auto* dX = Output(0);
|
||||
dX->ReshapeLike(X);
|
||||
|
@ -38,9 +38,11 @@ class AveragePoolGradientOp final :
|
||||
bool RunOnDeviceWithOrderNCHW() override;
|
||||
bool RunOnDeviceWithOrderNHWC() override;
|
||||
|
||||
// Input: X, Y_grad
|
||||
// Input: X, Y, Y_grad. Y is in fact not used, but to keep compatibility
|
||||
// with CuDNN we keep it here. Definitely not optimal, but probably does not
|
||||
// hurt that much.
|
||||
// Output: X_grad
|
||||
INPUT_OUTPUT_STATS(2, 2, 1, 1);
|
||||
INPUT_OUTPUT_STATS(3, 3, 1, 1);
|
||||
DISABLE_COPY_AND_ASSIGN(AveragePoolGradientOp);
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user