mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix c10d
-> dist
in test_ddp_hooks.py
(#61864)
Summary: **Overview:** The existing `test_ddp_hooks.py` test file uses a prefix `c10d`, which is not defined in the file, meaning the test errors if left as is. This renames each `c10d` prefix to `dist`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/61864 Test Plan: All four tests pass when run: ``` gpurun python test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py ``` Reviewed By: ejguan Differential Revision: D29783860 Pulled By: andwgu fbshipit-source-id: 16bdd2dfcb76192964246148f14851a74f8907c8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
109bd5e78a
commit
5186fa2831
@ -108,8 +108,8 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
|
||||
This unit test verifies the ``allreduce`` hook registered case gives same result
|
||||
with no hook registered case.
|
||||
"""
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
process_group = dist.ProcessGroupNCCL(store, self.rank, self.world_size)
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -125,8 +125,8 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
|
||||
This unit test verifies the ``fp16 compress`` hook registered case
|
||||
gives close result with no hook registered case.
|
||||
"""
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
process_group = dist.ProcessGroupNCCL(store, self.rank, self.world_size)
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -142,8 +142,8 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
|
||||
This unit test verifies the ``quantize per tensor`` hook registered case
|
||||
gives close result with no hook registered case.
|
||||
"""
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
process_group = dist.ProcessGroupNCCL(store, self.rank, self.world_size)
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -159,8 +159,8 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
|
||||
This unit test verifies the ``quantize per channel`` hook registered case
|
||||
gives close result with no hook registered case.
|
||||
"""
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
process_group = dist.ProcessGroupNCCL(store, self.rank, self.world_size)
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
|
Reference in New Issue
Block a user