fix: model.set_requires_gradient_sync(False) should be called to turn off gradient synchronization in FSDP2 (#3762)

* fix: `model.set_requires_gradient_sync(False)` should be called to turn off gradient synchronization in FSDP2.

* fix: remove trailing whitespace
This commit is contained in:
Walker
2025-09-07 05:57:46 +08:00
committed by GitHub
parent 62ede1ed2a
commit ec92b1af7a

View File

@ -1169,13 +1169,20 @@ class Accelerator:
>>> optimizer.zero_grad()
```
"""
context = contextlib.nullcontext
if self.use_distributed:
if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
context = getattr(model, "no_sync", context)
if self.is_fsdp2:
model.set_requires_gradient_sync(False)
try:
yield
finally:
model.set_requires_gradient_sync(True)
else:
context = contextlib.nullcontext
if self.use_distributed:
if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
context = getattr(model, "no_sync", context)
with context():
yield
with context():
yield
@staticmethod
@contextmanager