[DDP][PT2D] Fix some tracing bugs of DDP (#124421)

1. We need to clear the cache of get_legacy_mod_inlinelist to ensure that the newly added rule is picked up.
2. Don't register the post-accumulate-grad hook if the parameter does not require a gradient; both points are sketched below.
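
For illustration only, here is a minimal, self-contained sketch of the two points above. It does not reproduce the actual torch._dynamo.trace_rules internals: INLINELIST, get_inlinelist, register_hooks, and _noop_hook are hypothetical names, and only the lru_cache memoization and the requires_grad guard mirror the behavior the fix relies on.

```python
import functools

import torch.nn as nn

# Hypothetical stand-in for torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST;
# only the memoization behavior of the lookup function is mirrored here.
INLINELIST = {"torch.nn.parallel"}


@functools.lru_cache(None)
def get_inlinelist():
    # The first call snapshots INLINELIST; later additions to the set are
    # invisible to callers until the cached result is cleared.
    return frozenset(INLINELIST)


get_inlinelist()  # populate the cache
INLINELIST.add("torch.nn.parallel.distributed")
assert "torch.nn.parallel.distributed" not in get_inlinelist()  # stale cache
get_inlinelist.cache_clear()  # point 1: drop the memoized result
assert "torch.nn.parallel.distributed" in get_inlinelist()


def register_hooks(module: nn.Module):
    # Point 2: frozen parameters never accumulate a .grad, so skip them
    # instead of registering a post-accumulate-grad hook on them.
    def _noop_hook(param, *, param_index):
        pass  # the real DDP hook all-reduces param.grad here

    hooks = []
    for index, param in enumerate(module.parameters()):
        if not param.requires_grad:
            continue
        hooks.append(
            param.register_post_accumulate_grad_hook(
                functools.partial(_noop_hook, param_index=index)
            )
        )
    return hooks


layer = nn.Linear(4, 4)
layer.bias.requires_grad_(False)
assert len(register_hooks(layer)) == 1  # only the weight gets a hook
```

With the guard in place, hooks are attached only to parameters that will actually receive gradients, which is what the second hunk of the diff below does.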

Differential Revision: [D56315534](https://our.internmc.facebook.com/intern/diff/D56315534/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124421
Approved by: https://github.com/yf225
Author:       Chien-Chin Huang
Date:         2024-04-22 15:10:12 -07:00
Committed by: PyTorch MergeBot
Parent:       9a5b4d2403
Commit:       7b6e354ecd


@@ -894,6 +894,7 @@ class DistributedDataParallel(Module, Joinable):
             torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add(
                 "torch.nn.parallel.distributed"
             )
+            torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear()
         self._force_to_disable_cpp_reducer = (
             optimize_ddp == "python_reducer_without_compiled_forward"
         )
@@ -926,6 +927,8 @@ class DistributedDataParallel(Module, Joinable):
                 param.grad.copy_(gradient)
 
         for index, param in enumerate(self._module_parameters):
+            if not param.requires_grad:
+                continue
             self._accum_grad_hooks.append(
                 param.register_post_accumulate_grad_hook(
                     functools.partial(