mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add some tests for 3d channels last (#118283)
Part of a multi-PR work to fix #59168. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118283 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
bacbad5bc9
commit
aef820926c
@ -125,6 +125,15 @@ class Test(torch.jit.ScriptModule):
|
||||
r = r.contiguous()
|
||||
return r
|
||||
|
||||
@torch.jit.script_method
|
||||
def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||
r = torch.nn.functional.conv3d(x, w)
|
||||
if toChannelsLast:
|
||||
r = r.contiguous(memory_format=torch.channels_last_3d)
|
||||
else:
|
||||
r = r.contiguous()
|
||||
return r
|
||||
|
||||
@torch.jit.script_method
|
||||
def contiguous(self, x: Tensor) -> Tensor:
|
||||
return x.contiguous()
|
||||
|
Binary file not shown.
@ -348,15 +348,32 @@ public abstract class PytorchTestBase {
|
||||
@Test
|
||||
public void testChannelsLastConv2d() throws IOException {
|
||||
long[] inputShape = new long[] {1, 3, 2, 2};
|
||||
long[] dataNCHW = new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104};
|
||||
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataNHWC = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
|
||||
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||
long[] dataNCHW = new long[] {
|
||||
111, 112,
|
||||
121, 122,
|
||||
|
||||
211, 212,
|
||||
221, 222,
|
||||
|
||||
311, 312,
|
||||
321, 322};
|
||||
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataNHWC = new long[] {
|
||||
111, 211, 311, 112, 212, 312,
|
||||
|
||||
121, 221, 321, 122, 222, 322};
|
||||
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||
long[] weightShape = new long[] {3, 3, 1, 1};
|
||||
long[] dataWeightOIHW = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
|
||||
long[] dataWeightOIHW = new long[] {
|
||||
2, 0, 0,
|
||||
0, 1, 0,
|
||||
0, 0, -1};
|
||||
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataWeightOHWI = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
|
||||
long[] dataWeightOHWI = new long[] {
|
||||
2, 0, 0,
|
||||
0, 1, 0,
|
||||
0, 0, -1};
|
||||
|
||||
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
|
||||
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
@ -367,7 +384,15 @@ public abstract class PytorchTestBase {
|
||||
outputNCHW,
|
||||
MemoryFormat.CONTIGUOUS,
|
||||
new long[] {1, 3, 2, 2},
|
||||
new long[] {2, 4, 6, 8, 11, 12, 13, 14, -101, -102, -103, -104});
|
||||
new long[] {
|
||||
2*111, 2*112,
|
||||
2*121, 2*122,
|
||||
|
||||
211, 212,
|
||||
221, 222,
|
||||
|
||||
-311, -312,
|
||||
-321, -322});
|
||||
|
||||
final IValue outputNHWC =
|
||||
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
|
||||
@ -375,7 +400,89 @@ public abstract class PytorchTestBase {
|
||||
outputNHWC,
|
||||
MemoryFormat.CHANNELS_LAST,
|
||||
new long[] {1, 3, 2, 2},
|
||||
new long[] {2, 11, -101, 4, 12, -102, 6, 13, -103, 8, 14, -104});
|
||||
new long[] {
|
||||
2*111, 211, -311, 2*112, 212, -312,
|
||||
2*121, 221, -321, 2*122, 222, -322});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testChannelsLastConv3d() throws IOException {
|
||||
long[] inputShape = new long[] {1, 3, 2, 2, 2};
|
||||
long[] dataNCDHW = new long[] {
|
||||
1111, 1112,
|
||||
1121, 1122,
|
||||
1211, 1212,
|
||||
1221, 1222,
|
||||
|
||||
2111, 2112,
|
||||
2121, 2122,
|
||||
2211, 2212,
|
||||
2221, 2222,
|
||||
|
||||
3111, 3112,
|
||||
3121, 3122,
|
||||
3211, 3212,
|
||||
3221, 3222};
|
||||
Tensor inputNCDHW = Tensor.fromBlob(dataNCDHW, inputShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataNDHWC = new long[] {
|
||||
1111, 2111, 3111,
|
||||
1112, 2112, 3112,
|
||||
|
||||
1121, 2121, 3121,
|
||||
1122, 2122, 3122,
|
||||
|
||||
1211, 2211, 3211,
|
||||
1212, 2212, 3212,
|
||||
|
||||
1221, 2221, 3221,
|
||||
1222, 2222, 3222};
|
||||
|
||||
Tensor inputNDHWC = Tensor.fromBlob(dataNDHWC, inputShape, MemoryFormat.CHANNELS_LAST_3D);
|
||||
|
||||
long[] weightShape = new long[] {3, 3, 1, 1, 1};
|
||||
long[] dataWeightOIDHW = new long[] {
|
||||
2, 0, 0,
|
||||
0, 1, 0,
|
||||
0, 0, -1,
|
||||
};
|
||||
Tensor wNCDHW = Tensor.fromBlob(dataWeightOIDHW, weightShape, MemoryFormat.CONTIGUOUS);
|
||||
long[] dataWeightODHWI = new long[] {
|
||||
2, 0, 0,
|
||||
0, 1, 0,
|
||||
0, 0, -1,
|
||||
};
|
||||
Tensor wNDHWC = Tensor.fromBlob(dataWeightODHWI, weightShape, MemoryFormat.CHANNELS_LAST_3D);
|
||||
|
||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||
|
||||
final IValue outputNCDHW =
|
||||
module.runMethod("conv3d", IValue.from(inputNCDHW), IValue.from(wNCDHW), IValue.from(false));
|
||||
assertIValueTensor(
|
||||
outputNCDHW,
|
||||
MemoryFormat.CONTIGUOUS,
|
||||
new long[] {1, 3, 2, 2, 2},
|
||||
new long[] {
|
||||
2*1111, 2*1112, 2*1121, 2*1122,
|
||||
2*1211, 2*1212, 2*1221, 2*1222,
|
||||
|
||||
2111, 2112, 2121, 2122,
|
||||
2211, 2212, 2221, 2222,
|
||||
|
||||
-3111, -3112, -3121, -3122,
|
||||
-3211, -3212, -3221, -3222});
|
||||
|
||||
final IValue outputNDHWC =
|
||||
module.runMethod("conv3d", IValue.from(inputNDHWC), IValue.from(wNDHWC), IValue.from(true));
|
||||
assertIValueTensor(
|
||||
outputNDHWC,
|
||||
MemoryFormat.CHANNELS_LAST_3D,
|
||||
new long[] {1, 3, 2, 2, 2},
|
||||
new long[] {
|
||||
2*1111, 2111, -3111, 2*1112, 2112, -3112,
|
||||
2*1121, 2121, -3121, 2*1122, 2122, -3122,
|
||||
|
||||
2*1211, 2211, -3211, 2*1212, 2212, -3212,
|
||||
2*1221, 2221, -3221, 2*1222, 2222, -3222});
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -84,6 +84,15 @@ def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||
r = r.contiguous()
|
||||
return r
|
||||
|
||||
def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||
r = torch.conv3d(x, w)
|
||||
if (toChannelsLast):
|
||||
# memory_format=torch.channels_last_3d
|
||||
r = r.contiguous(memory_format=2)
|
||||
else:
|
||||
r = r.contiguous()
|
||||
return r
|
||||
|
||||
def contiguous(self, x: Tensor) -> Tensor:
|
||||
return x.contiguous()
|
||||
|
||||
|
Reference in New Issue
Block a user