Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)

Commit: style

@@ -209,7 +209,6 @@ if is_peft_available():
 if is_accelerate_available():
     from accelerate import Accelerator, skip_first_batches
     from accelerate import __version__ as accelerate_version
     from accelerate.state import AcceleratorState
     from accelerate.utils import (
         DataLoaderConfiguration,

@@ -4967,10 +4966,12 @@ class Trainer:
             # this would have been updated above, no need for it anymore
             accelerator_config.pop("gradient_accumulation_kwargs")

-        args = {"mixed_precision": self.args.mixed_precision,
-                "dataloader_config": dataloader_config,
-                "fsdp_plugin": self.args.fsdp_plugin,
-                "deepspeed_plugin": self.args.deepspeed_plugin}
+        args = {
+            "mixed_precision": self.args.mixed_precision,
+            "dataloader_config": dataloader_config,
+            "fsdp_plugin": self.args.fsdp_plugin,
+            "deepspeed_plugin": self.args.deepspeed_plugin,
+        }

         # We defer compatibility checks to accelerator
         if self.args.parallelism_config is not None:
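
Not part of the diff — a minimal, hedged sketch of what the restyled block above does: build a plain kwargs dict and hand it to accelerate's Accelerator. The concrete values here (bf16, the DataLoaderConfiguration flags) are illustrative assumptions, not the Trainer's actual settings.

    from accelerate import Accelerator
    from accelerate.utils import DataLoaderConfiguration

    # stand-ins for self.args.* / dataloader_config in the hunk above
    dataloader_config = DataLoaderConfiguration(split_batches=False, dispatch_batches=None)
    args = {
        "mixed_precision": "bf16",        # "fp16", "bf16" or "no"
        "dataloader_config": dataloader_config,
        "fsdp_plugin": None,              # a FullyShardedDataParallelPlugin when FSDP is enabled
        "deepspeed_plugin": None,         # a DeepSpeedPlugin when DeepSpeed is enabled
    }
    accelerator = Accelerator(**args)     # equivalent to passing the keywords directly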

@@ -4995,9 +4996,12 @@ class Trainer:
         if is_accelerate_available("1.2.0"):
             # if we don't have the correct version, we will rely on the env vars that were set in TrainingArguments instead
             from accelerate.utils import TorchDynamoPlugin
-            dynamo_plugin = TorchDynamoPlugin(backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
+
+            dynamo_plugin = TorchDynamoPlugin(
+                backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode
+            )
             args["dynamo_plugin"] = dynamo_plugin

         # create accelerator object
         self.accelerator = Accelerator(**args)
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
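
Not part of the diff — a hedged sketch of the TorchDynamoPlugin path shown above (accelerate >= 1.2.0, per the version check); the backend and mode values are illustrative:

    from accelerate import Accelerator
    from accelerate.utils import TorchDynamoPlugin

    # corresponds to torch_compile_backend / torch_compile_mode in TrainingArguments
    dynamo_plugin = TorchDynamoPlugin(backend="inductor", mode="default")
    accelerator = Accelerator(dynamo_plugin=dynamo_plugin)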

@@ -1540,7 +1540,7 @@ class TrainingArguments:
             self.mixed_precision = "fp16"
         elif self.bf16:
             self.mixed_precision = "bf16"

         if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile:
             self.torch_compile = True
         if self.torch_compile and self.torch_compile_backend is None:
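
For context (illustrative usage, not from the diff): the hunk above means that enabling fp16/bf16 selects the mixed_precision mode, and setting a compile backend or mode implies torch_compile=True. Assuming standard TrainingArguments parameters:

    from transformers import TrainingArguments

    targs = TrainingArguments(
        output_dir="out",                  # hypothetical path
        bf16=True,                         # -> mixed_precision == "bf16"
        torch_compile_backend="inductor",  # -> torch_compile is switched on implicitly
    )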

@@ -1555,7 +1555,7 @@ class TrainingArguments:
             os.environ["ACCELERATE_DYNAMO_BACKEND"] = self.torch_compile_backend
             if self.torch_compile_mode is not None:
                 os.environ["ACCELERATE_DYNAMO_MODE"] = self.torch_compile_mode

         # We need to setup the accelerator config here *before* the first call to `self.device`
         if is_accelerate_available():
             if not isinstance(self.accelerator_config, AcceleratorConfig):
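
Not part of the diff — a sketch of the fallback referenced in the hunks above: when the installed accelerate is too old to take an explicit TorchDynamoPlugin, the compile settings are exported as environment variables, which accelerate can pick up when it builds its own plugin. The values are examples:

    import os

    os.environ["ACCELERATE_DYNAMO_BACKEND"] = "inductor"
    os.environ["ACCELERATE_DYNAMO_MODE"] = "default"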

@@ -1656,8 +1656,9 @@ class TrainingArguments:
         if fsdp_plugin_args is not None:
             # Accelerate FSDP Plugin
             from accelerate.utils import FullyShardedDataParallelPlugin
+
             self.fsdp_plugin = FullyShardedDataParallelPlugin(**fsdp_plugin_args)

         self.deepspeed_plugin = None
         if self.deepspeed:
             # - must be run very last in arg parsing, since it will use a lot of these settings.
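
Not part of the diff — a hedged sketch of the plugin construction above: accelerate's FullyShardedDataParallelPlugin is built from a kwargs dict (fsdp_plugin_args in the hunk) and later handed to the Accelerator. The kwargs shown are illustrative and mirror keys handled further down in this file:

    from accelerate.utils import FullyShardedDataParallelPlugin

    fsdp_plugin_args = {"sync_module_states": True, "use_orig_params": True}
    fsdp_plugin = FullyShardedDataParallelPlugin(**fsdp_plugin_args)
    # under a distributed launch (e.g. `accelerate launch`), this would then be passed as:
    # accelerator = Accelerator(fsdp_plugin=fsdp_plugin)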

@@ -2643,7 +2644,6 @@ class TrainingArguments:
         self.ignore_data_skip = ignore_data_skip
         self.data_seed = sampler_seed
         return self

-
     def _process_fsdp_args(self):
         if self.fsdp is None:
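
The `return self` above keeps the setter chainable. A hedged usage sketch, assuming the hunk comes from TrainingArguments.set_dataloader (the sampler_seed / ignore_data_skip parameters suggest it):

    from transformers import TrainingArguments

    targs = TrainingArguments(output_dir="out").set_dataloader(
        train_batch_size=8,
        sampler_seed=42,       # stored on the instance as data_seed
        ignore_data_skip=True,
    )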

@@ -2759,12 +2759,11 @@ class TrainingArguments:
             # we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers
             fsdp_plugin_args["cpu_ram_efficient_loading"] = cpu_ram_efficient_loading
             os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading

-
             fsdp_plugin_args["sync_module_states"] = sync_module_states
             fsdp_plugin_args["use_orig_params"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()

         return fsdp_plugin_args


 class ParallelMode(Enum):
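
Not part of the diff — illustrative only: the fsdp_config keys processed in the hunk above (cpu_ram_efficient_loading, use_orig_params, with sync_module_states derived from them) are typically supplied through TrainingArguments, for example:

    from transformers import TrainingArguments

    targs = TrainingArguments(
        output_dir="out",
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "cpu_ram_efficient_loading": True,
            "use_orig_params": True,
        },
    )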