Throw error if stateless.functional_call called with nn.DataParallel (#107403)
Part of #77576

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107403
Approved by: https://github.com/mikaylagawarecki
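For illustration, a minimal sketch of the behavior this change enforces (not part of the commit; nn.Linear and the parameter names/shapes below are stand-ins for the test's MockModule): functional_call now raises a RuntimeError when the target module is an nn.DataParallel instance.

# Minimal sketch, assuming CUDA with at least 2 GPUs (as in the new test below).
# nn.Linear stands in for the test's MockModule; names and shapes are illustrative.
import torch
import torch.nn as nn

module = nn.Linear(5, 5).cuda()
dp_module = nn.DataParallel(module, [0, 1])

try:
    torch.func.functional_call(
        dp_module,
        {'module.weight': torch.zeros(5, 5, device='cuda'),
         'module.bias': torch.zeros(5, device='cuda')},
        (torch.ones(2, 5, device='cuda'),))
except RuntimeError as err:
    # After this change, the error message notes that functional_call
    # cannot be used with an nn.DataParallel module.
    print(err)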
commit 36141de427
parent 600f9ef2ad
committed by PyTorch MergeBot
@@ -131,6 +131,21 @@ class TestStatelessFunctionalAPI(TestCase):
         dp_module = torch.nn.DataParallel(module, [0, 1])
         self._run_call_with_mock_module(dp_module, functional_call, device='cuda', prefix='module')
 
+    @unittest.skipIf(not TEST_MULTIGPU, 'multi-GPU not supported')
+    @parametrize("functional_call", [
+        subtest(torch.func.functional_call, "torch_func"),
+        subtest(stateless.functional_call, "stateless")
+    ])
+    def test_functional_call_with_data_parallel_error(self, functional_call):
+        module = MockModule()
+        module.cuda()
+        dp_module = torch.nn.DataParallel(module, [0, 1])
+        with self.assertRaisesRegex(RuntimeError, r'used with nn.DataParallel module'):
+            functional_call(
+                dp_module,
+                {'module.weight': torch.zeros(5, device='cuda')},
+                (torch.ones(2, 5, device='cuda'),))
+
     @parametrize("functional_call", [
         subtest(torch.func.functional_call, "torch_func"),
         subtest(stateless.functional_call, "stateless")
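A usage note (my own assumption, not stated in this PR): functional_call can still be applied to the module that DataParallel wraps, since nn.DataParallel keeps it as the .module attribute and its parameter names carry no 'module.' prefix. Continuing the sketch above:

# Hypothetical workaround: unwrap DataParallel and call functional_call on the
# inner module directly (single-device execution).
out = torch.func.functional_call(
    dp_module.module,
    {'weight': torch.zeros(5, 5, device='cuda'),
     'bias': torch.zeros(5, device='cuda')},
    (torch.ones(2, 5, device='cuda'),))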