Compare commits

...

1 Commits

Author SHA1 Message Date
25345f3ea9 Bookmark 2024-04-02 09:08:00 -04:00
3 changed files with 143 additions and 2 deletions

View File

@@ -100,6 +100,7 @@ from .utils import (
    wait_for_everyone,
)
from .utils.constants import FSDP_PYTORCH_VERSION
from .utils.dataclasses import InferencePlugin, DeepSpeedInferencePlugin
from .utils.modeling import get_state_dict_offloaded_model
from .utils.other import is_compiled_module
@@ -187,6 +188,9 @@ class Accelerator:
        fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
            Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
            using *accelerate config*
        inference_plugin ([`~utils.InferencePlugin`], *optional*):
            Tweak how inference should be performed when using Accelerate only for inference while still taking
            advantage of `Accelerator`-related enhancements.
        megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
            Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
            directly using *accelerate config*
@@ -253,6 +257,7 @@ class Accelerator:
        dataloader_config: DataLoaderConfiguration | None = None,
        deepspeed_plugin: DeepSpeedPlugin | None = None,
        fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
        inference_plugin: InferencePlugin | None = None,
        megatron_lm_plugin: MegatronLMPlugin | None = None,
        rng_types: list[str | RNGType] | None = None,
        log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
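
The new keyword argument slots in alongside the existing plugin arguments. As a minimal sketch of the intended call site (assuming this branch's code is installed; the "fp16" value is illustrative, not from the commit):

from accelerate import Accelerator
from accelerate.utils.dataclasses import InferencePlugin

# Passing an InferencePlugin opts in to the inference-focused behavior;
# leaving it as None (the default) keeps the current behavior.
accelerator = Accelerator(inference_plugin=InferencePlugin(inference_dtype="fp16"))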

View File

@@ -15,6 +15,8 @@ import math
from types import MethodType
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass

from .state import PartialState
from .utils import (
    calculate_maximum_sizes,
@@ -26,6 +28,7 @@ from .utils import (
    pad_input_tensors,
    send_to_device,
)
from .utils.dataclasses import KwargsHandler

if is_pippy_available():

View File

@@ -26,15 +26,18 @@ import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, get_args
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args

import torch

from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE
from .environment import str_to_bool
from .imports import is_cuda_available, is_npu_available, is_xpu_available
from .imports import is_cuda_available, is_npu_available, is_xpu_available, is_deepspeed_available
from .versions import compare_versions

if is_deepspeed_available():
    from deepspeed.inference.config import DeepSpeedTPConfig


class KwargsHandler:
    """
@@ -968,6 +971,136 @@ class DeepSpeedPlugin:
                "It will only ask for the necessary config variables when using `deepspeed_config_file`."
            )
@dataclass
class DeepSpeedInferencePlugin:
    """
    Plugin to work with the DeepSpeed inference engine.
    """

    min_tokens: int = field(
        default=1,
        metadata={"help": "Minimum number of tokens expected to be generated."},
    )
    max_tokens: int = field(
        default=1024,
        metadata={"help": "Maximum number of tokens to work with, including both input and output tokens."},
    )
    replace_with_kernel_inject: bool = field(
        default=False,
        metadata={
            "help": (
                "Set to true to inject inference kernels for models such as Bert, GPT2, GPT-Neo and GPT-J. "
                "Otherwise, the injection_dict provides the names of two linear layers as a tuple: "
                "(attention_output projection, transformer output projection)"
            )
        },
    )
    enable_cuda_graph: bool = field(
        default=False,
        metadata={
            "help": "Whether to capture a CUDA graph of the inference ops, so that iterative calls can be faster using graph replay."
        },
    )
    injection_policy: Optional[dict] = field(
        default=None,
        metadata={
            "help": (
                "Dictionary mapping a client `nn.Module` to its injection policy, such as "
                "`{BertLayer: deepspeed.inference.HFBertLayerPolicy}`."
            )
        },
    )
    triangular_masking: bool = field(
        default=False,
        metadata={"help": "Whether to use triangular masking for attention."},
    )
    use_triton: bool = field(
        default=False,
        metadata={"help": "Whether to use Triton kernels for inference."},
    )
    triton_autotune: bool = field(
        default=False,
        metadata={
            "help": "Whether to perform Triton autotuning. Autotuning increases performance, but also increases the initial time taken to run."
        },
    )
    # dtype should be picked up from Accelerator
    zero_config: Optional[Union[Dict[str, Any], "DeepSpeedZeroConfig"]] = field(
        default=None,
        metadata={
            "help": (
                "Configuration for ZeRO when using the inference engine. If launching with `accelerate launch`, will "
                "utilize this configuration."
            )
        },
    )
    tensor_parallel_config: Optional[Union[Dict[str, Any], "DeepSpeedTPConfig"]] = field(
        default=None,
        metadata={
            "help": (
                "Configuration for tensor parallelism used to split the model across several GPUs. "
                "Valid keys can be found in `deepspeed.inference.config.DeepSpeedTPConfig`."
            )
        },
    )
    moe_config: Optional[Union[Dict[str, Any], "DeepSpeedMoEConfig"]] = field(
        default=None,
        metadata={
            "help": (
                "Configuration for Mixture of Experts (MoE) used to split the model across several GPUs. "
                "Valid keys can be found in `deepspeed.inference.config.DeepSpeedMoEConfig`."
            )
        },
    )
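
For illustration, the fields above could be configured as follows. This is a hypothetical sketch: the concrete values are assumptions, and `tp_size` is a key documented by `deepspeed.inference.config.DeepSpeedTPConfig`, not something this commit pins down.

from accelerate.utils.dataclasses import DeepSpeedInferencePlugin

ds_inference = DeepSpeedInferencePlugin(
    max_tokens=2048,                        # budget for input plus output tokens
    replace_with_kernel_inject=True,        # use fused inference kernels for supported models
    tensor_parallel_config={"tp_size": 2},  # assumed: split the model across two GPUs
)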
@dataclass
class InferencePlugin:
    """
    Plugin that helps determine how inference should be performed.
    """

    inference_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The dtype to use for inference. If not specified, will default to the dtype configured "
                "in the Accelerator."
            )
        },
    )
    deepspeed_config: Optional[DeepSpeedInferencePlugin] = field(
        default=None,
        metadata={"help": "Configuration for the DeepSpeed inference engine."},
    )
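
Putting the two dataclasses together, a DeepSpeed-backed setup might look like the sketch below; the values are illustrative, not taken from the commit.

from accelerate.utils.dataclasses import DeepSpeedInferencePlugin, InferencePlugin

plugin = InferencePlugin(
    inference_dtype="bf16",  # assumed value; when None, the Accelerator's dtype is used
    deepspeed_config=DeepSpeedInferencePlugin(enable_cuda_graph=True),
)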
@dataclass
class FullyShardedDataParallelPlugin: