Compare commits


1 Commit

Author SHA1 Message Date
cfcb847623 Fix accumulation 2023-06-21 11:00:31 -04:00


@@ -205,7 +205,7 @@ if is_peft_available():
 if is_accelerate_available():
     from accelerate import Accelerator, skip_first_batches
     from accelerate import __version__ as accelerate_version
-    from accelerate.utils import DistributedDataParallelKwargs
+    from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin

     if version.parse(accelerate_version) > version.parse("0.20.3"):
         from accelerate.utils import (
@@ -3843,9 +3843,11 @@ class Trainer:
     def create_accelerator_and_postprocess(self):
         # create accelerator object
+        accumulation_plugin = GradientAccumulationPlugin(
+            num_steps=self.args.gradient_accumulation_steps, sync_with_dataloader=False
+        )
         self.accelerator = Accelerator(
-            deepspeed_plugin=self.args.deepspeed_plugin,
-            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+            deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=accumulation_plugin
         )
         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
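
For context on what the patched code path does, below is a minimal standalone sketch of driving gradient accumulation through accelerate's GradientAccumulationPlugin, the same mechanism the Trainer now hands to Accelerator instead of a bare gradient_accumulation_steps count. The tiny linear model, SGD optimizer, random data, and the value num_steps=4 are illustrative placeholders and not part of the commit.

    import torch
    from accelerate import Accelerator
    from accelerate.utils import GradientAccumulationPlugin

    # Configure accumulation through the plugin rather than passing
    # gradient_accumulation_steps directly to Accelerator.
    plugin = GradientAccumulationPlugin(
        num_steps=4,                 # accumulate gradients over 4 micro-batches (illustrative value)
        sync_with_dataloader=False,  # don't force a sync at the end of each dataloader pass
    )
    accelerator = Accelerator(gradient_accumulation_plugin=plugin)

    # Placeholder model, optimizer, and data for the sketch.
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    dataloader = torch.utils.data.DataLoader(torch.randn(32, 10), batch_size=8)
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for batch in dataloader:
        # Inside accumulate(), gradient synchronization and the optimizer step
        # are skipped until num_steps micro-batches have been processed.
        with accelerator.accumulate(model):
            loss = model(batch).mean()
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

With sync_with_dataloader left at its default, accelerate also synchronizes gradients at the end of every dataloader pass; setting it to False, as the diff does, leaves the accumulation boundary entirely under the control of num_steps.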