mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-03 23:45:05 +08:00
Added support for multidimensional tensors in PReLU; Channel number now in second dimension
This commit is contained in:
@ -22,31 +22,18 @@ void THNN_(PReLU_updateOutput)(
|
||||
else
|
||||
{
|
||||
input = THTensor_(newContiguous)(input);
|
||||
long bs, ks;
|
||||
long bs = 1, ks = 1;
|
||||
{
|
||||
long input_ndim = THTensor_(nDimension)(input);
|
||||
switch (input_ndim)
|
||||
{
|
||||
case 1:
|
||||
bs = 1;
|
||||
ks = 1;
|
||||
break;
|
||||
case 2:
|
||||
bs = input->size[0];
|
||||
ks = 1;
|
||||
break;
|
||||
case 3:
|
||||
bs = 1;
|
||||
ks = input->size[1] * input->size[2];
|
||||
break;
|
||||
case 4:
|
||||
bs = input->size[0];
|
||||
ks = input->size[2] * input->size[3];
|
||||
break;
|
||||
}
|
||||
if (input->size[input_ndim > 1] != nOutputPlane)
|
||||
THError("Wrong number of input planes. Expected %d but got %d.", nOutputPlane, input->size[input_ndim > 1]);
|
||||
|
||||
if (input->size[(input_ndim + 1) % 2] != nOutputPlane)
|
||||
THError("wrong number of input planes");
|
||||
if (input_ndim > 1) {
|
||||
bs = input->size[0];
|
||||
for (int d = 2; d < input_ndim; d++) {
|
||||
ks *= input->size[d];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
real *output_data = THTensor_(data)(output);
|
||||
@ -100,31 +87,18 @@ void THNN_(PReLU_updateGradInput)(
|
||||
const real *weight_data = THTensor_(data)(weight);
|
||||
real *gradInput_data = THTensor_(data)(gradInput);
|
||||
|
||||
long bs, ks;
|
||||
long bs = 1, ks = 1;
|
||||
{
|
||||
long input_ndim = THTensor_(nDimension)(input);
|
||||
switch (input_ndim)
|
||||
{
|
||||
case 1:
|
||||
bs = 1;
|
||||
ks = 1;
|
||||
break;
|
||||
case 2:
|
||||
bs = input->size[0];
|
||||
ks = 1;
|
||||
break;
|
||||
case 3:
|
||||
bs = 1;
|
||||
ks = input->size[1] * input->size[2];
|
||||
break;
|
||||
case 4:
|
||||
bs = input->size[0];
|
||||
ks = input->size[2] * input->size[3];
|
||||
break;
|
||||
}
|
||||
if (input->size[input_ndim > 1] != nOutputPlane)
|
||||
THError("Wrong number of input planes. Expected %d but got %d.", nOutputPlane, input->size[input_ndim > 1]);
|
||||
|
||||
if (input->size[(input_ndim + 1) % 2] != nOutputPlane)
|
||||
THError("wrong number of input planes");
|
||||
if (input_ndim > 1) {
|
||||
bs = input->size[0];
|
||||
for (int d = 2; d < input_ndim; d++) {
|
||||
ks *= input->size[d];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
THIndex_t i, j, k;
|
||||
@ -184,31 +158,18 @@ void THNN_(PReLU_accGradParameters)(
|
||||
{
|
||||
input = THTensor_(newContiguous)(input);
|
||||
gradOutput = THTensor_(newContiguous)(gradOutput);
|
||||
long bs, ks;
|
||||
long bs = 1, ks = 1;
|
||||
{
|
||||
long input_ndim = THTensor_(nDimension)(input);
|
||||
switch (input_ndim)
|
||||
{
|
||||
case 1:
|
||||
bs = 1;
|
||||
ks = 1;
|
||||
break;
|
||||
case 2:
|
||||
bs = input->size[0];
|
||||
ks = 1;
|
||||
break;
|
||||
case 3:
|
||||
bs = 1;
|
||||
ks = input->size[1] * input->size[2];
|
||||
break;
|
||||
case 4:
|
||||
bs = input->size[0];
|
||||
ks = input->size[2] * input->size[3];
|
||||
break;
|
||||
}
|
||||
if (input->size[input_ndim > 1] != nOutputPlane)
|
||||
THError("Wrong number of input planes. Expected %d but got %d.", nOutputPlane, input->size[input_ndim > 1]);
|
||||
|
||||
if (input->size[(input_ndim + 1) % 2] != nOutputPlane)
|
||||
THError("wrong number of input planes");
|
||||
if (input_ndim > 1) {
|
||||
bs = input->size[0];
|
||||
for (int d = 2; d < input_ndim; d++) {
|
||||
ks *= input->size[d];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const real *input_data = THTensor_(data)(input);
|
||||
|
||||
Reference in New Issue
Block a user