diff --git a/android/pytorch_android/generate_test_torchscripts.py b/android/pytorch_android/generate_test_torchscripts.py index c3c9518517ae..a487bd1242e0 100644 --- a/android/pytorch_android/generate_test_torchscripts.py +++ b/android/pytorch_android/generate_test_torchscripts.py @@ -125,6 +125,15 @@ class Test(torch.jit.ScriptModule): r = r.contiguous() return r + @torch.jit.script_method + def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: + r = torch.nn.functional.conv3d(x, w) + if toChannelsLast: + r = r.contiguous(memory_format=torch.channels_last_3d) + else: + r = r.contiguous() + return r + @torch.jit.script_method def contiguous(self, x: Tensor) -> Tensor: return x.contiguous() diff --git a/android/pytorch_android/src/androidTest/assets/android_api_module.ptl b/android/pytorch_android/src/androidTest/assets/android_api_module.ptl index df62dd862088..9adfb84bf855 100644 Binary files a/android/pytorch_android/src/androidTest/assets/android_api_module.ptl and b/android/pytorch_android/src/androidTest/assets/android_api_module.ptl differ diff --git a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchTestBase.java b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchTestBase.java index d2dfa93da17a..7980a34c0434 100644 --- a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchTestBase.java +++ b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchTestBase.java @@ -348,15 +348,32 @@ public abstract class PytorchTestBase { @Test public void testChannelsLastConv2d() throws IOException { long[] inputShape = new long[] {1, 3, 2, 2}; - long[] dataNCHW = new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104}; - Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS); - long[] dataNHWC = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104}; - Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST); + long[] dataNCHW = new long[] { + 111, 112, + 121, 122, + 211, 212, + 221, 222, + + 311, 312, + 321, 322}; + Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS); + long[] dataNHWC = new long[] { + 111, 211, 311, 112, 212, 312, + + 121, 221, 321, 122, 222, 322}; + Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST); long[] weightShape = new long[] {3, 3, 1, 1}; - long[] dataWeightOIHW = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1}; + long[] dataWeightOIHW = new long[] { + 2, 0, 0, + 0, 1, 0, + 0, 0, -1}; Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS); - long[] dataWeightOHWI = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1}; + long[] dataWeightOHWI = new long[] { + 2, 0, 0, + 0, 1, 0, + 0, 0, -1}; + Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST); final Module module = loadModel(TEST_MODULE_ASSET_NAME); @@ -367,7 +384,15 @@ public abstract class PytorchTestBase { outputNCHW, MemoryFormat.CONTIGUOUS, new long[] {1, 3, 2, 2}, - new long[] {2, 4, 6, 8, 11, 12, 13, 14, -101, -102, -103, -104}); + new long[] { + 2*111, 2*112, + 2*121, 2*122, + + 211, 212, + 221, 222, + + -311, -312, + -321, -322}); final IValue outputNHWC = module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true)); @@ -375,7 +400,89 @@ public abstract class PytorchTestBase { outputNHWC, MemoryFormat.CHANNELS_LAST, new long[] {1, 3, 2, 2}, - new long[] {2, 11, -101, 4, 12, -102, 6, 13, -103, 8, 14, -104}); + new long[] { + 2*111, 211, -311, 2*112, 212, -312, + 2*121, 221, -321, 2*122, 222, -322}); + } + + @Test + public void testChannelsLastConv3d() throws IOException { + long[] inputShape = new long[] {1, 3, 2, 2, 2}; + long[] dataNCDHW = new long[] { + 1111, 1112, + 1121, 1122, + 1211, 1212, + 1221, 1222, + + 2111, 2112, + 2121, 2122, + 2211, 2212, + 2221, 2222, + + 3111, 3112, + 3121, 3122, + 3211, 3212, + 3221, 3222}; + Tensor inputNCDHW = Tensor.fromBlob(dataNCDHW, inputShape, MemoryFormat.CONTIGUOUS); + long[] dataNDHWC = new long[] { + 1111, 2111, 3111, + 1112, 2112, 3112, + + 1121, 2121, 3121, + 1122, 2122, 3122, + + 1211, 2211, 3211, + 1212, 2212, 3212, + + 1221, 2221, 3221, + 1222, 2222, 3222}; + + Tensor inputNDHWC = Tensor.fromBlob(dataNDHWC, inputShape, MemoryFormat.CHANNELS_LAST_3D); + + long[] weightShape = new long[] {3, 3, 1, 1, 1}; + long[] dataWeightOIDHW = new long[] { + 2, 0, 0, + 0, 1, 0, + 0, 0, -1, + }; + Tensor wNCDHW = Tensor.fromBlob(dataWeightOIDHW, weightShape, MemoryFormat.CONTIGUOUS); + long[] dataWeightODHWI = new long[] { + 2, 0, 0, + 0, 1, 0, + 0, 0, -1, + }; + Tensor wNDHWC = Tensor.fromBlob(dataWeightODHWI, weightShape, MemoryFormat.CHANNELS_LAST_3D); + + final Module module = loadModel(TEST_MODULE_ASSET_NAME); + + final IValue outputNCDHW = + module.runMethod("conv3d", IValue.from(inputNCDHW), IValue.from(wNCDHW), IValue.from(false)); + assertIValueTensor( + outputNCDHW, + MemoryFormat.CONTIGUOUS, + new long[] {1, 3, 2, 2, 2}, + new long[] { + 2*1111, 2*1112, 2*1121, 2*1122, + 2*1211, 2*1212, 2*1221, 2*1222, + + 2111, 2112, 2121, 2122, + 2211, 2212, 2221, 2222, + + -3111, -3112, -3121, -3122, + -3211, -3212, -3221, -3222}); + + final IValue outputNDHWC = + module.runMethod("conv3d", IValue.from(inputNDHWC), IValue.from(wNDHWC), IValue.from(true)); + assertIValueTensor( + outputNDHWC, + MemoryFormat.CHANNELS_LAST_3D, + new long[] {1, 3, 2, 2, 2}, + new long[] { + 2*1111, 2111, -3111, 2*1112, 2112, -3112, + 2*1121, 2121, -3121, 2*1122, 2122, -3122, + + 2*1211, 2211, -3211, 2*1212, 2212, -3212, + 2*1221, 2221, -3221, 2*1222, 2222, -3222}); } @Test diff --git a/android/pytorch_android/test_asset.jit b/android/pytorch_android/test_asset.jit index 3bd9037da4ee..8605ab13d555 100644 --- a/android/pytorch_android/test_asset.jit +++ b/android/pytorch_android/test_asset.jit @@ -84,6 +84,15 @@ def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: r = r.contiguous() return r +def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: + r = torch.conv3d(x, w) + if (toChannelsLast): + # memory_format=torch.channels_last_3d + r = r.contiguous(memory_format=2) + else: + r = r.contiguous() + return r + def contiguous(self, x: Tensor) -> Tensor: return x.contiguous() diff --git a/ios/TestApp/models/android_api_module.ptl b/ios/TestApp/models/android_api_module.ptl index df62dd862088..9adfb84bf855 100644 Binary files a/ios/TestApp/models/android_api_module.ptl and b/ios/TestApp/models/android_api_module.ptl differ diff --git a/test/mobile/model_test/android_api_module.py b/test/mobile/model_test/android_api_module.py index acada05fc2ff..ac81bea0d57e 100644 --- a/test/mobile/model_test/android_api_module.py +++ b/test/mobile/model_test/android_api_module.py @@ -112,6 +112,15 @@ class AndroidAPIModule(torch.jit.ScriptModule): r = r.contiguous() return r + @torch.jit.script_method + def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor: + r = torch.nn.functional.conv3d(x, w) + if toChannelsLast: + r = r.contiguous(memory_format=torch.channels_last_3d) + else: + r = r.contiguous() + return r + @torch.jit.script_method def contiguous(self, x: Tensor) -> Tensor: return x.contiguous()