[DDP][PT2D] Fix some tracing bugs of DDP (#124421)
1. We need to clear the cache of get_legacy_mod_inlinelist to ensure the newly added rule will be captured.
2. Don't add the hook if the parameter does not require gradient.

Differential Revision: [D56315534](https://our.internmc.facebook.com/intern/diff/D56315534/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124421
Approved by: https://github.com/yf225
committed by: PyTorch MergeBot
parent: 9a5b4d2403
commit: 7b6e354ecd
@@ -894,6 +894,7 @@ class DistributedDataParallel(Module, Joinable):
             torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add(
                 "torch.nn.parallel.distributed"
             )
+            torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear()
         self._force_to_disable_cpp_reducer = (
             optimize_ddp == "python_reducer_without_compiled_forward"
         )
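The added cache_clear() call matters because, per the commit message, get_legacy_mod_inlinelist caches its result; a rule added to LEGACY_MOD_INLINELIST after the cache is populated would otherwise never be seen. A minimal, self-contained sketch of that caching pattern, assuming an lru_cache-style memoization (the set contents and the decorator here are illustrative, not the actual trace_rules code):

```python
import functools

LEGACY_MOD_INLINELIST = {"torch.existing.module"}  # illustrative contents


@functools.lru_cache(None)
def get_legacy_mod_inlinelist():
    # Build the inline-list once; later calls return the cached set.
    return set(LEGACY_MOD_INLINELIST)


get_legacy_mod_inlinelist()  # first call populates the cache

LEGACY_MOD_INLINELIST.add("torch.nn.parallel.distributed")
assert "torch.nn.parallel.distributed" not in get_legacy_mod_inlinelist()  # stale cache

get_legacy_mod_inlinelist.cache_clear()
assert "torch.nn.parallel.distributed" in get_legacy_mod_inlinelist()  # rule now captured
```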
@@ -926,6 +927,8 @@ class DistributedDataParallel(Module, Joinable):
                 param.grad.copy_(gradient)
 
         for index, param in enumerate(self._module_parameters):
+            if not param.requires_grad:
+                continue
             self._accum_grad_hooks.append(
                 param.register_post_accumulate_grad_hook(
                     functools.partial(
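The two added lines skip parameters that do not require gradients: such parameters never accumulate a .grad, so a post-accumulate-grad hook registered on them would never fire, and autograd may reject the registration outright. A minimal sketch of that guard using a toy model; compiled_accum_grad_hook is a hypothetical stand-in for DDP's real per-parameter all-reduce hook:

```python
import functools

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model.bias.requires_grad_(False)  # pretend this parameter is frozen


def compiled_accum_grad_hook(param, *, param_index):
    # Stand-in for the gradient all-reduce DDP runs once a grad is accumulated.
    print(f"grad ready for parameter {param_index}")


accum_grad_hooks = []
for index, param in enumerate(model.parameters()):
    if not param.requires_grad:
        # Frozen parameters never accumulate gradients; registering the hook
        # on them is useless (and may raise), so skip them.
        continue
    accum_grad_hooks.append(
        param.register_post_accumulate_grad_hook(
            functools.partial(compiled_accum_grad_hook, param_index=index)
        )
    )

model(torch.randn(3, 4)).sum().backward()  # hook fires only for the weight
```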