mirror of https://github.com/huggingface/accelerate.git
synced 2025-11-15 14:55:09 +08:00

Compare commits: v0.29.3...deepspeed- (1 commit)

Commit: 25345f3ea9
@@ -100,6 +100,7 @@ from .utils import (
     wait_for_everyone,
 )
 from .utils.constants import FSDP_PYTORCH_VERSION
+from .utils.dataclasses import InferencePlugin, DeepSpeedInferencePlugin
 from .utils.modeling import get_state_dict_offloaded_model
 from .utils.other import is_compiled_module
 
@@ -187,6 +188,9 @@ class Accelerator:
         fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
             Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
             using *accelerate config*
+        inference_plugin ([`~utils.InferencePlugin`], *optional*):
+            Tweak how inference is performed when using Accelerate only for inference while still wanting the
+            `Accelerator` related enhancements.
         megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
             Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
             directly using *accelerate config*
@@ -253,6 +257,7 @@ class Accelerator:
         dataloader_config: DataLoaderConfiguration | None = None,
         deepspeed_plugin: DeepSpeedPlugin | None = None,
         fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
+        inference_plugin: InferencePlugin | None = None,
         megatron_lm_plugin: MegatronLMPlugin | None = None,
         rng_types: list[str | RNGType] | None = None,
         log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
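Taken together with the docstring hunk above, this adds one new keyword to `Accelerator.__init__`. A minimal, hypothetical usage sketch, assuming this branch is installed and that `InferencePlugin` is re-exported from `accelerate.utils` like the other plugins (a fuller example follows the new dataclasses below):

```python
from accelerate import Accelerator
from accelerate.utils import InferencePlugin  # assumed re-export, as with DeepSpeedPlugin

# The new keyword slots in next to the existing deepspeed/fsdp/megatron plugin arguments.
accelerator = Accelerator(inference_plugin=InferencePlugin(inference_dtype="fp16"))
```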
@@ -15,6 +15,8 @@ import math
 from types import MethodType
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+from dataclasses import dataclass
+
 from .state import PartialState
 from .utils import (
     calculate_maximum_sizes,
@@ -26,6 +28,7 @@ from .utils import (
     pad_input_tensors,
     send_to_device,
 )
+from .utils.dataclasses import KwargsHandler
 
 
 if is_pippy_available():
@@ -26,15 +26,18 @@ import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from datetime import timedelta
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, get_args
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args
 
 import torch
 
 from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
-from .imports import is_cuda_available, is_npu_available, is_xpu_available
+from .imports import is_cuda_available, is_deepspeed_available, is_npu_available, is_xpu_available
 from .versions import compare_versions
 
+if is_deepspeed_available():
+    from deepspeed.inference.config import DeepSpeedTPConfig
+
 
 class KwargsHandler:
     """
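The guarded import above is why the new dataclasses below annotate DeepSpeed types as quoted strings: `utils/dataclasses.py` must stay importable when `deepspeed` is not installed. A small sketch of the same pattern with a hypothetical optional dependency (`fastlib` does not exist; it stands in for `deepspeed`):

```python
from typing import Any, Dict, Optional, Union


def is_fastlib_available() -> bool:
    # Hypothetical availability check, mirroring is_deepspeed_available().
    try:
        import fastlib  # noqa: F401

        return True
    except ImportError:
        return False


if is_fastlib_available():
    from fastlib.config import FastConfig

# The quoted "FastConfig" stays an unresolved ForwardRef, so this module still
# imports cleanly when fastlib (and therefore FastConfig) is missing.
def configure(config: Optional[Union[Dict[str, Any], "FastConfig"]] = None) -> None:
    ...
```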
@@ -968,6 +971,136 @@ class DeepSpeedPlugin:
                 "It will only ask for the necessary config variables when using `deepspeed_config_file`."
             )
 
+
+@dataclass
+class DeepSpeedInferencePlugin:
+    """
+    Plugin for configuring the DeepSpeed inference engine.
+    """
+
+    min_tokens: int = field(
+        default=1,
+        metadata={"help": "Minimum number of tokens the model is expected to generate."},
+    )
+
+    max_tokens: int = field(
+        default=1024,
+        metadata={"help": "Maximum number of tokens the engine is expected to work with, including input and output tokens."},
+    )
+
+    replace_with_kernel_inject: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Set to true to inject inference kernels for models such as Bert, GPT2, GPT-Neo and GPT-J. "
+                "Otherwise, the injection_dict provides the names of two linear layers as a tuple: "
+                "(attention_output projection, transformer output projection)."
+            )
+        },
+    )
+
+    enable_cuda_graph: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to capture the CUDA graph of the inference ops, so that iterative calls run faster via graph replay."
+        },
+    )
+
+    injection_policy: Optional[dict] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Dictionary mapping a client `nn.Module` to its injection policy, such as "
+                "`{BertLayer: deepspeed.inference.HFBertLayerPolicy}`."
+            )
+        },
+    )
+
+    triangular_masking: bool = field(
+        default=False,
+        metadata={"help": "Whether to use triangular masking for attention."},
+    )
+
+    use_triton: bool = field(
+        default=False,
+        metadata={"help": "Whether to use Triton kernels for inference."},
+    )
+
+    triton_autotune: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to run Triton autotuning. Autotuning increases performance, at the cost of a longer initial startup time."
+        },
+    )
+
+    # dtype should be picked up from Accelerator
+
+    zero_config: Optional[Union[Dict[str, Any], "DeepSpeedZeroConfig"]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Configuration for ZeRO when using the inference engine. If launching with `accelerate launch`, will "
+                "utilize this configuration."
+            )
+        },
+    )
+
+    tensor_parallel_config: Optional[Union[Dict[str, Any], "DeepSpeedTPConfig"]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Configuration for tensor parallelism used to split the model across several GPUs. "
+                "Valid keys can be found in `deepspeed.inference.config.DeepSpeedTPConfig`."
+            )
+        },
+    )
+
+    moe_config: Optional[Union[Dict[str, Any], "DeepSpeedMoEConfig"]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Configuration for Mixture of Experts (MoE) used to split the model across several GPUs. "
+                "Valid keys can be found in `deepspeed.inference.config.DeepSpeedMoEConfig`."
+            )
+        },
+    )
+
+
+@dataclass
+class InferencePlugin:
+    """
+    Plugin that determines how inference should be performed.
+    """
+
+    inference_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The dtype to use for inference. If not specified, will default to the dtype configured "
+                "in the Accelerator."
+            )
+        },
+    )
+
+    deepspeed_config: Optional[DeepSpeedInferencePlugin] = field(
+        default=None,
+        metadata={"help": "Configuration for the DeepSpeed inference engine."},
+    )
+
 
 @dataclass
 class FullyShardedDataParallelPlugin:
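With the hunk applied, the new plugin can be built like any other accelerate plugin dataclass. A hedged sketch, assuming this branch and that `DeepSpeedInferencePlugin` is re-exported from `accelerate.utils`; `tp_size` is a real `DeepSpeedTPConfig` field, but the values are illustrative:

```python
from accelerate.utils import DeepSpeedInferencePlugin  # assumed re-export

# Kernel injection for a supported architecture, with two-way tensor parallelism
# passed as a plain dict (a DeepSpeedTPConfig instance also type-checks when
# deepspeed is installed).
ds_inference = DeepSpeedInferencePlugin(
    max_tokens=2048,
    replace_with_kernel_inject=True,
    tensor_parallel_config={"tp_size": 2},
)
```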
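And wiring it into the `Accelerator` through the new `inference_plugin` keyword from the first hunks. How `prepare()` ultimately consumes these fields is not shown in this diff, so the end-to-end behavior below is assumed:

```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedInferencePlugin, InferencePlugin

plugin = InferencePlugin(
    inference_dtype="fp16",  # per the help text, defaults to the Accelerator dtype when omitted
    deepspeed_config=DeepSpeedInferencePlugin(max_tokens=2048, use_triton=True),
)
accelerator = Accelerator(inference_plugin=plugin)
# model = accelerator.prepare(model)  # presumably routes inference through DeepSpeed
```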