Mirror of https://github.com/huggingface/accelerate.git, synced 2025-10-20 10:03:46 +08:00
fix: model.set_requires_gradient_sync(False) should be called to turn off gradient synchronization in FSDP2 (#3762)
* fix: `model.set_requires_gradient_sync(False)` should be called to turn off gradient synchronization in FSDP2.
* fix: remove trailing whitespace
```diff
@@ -1169,13 +1169,20 @@ class Accelerator:
         >>> optimizer.zero_grad()
         ```
         """
-        context = contextlib.nullcontext
-        if self.use_distributed:
-            if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
-                context = getattr(model, "no_sync", context)
+        if self.is_fsdp2:
+            model.set_requires_gradient_sync(False)
+            try:
+                yield
+            finally:
+                model.set_requires_gradient_sync(True)
+        else:
+            context = contextlib.nullcontext
+            if self.use_distributed:
+                if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
+                    context = getattr(model, "no_sync", context)
 
-        with context():
-            yield
+            with context():
+                yield
 
     @staticmethod
     @contextmanager
```