Compare commits

...

2 Commits

5 changed files with 330 additions and 220 deletions

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations
import os
from typing import TYPE_CHECKING
from ..utils import is_torch_available
@@ -31,3 +32,24 @@ def is_fsdp_managed_module(module: nn.Module) -> bool:
return isinstance(module, torch.distributed.fsdp.FullyShardedDataParallel) or getattr(
module, "_is_fsdp_managed_module", False
)
def enable_cpu_ram_efficient_loading():
"""
Enable CPU RAM efficient loading of model weights by setting `FSDP_CPU_RAM_EFFICIENT_LOADING` to `"true"`.
"""
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
def disable_cpu_ram_efficient_loading():
"""
Disable CPU RAM efficient loading of model weights by setting `FSDP_CPU_RAM_EFFICIENT_LOADING` to `"false"`.
"""
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "false"
def set_cpu_ram_efficient_loading(value: bool):
"""
Enable or disable CPU RAM efficient loading of model weights by setting `FSDP_CPU_RAM_EFFICIENT_LOADING` accordingly.
"""
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(bool(value)).lower()

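For reference, a minimal usage sketch of the three helpers added above (the import path matches the `transformers.integrations.fsdp` import used later in this diff; the expected values follow directly from the function bodies):

import os

from transformers.integrations.fsdp import (
    disable_cpu_ram_efficient_loading,
    enable_cpu_ram_efficient_loading,
    set_cpu_ram_efficient_loading,
)

enable_cpu_ram_efficient_loading()
assert os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] == "true"

disable_cpu_ram_efficient_loading()
assert os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] == "false"

# str(bool(True)).lower() -> "true", so the setter normalizes a bool to the lowercase string
set_cpu_ram_efficient_loading(True)
assert os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] == "true"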
View File

@@ -41,6 +41,7 @@ from torch.utils.data.distributed import DistributedSampler
from .integrations.deepspeed import is_deepspeed_zero3_enabled
from .tokenization_utils_base import BatchEncoding
from .utils import (
is_accelerate_available,
is_sagemaker_mp_enabled,
is_torch_available,
is_torch_xla_available,
@@ -49,6 +50,10 @@ from .utils import (
)
if is_accelerate_available():
from accelerate.utils import FullyShardedDataParallelPlugin, TorchDynamoPlugin
if is_training_run_on_sagemaker():
logging.add_handler(StreamHandler(sys.stdout))
@@ -1312,6 +1317,27 @@ class AcceleratorConfig:
" The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`."
},
)
mixed_precision: Optional[str] = field(
default=None,
metadata={
"help": "The mixed precision policy to use. If not set, the policy will be determined by the `ACCELERATE_MIXED_PRECISION` environment variable. "
"Should not be passed in through a config file."
},
)
dynamo_plugin: Optional["TorchDynamoPlugin"] = field( # noqa: F821
default=None,
metadata={
"help": "The dynamo config to use. If not set, the config will be determined by the `ACCELERATE_DYNAMO_CONFIG` environment variable. "
"Should not be passed in through a config file."
},
)
fsdp_plugin: Optional["FullyShardedDataParallelPlugin"] = field( # noqa: F821
default=None,
metadata={
"help": "The FSDP config to use. If not set, the config will be determined by environmental variables set during `accelerate launch`. "
"Should not be passed in through a config file."
},
)
use_configured_state: bool = field(
default=False,
metadata={
@@ -1333,6 +1359,13 @@ class AcceleratorConfig:
f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `transformers`"
" version or fix (and potentially remove these keys) from your config file."
)
# Check for fields that should not be set in a config file.
# Iterate with `invalid_field` so we don't shadow the `field` helper from `dataclasses`.
invalid_fields = ["mixed_precision", "fsdp_plugin", "dynamo_plugin"]
for invalid_field in invalid_fields:
if config_dict.get(invalid_field) is not None:
raise ValueError(
f"The `{invalid_field}` field should not be set in a config file. It is determined by the `TrainingArguments`."
)
return cls(**config_dict)
def to_dict(self):

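To make the new `from_json_file` guard concrete, here is a small sketch. It assumes `AcceleratorConfig` lives in `transformers.trainer_pt_utils` (the file path is not shown in this diff) and that `split_batches` is one of its pre-existing fields:

import json
from tempfile import NamedTemporaryFile

from transformers.trainer_pt_utils import AcceleratorConfig  # assumed module path

with NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    # `mixed_precision` is one of the three reserved keys rejected above
    json.dump({"split_batches": False, "mixed_precision": "fp16"}, f)

try:
    AcceleratorConfig.from_json_file(f.name)
except ValueError as err:
    print(err)  # the `mixed_precision` field should not be set in a config file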
View File

@@ -401,26 +401,14 @@ class TrainingArguments:
use_ipex (`bool`, *optional*, defaults to `False`):
Use Intel extension for PyTorch when it is available. [IPEX
installation](https://github.com/intel/intel-extension-for-pytorch).
bf16 (`bool`, *optional*, defaults to `False`):
Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
NVIDIA architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change.
fp16 (`bool`, *optional*, defaults to `False`):
Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
fp16_opt_level (`str`, *optional*, defaults to 'O1'):
For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on
the [Apex documentation](https://nvidia.github.io/apex/amp).
fp16_backend (`str`, *optional*, defaults to `"auto"`):
This argument is deprecated. Use `half_precision_backend` instead.
half_precision_backend (`str`, *optional*, defaults to `"auto"`):
The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will
use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the
requested backend.
bf16_full_eval (`bool`, *optional*, defaults to `False`):
Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
metric values. This is an experimental API and it may change.
fp16_full_eval (`bool`, *optional*, defaults to `False`):
Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
metric values.
mixed_precision_dtype (`str`, *optional*, defaults to `"no"`):
The type of mixed precision to use. Can be one of `"no"`, `"fp16"`, `"bf16"`, or `"fp8"`.
mixed_precision_config (`dict`, *optional*):
A dictionary of configuration options for mixed precision training. Valid keys are:
- fp16_opt_level
- backend
- full_eval: Whether to use full `mixed_precision_dtype` evaluation instead of 32-bit.
This will be faster and save memory but can harm metric values.
tf32 (`bool`, *optional*):
Whether to enable the TF32 mode, available in Ampere and newer GPU architectures. The default value depends
on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32`. For more details please refer to
@@ -1054,48 +1042,19 @@ class TrainingArguments:
)
},
)
bf16: bool = field(
default=False,
mixed_precision_dtype: Optional[str] = field(
default=None,
metadata={"help": "Mixed precision dtype to use. Can be one of `'no'`, `'fp16'`, `'bf16'`, or `'fp8'`."},
)
mixed_precision_config: Optional[dict] = field(
default=None,
metadata={
"help": (
"Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA"
" architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change."
"A dictionary of configuration options for mixed precision training. Valid keys are: "
"`fp16_opt_level`, `backend`, `full_eval`."
)
},
)
fp16: bool = field(
default=False,
metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"},
)
fp16_opt_level: str = field(
default="O1",
metadata={
"help": (
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
"See details at https://nvidia.github.io/apex/amp.html"
)
},
)
half_precision_backend: str = field(
default="auto",
metadata={
"help": "The backend to be used for half precision.",
"choices": ["auto", "apex", "cpu_amp"],
},
)
bf16_full_eval: bool = field(
default=False,
metadata={
"help": (
"Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may"
" change."
)
},
)
fp16_full_eval: bool = field(
default=False,
metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
)
tf32: Optional[bool] = field(
default=None,
metadata={
@@ -1394,6 +1353,31 @@ class TrainingArguments:
"choices": ["auto", "apex", "cpu_amp"],
},
)
bf16: bool = field(
default=False,
metadata={"help": "Deprecated. Use `mixed_precision_dtype='bf16'` instead."},
)
fp16: bool = field(
default=False,
metadata={"help": "Deprecated. Use `mixed_precision_dtype='fp16'` instead."},
)
fp16_opt_level: str = field(
default="O1",
metadata={"help": "Deprecated. Use `mixed_precision_config` and set `fp16_opt_level` instead."},
)
half_precision_backend: str = field(
default="auto",
metadata={"help": "Deprecated. Use `mixed_precision_config` and set `backend` instead."},
)
bf16_full_eval: bool = field(
default=False,
metadata={"help": "Deprecated. Use `mixed_precision_config` and set `full_eval` instead."},
)
fp16_full_eval: bool = field(
default=False,
metadata={"help": "Deprecated. Use `mixed_precision_config` and set `full_eval` instead."},
)
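Taken together, the deprecations above amount to a small migration for callers. A hedged before/after sketch, with argument names taken from the help strings in this diff:

# before (deprecated flags)
args = TrainingArguments(output_dir="out", bf16=True, bf16_full_eval=True)

# after (consolidated arguments)
args = TrainingArguments(
    output_dir="out",
    mixed_precision_dtype="bf16",
    mixed_precision_config={"full_eval": True},
)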
evaluation_strategy: Union[IntervalStrategy, str] = field(
default=None,
metadata={"help": "Deprecated. Use `eval_strategy` instead"},
@@ -1589,6 +1573,9 @@ class TrainingArguments:
)
self.use_cpu = self.no_cuda
if self.mixed_precision_config is None:
self.mixed_precision_config = {}
self.eval_strategy = IntervalStrategy(self.eval_strategy)
self.logging_strategy = IntervalStrategy(self.logging_strategy)
self.save_strategy = SaveStrategy(self.save_strategy)
@@ -1682,10 +1669,10 @@ class TrainingArguments:
if self.fp16_backend and self.fp16_backend != "auto":
warnings.warn(
"`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
" `half_precision_backend` instead",
" `mixed_precision_config` and set `backend` instead",
FutureWarning,
)
self.half_precision_backend = self.fp16_backend
self.mixed_precision_config["backend"] = self.fp16_backend
if self.bf16 or self.bf16_full_eval:
if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_xla_available():
@@ -1737,13 +1724,60 @@ class TrainingArguments:
if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16:
raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0")
if self.torchdynamo is not None:
warnings.warn(
"`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
" `torch_compile_backend` instead",
FutureWarning,
)
self.torch_compile_backend = self.torchdynamo
if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile:
self.torch_compile = True
if self.torch_compile and self.torch_compile_backend is None:
self.torch_compile_backend = "inductor"
# accelerate integration for torch compile
dynamo_plugin = None
if self.torch_compile:
# set env vars for accelerate
# TODO: remove this once we've bumped the minimum accelerate version
if not is_accelerate_available("1.2.0"):
os.environ["ACCELERATE_DYNAMO_BACKEND"] = self.torch_compile_backend
if self.torch_compile_mode is not None:
os.environ["ACCELERATE_DYNAMO_MODE"] = self.torch_compile_mode
else:
from accelerate.utils import TorchDynamoPlugin
dynamo_plugin = TorchDynamoPlugin(backend=self.torch_compile_backend, mode=self.torch_compile_mode)
# Process FSDP before we proceed
fsdp_plugin_args = self._process_fsdp_plugin()
# if training args is specified, it will override the one specified in the accelerate config
if self.half_precision_backend != "apex":
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
if self.fp16:
mixed_precision_dtype = "fp16"
elif self.bf16:
mixed_precision_dtype = "bf16"
self.fp16 = mixed_precision_dtype == "fp16"
self.bf16 = mixed_precision_dtype == "bf16"
# We need to set up the accelerator config here *before* the first call to `self.device`
if is_accelerate_available():
core_accelerate_config_args = {
"dynamo_plugin": dynamo_plugin,
"fsdp_plugin": fsdp_plugin_args,
"mixed_precision": mixed_precision_dtype,
}
if not isinstance(self.accelerator_config, AcceleratorConfig):
if self.accelerator_config is None:
self.accelerator_config = AcceleratorConfig()
self.accelerator_config = AcceleratorConfig(**core_accelerate_config_args)
elif isinstance(self.accelerator_config, dict):
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
self.accelerator_config = AcceleratorConfig(
**{**core_accelerate_config_args, **self.accelerator_config}
)
# Check that a user didn't pass in the class instantiator
# such as `accelerator_config = AcceleratorConfig`
elif isinstance(self.accelerator_config, type):
@@ -1752,7 +1786,11 @@ class TrainingArguments:
"Please pass in a fully constructed `AcceleratorConfig` object instead."
)
else:
# Load config from JSON file and update with core args
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
self.accelerator_config.__dict__.update(
{k: v for k, v in core_accelerate_config_args.items() if v is not None}
)
if self.dispatch_batches is not None:
warnings.warn(
@@ -1787,26 +1825,6 @@ class TrainingArguments:
logger.warning(f"Can not specify world size due to {e}. Turn average_tokens_across_devices to False.")
self.average_tokens_across_devices = False
if self.torchdynamo is not None:
warnings.warn(
"`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
" `torch_compile_backend` instead",
FutureWarning,
)
self.torch_compile_backend = self.torchdynamo
if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile:
self.torch_compile = True
if self.torch_compile and self.torch_compile_backend is None:
self.torch_compile_backend = "inductor"
# accelerate integration for torch compile
if self.torch_compile:
# set env vars for accelerate
prefix = "ACCELERATE_DYNAMO_"
os.environ[prefix + "BACKEND"] = self.torch_compile_backend
if self.torch_compile_mode is not None:
os.environ[prefix + "MODE"] = self.torch_compile_mode
if self.framework == "pt" and is_torch_available() and self.torch_compile:
if is_torch_tf32_available():
if self.tf32 is None and not self.fp16 or self.bf16:
@@ -1833,15 +1851,6 @@ class TrainingArguments:
torch.backends.cudnn.allow_tf32 = False
# no need to assert on else
# if training args is specified, it will override the one specified in the accelerate config
if self.half_precision_backend != "apex":
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
if self.fp16:
mixed_precision_dtype = "fp16"
elif self.bf16:
mixed_precision_dtype = "bf16"
os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
if self.report_to is None:
logger.info(
"The default value for the training argument `--report_to` will change in v5 (from all installed "
@@ -1877,131 +1886,6 @@ class TrainingArguments:
if not isinstance(self.warmup_steps, int) or self.warmup_steps < 0:
raise ValueError("warmup_steps must be of type int and must be 0 or a positive integer.")
if isinstance(self.fsdp, bool):
self.fsdp = [FSDPOption.FULL_SHARD] if self.fsdp else ""
if isinstance(self.fsdp, str):
self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
if self.fsdp == [FSDPOption.OFFLOAD]:
raise ValueError(
"`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or "
'`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.'
)
elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
if self.gradient_checkpointing and (
FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp
):
logger.warning(
"When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
" use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"
" operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404"
)
if self.fsdp_config is None:
self.fsdp_config = {}
if isinstance(self.fsdp_config, str):
if len(self.fsdp) == 0:
warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
self.fsdp_config = json.load(f)
for k in list(self.fsdp_config.keys()):
if k.startswith("fsdp_"):
v = self.fsdp_config.pop(k)
self.fsdp_config[k[5:]] = v
if self.fsdp_min_num_params > 0:
warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)
self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params)
# if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str):
self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]]
if self.fsdp_transformer_layer_cls_to_wrap is not None:
warnings.warn(
"using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
)
self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
"transformer_layer_cls_to_wrap", []
) + [self.fsdp_transformer_layer_cls_to_wrap]
if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0:
warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.")
if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
if (
len(self.fsdp) > 0
and self.fsdp_config["min_num_params"] > 0
and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None
):
raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
self.fsdp_config["xla_fsdp_v2"] = self.fsdp_config.get("xla_fsdp_v2", False)
self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
if self.fsdp_config["xla"]:
if len(self.fsdp) > 0:
# store XLA fsdp configuration parameters into a dictionary
# Copy the config to avoid modifying the original config (which may be used for JSON serialization)
self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}).copy()
# apply appropriate string to torch.dtype conversions for parameters
if "compute_dtype" in self.xla_fsdp_config:
self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"])
if "buffer_dtype" in self.xla_fsdp_config:
self.xla_fsdp_config["buffer_dtype"] = getattr(torch, self.xla_fsdp_config["buffer_dtype"])
else:
warnings.warn("XLA FSDP can be used only when `--fsdp` is specified.")
else:
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
# accelerate integration for FSDP
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
os.environ["ACCELERATE_USE_FSDP"] = "true"
from accelerate.utils.constants import (
FSDP_AUTO_WRAP_POLICY,
FSDP_SHARDING_STRATEGY,
)
prefix = "FSDP_"
for fsdp_option in self.fsdp:
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
# set environment variable for FSDP sharding strategy
os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
)
elif fsdp_option == FSDPOption.OFFLOAD:
os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
elif fsdp_option == FSDPOption.AUTO_WRAP:
os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
if self.fsdp_config["min_num_params"] > 0:
os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"])
os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["transformer_layer_cls_to_wrap"]
)
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()
sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()
if sync_module_states == "false" and cpu_ram_efficient_loading == "true":
# In this case, all the processes except the main process would have random weights leading
# to unexpected behaviour during training, thus throwing error here to prevent it.
raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states
os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
@@ -2037,15 +1921,13 @@ class TrainingArguments:
# Accelerate DeepSpeed Plugin
from accelerate.utils import DeepSpeedPlugin
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
elif strtobool(os.environ.get("ACCELERATE_USE_DEEPSPEED", "false")):
# Accelerate DeepSpeed Plugin
from accelerate.utils import DeepSpeedPlugin
self.deepspeed_plugin = DeepSpeedPlugin()
mixed_precision = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
self.deepspeed_plugin.set_mixed_precision(mixed_precision)
self.deepspeed_plugin.set_mixed_precision(self.mixed_precision_dtype)
self.deepspeed_plugin.set_deepspeed_weakref()
if self.use_cpu:
@@ -3091,6 +2973,136 @@ class TrainingArguments:
self.data_seed = sampler_seed
return self
def _process_fsdp_plugin(self):
if isinstance(self.fsdp, bool):
self.fsdp = [FSDPOption.FULL_SHARD] if self.fsdp else ""
if isinstance(self.fsdp, str):
self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
if self.fsdp == [FSDPOption.OFFLOAD]:
raise ValueError(
"`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or "
'`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.'
)
elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
if self.gradient_checkpointing and (
FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp
):
logger.warning(
"When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
" use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"
" operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404"
)
if self.fsdp_config is None:
self.fsdp_config = {}
if isinstance(self.fsdp_config, str):
if len(self.fsdp) == 0:
warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
self.fsdp_config = json.load(f)
for k in list(self.fsdp_config.keys()):
if k.startswith("fsdp_"):
v = self.fsdp_config.pop(k)
self.fsdp_config[k[5:]] = v
if self.fsdp_min_num_params > 0:
warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)
self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params)
# if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str):
self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]]
if self.fsdp_transformer_layer_cls_to_wrap is not None:
warnings.warn(
"using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
)
self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
"transformer_layer_cls_to_wrap", []
) + [self.fsdp_transformer_layer_cls_to_wrap]
if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0:
warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.")
if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
if (
len(self.fsdp) > 0
and self.fsdp_config["min_num_params"] > 0
and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None
):
raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
self.fsdp_config["xla_fsdp_v2"] = self.fsdp_config.get("xla_fsdp_v2", False)
self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
if self.fsdp_config["xla"]:
if len(self.fsdp) > 0:
# store XLA fsdp configuration parameters into a dictionary
# Copy the config to avoid modifying the original config (which may be used for JSON serialization)
self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}).copy()
# apply appropriate string to torch.dtype conversions for parameters
if "compute_dtype" in self.xla_fsdp_config:
self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"])
if "buffer_dtype" in self.xla_fsdp_config:
self.xla_fsdp_config["buffer_dtype"] = getattr(torch, self.xla_fsdp_config["buffer_dtype"])
else:
warnings.warn("XLA FSDP can be used only when `--fsdp` is specified.")
else:
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
fsdp_plugin_args = None
# accelerate integration for FSDP
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
self.use_fsdp = True
from accelerate.utils.constants import (
FSDP_AUTO_WRAP_POLICY,
FSDP_SHARDING_STRATEGY,
)
from transformers.integrations.fsdp import set_cpu_ram_efficient_loading
# Because `self.fsdp` is a list of options, we need to convert it to kwargs for accelerate's `FullyShardedDataParallelPlugin`
fsdp_plugin_args = {}
for fsdp_option in self.fsdp:
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
# set the FSDP sharding strategy on the plugin kwargs (previously an environment variable)
fsdp_plugin_args["sharding_strategy"] = fsdp_option
elif fsdp_option == FSDPOption.OFFLOAD:
fsdp_plugin_args["offload_params"] = True
elif fsdp_option == FSDPOption.AUTO_WRAP:
fsdp_plugin_args["auto_wrap_policy"] = FSDP_AUTO_WRAP_POLICY[0]
if self.fsdp_config["min_num_params"] > 0:
fsdp_plugin_args["min_num_params"] = self.fsdp_config["min_num_params"]
elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
fsdp_plugin_args["transformer_layer_cls_to_wrap"] = ",".join(
self.fsdp_config["transformer_layer_cls_to_wrap"]
)
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper()
fsdp_plugin_args["forward_prefetch"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()
sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()
if sync_module_states == "false" and cpu_ram_efficient_loading == "true":
# In this case, all the processes except the main process would have random weights leading
# to unexpected behaviour during training, thus throwing error here to prevent it.
raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
fsdp_plugin_args["sync_module_states"] = sync_module_states
fsdp_plugin_args["cpu_ram_efficient_loading"] = cpu_ram_efficient_loading
# To trickle through to model init, we need to set the environmental variable here
set_cpu_ram_efficient_loading(cpu_ram_efficient_loading)
fsdp_plugin_args["use_orig_params"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
return fsdp_plugin_args
class ParallelMode(Enum):
NOT_PARALLEL = "not_parallel"

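Where the old code exported `FSDP_*` environment variables, `_process_fsdp_plugin` now returns a kwargs dict that flows into `AcceleratorConfig(fsdp_plugin=...)` and, from there, accelerate's `FullyShardedDataParallelPlugin`. A rough sketch of its output for a common configuration, inferred from the function body above (several values stay lowercase strings because the code calls `str(...).lower()` on them):

args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard offload",
    fsdp_config={"use_orig_params": False},
)
# For these inputs, _process_fsdp_plugin() would return roughly:
# {
#     "sharding_strategy": FSDPOption.FULL_SHARD,
#     "offload_params": True,
#     "backward_prefetch": "NO_PREFETCH",
#     "forward_prefetch": "false",
#     "sync_module_states": "true",
#     "cpu_ram_efficient_loading": "false",
#     "use_orig_params": "false",
# }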
View File

@@ -72,14 +72,15 @@ class TestFSDPTrainer(TestCasePlus):
cmd = [
"accelerate",
"launch",
"--use_fsdp",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"--num_processes",
f"{torch.cuda.device_count()}",
f"{self.test_file_dir}/test_trainer_fsdp.py",
"--fsdp",
"full_shard",
"--fsdp_transformer_layer_cls_to_wrap",
"GPT2Block",
f"{self.test_file_dir}/test_trainer_fsdp.py",
"--output_dir",
f"{output_dir}",
"--report_to",

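With this change, FSDP is configured through the script's own `TrainingArguments` flags rather than the launcher's `--use_fsdp`, so the assembled command looks roughly like the sketch below (port, process count, and paths are illustrative placeholders):

cmd = [
    "accelerate", "launch",
    "--main_process_port", "29500",
    "--num_processes", "2",
    "tests/trainer/test_trainer_fsdp.py",
    "--fsdp", "full_shard",
    "--fsdp_transformer_layer_cls_to_wrap", "GPT2Block",
    "--output_dir", "/tmp/fsdp_out",
    "--report_to", "none",
]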
View File

@@ -0,0 +1,42 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tempfile import TemporaryDirectory
from transformers import TrainingArguments
from transformers.testing_utils import TestCasePlus, is_accelerate_available, require_accelerate
if is_accelerate_available():
from accelerate.utils import patch_environment
@require_accelerate
class TrainingArgsTest(TestCasePlus):
"""
Tests the core `TrainingArguments` class for pre and post processing.
"""
def test_mixed_precision(self):
with TemporaryDirectory() as temp_dir:
# First with no env: constructing with `fp16=True` should simply not raise
TrainingArguments(fp16=True, output_dir=temp_dir)
args = TrainingArguments(output_dir=temp_dir, fp16=False)
self.assertEqual(args.fp16, False)
# Then with env
with patch_environment(accelerate_mixed_precision="fp16"):
args = TrainingArguments(output_dir=temp_dir)
self.assertEqual(args.fp16, True)