No-batch-dim support for ConvNd (#70506)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70506

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33355034

Pulled By: jbschlosser

fbshipit-source-id: 5a42645299b1d82cee7d461826acca1c5b35a71c
This commit is contained in:
Joel Schlosser
2022-01-06 16:52:12 -08:00
committed by Facebook GitHub Bot
parent 6896b2d734
commit 7b8f73dd32
7 changed files with 326 additions and 107 deletions

View File

@ -159,7 +159,7 @@ class TestORTTensor(common.TestCase):
bias = torch.empty(6, device='ort')
# Make sure forward is overriden
out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1)
out = torch.nn.functional.conv2d(input, weight, bias, 2, 0, 1, 1)
self.assertEqual(ort_extension.get_test_int(), 2)
self.assertEqual(out.shape[0], input.shape[0])
self.assertEqual(out.shape[1], weight.shape[0])