Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-18 16:44:39 +08:00)

Comparing commits: v0.32.1...deepspeed- (1 commit, SHA 25345f3ea9)
@@ -100,6 +100,7 @@ from .utils import (
     wait_for_everyone,
 )
 from .utils.constants import FSDP_PYTORCH_VERSION
+from .utils.dataclasses import InferencePlugin, DeepSpeedInferencePlugin
 from .utils.modeling import get_state_dict_offloaded_model
 from .utils.other import is_compiled_module
 
@@ -187,6 +188,9 @@ class Accelerator:
         fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
             Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
             using *accelerate config*
+        inference_plugin ([`~utils.InferencePlugin`], *optional*):
+            Tweak how inference should be performed when using Accelerate only for inference while still wanting the
+            `Accelerator`-related enhancements.
         megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
             Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
             directly using *accelerate config*
@@ -253,6 +257,7 @@ class Accelerator:
         dataloader_config: DataLoaderConfiguration | None = None,
         deepspeed_plugin: DeepSpeedPlugin | None = None,
         fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
+        inference_plugin: InferencePlugin | None = None,
         megatron_lm_plugin: MegatronLMPlugin | None = None,
         rng_types: list[str | RNGType] | None = None,
         log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
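
Taken together, the three hunks above thread a new `inference_plugin` keyword through `Accelerator`. A minimal usage sketch, assuming the surrounding `Accelerator` API behaves as in the released library (the import path follows the diff's `utils.dataclasses` location; the dtype value is illustrative, not a default):

    import torch

    from accelerate import Accelerator
    from accelerate.utils.dataclasses import InferencePlugin

    # Configure inference-only behavior instead of a training-oriented plugin.
    plugin = InferencePlugin(inference_dtype="fp16")

    # The new keyword slots in next to deepspeed_plugin / fsdp_plugin.
    accelerator = Accelerator(inference_plugin=plugin)
    model = accelerator.prepare(torch.nn.Linear(8, 8))
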
@@ -15,6 +15,8 @@ import math
 from types import MethodType
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+from dataclasses import dataclass
+
 from .state import PartialState
 from .utils import (
     calculate_maximum_sizes,
@@ -26,6 +28,7 @@ from .utils import (
     pad_input_tensors,
     send_to_device,
 )
+from .utils.dataclasses import KwargsHandler
 
 
 if is_pippy_available():
@@ -26,15 +26,18 @@ import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from datetime import timedelta
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, get_args
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args
 
 import torch
 
 from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
-from .imports import is_cuda_available, is_npu_available, is_xpu_available
+from .imports import is_cuda_available, is_npu_available, is_xpu_available, is_deepspeed_available
 from .versions import compare_versions
 
+if is_deepspeed_available():
+    from deepspeed.inference.config import DeepSpeedTPConfig
+
 
 class KwargsHandler:
     """
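
Note the shape of the guarded import: the deepspeed class is imported only when the package is present, while annotations later in this diff refer to it as a quoted string. A self-contained sketch of why that works (the class name mirrors the diff; the try/except stands in for the `is_deepspeed_available()` check, and the `tp_size` key is illustrative):

    from dataclasses import dataclass
    from typing import Any, Dict, Optional, Union

    try:
        # Optional dependency; only resolvable when deepspeed is installed.
        from deepspeed.inference.config import DeepSpeedTPConfig
    except ImportError:
        DeepSpeedTPConfig = None

    @dataclass
    class _TPExample:
        # "DeepSpeedTPConfig" is a forward reference: the string is never
        # evaluated at class-creation time, so this module imports cleanly
        # even when deepspeed is absent.
        tensor_parallel_config: Optional[Union[Dict[str, Any], "DeepSpeedTPConfig"]] = None

    cfg = _TPExample(tensor_parallel_config={"tp_size": 2})  # the dict form always works
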
@@ -968,6 +971,136 @@ class DeepSpeedPlugin:
                 "It will only ask for the necessary config variables when using `deepspeed_config_file`."
             )
 
+
+
+@dataclass
+class DeepSpeedInferencePlugin:
+    """
+    Plugin to work with the DeepSpeed inference engine.
+    """
+
+    min_tokens: int = field(
+        default=1,
+        metadata={"help": "Minimum number of tokens expected to be generated."},
+    )
+    max_tokens: int = field(
+        default=1024,
+        metadata={"help": "Maximum number of tokens expected to work with, including input and output tokens."},
+    )
+    replace_with_kernel_inject: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Set to true to inject inference kernels for models such as Bert, GPT2, GPT-Neo and GPT-J. "
+                "Otherwise, the injection_dict provides the names of two linear layers as a tuple: "
+                "(attention_output projection, transformer output projection)."
+            )
+        },
+    )
+    enable_cuda_graph: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to capture the CUDA graph of the inference ops, so that iterative calls run faster via graph replay."
+        },
+    )
+    injection_policy: Optional[dict] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Dictionary mapping a client `nn.Module` to its injection policy, such as "
+                "`{BertLayer: deepspeed.inference.HFBertLayerPolicy}`."
+            )
+        },
+    )
+    triangular_masking: bool = field(
+        default=False,
+        metadata={"help": "Whether to use triangular masking for attention."},
+    )
+    use_triton: bool = field(
+        default=False,
+        metadata={"help": "Whether to use Triton kernels for inference."},
+    )
+    triton_autotune: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to perform Triton autotuning. Tuning increases performance, at the cost of a longer initial run."
+        },
+    )
+
+    # dtype should be picked up from the Accelerator
+
+    zero_config: Optional[Union[Dict[str, Any], "DeepSpeedZeroConfig"]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Configuration for ZeRO when using the inference engine. If launching with `accelerate launch`, will "
+                "utilize this configuration."
+            )
+        },
+    )
+    tensor_parallel_config: Optional[Union[Dict[str, Any], "DeepSpeedTPConfig"]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Configuration for tensor parallelism used to split the model across several GPUs. "
+                "Valid keys can be found in `deepspeed.inference.config.DeepSpeedTPConfig`."
+            )
+        },
+    )
+    moe_config: Optional[Union[Dict[str, Any], "DeepSpeedMoEConfig"]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Configuration for Mixture of Experts (MoE) used to split the model across several GPUs. "
+                "Valid keys can be found in `deepspeed.inference.config.DeepSpeedMoEConfig`."
+            )
+        },
+    )
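
Since `DeepSpeedInferencePlugin` is a plain dataclass, configuration is just keyword arguments. A hedged sketch (the values are illustrative, not recommendations; the `tp_size` key is assumed to follow `deepspeed.inference.config.DeepSpeedTPConfig`):

    from accelerate.utils.dataclasses import DeepSpeedInferencePlugin

    ds_inference = DeepSpeedInferencePlugin(
        max_tokens=2048,                        # budget covering input + output tokens
        replace_with_kernel_inject=True,        # use fused inference kernels where supported
        tensor_parallel_config={"tp_size": 2},  # dict form; validated against DeepSpeedTPConfig
    )
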
+
+
+@dataclass
+class InferencePlugin:
+    """
+    Plugin that helps determine how inference should be performed.
+    """
+
+    inference_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The dtype to use for inference. If not specified, will default to the dtype configured in the "
+                "Accelerator."
+            )
+        },
+    )
+    deepspeed_config: Optional[DeepSpeedInferencePlugin] = field(
+        default=None,
+        metadata={"help": "Configuration for the DeepSpeed inference engine."},
+    )
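
The two dataclasses compose: `InferencePlugin` is the object `Accelerator` receives, and its `deepspeed_config` field carries the engine-specific options above. A sketch of the nesting, which also shows how the `metadata={"help": ...}` entries double as self-documentation (values illustrative):

    from dataclasses import fields

    from accelerate.utils.dataclasses import DeepSpeedInferencePlugin, InferencePlugin

    plugin = InferencePlugin(
        inference_dtype="fp16",  # falls back to the Accelerator's dtype when None
        deepspeed_config=DeepSpeedInferencePlugin(max_tokens=2048),
    )

    # Every field carries its own help text:
    for f in fields(plugin):
        print(f"{f.name}: {f.metadata['help']}")
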
+
+
 @dataclass
 class FullyShardedDataParallelPlugin: