From ec92b1af7a5fbaf362da04401fc2576570368a1c Mon Sep 17 00:00:00 2001
From: Walker <33346657+EquationWalker@users.noreply.github.com>
Date: Sun, 7 Sep 2025 05:57:46 +0800
Subject: [PATCH] fix: model.set_requires_gradient_sync(False) should be
 called to turn off gradient synchronization in FSDP2 (#3762)

* fix: `model.set_requires_gradient_sync(False)` should be called to turn
  off gradient synchronization in FSDP2.

* fix: remove trailing whitespace
---
 src/accelerate/accelerator.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 2c925eff..ee787572 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1169,13 +1169,20 @@ class Accelerator:
         >>> optimizer.zero_grad()
         ```
         """
-        context = contextlib.nullcontext
-        if self.use_distributed:
-            if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
-                context = getattr(model, "no_sync", context)
+        if self.is_fsdp2:
+            model.set_requires_gradient_sync(False)
+            try:
+                yield
+            finally:
+                model.set_requires_gradient_sync(True)
+        else:
+            context = contextlib.nullcontext
+            if self.use_distributed:
+                if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
+                    context = getattr(model, "no_sync", context)
 
-        with context():
-            yield
+            with context():
+                yield
 
     @staticmethod
     @contextmanager
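
For context, below is a minimal sketch of how the patched `Accelerator.no_sync` context manager is typically used in a gradient-accumulation loop. The toy `Linear` model, the loss, and the `gradient_accumulation_steps` value are illustrative placeholders, not part of the patch; the sketch assumes the process is launched with an FSDP2 configuration (on a single process, `no_sync` simply falls back to a null context).

```python
# Minimal usage sketch (hypothetical model/optimizer/hyperparameters).
# With this patch, entering `accelerator.no_sync(model)` calls
# `model.set_requires_gradient_sync(False)` for FSDP2 models (and falls
# back to the DDP-style `model.no_sync()` otherwise), then restores
# synchronization to True on exit.
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # assumed to be launched with an FSDP2 config
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

gradient_accumulation_steps = 4
for step in range(16):
    batch = torch.randn(4, 8, device=accelerator.device)
    is_sync_step = (step + 1) % gradient_accumulation_steps == 0
    if not is_sync_step:
        # Intermediate micro-batch: skip cross-rank gradient communication.
        with accelerator.no_sync(model):
            loss = model(batch).sum()
            accelerator.backward(loss)
    else:
        # Final micro-batch: gradients synchronize across ranks as usual.
        loss = model(batch).sum()
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```

The try/finally in the patch matters here: even if the body of the `with` block raises, gradient synchronization is re-enabled before the exception propagates, so later steps are not silently left with synchronization turned off.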