Compare commits

...

1 Commit

Author  SHA1        Message   Date
        25345f3ea9  Bookmark  2024-04-02 09:08:00 -04:00
3 changed files with 143 additions and 2 deletions

src/accelerate/accelerator.py (View File)

@@ -100,6 +100,7 @@ from .utils import (
     wait_for_everyone,
 )
 from .utils.constants import FSDP_PYTORCH_VERSION
+from .utils.dataclasses import InferencePlugin, DeepSpeedInferencePlugin
 from .utils.modeling import get_state_dict_offloaded_model
 from .utils.other import is_compiled_module
@@ -187,6 +188,9 @@ class Accelerator:
         fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
             Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
             using *accelerate config*
+        inference_plugin ([`~utils.InferencePlugin`], *optional*):
+            Tweak how inference should be performed when using Accelerate only for inference while still wanting the
+            `Accelerator`-related enhancements.
         megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
             Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
             directly using *accelerate config*
@@ -253,6 +257,7 @@ class Accelerator:
         dataloader_config: DataLoaderConfiguration | None = None,
         deepspeed_plugin: DeepSpeedPlugin | None = None,
         fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
+        inference_plugin: InferencePlugin | None = None,
         megatron_lm_plugin: MegatronLMPlugin | None = None,
         rng_types: list[str | RNGType] | None = None,
         log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
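
A minimal sketch of how the new argument would be wired up, assuming the plugin is importable from accelerate.utils.dataclasses as in the import hunk above; the "fp16" dtype string and the prepare() behavior are assumptions, since that plumbing is not part of this diff:

    import torch
    from accelerate import Accelerator
    from accelerate.utils.dataclasses import InferencePlugin

    # Construct the plugin and pass it in, mirroring deepspeed_plugin / fsdp_plugin.
    plugin = InferencePlugin(inference_dtype="fp16")  # dtype string format assumed
    accelerator = Accelerator(inference_plugin=plugin)

    # A stand-in module; any model used purely for inference.
    model = accelerator.prepare(torch.nn.Linear(8, 8))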

src/accelerate/inference.py (View File)

@@ -15,6 +15,8 @@ import math
 from types import MethodType
 from typing import Any, Dict, List, Optional, Tuple, Union
+from dataclasses import dataclass
 from .state import PartialState
 from .utils import (
     calculate_maximum_sizes,
@@ -26,6 +28,7 @@ from .utils import (
     pad_input_tensors,
     send_to_device,
 )
+from .utils.dataclasses import KwargsHandler

 if is_pippy_available():

src/accelerate/utils/dataclasses.py (View File)

@@ -26,15 +26,18 @@ import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from datetime import timedelta
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, get_args
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args

 import torch

 from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
-from .imports import is_cuda_available, is_npu_available, is_xpu_available
+from .imports import is_cuda_available, is_npu_available, is_xpu_available, is_deepspeed_available
 from .versions import compare_versions

+if is_deepspeed_available():
+    # Needed for the forward-referenced annotations in DeepSpeedInferencePlugin below.
+    from deepspeed.inference.config import DeepSpeedMoEConfig, DeepSpeedTPConfig
+    from deepspeed.runtime.zero.config import DeepSpeedZeroConfig

 class KwargsHandler:
     """
@@ -968,6 +971,136 @@ class DeepSpeedPlugin:
                 "It will only ask for the necessary config variables when using `deepspeed_config_file`."
             )

+
+@dataclass
+class DeepSpeedInferencePlugin:
+    """
+    Plugin for configuring the DeepSpeed inference engine.
+    """
+
+    min_tokens: int = field(
+        default=1,
+        metadata={
+            "help": "Minimum number of tokens the model is expected to generate."
+        }
+    )
+    max_tokens: int = field(
+        default=1024,
+        metadata={
+            "help": "Maximum number of tokens the engine is expected to work with, including the input and output tokens."
+        }
+    )
+    replace_with_kernel_inject: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Set to true to inject inference kernels for models such as Bert, GPT2, GPT-Neo and GPT-J. "
+                "Otherwise, the injection_dict provides the names of two linear layers as a tuple: "
+                "(attention_output projection, transformer output projection)."
+            )
+        }
+    )
+    enable_cuda_graph: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to capture the CUDA graph of the inference ops, so that iterative calls run faster via graph replay."
+        }
+    )
+    injection_policy: Optional[dict] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Dictionary mapping a client `nn.Module` to its injection policy, such as "
+                "`{BertLayer: deepspeed.inference.HFBertLayerPolicy}`."
+            )
+        }
+    )
+    triangular_masking: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use triangular masking for attention."
+        }
+    )
+    use_triton: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use Triton kernels for inference."
+        }
+    )
+    triton_autotune: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to perform Triton autotuning. Autotuning improves performance, while increasing the initial time taken to run."
+        }
+    )
+    # dtype should be picked up from Accelerator
+    zero_config: Optional[Union[Dict[str, Any], "DeepSpeedZeroConfig"]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Configuration for ZeRO when using the inference engine. If launching with `accelerate launch`, this "
+                "configuration will be utilized."
+            )
+        }
+    )
+    tensor_parallel_config: Optional[Union[Dict[str, Any], "DeepSpeedTPConfig"]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Configuration for tensor parallelism used to split the model across several GPUs. "
+                "Valid keys can be found in `deepspeed.inference.config.DeepSpeedTPConfig`."
+            )
+        }
+    )
+    moe_config: Optional[Union[Dict[str, Any], "DeepSpeedMoEConfig"]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Configuration for Mixture of Experts (MoE) used to split the model across several GPUs. "
+                "Valid keys can be found in `deepspeed.inference.config.DeepSpeedMoEConfig`."
+            )
+        }
+    )
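
The defaults above (min_tokens=1, max_tokens=1024) mirror DeepSpeed's own inference defaults. As a hedged sketch of what this plugin plausibly translates to, the fields line up with the keyword arguments of deepspeed.init_inference, DeepSpeed's actual inference entry point; the forwarding itself is not shown in this diff:

    import deepspeed
    import torch

    # Stand-in model; a real use would pass a transformer so kernel injection applies.
    model = torch.nn.Linear(8, 8)

    # Assumed field-to-kwarg mapping: deepspeed.init_inference and these config
    # keys are DeepSpeed's real API, but how this PR forwards them is an assumption.
    engine = deepspeed.init_inference(
        model,
        dtype=torch.float16,               # "dtype should be picked up from Accelerator"
        min_tokens=1,                      # DeepSpeedInferencePlugin.min_tokens
        max_tokens=1024,                   # DeepSpeedInferencePlugin.max_tokens
        replace_with_kernel_inject=False,  # DeepSpeedInferencePlugin.replace_with_kernel_inject
        enable_cuda_graph=False,           # DeepSpeedInferencePlugin.enable_cuda_graph
        triangular_masking=False,          # DeepSpeedInferencePlugin.triangular_masking
    )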
+
+
+@dataclass
+class InferencePlugin:
+    """
+    Plugin that helps determine how inference should be performed.
+    """
+
+    inference_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The dtype to use for inference. If not specified, will default to the dtype configured "
+                "in the `Accelerator`."
+            )
+        }
+    )
+    deepspeed_config: Optional[DeepSpeedInferencePlugin] = field(
+        default=None,
+        metadata={
+            "help": "Configuration for the DeepSpeed inference engine."
+        }
+    )
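
Putting the two dataclasses together: the field names below come straight from this diff, while the concrete values and the tp_size key (taken from deepspeed.inference.config.DeepSpeedTPConfig) are illustrative:

    from accelerate.utils.dataclasses import DeepSpeedInferencePlugin, InferencePlugin

    # Illustrative values, not recommendations from this PR.
    ds_config = DeepSpeedInferencePlugin(
        max_tokens=2048,
        replace_with_kernel_inject=True,        # kernel injection for supported models (e.g. GPT2)
        tensor_parallel_config={"tp_size": 2},  # keys follow deepspeed.inference.config.DeepSpeedTPConfig
    )

    # The outer plugin bundles the dtype choice with the engine configuration.
    plugin = InferencePlugin(inference_dtype="fp16", deepspeed_config=ds_config)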
 @dataclass
 class FullyShardedDataParallelPlugin: