Marc Sun
2025-10-17 15:53:49 +00:00
parent 9e7a80a4e9
commit a4c7a0f2fd
2 changed files with 17 additions and 14 deletions

src/transformers/trainer.py

@@ -209,7 +209,6 @@ if is_peft_available():
if is_accelerate_available():
    from accelerate import Accelerator, skip_first_batches
    from accelerate import __version__ as accelerate_version
    from accelerate.state import AcceleratorState
    from accelerate.utils import (
        DataLoaderConfiguration,
@@ -4967,10 +4966,12 @@ class Trainer:
        # this would have been updated above, no need for it anymore
        accelerator_config.pop("gradient_accumulation_kwargs")
        args = {"mixed_precision": self.args.mixed_precision,
                "dataloader_config": dataloader_config,
                "fsdp_plugin": self.args.fsdp_plugin,
                "deepspeed_plugin": self.args.deepspeed_plugin}
        args = {
            "mixed_precision": self.args.mixed_precision,
            "dataloader_config": dataloader_config,
            "fsdp_plugin": self.args.fsdp_plugin,
            "deepspeed_plugin": self.args.deepspeed_plugin,
        }
        # We defer compatibility checks to accelerator
        if self.args.parallelism_config is not None:
@@ -4995,9 +4996,12 @@ class Trainer:
        if is_accelerate_available("1.2.0"):
            # if we don't have the correct version, we rely on the env vars that were set in TrainingArguments instead
            from accelerate.utils import TorchDynamoPlugin
            dynamo_plugin = TorchDynamoPlugin(backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
            dynamo_plugin = TorchDynamoPlugin(
                backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode
            )
            args["dynamo_plugin"] = dynamo_plugin
        # create accelerator object
        self.accelerator = Accelerator(**args)
        # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
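For context, a minimal standalone sketch of the pattern this hunk moves to: the Trainer collects plugin objects in a dict and hands them to Accelerator directly rather than relying on environment variables. The mixed-precision, backend, and mode values below are illustrative, not what Trainer actually computes.

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration, TorchDynamoPlugin

# Gather constructor arguments explicitly, mirroring the reformatted `args` dict above.
args = {
    "mixed_precision": "bf16",  # Trainer derives this from TrainingArguments
    "dataloader_config": DataLoaderConfiguration(),
    "fsdp_plugin": None,        # set when FSDP is configured
    "deepspeed_plugin": None,   # set when DeepSpeed is configured
}

# `dynamo_plugin` is only accepted by sufficiently recent accelerate releases,
# hence the `is_accelerate_available("1.2.0")` guard in the hunk above.
args["dynamo_plugin"] = TorchDynamoPlugin(backend="inductor", mode="default")

accelerator = Accelerator(**args)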

src/transformers/training_args.py

@@ -1540,7 +1540,7 @@ class TrainingArguments:
            self.mixed_precision = "fp16"
        elif self.bf16:
            self.mixed_precision = "bf16"
        if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile:
            self.torch_compile = True
        if self.torch_compile and self.torch_compile_backend is None:
@@ -1555,7 +1555,7 @@ class TrainingArguments:
            os.environ["ACCELERATE_DYNAMO_BACKEND"] = self.torch_compile_backend
            if self.torch_compile_mode is not None:
                os.environ["ACCELERATE_DYNAMO_MODE"] = self.torch_compile_mode
        # We need to set up the accelerator config here *before* the first call to `self.device`
        if is_accelerate_available():
            if not isinstance(self.accelerator_config, AcceleratorConfig):
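These ACCELERATE_DYNAMO_* variables are the fallback path referenced in the trainer.py hunk: when accelerate is too old to accept a `dynamo_plugin` argument, it reads the dynamo settings from the environment when the Accelerator is created. A minimal sketch of that fallback, with an inductor backend chosen purely for illustration:

import os

from accelerate import Accelerator

# Older accelerate releases pick up torch.compile settings from these
# environment variables instead of an explicit TorchDynamoPlugin argument.
os.environ["ACCELERATE_DYNAMO_BACKEND"] = "inductor"
os.environ["ACCELERATE_DYNAMO_MODE"] = "default"

accelerator = Accelerator()  # builds its dynamo plugin from the env vars above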
@@ -1656,8 +1656,9 @@ class TrainingArguments:
        if fsdp_plugin_args is not None:
            # Accelerate FSDP Plugin
            from accelerate.utils import FullyShardedDataParallelPlugin
            self.fsdp_plugin = FullyShardedDataParallelPlugin(**fsdp_plugin_args)
        self.deepspeed_plugin = None
        if self.deepspeed:
            # - must be run very last in arg parsing, since it will use a lot of these settings.
@@ -2643,7 +2644,6 @@ class TrainingArguments:
        self.ignore_data_skip = ignore_data_skip
        self.data_seed = sampler_seed
        return self
    def _process_fsdp_args(self):
        if self.fsdp is None:
@@ -2759,12 +2759,11 @@ class TrainingArguments:
        # we need to set the env var here, otherwise we get a warning in accelerate; we also need to set it for transformers
        fsdp_plugin_args["cpu_ram_efficient_loading"] = cpu_ram_efficient_loading
        os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
        fsdp_plugin_args["sync_module_states"] = sync_module_states
        fsdp_plugin_args["use_orig_params"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
        return fsdp_plugin_args

class ParallelMode(Enum):
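The dict returned by `_process_fsdp_args` is what the earlier training_args.py hunk unpacks into `FullyShardedDataParallelPlugin(**fsdp_plugin_args)`. A minimal sketch of that hand-off, using illustrative boolean values rather than the real config parsing:

import os

from accelerate.utils import FullyShardedDataParallelPlugin

# Illustrative stand-in for the dict built by `_process_fsdp_args`.
fsdp_plugin_args = {
    "cpu_ram_efficient_loading": True,
    "sync_module_states": True,  # required when cpu_ram_efficient_loading is enabled
    "use_orig_params": True,
}

# Mirror the env var set in the hunk above so accelerate and transformers agree.
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"

fsdp_plugin = FullyShardedDataParallelPlugin(**fsdp_plugin_args)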