Add some tests for 3d channels last (#118283)

Part of a multi-PR effort to fix #59168.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118283
Approved by: https://github.com/albanD
hyperfraise
2024-01-30 01:26:43 +00:00
committed by PyTorch MergeBot
parent bacbad5bc9
commit aef820926c
6 changed files with 142 additions and 8 deletions
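For context on what the new tests exercise (a minimal eager-mode sketch, not part of this diff): a tensor in torch.channels_last_3d format keeps its logical NCDHW indexing but stores the channel dimension innermost in memory, which is visible in its strides.

import torch

# Same shape as the tests below: N=1, C=3, D=H=W=2.
x = torch.arange(24, dtype=torch.float32).reshape(1, 3, 2, 2, 2)
print(x.stride())  # (24, 8, 4, 2, 1): default NCDHW layout, W innermost

y = x.contiguous(memory_format=torch.channels_last_3d)
print(y.stride())  # (24, 1, 12, 6, 3): NDHWC in memory, C innermost
assert y.is_contiguous(memory_format=torch.channels_last_3d)
assert torch.equal(x, y)  # logical values are unchanged; only the layout differs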

@@ -125,6 +125,15 @@ class Test(torch.jit.ScriptModule):
        r = r.contiguous()
        return r

    @torch.jit.script_method
    def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
        r = torch.nn.functional.conv3d(x, w)
        if toChannelsLast:
            r = r.contiguous(memory_format=torch.channels_last_3d)
        else:
            r = r.contiguous()
        return r

    @torch.jit.script_method
    def contiguous(self, x: Tensor) -> Tensor:
        return x.contiguous()

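The new script method can be sanity-checked outside the Android harness; a minimal sketch, assuming a standalone scripted function behaves like the script_method above:

import torch
from torch import Tensor

@torch.jit.script
def conv3d(x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
    # Mirrors the method above: a conv3d followed by an explicit memory format.
    r = torch.nn.functional.conv3d(x, w)
    if toChannelsLast:
        r = r.contiguous(memory_format=torch.channels_last_3d)
    else:
        r = r.contiguous()
    return r

x = torch.randn(1, 3, 2, 2, 2)
w = torch.randn(3, 3, 1, 1, 1)
assert conv3d(x, w, True).is_contiguous(memory_format=torch.channels_last_3d)
assert conv3d(x, w, False).is_contiguous()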
@@ -348,15 +348,32 @@ public abstract class PytorchTestBase {
@Test
public void testChannelsLastConv2d() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2};
-long[] dataNCHW = new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104};
-Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
-long[] dataNHWC = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
-Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
+long[] dataNCHW = new long[] {
+    111, 112,
+    121, 122,
+    211, 212,
+    221, 222,
+    311, 312,
+    321, 322};
+Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
+long[] dataNHWC = new long[] {
+    111, 211, 311, 112, 212, 312,
+    121, 221, 321, 122, 222, 322};
+Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
long[] weightShape = new long[] {3, 3, 1, 1};
-long[] dataWeightOIHW = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
+long[] dataWeightOIHW = new long[] {
+    2, 0, 0,
+    0, 1, 0,
+    0, 0, -1};
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
-long[] dataWeightOHWI = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
+long[] dataWeightOHWI = new long[] {
+    2, 0, 0,
+    0, 1, 0,
+    0, 0, -1};
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
@@ -367,7 +384,15 @@ public abstract class PytorchTestBase {
outputNCHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2},
-new long[] {2, 4, 6, 8, 11, 12, 13, 14, -101, -102, -103, -104});
+new long[] {
+    2*111, 2*112,
+    2*121, 2*122,
+    211, 212,
+    221, 222,
+    -311, -312,
+    -321, -322});
final IValue outputNHWC =
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
@@ -375,7 +400,89 @@ public abstract class PytorchTestBase {
outputNHWC,
MemoryFormat.CHANNELS_LAST,
new long[] {1, 3, 2, 2},
-new long[] {2, 11, -101, 4, 12, -102, 6, 13, -103, 8, 14, -104});
+new long[] {
+    2*111, 211, -311, 2*112, 212, -312,
+    2*121, 221, -321, 2*122, 222, -322});
}
@Test
public void testChannelsLastConv3d() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2, 2};
long[] dataNCDHW = new long[] {
1111, 1112,
1121, 1122,
1211, 1212,
1221, 1222,
2111, 2112,
2121, 2122,
2211, 2212,
2221, 2222,
3111, 3112,
3121, 3122,
3211, 3212,
3221, 3222};
Tensor inputNCDHW = Tensor.fromBlob(dataNCDHW, inputShape, MemoryFormat.CONTIGUOUS);
long[] dataNDHWC = new long[] {
1111, 2111, 3111,
1112, 2112, 3112,
1121, 2121, 3121,
1122, 2122, 3122,
1211, 2211, 3211,
1212, 2212, 3212,
1221, 2221, 3221,
1222, 2222, 3222};
Tensor inputNDHWC = Tensor.fromBlob(dataNDHWC, inputShape, MemoryFormat.CHANNELS_LAST_3D);
long[] weightShape = new long[] {3, 3, 1, 1, 1};
long[] dataWeightOIDHW = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1,
};
Tensor wNCDHW = Tensor.fromBlob(dataWeightOIDHW, weightShape, MemoryFormat.CONTIGUOUS);
long[] dataWeightODHWI = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1,
};
Tensor wNDHWC = Tensor.fromBlob(dataWeightODHWI, weightShape, MemoryFormat.CHANNELS_LAST_3D);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCDHW =
module.runMethod("conv3d", IValue.from(inputNCDHW), IValue.from(wNCDHW), IValue.from(false));
assertIValueTensor(
outputNCDHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2, 2},
new long[] {
2*1111, 2*1112, 2*1121, 2*1122,
2*1211, 2*1212, 2*1221, 2*1222,
2111, 2112, 2121, 2122,
2211, 2212, 2221, 2222,
-3111, -3112, -3121, -3122,
-3211, -3212, -3221, -3222});
final IValue outputNDHWC =
module.runMethod("conv3d", IValue.from(inputNDHWC), IValue.from(wNDHWC), IValue.from(true));
assertIValueTensor(
outputNDHWC,
MemoryFormat.CHANNELS_LAST_3D,
new long[] {1, 3, 2, 2, 2},
new long[] {
2*1111, 2111, -3111, 2*1112, 2112, -3112,
2*1121, 2121, -3121, 2*1122, 2122, -3122,
2*1211, 2211, -3211, 2*1212, 2212, -3212,
2*1221, 2221, -3221, 2*1222, 2222, -3222});
}
@Test

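The expected arrays follow from the 1x1x1 kernel: with weight rows (2, 0, 0), (0, 1, 0), (0, 0, -1), output channel 0 is twice input channel 0, channel 1 passes through, and channel 2 is negated; the NDHWC expectation is the same values with channels interleaved innermost. A quick eager-mode check of that arithmetic (a sketch, independent of the Java harness):

import torch
import torch.nn.functional as F

x = torch.tensor([1111, 1112, 1121, 1122, 1211, 1212, 1221, 1222,
                  2111, 2112, 2121, 2122, 2211, 2212, 2221, 2222,
                  3111, 3112, 3121, 3122, 3211, 3212, 3221, 3222],
                 dtype=torch.float32).reshape(1, 3, 2, 2, 2)
w = torch.tensor([2, 0, 0, 0, 1, 0, 0, 0, -1],
                 dtype=torch.float32).reshape(3, 3, 1, 1, 1)
out = F.conv3d(x, w)
# Channel 0 doubled, channel 1 unchanged, channel 2 negated.
assert torch.equal(out, torch.cat([2 * x[:, :1], x[:, 1:2], -x[:, 2:]], dim=1))
# The NDHWC expected array is the NCDHW result with C moved innermost.
print(out.permute(0, 2, 3, 4, 1).reshape(-1)[:6].tolist())
# [2222.0, 2111.0, -3111.0, 2224.0, 2112.0, -3112.0], i.e. 2*1111, 2111, -3111, ...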
@@ -84,6 +84,15 @@ def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
    r = r.contiguous()
    return r

def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
    r = torch.conv3d(x, w)
    if toChannelsLast:
        # memory_format=torch.channels_last_3d
        r = r.contiguous(memory_format=3)
    else:
        r = r.contiguous()
    return r

def contiguous(self, x: Tensor) -> Tensor:
    return x.contiguous()

@@ -112,6 +112,15 @@ class AndroidAPIModule(torch.jit.ScriptModule):
        r = r.contiguous()
        return r

    @torch.jit.script_method
    def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
        r = torch.nn.functional.conv3d(x, w)
        if toChannelsLast:
            r = r.contiguous(memory_format=torch.channels_last_3d)
        else:
            r = r.contiguous()
        return r

    @torch.jit.script_method
    def contiguous(self, x: Tensor) -> Tensor:
        return x.contiguous()
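For reference, the relationship between the two Java input blobs in the test above can be checked from Python; a sketch assuming Tensor.fromBlob with CHANNELS_LAST_3D reads the blob in physical NDHWC order, which is what the test's arrays suggest:

import torch

ncdhw = torch.tensor([1111, 1112, 1121, 1122, 1211, 1212, 1221, 1222,
                      2111, 2112, 2121, 2122, 2211, 2212, 2221, 2222,
                      3111, 3112, 3121, 3122, 3211, 3212, 3221, 3222]).reshape(1, 3, 2, 2, 2)
# Moving C innermost reproduces dataNDHWC from dataNCDHW.
ndhwc_blob = ncdhw.permute(0, 2, 3, 4, 1).reshape(-1)
print(ndhwc_blob[:6].tolist())  # [1111, 2111, 3111, 1112, 2112, 3112]
# Rebuilding the logical NCDHW tensor from the NDHWC blob yields strides that
# already satisfy the channels_last_3d contract.
rebuilt = ndhwc_blob.reshape(1, 2, 2, 2, 3).permute(0, 4, 1, 2, 3)
assert torch.equal(rebuilt, ncdhw)
assert rebuilt.is_contiguous(memory_format=torch.channels_last_3d)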