Compare commits

...

5 Commits

SHA1 Message Date
e91acff698 Finish launch dataclasses 2023-10-06 14:07:23 -04:00
7642a84920 Bookmark 2023-10-05 15:41:30 -04:00
39ed4554a2 Futher along 2023-10-05 11:58:52 -04:00
d4debcea79 Working API 2023-10-05 10:45:03 -04:00
58a8198c5c Start of new CLI process method 2023-10-03 15:00:28 +00:00
4 changed files with 435 additions and 432 deletions

View File

@@ -14,12 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from dataclasses import dataclass
from typing import Literal
from huggingface_hub import model_info
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
from accelerate import init_empty_weights
from accelerate.utils import (
Arguments,
calculate_maximum_sizes,
convert_bytes,
is_timm_available,
@@ -177,6 +180,31 @@ def create_ascii_table(headers: list, rows: list, title: str):
return table
@dataclass
class EstimateArguments(Arguments):
"""
Arguments for the `accelerate estimate-memory` command.
Args:
model_name (`str`):
The model name on the Hugging Face Hub.
library_name (`str`):
The library the model has an integration with, such as `transformers`, needed only if this information is
not stored on the Hub. Must be one of `timm` or `transformers`.
dtypes (`list[str]`, *optional*, defaults to `["float32", "float16", "int8", "int4"]`):
The dtypes to use for the model, must be one (or many) of `float32`, `float16`, `int8`, and `int4`.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This flag should
only be used for repositories you trust and in which you have read the code, as it will execute code
present on the Hub on your local machine.
"""
model_name: str
library_name: Literal["timm", "transformers"]
dtypes: list[str]
trust_remote_code: bool = False
def estimate_command_parser(subparsers=None):
if subparsers is not None:
parser = subparsers.add_parser("estimate-memory")
@@ -207,11 +235,11 @@ def estimate_command_parser(subparsers=None):
)
if subparsers is not None:
parser.set_defaults(func=estimate_command)
parser.set_defaults(func=estimate_memory)
return parser
def gather_data(args):
def gather_data(args: EstimateArguments):
"Creates an empty model and gathers the data for the sizes"
try:
model = create_empty_model(
@@ -246,7 +274,7 @@ def gather_data(args):
return data
def estimate_command(args):
def estimate_memory(args: EstimateArguments):
data = gather_data(args)
for row in data:
for i, item in enumerate(row):
@@ -263,7 +291,7 @@ def estimate_command(args):
def main():
parser = estimate_command_parser()
args = parser.parse_args()
estimate_command(args)
estimate_memory(args)
if __name__ == "__main__":
main()
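For context, the new dataclass makes the estimator usable without going through argparse. A minimal sketch against this branch (not part of the diff; the model id and dtypes below are placeholders):

# Sketch only: drive `accelerate estimate-memory` programmatically via the new dataclass.
from accelerate.commands.estimate import EstimateArguments, gather_data

args = EstimateArguments(
    model_name="bert-base-cased",   # placeholder Hub id
    library_name="transformers",
    dtypes=["float32", "float16"],
    trust_remote_code=False,
)
data = gather_data(args)  # rows of per-dtype size estimates, the same data the CLI renders as a table

# Because `Arguments.__post_init__` calls `validate_types()`, an invalid Literal value fails fast, e.g.
# library_name="torchvision" raises:
#   ValueError: Invalid value for `library_name`. Must be one of ['timm', 'transformers'] ...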

View File

@@ -20,16 +20,18 @@ import logging
import os
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import ClassVar, Literal
import psutil
import torch
from accelerate.commands.config import default_config_file, load_config_from_file
from accelerate.commands.config.config_args import SageMakerConfig
from accelerate.commands.config.config_utils import DYNAMO_BACKENDS
from accelerate.state import get_int_from_env
from accelerate.utils import (
Arguments,
ComputeEnvironment,
DistributedType,
PrepareForLaunch,
@@ -49,7 +51,7 @@ from accelerate.utils import (
prepare_simple_launcher_cmd_env,
prepare_tpu,
)
from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, TORCH_DYNAMO_MODES
from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
if is_rich_available():
@@ -125,6 +127,305 @@ class _CustomHelpAction(argparse._HelpAction):
super().__call__(parser, namespace, values, option_string)
@dataclass
class ResourceArguments(Arguments):
"""
Arguments for fine-tuning what and how available hardware should be used.
Args:
cpu (`bool`, *optional*, defaults to `False`):
Whether or not to force the training on the CPU.
multi_gpu (`bool`, *optional*, defaults to `False`):
Whether or not this should launch a distributed GPU training.
tpu (`bool`, *optional*, defaults to `False`):
Whether or not this should launch a TPU training.
ipex (`bool`, *optional*, defaults to `False`):
Whether or not this should launch an Intel PyTorch Extension (IPEX) training.
mixed_precision (`str`, *optional*, defaults to `no`):
Whether or not to use mixed precision training. Choose between FP16, BF16 (bfloat16) or FP8 training. BF16
training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.
num_processes (`int`, *optional*, defaults to `None`):
The total number of processes to be launched in parallel.
num_machines (`int`, *optional*, defaults to `None`):
The total number of machines used in this training.
num_cpu_threads_per_process (`int`, *optional*, defaults to `None`):
The number of CPU threads per process. Can be tuned for optimal performance.
use_deepspeed (`bool`, *optional*, defaults to `False`):
Whether to use DeepSpeed.
use_fsdp (`bool`, *optional*, defaults to `False`):
Whether to use FSDP.
use_megatron_lm (`bool`, *optional*, defaults to `False`):
Whether to use Megatron-LM.
use_xpu (`bool`, *optional*, defaults to `False`):
Whether to use the IPEX plugin to speed up training on XPU specifically.
"""
cpu: bool = False
multi_gpu: bool = False
tpu: bool = False
ipex: bool = False
mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
num_processes: int = None
num_machines: int = None
num_cpu_threads_per_process: int = None
use_deepspeed: bool = False
use_fsdp: bool = False
use_megatron_lm: bool = False
use_xpu: bool = False
@dataclass
class DynamoArguments(Arguments):
"""
Arguments related to `torchdynamo`
Args:
backend (`str`):
Backend to optimize your training with dynamo, see more at https://github.com/pytorch/torchdynamo.
mode (`str`, *optional*, defaults to "default"):
Mode to optimize your training with dynamo.
use_fullgraph (`bool`, *optional*):
Whether to use full graph mode for dynamo, or whether it is acceptable to break the model into several subgraphs.
use_dynamic (`bool`, *optional*):
Whether to enable dynamic shape tracing.
"""
prefix: ClassVar[str] = "dynamo_"
backend: Literal[
"no",
"eager",
"aot_eager",
"inductor",
"nvfuser",
"aot_nvfuser",
"aot_cudagraphs",
"ofi",
"fx2trt",
"onnxrt",
"ipex",
] = "no"
mode: Literal["default", "reduce-overhead", "max-autotune"] = "default"
use_fullgraph: bool = False
use_dynamic: bool = False
@dataclass
class CUDAArguments(Arguments):
"""
Arguments related to CUDA usage.
Args:
gpu_ids (`str`):
What GPUs (by id) should be used for training on this machine as a comma-separated list.
same_network (`bool`):
Whether all machines used for multinode training exist on the same local network.
machine_rank (`int`):
The rank of the machine on which this script is launched.
main_process_ip (`str`):
The IP address of the machine of rank 0.
main_process_port (`int`):
The port to use to communicate with the machine of rank 0.
tee (`str`, *optional*, defaults to "0"):
Tee std streams into a log file and also to console.
role (`str`, *optional*, defaults to "default"):
User-defined role for the workers.
rdzv_backend (`str`, *optional*, defaults to "static"):
The rendezvous method to use, such as "static" or "c10d".
rdzv_conf (`str`, *optional*, defaults to ""):
Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).
max_restarts (`int`, *optional*, defaults to 0):
Maximum number of worker group restarts before failing.
monitor_interval (`float`, *optional*, defaults to 5.0):
Interval, in seconds, to monitor the state of workers.
"""
gpu_ids: str = None
same_network: bool = False
machine_rank: int = None
main_process_ip: str = None
main_process_port: int = None
tee: str = "0"
role: str = "default"
rdzv_backend: Literal["static", "c10d"] = "static"
rdzv_conf: str = ""
max_restarts: int = 0
monitor_interval: float = 5.0
@dataclass
class TPUArguments(Arguments):
"""
Arguments related to TPU usage.
Args:
tpu_cluster (`bool`):
Whether to use a GCP TPU pod for training.
tpu_use_sudo (`bool`):
Whether to use `sudo` when running the TPU training script in each pod.
vm (list of `str`):
List of single Compute VM instance names. If not provided, we assume usage of instance groups. For TPU pods.
env (list of `str`):
List of environment variables to set on the Compute VM instances. For TPU pods.
main_training_function (`str`):
The name of the main function to be executed in your script (only for TPU training).
downcast_bf16 (`bool`):
When using bf16 precision on TPUs, whether both float and double tensors are cast to bfloat16 or double
tensors remain as float32.
"""
tpu_cluster: bool = False
tpu_use_sudo: bool = False
vm: list[str] = field(default_factory=list)
env: list[str] = field(default_factory=list)
main_training_function: str = None
downcast_bf16: bool = False
@dataclass
class DeepSpeedArguments(Arguments):
"""
Arguments related to DeepSpeed
Args:
deepspeed_config_file (`str`, *optional*):
DeepSpeed config file to use.
zero_stage (`int`, *optional*, defaults to 2):
DeepSpeed's ZeRO optimization stage.
offload_optimizer_device (`str`, *optional*, defaults to "none"):
Decides where (none|cpu|nvme) to offload optimizer states.
offload_param_device (`str`, *optional*, defaults to "none"):
Decides where (none|cpu|nvme) to offload parameters.
offload_optimizer_nvme_path (`str`, *optional*, defaults to "none"):
The NVMe path to offload optimizer states to.
offload_param_nvme_path (`str`, *optional*, defaults to "none"):
The NVMe path to offload parameters to.
gradient_accumulation_steps (`int`, *optional*, defaults to 1):
Number of gradient accumulation steps used in your training script when using DeepSpeed.
gradient_clipping (`float`, *optional*, defaults to 1.0):
Gradient clipping value used in your training script when using DeepSpeed.
zero3_init_flag (`bool`, *optional*):
Whether to enable `deepspeed.zero.Init` for constructing massive models. Only applicable with DeepSpeed
ZeRO Stage-3.
zero3_save_16bit_model (`bool`, *optional*):
Whether to save 16-bit model weights when using ZeRO Stage-3. Only applicable with DeepSpeed ZeRO Stage-3.
deepspeed_hostfile (`str`, *optional*):
DeepSpeed hostfile for configuring multi-node compute resources.
deepspeed_exclusion_filter (`str`, *optional*):
DeepSpeed exclusion filter string when using a multi-node setup.
deepspeed_inclusion_filter (`str`, *optional*):
DeepSpeed inclusion filter string when using a multi-node setup.
deepspeed_multinode_launcher (`str`, *optional*, defaults to "pdsh"):
DeepSpeed multi-node launcher to use.
"""
deepspeed_config_file: str = None
zero_stage: int = 2
offload_optimizer_device: Literal["none", "cpu", "nvme"] = "none"
offload_param_device: Literal["none", "cpu", "nvme"] = "none"
offload_optimizer_nvme_path: str = "none"
offload_param_nvme_path: str = "none"
gradient_accumulation_steps: int = 1
gradient_clipping: float = 1.0
zero3_init_flag: bool = True
zero3_save_16bit_model: bool = False
deepspeed_hostfile: str = None
deepspeed_exclusion_filter: str = None
deepspeed_inclusion_filter: str = None
deepspeed_multinode_launcher: Literal["pdsh", "standard", "openmpi", "mvapich", "mpich"] = "pdsh"
@dataclass
class FSDPArguments(Arguments):
"""
Arguments related to Fully Sharded Data Parallelism (FSDP)
Args:
offload_params (`bool`, *optional*):
Decides whether to offload parameters and gradients to CPU.
min_num_params (`int`, *optional*, defaults to 1e8):
FSDP's minimum number of parameters for Default Auto Wrapping.
sharding_strategy (`int`, *optional*, defaults to 1):
FSDP's Sharding Strategy.
auto_wrap_policy (`str`, *optional*):
FSDP's auto wrap policy.
transformer_layer_cls_to_wrap (`str`, *optional*):
Transformer layer class name (case-sensitive) to wrap, e.g., `BertLayer`, `GPTJBlock`, `T5Block`, etc.
backward_prefetch_policy (`str`, *optional*):
FSDP's backward prefetch policy.
state_dict_type (`str`, *optional*):
FSDP's state dict type.
forward_prefetch (`bool`, *optional*):
Whether to explicitly prefetch the next upcoming all-gather while executing in the forward pass.
use_orig_params (`bool`, *optional*):
Whether to allow non-uniform `requires_grad` during init, which means support for interspersed frozen and
trainable parameters.
sync_module_states (`bool`, *optional*, defaults to `True`):
Whether to broadcast module parameters from rank 0.
"""
prefix: ClassVar[str] = "fsdp_"
offload_params: bool = False
min_num_params: int = 1e8
sharding_strategy: int = 1
auto_wrap_policy: str = None
transformer_layer_cls_to_wrap: str = None
backward_prefetch_policy: str = None
state_dict_type: str = None
forward_prefetch: bool = False
use_orig_params: bool = False
sync_module_states: bool = True
@dataclass
class MegatronLMArguments(Arguments):
"""
Arguments related to Megatron-LM
Args:
tp_degree (`int`, *optional*, defaults to 1):
Tensor Parallelism (TP) degree.
pp_degree (`int`, *optional*, defaults to 1):
Pipeline Parallelism (PP) degree.
num_micro_batches (`int`, *optional*):
Number of micro batches when `pp_degree` > 1.
sequence_parallelism (`bool`, *optional*):
Whether to enable Sequence Parallelism when `tp_degree` > 1.
recompute_activations (`bool`, *optional*):
Whether to enable Selective Activation Recomputation.
use_distributed_optimizer (`bool`, *optional*):
Whether to use the distributed optimizer, which shards optimizer state and gradients across Data Parallel (DP)
ranks.
gradient_clipping (`float`, *optional*, defaults to 1.0):
Gradient clipping value based on global L2 Norm (0 to disable).
"""
prefix: ClassVar[str] = "megatron_lm_"
tp_degree: int = 1
pp_degree: int = 1
num_micro_batches: int = None
sequence_parallelism: bool = None
recompute_activations: bool = None
use_distributed_optimizer: bool = None
gradient_clipping: float = 1.0
@dataclass
class AWSArguments(Arguments):
"""
Arguments related to AWS
Args:
access_key_id (`str`, *optional*):
The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job.
secret_access_key (`str`, *optional*):
The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job.
"""
prefix: ClassVar[str] = "aws_"
access_key_id: str = None
secret_access_key: str = None
def launch_command_parser(subparsers=None):
if subparsers is not None:
parser = subparsers.add_parser("launch", add_help=False, allow_abbrev=False)
@@ -135,182 +436,16 @@ def launch_command_parser(subparsers=None):
parser.add_argument("-h", "--help", action="help", help="Show this help message and exit.")
parser.add_argument(
"--config_file", default=None, help="The config file to use for the default values in the launching script."
"--config_file",
type=str,
default=None,
help="The config file to use for the default values in the launching script.",
)
parser.add_argument(
"--quiet",
"-q",
action="store_true",
help="Silence subprocess errors from the launch stack trace and only show the relevant tracebacks. (Only applicable to DeepSpeed and single-process configurations)",
)
# Hardware selection arguments
hardware_args = parser.add_argument_group(
"Hardware Selection Arguments", "Arguments for selecting the hardware to be used."
)
hardware_args.add_argument(
"--cpu", default=False, action="store_true", help="Whether or not to force the training on the CPU."
)
hardware_args.add_argument(
"--multi_gpu",
default=False,
action="store_true",
help="Whether or not this should launch a distributed GPU training.",
)
hardware_args.add_argument(
"--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training."
)
hardware_args.add_argument(
"--ipex",
default=False,
action="store_true",
help="Whether or not this should launch a Intel PyTorch Extension (IPEX) training.",
)
# Resource selection arguments
resource_args = parser.add_argument_group(
"Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used."
)
resource_args.add_argument(
"--mixed_precision",
type=str,
choices=["no", "fp16", "bf16", "fp8"],
help="Whether or not to use mixed precision training. "
"Choose between FP16 and BF16 (bfloat16) training. "
"BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.",
)
resource_args.add_argument(
"--num_processes", type=int, default=None, help="The total number of processes to be launched in parallel."
)
resource_args.add_argument(
"--num_machines", type=int, default=None, help="The total number of machines used in this training."
)
resource_args.add_argument(
"--num_cpu_threads_per_process",
type=int,
default=None,
help="The number of CPU threads per process. Can be tuned for optimal performance.",
)
# Dynamo arguments
resource_args.add_argument(
"--dynamo_backend",
type=str,
choices=["no"] + [b.lower() for b in DYNAMO_BACKENDS],
help="Choose a backend to optimize your training with dynamo, see more at "
"https://github.com/pytorch/torchdynamo.",
)
resource_args.add_argument(
"--dynamo_mode",
type=str,
default="default",
choices=TORCH_DYNAMO_MODES,
help="Choose a mode to optimize your training with dynamo.",
)
resource_args.add_argument(
"--dynamo_use_fullgraph",
default=False,
action="store_true",
help="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs",
)
resource_args.add_argument(
"--dynamo_use_dynamic",
default=False,
action="store_true",
help="Whether to enable dynamic shape tracing.",
)
# Training Paradigm arguments
paradigm_args = parser.add_argument_group(
"Training Paradigm Arguments", "Arguments for selecting which training paradigm to be used."
)
paradigm_args.add_argument(
"--use_deepspeed",
default=False,
action="store_true",
help="Whether to use deepspeed.",
)
paradigm_args.add_argument(
"--use_fsdp",
default=False,
action="store_true",
help="Whether to use fsdp.",
)
paradigm_args.add_argument(
"--use_megatron_lm",
default=False,
action="store_true",
help="Whether to use Megatron-LM.",
)
paradigm_args.add_argument(
"--use_xpu",
default=False,
action="store_true",
help="Whether to use IPEX plugin to speed up training on XPU specifically.",
)
# distributed GPU training arguments
distributed_args = parser.add_argument_group("Distributed GPUs", "Arguments related to distributed GPU training.")
distributed_args.add_argument(
"--gpu_ids",
default=None,
help="What GPUs (by id) should be used for training on this machine as a comma-seperated list",
)
distributed_args.add_argument(
"--same_network",
default=False,
action="store_true",
help="Whether all machines used for multinode training exist on the same local network.",
)
distributed_args.add_argument(
"--machine_rank", type=int, default=None, help="The rank of the machine on which this script is launched."
)
distributed_args.add_argument(
"--main_process_ip", type=str, default=None, help="The IP address of the machine of rank 0."
)
distributed_args.add_argument(
"--main_process_port",
type=int,
default=None,
help="The port to use to communicate with the machine of rank 0.",
)
distributed_args.add_argument(
"-t",
"--tee",
default="0",
type=str,
help="Tee std streams into a log file and also to console.",
)
distributed_args.add_argument(
"--role",
type=str,
default="default",
help="User-defined role for the workers.",
)
# Rendezvous related arguments
distributed_args.add_argument(
"--rdzv_backend",
type=str,
default="static",
help="The rendezvous method to use, such as 'static' (the default) or 'c10d'",
)
distributed_args.add_argument(
"--rdzv_conf",
type=str,
default="",
help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
)
distributed_args.add_argument(
"--max_restarts",
type=int,
default=0,
help="Maximum number of worker group restarts before failing.",
)
distributed_args.add_argument(
"--monitor_interval",
type=float,
default=5,
help="Interval, in seconds, to monitor the state of workers.",
)
parser.add_argument(
"-m",
"--module",
@@ -323,279 +458,44 @@ def launch_command_parser(subparsers=None):
help="Skip prepending the training script with 'python' - just execute it directly. Useful when the script is not a Python script.",
)
# Resource selection arguments
resource_args = parser.add_argument_group(
"Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used."
)
ResourceArguments().add_to_parser(resource_args)
# Dynamo arguments
DynamoArguments().add_to_parser(resource_args)
# distributed GPU training arguments
distributed_args = parser.add_argument_group("Distributed GPUs", "Arguments related to distributed GPU training.")
CUDAArguments().add_to_parser(distributed_args)
# TPU arguments
tpu_args = parser.add_argument_group("TPU", "Arguments related to TPU.")
tpu_args.add_argument(
"--tpu_cluster",
action="store_true",
dest="tpu_use_cluster",
help="Whether to use a GCP TPU pod for training.",
)
TPUArguments().add_to_parser(tpu_args)
tpu_args.add_argument(
"--no_tpu_cluster",
action="store_false",
dest="tpu_use_cluster",
help="Should not be passed explicitly, this is for internal use only.",
)
tpu_args.add_argument(
"--tpu_use_sudo",
action="store_true",
help="Whether to use `sudo` when running the TPU training script in each pod.",
)
tpu_args.add_argument(
"--vm",
type=str,
action="append",
help=(
"List of single Compute VM instance names. "
"If not provided we assume usage of instance groups. For TPU pods."
),
)
tpu_args.add_argument(
"--env",
type=str,
action="append",
help="List of environment variables to set on the Compute VM instances. For TPU pods.",
)
tpu_args.add_argument(
"--main_training_function",
type=str,
default=None,
help="The name of the main function to be executed in your script (only for TPU training).",
)
tpu_args.add_argument(
"--downcast_bf16",
action="store_true",
help="Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.",
)
# DeepSpeed arguments
deepspeed_args = parser.add_argument_group("DeepSpeed Arguments", "Arguments related to DeepSpeed.")
deepspeed_args.add_argument(
"--deepspeed_config_file",
default=None,
type=str,
help="DeepSpeed config file.",
)
deepspeed_args.add_argument(
"--zero_stage",
default=None,
type=int,
help="DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed). "
"If unspecified, will default to `2`.",
)
deepspeed_args.add_argument(
"--offload_optimizer_device",
default=None,
type=str,
help="Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed). "
"If unspecified, will default to 'none'.",
)
deepspeed_args.add_argument(
"--offload_param_device",
default=None,
type=str,
help="Decides where (none|cpu|nvme) to offload parameters (useful only when `use_deepspeed` flag is passed). "
"If unspecified, will default to 'none'.",
)
deepspeed_args.add_argument(
"--offload_optimizer_nvme_path",
default=None,
type=str,
help="Decides Nvme Path to offload optimizer states (useful only when `use_deepspeed` flag is passed). "
"If unspecified, will default to 'none'.",
)
deepspeed_args.add_argument(
"--offload_param_nvme_path",
default=None,
type=str,
help="Decides Nvme Path to offload parameters (useful only when `use_deepspeed` flag is passed). "
"If unspecified, will default to 'none'.",
)
deepspeed_args.add_argument(
"--gradient_accumulation_steps",
default=None,
type=int,
help="No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed). "
"If unspecified, will default to `1`.",
)
deepspeed_args.add_argument(
"--gradient_clipping",
default=None,
type=float,
help="gradient clipping value used in your training script (useful only when `use_deepspeed` flag is passed). "
"If unspecified, will default to `1.0`.",
)
deepspeed_args.add_argument(
"--zero3_init_flag",
default=None,
type=str,
help="Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. "
"Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `true`.",
)
deepspeed_args.add_argument(
"--zero3_save_16bit_model",
default=None,
type=str,
help="Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. "
"Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `false`.",
)
deepspeed_args.add_argument(
"--deepspeed_hostfile",
default=None,
type=str,
help="DeepSpeed hostfile for configuring multi-node compute resources.",
)
deepspeed_args.add_argument(
"--deepspeed_exclusion_filter",
default=None,
type=str,
help="DeepSpeed exclusion filter string when using mutli-node setup.",
)
deepspeed_args.add_argument(
"--deepspeed_inclusion_filter",
default=None,
type=str,
help="DeepSpeed inclusion filter string when using mutli-node setup.",
)
deepspeed_args.add_argument(
"--deepspeed_multinode_launcher",
default=None,
type=str,
help="DeepSpeed multi-node launcher to use. If unspecified, will default to `pdsh`.",
)
DeepSpeedArguments().add_to_parser(deepspeed_args)
# fsdp arguments
fsdp_args = parser.add_argument_group("FSDP Arguments", "Arguments related to Fully Shared Data Parallelism.")
fsdp_args.add_argument(
"--fsdp_offload_params",
default="false",
type=str,
help="Decides Whether (true|false) to offload parameters and gradients to CPU. (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
"--fsdp_min_num_params",
type=int,
default=1e8,
help="FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
"--fsdp_sharding_strategy",
type=int,
default=1,
help="FSDP's Sharding Strategy. (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
"--fsdp_auto_wrap_policy",
type=str,
default=None,
help="FSDP's auto wrap policy. (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
"--fsdp_transformer_layer_cls_to_wrap",
default=None,
type=str,
help="Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... "
"(useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
"--fsdp_backward_prefetch_policy",
default=None,
type=str,
help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
"--fsdp_state_dict_type",
default=None,
type=str,
help="FSDP's state dict type. (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
"--fsdp_forward_prefetch",
default="false",
type=str,
help="If True, then FSDP explicitly prefetches the next upcoming "
"all-gather while executing in the forward pass (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
"--fsdp_use_orig_params",
default="false",
type=str,
help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres."
" (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
"--fsdp_sync_module_states",
default="true",
type=str,
help="If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0."
" (useful only when `use_fsdp` flag is passed).",
)
FSDPArguments().add_to_parser(fsdp_args)
# megatron_lm args
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
megatron_lm_args.add_argument(
"--megatron_lm_tp_degree",
type=int,
default=1,
help="Megatron-LM's Tensor Parallelism (TP) degree. (useful only when `use_megatron_lm` flag is passed).",
)
megatron_lm_args.add_argument(
"--megatron_lm_pp_degree",
type=int,
default=1,
help="Megatron-LM's Pipeline Parallelism (PP) degree. (useful only when `use_megatron_lm` flag is passed).",
)
megatron_lm_args.add_argument(
"--megatron_lm_num_micro_batches",
type=int,
default=None,
help="Megatron-LM's number of micro batches when PP degree > 1. (useful only when `use_megatron_lm` flag is passed).",
)
megatron_lm_args.add_argument(
"--megatron_lm_sequence_parallelism",
default=None,
type=str,
help="Decides Whether (true|false) to enable Sequence Parallelism when TP degree > 1. "
"(useful only when `use_megatron_lm` flag is passed).",
)
megatron_lm_args.add_argument(
"--megatron_lm_recompute_activations",
default=None,
type=str,
help="Decides Whether (true|false) to enable Selective Activation Recomputation. "
"(useful only when `use_megatron_lm` flag is passed).",
)
megatron_lm_args.add_argument(
"--megatron_lm_use_distributed_optimizer",
default=None,
type=str,
help="Decides Whether (true|false) to use distributed optimizer "
"which shards optimizer state and gradients across Data Pralellel (DP) ranks. "
"(useful only when `use_megatron_lm` flag is passed).",
)
megatron_lm_args.add_argument(
"--megatron_lm_gradient_clipping",
default=1.0,
type=float,
help="Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable). "
"(useful only when `use_megatron_lm` flag is passed).",
)
MegatronLMArguments().add_to_parser(megatron_lm_args)
# AWS arguments
aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.")
aws_args.add_argument(
"--aws_access_key_id",
type=str,
default=None,
help="The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job",
)
aws_args.add_argument(
"--aws_secret_access_key",
type=str,
default=None,
help="The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job.",
)
AWSArguments().add_to_parser(aws_args)
parser.add_argument(
"--debug",
action="store_true",

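The net effect of the refactor above: each argument group now hands its `add_argument_group` to a dataclass, and the flag names, `choices`, defaults, and `help` text come from the dataclass fields and docstring. A rough sketch of the pattern with a throwaway dataclass (hypothetical `DemoArguments`, not part of this PR; assumes the `Arguments` export added later in this diff):

# Hypothetical example of the pattern used above; `DemoArguments` does not exist in accelerate.
import argparse
from dataclasses import dataclass
from typing import ClassVar, Literal

from accelerate.utils import Arguments  # exported by this branch


@dataclass
class DemoArguments(Arguments):
    """
    Demo arguments.
    Args:
        backend (`str`, *optional*, defaults to "no"):
            Which backend to use.
        verbose (`bool`, *optional*, defaults to `False`):
            Whether to print extra logs.
    """

    prefix: ClassVar[str] = "demo_"
    backend: Literal["no", "inductor"] = "no"
    verbose: bool = False


parser = argparse.ArgumentParser("demo", allow_abbrev=False)
group = parser.add_argument_group("Demo", "Arguments related to the demo.")
DemoArguments().add_to_parser(group)  # registers --demo_backend (choices from the Literal) and --demo_verbose (store_true)
args = parser.parse_args(["--demo_backend", "inductor", "--demo_verbose"])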
View File

@@ -12,6 +12,7 @@ from .constants import (
WEIGHTS_NAME,
)
from .dataclasses import (
Arguments,
AutocastKwargs,
BnbQuantizationConfig,
ComputeEnvironment,

View File

@@ -20,13 +20,15 @@ import argparse
import copy
import enum
import functools
import inspect
import os
import re
import typing
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from dataclasses import MISSING, dataclass, field, fields
from datetime import timedelta
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple
import torch
@@ -1508,3 +1510,75 @@ class BnbQuantizationConfig:
if not isinstance(self.torch_dtype, torch.dtype):
raise ValueError("torch_dtype must be a torch.dtype")
class Arguments:
"""
Base dataclass for CLI arguments. Contains methods for type validation and for conversion to `argparse`.
Allows for compatibility between raw Python usage and `argparse`.
A `prefix` can be set which will be prepended to each argument name when converting to `argparse`.
"""
prefix: ClassVar[str] = ""
def __post_init__(self):
self.validate_types()
def validate_types(self):
"Verifies that each `Literal` field holds one of its allowed values."
for arg in fields(self):
parameter_type = typing.get_origin(arg.type)
if parameter_type == typing.Literal:
if self.__dict__[arg.name] not in arg.type.__args__:
raise ValueError(
f"Invalid value for `{arg.name}`. Must be one of {list(arg.type.__args__)} not {self.__dict__[arg.name]}"
)
def to_argparse(self):
command = []
for arg in fields(self):
parameter_type = typing.get_origin(arg.type)
if parameter_type != typing.Literal:
command.append(f"{self.__dict__[arg.name]}")
else:
command.append(f"--{self.prefix}{arg.name}={self.__dict__[arg.name]}")
return command
def add_to_parser(self, parser: argparse.ArgumentParser = None):
"""
Adds the dataclass fields as arguments to the given `argparse.ArgumentParser` (or argument group), with `help` text derived from the docstring of the class.
"""
# Build a map from each argument name to its description by parsing the "Args:" section of the class docstring
param_to_docstring = {}
docstring = inspect.getdoc(self)
args = docstring.split("Args:\n")[1]
args = inspect.cleandoc(args)
args = re.split(r"\n(?=[^\s])", args)
for arg in args:
arg = arg.replace("\n ", " ")
param = arg.split(" ")[0]
docstring = arg.split(": ")[1]
param_to_docstring[param] = docstring
# Register each dataclass field on the parser, deriving `default`, `action`, `choices`, and `type` from the field
for arg in fields(self):
name = arg.name
docstring = param_to_docstring[name]
if not isinstance(arg.type, type):
parameter_type = typing.get_origin(arg.type)
else:
parameter_type = arg.type
arg_dict = {}
# Skip `MISSING` (no default or a default_factory) so argparse falls back to its own default of `None`
if arg.default is not None and arg.default is not MISSING:
arg_dict["default"] = arg.default
if arg.default is False:
arg_dict["action"] = "store_true"
arg_dict["help"] = docstring
if parameter_type == typing.Literal:
arg_dict["choices"] = arg.type.__args__
arg_dict["type"] = str
elif parameter_type == list:
arg_dict["action"] = "append"
arg_dict["type"] = str
# Plain scalar fields need an explicit `type`; `store_true` flags and bools (parsed as strings) must not set one
elif arg_dict.get("action") != "store_true" and parameter_type is not bool:
arg_dict["type"] = parameter_type
parser.add_argument(f"--{self.prefix}{name}", **arg_dict)