Throw error if stateless.functional_call called with nn.DataParallel (#107403)

Part of #77576

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107403
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Kurt Mohler
2023-08-18 03:01:59 +00:00
committed by PyTorch MergeBot
parent 600f9ef2ad
commit 36141de427
2 changed files with 19 additions and 0 deletions

View File

@ -131,6 +131,21 @@ class TestStatelessFunctionalAPI(TestCase):
dp_module = torch.nn.DataParallel(module, [0, 1])
self._run_call_with_mock_module(dp_module, functional_call, device='cuda', prefix='module')
@unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_functional_call_with_data_parallel_error(self, functional_call):
module = MockModule()
module.cuda()
dp_module = torch.nn.DataParallel(module, [0, 1])
with self.assertRaisesRegex(RuntimeError, r'used with nn.DataParallel module'):
functional_call(
dp_module,
{'module.weight': torch.zeros(5, device='cuda')},
(torch.ones(2, 5, device='cuda'),))
@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")