mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Pydantic v2 migration (#5167)
Pydantic v2 has been out for some time now. We have been relying on using the v1 API available in v2 until now. This is a refresh of #3902 to bring proper v2 support to DeepSpeed. Corresponding DeepSpeed-MII PR [here](https://github.com/microsoft/DeepSpeed-MII/pull/423). @loadams --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Logan Adams <loadams@microsoft.com> Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com> Co-authored-by: Abhishek Kulkarni <11399+adk9@users.noreply.github.com> Co-authored-by: Abhishek Kulkarni <abkulkarni@microsoft.com> Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
This commit is contained in:
3
.github/workflows/nv-a6000.yml
vendored
3
.github/workflows/nv-a6000.yml
vendored
@ -47,7 +47,8 @@ jobs:
|
|||||||
- name: Install deepspeed
|
- name: Install deepspeed
|
||||||
run: |
|
run: |
|
||||||
python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
|
python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
|
||||||
python -m pip install pydantic==1.10.11
|
# Update packages included in the container that do not support pydantic 2+ to versions that do
|
||||||
|
python -m pip install thinc spacy confection --upgrade
|
||||||
python -m pip install .[dev,1bit,autotuning,inf]
|
python -m pip install .[dev,1bit,autotuning,inf]
|
||||||
ds_report
|
ds_report
|
||||||
- name: Python environment
|
- name: Python environment
|
||||||
|
@ -3,20 +3,12 @@
|
|||||||
|
|
||||||
# DeepSpeed Team
|
# DeepSpeed Team
|
||||||
|
|
||||||
|
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||||
|
|
||||||
from .constants import *
|
from .constants import *
|
||||||
from ..pydantic_v1 import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class CommsConfig(BaseModel):
|
class CommsLoggerConfig(DeepSpeedConfigModel):
|
||||||
|
|
||||||
class Config:
|
|
||||||
validate_all = True
|
|
||||||
validate_assignment = True
|
|
||||||
use_enum_values = True
|
|
||||||
extra = 'forbid'
|
|
||||||
|
|
||||||
|
|
||||||
class CommsLoggerConfig(CommsConfig):
|
|
||||||
enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
|
enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
|
||||||
prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
|
prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
|
||||||
prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
|
prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
|
||||||
|
@ -5,38 +5,25 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import deepspeed
|
import deepspeed
|
||||||
from deepspeed.pydantic_v1 import Field, validator
|
from pydantic import Field, field_validator
|
||||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||||
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
|
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
|
||||||
from typing import Dict, Union
|
from typing import Dict, Union, Optional
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class DtypeEnum(Enum):
|
class DtypeEnum(Enum):
|
||||||
# The torch dtype must always be the first value (so we return torch.dtype)
|
fp16 = (torch.float16, "torch.float16", "fp16", "float16", "half")
|
||||||
fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
|
fp32 = (torch.float32, "torch.float32", "fp32", "float32", "float")
|
||||||
fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
|
bf16 = (torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat")
|
||||||
bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat"
|
int8 = (torch.int8, "torch.int8", "int8")
|
||||||
int8 = torch.int8, "torch.int8", "int8"
|
|
||||||
|
|
||||||
# Copied from https://stackoverflow.com/a/43210118
|
@classmethod
|
||||||
# Allows us to use multiple values for each Enum index and returns first
|
def from_str(cls, value: str):
|
||||||
# listed value when Enum is called
|
for dtype in cls:
|
||||||
def __new__(cls, *values):
|
if value in dtype.value:
|
||||||
obj = object.__new__(cls)
|
return dtype
|
||||||
# first value is canonical value
|
raise ValueError(f"'{value}' is not a valid DtypeEnum")
|
||||||
obj._value_ = values[0]
|
|
||||||
for other_value in values[1:]:
|
|
||||||
cls._value2member_map_[other_value] = obj
|
|
||||||
obj._all_values = values
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<%s.%s: %s>" % (
|
|
||||||
self.__class__.__name__,
|
|
||||||
self._name_,
|
|
||||||
", ".join([repr(v) for v in self._all_values]),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MoETypeEnum(str, Enum):
|
class MoETypeEnum(str, Enum):
|
||||||
@ -91,24 +78,24 @@ class QuantTypeEnum(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class BaseQuantConfig(DeepSpeedConfigModel):
|
class BaseQuantConfig(DeepSpeedConfigModel):
|
||||||
enabled = True
|
enabled: bool = True
|
||||||
num_bits = 8
|
num_bits: int = 8
|
||||||
q_type: QuantTypeEnum = QuantTypeEnum.sym
|
q_type: QuantTypeEnum = QuantTypeEnum.sym
|
||||||
q_groups: int = 1
|
q_groups: int = 1
|
||||||
|
|
||||||
|
|
||||||
class WeightQuantConfig(BaseQuantConfig):
|
class WeightQuantConfig(BaseQuantConfig):
|
||||||
enabled = True
|
enabled: bool = True
|
||||||
quantized_initialization: Dict = {}
|
quantized_initialization: Dict = {}
|
||||||
post_init_quant: Dict = {}
|
post_init_quant: Dict = {}
|
||||||
|
|
||||||
|
|
||||||
class ActivationQuantConfig(BaseQuantConfig):
|
class ActivationQuantConfig(BaseQuantConfig):
|
||||||
enabled = True
|
enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
class QKVQuantConfig(DeepSpeedConfigModel):
|
class QKVQuantConfig(DeepSpeedConfigModel):
|
||||||
enabled = True
|
enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
class QuantizationConfig(DeepSpeedConfigModel):
|
class QuantizationConfig(DeepSpeedConfigModel):
|
||||||
@ -120,9 +107,9 @@ class QuantizationConfig(DeepSpeedConfigModel):
|
|||||||
|
|
||||||
# todo: brainstorm on how to do ckpt loading for DS inference
|
# todo: brainstorm on how to do ckpt loading for DS inference
|
||||||
class InferenceCheckpointConfig(DeepSpeedConfigModel):
|
class InferenceCheckpointConfig(DeepSpeedConfigModel):
|
||||||
checkpoint_dir: str = None
|
checkpoint_dir: Optional[str] = None
|
||||||
save_mp_checkpoint_path: str = None
|
save_mp_checkpoint_path: Optional[str] = None
|
||||||
base_dir: str = None
|
base_dir: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
||||||
@ -136,7 +123,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
|||||||
`(attention_output projection, transformer output projection)`
|
`(attention_output projection, transformer output projection)`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dtype: DtypeEnum = torch.float16
|
dtype: torch.dtype = torch.float16
|
||||||
"""
|
"""
|
||||||
Desired model data type, will convert model to this type.
|
Desired model data type, will convert model to this type.
|
||||||
Supported target types: `torch.half`, `torch.int8`, `torch.float`
|
Supported target types: `torch.half`, `torch.int8`, `torch.float`
|
||||||
@ -198,7 +185,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
#todo: refactor the following 3 into the new checkpoint_config
|
#todo: refactor the following 3 into the new checkpoint_config
|
||||||
checkpoint: Union[str, Dict] = None
|
checkpoint: Optional[Union[str, Dict]] = None
|
||||||
"""
|
"""
|
||||||
Path to deepspeed compatible checkpoint or path to JSON with load policy.
|
Path to deepspeed compatible checkpoint or path to JSON with load policy.
|
||||||
"""
|
"""
|
||||||
@ -214,7 +201,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
|||||||
specifying whether the inference-module is created with empty or real Tensor
|
specifying whether the inference-module is created with empty or real Tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
save_mp_checkpoint_path: str = None
|
save_mp_checkpoint_path: Optional[str] = None
|
||||||
"""
|
"""
|
||||||
The path for which we want to save the loaded model with a checkpoint. This
|
The path for which we want to save the loaded model with a checkpoint. This
|
||||||
feature is used for adjusting the parallelism degree to help alleviate the
|
feature is used for adjusting the parallelism degree to help alleviate the
|
||||||
@ -243,19 +230,21 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
|||||||
|
|
||||||
replace_method: str = Field(
|
replace_method: str = Field(
|
||||||
"auto",
|
"auto",
|
||||||
deprecated=True,
|
json_schema_extra={
|
||||||
deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference")
|
"deprecated": True,
|
||||||
|
"deprecated_msg": "This parameter is no longer needed, please remove from your call to DeepSpeed-inference"
|
||||||
|
})
|
||||||
|
|
||||||
injection_policy: Dict = Field(None, alias="injection_dict")
|
injection_policy: Optional[Dict] = Field(None, alias="injection_dict")
|
||||||
"""
|
"""
|
||||||
Dictionary mapping a client nn.Module to its corresponding injection
|
Dictionary mapping a client nn.Module to its corresponding injection
|
||||||
policy. e.g., `{BertLayer : deepspeed.inference.HFBertLayerPolicy}`
|
policy. e.g., `{BertLayer : deepspeed.inference.HFBertLayerPolicy}`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
injection_policy_tuple: tuple = None
|
injection_policy_tuple: Optional[tuple] = None
|
||||||
""" TODO: Add docs """
|
""" TODO: Add docs """
|
||||||
|
|
||||||
config: Dict = Field(None, alias="args") # todo: really no need for this field if we can refactor
|
config: Optional[Dict] = Field(None, alias="args") # todo: really no need for this field if we can refactor
|
||||||
|
|
||||||
max_out_tokens: int = Field(1024, alias="max_tokens")
|
max_out_tokens: int = Field(1024, alias="max_tokens")
|
||||||
"""
|
"""
|
||||||
@ -274,31 +263,49 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
|||||||
|
|
||||||
transposed_mode: bool = Field(False, alias="transposed_mode")
|
transposed_mode: bool = Field(False, alias="transposed_mode")
|
||||||
|
|
||||||
mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
|
mp_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.tp_size"})
|
||||||
"""
|
"""
|
||||||
Desired model parallel size, default is 1 meaning no model parallelism.
|
Desired model parallel size, default is 1 meaning no model parallelism.
|
||||||
Deprecated, please use the ``tensor_parallel` config to control model
|
Deprecated, please use the ``tensor_parallel` config to control model
|
||||||
parallelism.
|
parallelism.
|
||||||
"""
|
"""
|
||||||
mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
|
mpu: object = Field(None, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.mpu"})
|
||||||
ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
|
ep_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "moe.ep_size"})
|
||||||
ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group")
|
ep_group: object = Field(None,
|
||||||
ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group")
|
alias="expert_group",
|
||||||
moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
|
json_schema_extra={
|
||||||
moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type")
|
"deprecated": True,
|
||||||
|
"new_param": "moe.ep_group"
|
||||||
|
})
|
||||||
|
ep_mp_group: object = Field(None,
|
||||||
|
alias="expert_mp_group",
|
||||||
|
json_schema_extra={
|
||||||
|
"deprecated": True,
|
||||||
|
"new_param": "moe.ep_mp_group"
|
||||||
|
})
|
||||||
|
moe_experts: list = Field([1], json_schema_extra={"deprecated": True, "new_param": "moe.moe_experts"})
|
||||||
|
moe_type: MoETypeEnum = Field(MoETypeEnum.standard,
|
||||||
|
json_schema_extra={
|
||||||
|
"deprecated": True,
|
||||||
|
"new_param": "moe.type"
|
||||||
|
})
|
||||||
|
|
||||||
@validator("moe")
|
@field_validator("dtype", mode="before")
|
||||||
|
def validate_dtype(cls, field_value, values):
|
||||||
|
if isinstance(field_value, str):
|
||||||
|
return DtypeEnum.from_str(field_value).value[0]
|
||||||
|
if isinstance(field_value, torch.dtype):
|
||||||
|
return field_value
|
||||||
|
raise TypeError(f"Invalid type for dtype: {type(field_value)}")
|
||||||
|
|
||||||
|
@field_validator("moe")
|
||||||
def moe_backward_compat(cls, field_value, values):
|
def moe_backward_compat(cls, field_value, values):
|
||||||
if isinstance(field_value, bool):
|
if isinstance(field_value, bool):
|
||||||
return DeepSpeedMoEConfig(moe=field_value)
|
return DeepSpeedMoEConfig(moe=field_value)
|
||||||
return field_value
|
return field_value
|
||||||
|
|
||||||
@validator("use_triton")
|
@field_validator("use_triton")
|
||||||
def has_triton(cls, field_value, values):
|
def has_triton(cls, field_value, values):
|
||||||
if field_value and not deepspeed.HAS_TRITON:
|
if field_value and not deepspeed.HAS_TRITON:
|
||||||
raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
|
raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
|
||||||
return field_value
|
return field_value
|
||||||
|
|
||||||
class Config:
|
|
||||||
# Get the str representation of the datatype for serialization
|
|
||||||
json_encoders = {torch.dtype: lambda x: str(x)}
|
|
||||||
|
@ -3,8 +3,9 @@
|
|||||||
|
|
||||||
# DeepSpeed Team
|
# DeepSpeed Team
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from deepspeed.pydantic_v1 import Field
|
|
||||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||||
from .ragged import DSStateManagerConfig
|
from .ragged import DSStateManagerConfig
|
||||||
|
|
||||||
|
@ -27,9 +27,9 @@ class TensorMetadata(DeepSpeedConfigModel):
|
|||||||
"""
|
"""
|
||||||
A class to represent a tensor specification.
|
A class to represent a tensor specification.
|
||||||
"""
|
"""
|
||||||
dtype: Optional[str]
|
dtype: Optional[str] = None
|
||||||
shape: Optional[Tuple[int, ...]]
|
shape: Optional[Tuple[int, ...]] = None
|
||||||
strides: Optional[Tuple[int, ...]]
|
strides: Optional[Tuple[int, ...]] = None
|
||||||
offset: int
|
offset: int
|
||||||
|
|
||||||
|
|
||||||
@ -37,7 +37,7 @@ class ParameterMetadata(DeepSpeedConfigModel):
|
|||||||
"""
|
"""
|
||||||
A class to represent a parameter specification.
|
A class to represent a parameter specification.
|
||||||
"""
|
"""
|
||||||
core_param: TensorMetadata = None
|
core_param: Optional[TensorMetadata] = None
|
||||||
aux_params: Dict[str, TensorMetadata] = {}
|
aux_params: Dict[str, TensorMetadata] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from deepspeed.pydantic_v1 import PositiveInt, validator
|
from pydantic import PositiveInt, model_validator
|
||||||
|
|
||||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||||
from ..inference_utils import DtypeEnum
|
from ..inference_utils import DtypeEnum
|
||||||
@ -173,11 +173,9 @@ class DSStateManagerConfig(DeepSpeedConfigModel):
|
|||||||
Enable tracking for offloading KV-cache to host memory. Currently unsupported.
|
Enable tracking for offloading KV-cache to host memory. Currently unsupported.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@validator("max_ragged_sequence_count")
|
@model_validator(mode="after")
|
||||||
def max_ragged_sequence_count_validator(cls, v: int, values: dict):
|
def max_ragged_sequence_count_validator(self):
|
||||||
# If the attributes below failed their validation they won't appear in the values dict.
|
# If the attributes below failed their validation they won't appear in the values dict.
|
||||||
if "max_tracked_sequences" in values and v > values["max_tracked_sequences"]:
|
assert self.max_ragged_sequence_count <= self.max_tracked_sequences, "max_ragged_sequence_count must be less than max_tracked_sequences"
|
||||||
raise ValueError("max_ragged_sequence_count must be less than max_tracked_sequences")
|
assert self.max_ragged_sequence_count <= self.max_ragged_batch_size, "max_ragged_sequence_count must be less than max_ragged_batch_size"
|
||||||
if "max_ragged_batch_size" in values and v > values["max_ragged_batch_size"]:
|
return self
|
||||||
raise ValueError("max_ragged_sequence_count must be less than max_ragged_batch_size")
|
|
||||||
return v
|
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from deepspeed.pydantic_v1 import root_validator
|
from pydantic import model_validator
|
||||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||||
|
|
||||||
|
|
||||||
@ -36,10 +36,10 @@ class WandbConfig(DeepSpeedConfigModel):
|
|||||||
enabled: bool = False
|
enabled: bool = False
|
||||||
""" Whether logging to WandB is enabled. Requires `wandb` package is installed. """
|
""" Whether logging to WandB is enabled. Requires `wandb` package is installed. """
|
||||||
|
|
||||||
group: str = None
|
group: Optional[str] = None
|
||||||
""" Name for the WandB group. This can be used to group together runs. """
|
""" Name for the WandB group. This can be used to group together runs. """
|
||||||
|
|
||||||
team: str = None
|
team: Optional[str] = None
|
||||||
""" Name for the WandB team. """
|
""" Name for the WandB team. """
|
||||||
|
|
||||||
project: str = "deepspeed"
|
project: str = "deepspeed"
|
||||||
@ -137,8 +137,8 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
|
|||||||
csv_monitor: CSVConfig = {}
|
csv_monitor: CSVConfig = {}
|
||||||
""" Local CSV output of monitoring data. """
|
""" Local CSV output of monitoring data. """
|
||||||
|
|
||||||
@root_validator
|
@model_validator(mode="after")
|
||||||
def check_enabled(cls, values):
|
def check_enabled(self):
|
||||||
values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get(
|
enabled = self.tensorboard.enabled or self.wandb.enabled or self.csv_monitor.enabled or self.comet.enabled
|
||||||
"csv_monitor").enabled or values.get("comet").enabled
|
self.__dict__["enabled"] = enabled
|
||||||
return values
|
return self
|
||||||
|
@ -1,16 +0,0 @@
|
|||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
# DeepSpeed Team
|
|
||||||
"""Pydantic v1 compatibility module.
|
|
||||||
|
|
||||||
Pydantic v2 introduced breaking changes that hinder its adoption:
|
|
||||||
https://docs.pydantic.dev/latest/migration/. To provide deepspeed users the option to
|
|
||||||
migrate to pydantic v2 on their own timeline, deepspeed uses this compatibility module
|
|
||||||
as a pydantic-version-agnostic alias for pydantic's v1 API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
from pydantic.v1 import * # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
from pydantic import * # noqa: F401
|
|
@ -5,11 +5,12 @@
|
|||||||
"""
|
"""
|
||||||
Collection of DeepSpeed configuration utilities
|
Collection of DeepSpeed configuration utilities
|
||||||
"""
|
"""
|
||||||
import json
|
|
||||||
import collections
|
import collections
|
||||||
import collections.abc
|
import json
|
||||||
|
import torch
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from deepspeed.pydantic_v1 import BaseModel
|
from pydantic import BaseModel, ConfigDict, field_serializer
|
||||||
|
|
||||||
from deepspeed.utils import logger
|
from deepspeed.utils import logger
|
||||||
|
|
||||||
|
|
||||||
@ -54,67 +55,73 @@ class DeepSpeedConfigModel(BaseModel):
|
|||||||
if (not strict): # This is temporary until we refactor all DS configs, allows HF to load models
|
if (not strict): # This is temporary until we refactor all DS configs, allows HF to load models
|
||||||
data = {k: v for k, v in data.items() if (v != "auto" or k == "replace_method")}
|
data = {k: v for k, v in data.items() if (v != "auto" or k == "replace_method")}
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
self._deprecated_fields_check(self)
|
self._deprecated_fields_check()
|
||||||
|
|
||||||
def _process_deprecated_field(self, pydantic_config, field):
|
def _process_deprecated_field(self, dep_field):
|
||||||
# Get information about the deprecated field
|
# Get information about the deprecated field
|
||||||
fields_set = pydantic_config.__fields_set__
|
pydantic_config = self
|
||||||
dep_param = field.name
|
fields_set = pydantic_config.model_fields_set
|
||||||
kwargs = field.field_info.extra
|
kwargs = pydantic_config.model_fields[dep_field].json_schema_extra
|
||||||
new_param_fn = kwargs.get("new_param_fn", lambda x: x)
|
new_param_fn = kwargs.get("new_param_fn", lambda x: x)
|
||||||
param_value = new_param_fn(getattr(pydantic_config, dep_param))
|
param_value = new_param_fn(getattr(pydantic_config, dep_field))
|
||||||
new_param = kwargs.get("new_param", "")
|
new_field = kwargs.get("new_param", "")
|
||||||
dep_msg = kwargs.get("deprecated_msg", "")
|
dep_msg = kwargs.get("deprecated_msg", "")
|
||||||
if dep_param in fields_set:
|
if dep_field in fields_set:
|
||||||
logger.warning(f"Config parameter {dep_param} is deprecated" +
|
logger.warning(f"Config parameter {dep_field} is deprecated" +
|
||||||
(f" use {new_param} instead" if new_param else "") + (f". {dep_msg}" if dep_msg else ""))
|
(f" use {new_field} instead" if new_field else "") + (f". {dep_msg}" if dep_msg else ""))
|
||||||
# Check if there is a new param and if it should be set with a value
|
# Check if there is a new param and if it should be set with a value
|
||||||
if new_param and kwargs.get("set_new_param", True):
|
if new_field and kwargs.get("set_new_param", True):
|
||||||
# Remove the deprecate field if there is a replacing field
|
# Remove the deprecate field if there is a replacing field
|
||||||
try:
|
try:
|
||||||
delattr(pydantic_config, dep_param)
|
delattr(pydantic_config, dep_field)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Tried removing deprecated '{dep_param}' from config")
|
logger.error(f"Tried removing deprecated '{dep_field}' from config")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# Set new param value
|
# Set new param value
|
||||||
new_param_nested = new_param.split(".")
|
new_param_nested = new_field.split(".")
|
||||||
if len(new_param_nested) > 1:
|
if len(new_param_nested) > 1:
|
||||||
# If the new param exists in a subconfig, we need to get
|
# If the new param exists in a subconfig, we need to get
|
||||||
# the fields set for that subconfig
|
# the fields set for that subconfig
|
||||||
pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config)
|
pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config)
|
||||||
fields_set = pydantic_config.__fields_set__
|
fields_set = pydantic_config.model_fields_set
|
||||||
new_param_name = new_param_nested[-1]
|
new_param_name = new_param_nested[-1]
|
||||||
assert (
|
assert (
|
||||||
new_param_name not in fields_set
|
new_param_name not in fields_set
|
||||||
), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together"
|
), f"Cannot provide deprecated parameter '{dep_field}' and replacing parameter '{new_field}' together"
|
||||||
# A custom function for converting the old param value to new param value can be provided
|
# A custom function for converting the old param value to new param value can be provided
|
||||||
try:
|
try:
|
||||||
setattr(pydantic_config, new_param_name, param_value)
|
setattr(pydantic_config, new_param_name, param_value)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'")
|
logger.error(f"Tried setting value for '{new_field}' with value from deprecated '{dep_field}'")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _deprecated_fields_check(self, pydantic_config):
|
def _deprecated_fields_check(self):
|
||||||
fields = pydantic_config.__fields__
|
fields = self.model_fields
|
||||||
for field in fields.values():
|
for field_name, field_info in fields.items():
|
||||||
if field.field_info.extra.get("deprecated", False):
|
if field_info.json_schema_extra and field_info.json_schema_extra.get("deprecated", False):
|
||||||
self._process_deprecated_field(pydantic_config, field)
|
self._process_deprecated_field(field_name)
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(
|
||||||
validate_all = True
|
validate_default=True,
|
||||||
validate_assignment = True
|
validate_assignment=True,
|
||||||
use_enum_values = True
|
use_enum_values=True,
|
||||||
allow_population_by_field_name = True
|
populate_by_name=True,
|
||||||
extra = "forbid"
|
extra="forbid",
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed=True,
|
||||||
|
protected_namespaces=(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_serializer("dtype", check_fields=False)
|
||||||
|
def serialize_torch_dtype(dtype: torch.dtype) -> str:
|
||||||
|
return str(dtype)
|
||||||
|
|
||||||
|
|
||||||
def get_config_default(config, field_name):
|
def get_config_default(config, field_name):
|
||||||
assert field_name in config.__fields__, f"'{field_name}' is not a field in {config}"
|
assert field_name in config.model_fields, f"'{field_name}' is not a field in {config}"
|
||||||
assert not config.__fields__.get(
|
assert not config.model_fields.get(
|
||||||
field_name).required, f"'{field_name}' is a required field and does not have a default value"
|
field_name).is_required(), f"'{field_name}' is a required field and does not have a default value"
|
||||||
return config.__fields__.get(field_name).default
|
return config.model_fields.get(field_name).get_default()
|
||||||
|
|
||||||
|
|
||||||
class pp_int(int):
|
class pp_int(int):
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
import sys
|
import sys
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from deepspeed.pydantic_v1 import Field, validator, root_validator
|
from pydantic import Field, model_validator
|
||||||
from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
|
from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
|
||||||
from deepspeed.utils import logger
|
from deepspeed.utils import logger
|
||||||
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
|
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
|
||||||
@ -30,7 +30,7 @@ ZeRO optimization should be enabled as:
|
|||||||
"reduce_bucket_size": 500000000,
|
"reduce_bucket_size": 500000000,
|
||||||
"load_from_fp32_weights": [true|false],
|
"load_from_fp32_weights": [true|false],
|
||||||
"cpu_offload": [true|false] (deprecated),
|
"cpu_offload": [true|false] (deprecated),
|
||||||
"cpu_offload_params" : [true|false] (deprecated),
|
"cpu_offload_param" : [true|false] (deprecated),
|
||||||
"cpu_offload_use_pin_memory": [true|false] (deprecated),
|
"cpu_offload_use_pin_memory": [true|false] (deprecated),
|
||||||
"sub_group_size" : 1000000000000,
|
"sub_group_size" : 1000000000000,
|
||||||
"offload_param": {...},
|
"offload_param": {...},
|
||||||
@ -128,7 +128,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
|
|||||||
the allgather for large model sizes
|
the allgather for large model sizes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
overlap_comm: bool = None # None for dynamic default value (see validator `overlap_comm_valid` below)
|
overlap_comm: Optional[bool] = None # None for dynamic default value (see validator `overlap_comm_valid` below)
|
||||||
"""
|
"""
|
||||||
Attempts to overlap the reduction of the gradients with backward computation
|
Attempts to overlap the reduction of the gradients with backward computation
|
||||||
"""
|
"""
|
||||||
@ -168,27 +168,37 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
|
|||||||
parameters). Used by ZeRO3-Offload and ZeRO-Infinity
|
parameters). Used by ZeRO3-Offload and ZeRO-Infinity
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cpu_offload_param: bool = Field(
|
cpu_offload_param: Optional[bool] = Field(
|
||||||
None,
|
None,
|
||||||
deprecated=True,
|
json_schema_extra={
|
||||||
new_param="offload_param",
|
"deprecated": True,
|
||||||
new_param_fn=(lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu) if val else None),
|
"new_param": "offload_param",
|
||||||
|
"new_param_fn": (lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu)
|
||||||
|
if val else None)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
""" Deprecated, please use ``offload_param`` """
|
""" Deprecated, please use ``offload_param`` """
|
||||||
|
|
||||||
cpu_offload_use_pin_memory: bool = Field(
|
cpu_offload_use_pin_memory: Optional[bool] = Field(
|
||||||
None,
|
None,
|
||||||
deprecated=True,
|
json_schema_extra={
|
||||||
new_param="offload_param or offload_optimizer",
|
"deprecated": True,
|
||||||
set_new_param=False,
|
"new_param": "offload_param or offload_optimizer",
|
||||||
|
"set_new_param": False
|
||||||
|
},
|
||||||
)
|
)
|
||||||
""" Deprecated, please use ``offload_param`` or ``offload_optimizer`` """
|
""" Deprecated, please use ``offload_param`` or ``offload_optimizer`` """
|
||||||
|
|
||||||
cpu_offload: bool = Field(
|
cpu_offload: Optional[bool] = Field(
|
||||||
None,
|
None,
|
||||||
deprecated=True,
|
json_schema_extra={
|
||||||
new_param="offload_optimizer",
|
"deprecated":
|
||||||
new_param_fn=(lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu) if val else None),
|
True,
|
||||||
|
"new_param":
|
||||||
|
"offload_optimizer",
|
||||||
|
"new_param_fn": (lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu)
|
||||||
|
if val else None)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
""" Deprecated, please use ``offload_optimizer`` """
|
""" Deprecated, please use ``offload_optimizer`` """
|
||||||
|
|
||||||
@ -242,8 +252,10 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
stage3_gather_fp16_weights_on_model_save: bool = Field(False,
|
stage3_gather_fp16_weights_on_model_save: bool = Field(False,
|
||||||
deprecated=True,
|
json_schema_extra={
|
||||||
new_param="gather_16bit_weights_on_model_save")
|
"deprecated": True,
|
||||||
|
"new_param": "gather_16bit_weights_on_model_save"
|
||||||
|
})
|
||||||
""" Deprecated, please use ``gather_16bit_weights_on_model_save`` """
|
""" Deprecated, please use ``gather_16bit_weights_on_model_save`` """
|
||||||
|
|
||||||
ignore_unused_parameters: bool = True
|
ignore_unused_parameters: bool = True
|
||||||
@ -309,16 +321,15 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Validators
|
# Validators
|
||||||
@validator("overlap_comm")
|
@model_validator(mode="after")
|
||||||
def overlap_comm_valid(cls, field_value, values):
|
def overlap_comm_valid(self):
|
||||||
if field_value is None:
|
if self.overlap_comm is None:
|
||||||
assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'"
|
self.overlap_comm = self.stage == ZeroStageEnum.weights
|
||||||
field_value = values["stage"] == ZeroStageEnum.weights
|
return self
|
||||||
return field_value
|
|
||||||
|
|
||||||
@root_validator
|
@model_validator(mode="after")
|
||||||
def offload_ratio_check(cls, values):
|
def offload_ratio_check(self):
|
||||||
offload_config = getattr(values, "offload_optimizer", {})
|
offload_config = self.offload_optimizer
|
||||||
if offload_config and offload_config.ratio < 1.0:
|
if offload_config and offload_config.ratio < 1.0:
|
||||||
assert values.get("stage") == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
|
assert self.stage == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
|
||||||
return values
|
return self
|
||||||
|
@ -5,7 +5,9 @@
|
|||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from deepspeed.pydantic_v1 import Field, validator
|
from pydantic import Field, model_validator
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int
|
from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int
|
||||||
|
|
||||||
|
|
||||||
@ -25,7 +27,7 @@ class DeepSpeedZeroOffloadParamConfig(DeepSpeedConfigModel):
|
|||||||
`nvme`.
|
`nvme`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
nvme_path: Path = None
|
nvme_path: Optional[Path] = None
|
||||||
""" Filesystem path for NVMe device for parameter offloading. """
|
""" Filesystem path for NVMe device for parameter offloading. """
|
||||||
|
|
||||||
buffer_count: int = Field(5, ge=0)
|
buffer_count: int = Field(5, ge=0)
|
||||||
@ -56,7 +58,7 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
|
|||||||
`nvme`. Optimizer computation is offload to CPU regardless of device option.
|
`nvme`. Optimizer computation is offload to CPU regardless of device option.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
nvme_path: Path = None
|
nvme_path: Optional[Path] = None
|
||||||
""" Filesystem path for NVMe device for optimizer state offloading. """
|
""" Filesystem path for NVMe device for optimizer state offloading. """
|
||||||
|
|
||||||
buffer_count: int = Field(4, ge=0)
|
buffer_count: int = Field(4, ge=0)
|
||||||
@ -88,10 +90,11 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
|
|||||||
fast_init: bool = False
|
fast_init: bool = False
|
||||||
""" Enable fast optimizer initialization when offloading to NVMe. """
|
""" Enable fast optimizer initialization when offloading to NVMe. """
|
||||||
|
|
||||||
@validator("pipeline_read", "pipeline_write", always=True)
|
|
||||||
def set_pipeline(cls, field_value, values):
|
|
||||||
values["pipeline"] = field_value or values.get("pipeline", False)
|
|
||||||
return field_value
|
|
||||||
|
|
||||||
ratio: float = Field(1.0, ge=0.0, le=1.0)
|
ratio: float = Field(1.0, ge=0.0, le=1.0)
|
||||||
""" Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""
|
""" Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def set_pipeline(self):
|
||||||
|
pipeline = self.pipeline_read or self.pipeline_write
|
||||||
|
self.__dict__["pipeline"] = pipeline
|
||||||
|
return self
|
||||||
|
@ -725,8 +725,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
|
|||||||
def get_first_param_index(self, group_id, param_group, partition_id):
|
def get_first_param_index(self, group_id, param_group, partition_id):
|
||||||
for index, param in enumerate(param_group):
|
for index, param in enumerate(param_group):
|
||||||
param_id = self.get_param_id(param)
|
param_id = self.get_param_id(param)
|
||||||
if partition_id in self.param_to_partition_ids[group_id][param_id]:
|
if group_id in self.param_to_partition_ids and param_id in self.param_to_partition_ids[group_id]:
|
||||||
return index
|
if partition_id in self.param_to_partition_ids[group_id][param_id]:
|
||||||
|
return index
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def initialize_gradient_partitioning_data_structures(self):
|
def initialize_gradient_partitioning_data_structures(self):
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
autodoc_pydantic
|
autodoc_pydantic>=2.0.0
|
||||||
docutils<0.18
|
docutils<0.18
|
||||||
hjson
|
hjson
|
||||||
packaging
|
packaging
|
||||||
psutil
|
psutil
|
||||||
py-cpuinfo
|
py-cpuinfo
|
||||||
pydantic<2.0.0
|
pydantic>=2.0.0
|
||||||
recommonmark
|
recommonmark
|
||||||
sphinx_rtd_theme
|
sphinx_rtd_theme
|
||||||
torch
|
torch
|
||||||
|
@ -4,6 +4,6 @@ numpy
|
|||||||
packaging>=20.0
|
packaging>=20.0
|
||||||
psutil
|
psutil
|
||||||
py-cpuinfo
|
py-cpuinfo
|
||||||
pydantic
|
pydantic>=2.0.0
|
||||||
torch
|
torch
|
||||||
tqdm
|
tqdm
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from deepspeed.pydantic_v1 import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from deepspeed.inference.v2.ragged import DSStateManagerConfig
|
from deepspeed.inference.v2.ragged import DSStateManagerConfig
|
||||||
|
|
||||||
|
@ -67,13 +67,11 @@ def _batch_assert(status, ds_config, batch, micro_batch, gas, success):
|
|||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
assert not status
|
assert not status
|
||||||
print("Failed but All is well")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
assert ds_config.train_batch_size == batch
|
assert ds_config.train_batch_size == batch
|
||||||
assert ds_config.train_micro_batch_size_per_gpu == micro_batch
|
assert ds_config.train_micro_batch_size_per_gpu == micro_batch
|
||||||
assert ds_config.gradient_accumulation_steps == gas
|
assert ds_config.gradient_accumulation_steps == gas
|
||||||
print("All is well")
|
|
||||||
|
|
||||||
|
|
||||||
#Tests different batch config provided in deepspeed json file
|
#Tests different batch config provided in deepspeed json file
|
||||||
|
@ -4,18 +4,25 @@
|
|||||||
# DeepSpeed Team
|
# DeepSpeed Team
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
from typing import List
|
import os
|
||||||
from deepspeed.pydantic_v1 import Field, ValidationError
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import Field, ValidationError
|
||||||
|
|
||||||
from deepspeed.runtime import config as ds_config
|
from deepspeed.runtime import config as ds_config
|
||||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||||
|
|
||||||
|
|
||||||
class SimpleConf(DeepSpeedConfigModel):
|
class SimpleConf(DeepSpeedConfigModel):
|
||||||
param_1: int = 0
|
param_1: int = 0
|
||||||
param_2_old: str = Field(None, deprecated=True, new_param="param_2", new_param_fn=(lambda x: [x]))
|
param_2_old: Optional[str] = Field(None,
|
||||||
param_2: List[str] = None
|
json_schema_extra={
|
||||||
|
"deprecated": True,
|
||||||
|
"new_param": "param_2",
|
||||||
|
"new_param_fn": (lambda x: [x])
|
||||||
|
})
|
||||||
|
param_2: Optional[List[str]] = None
|
||||||
param_3: int = Field(0, alias="param_3_alias")
|
param_3: int = Field(0, alias="param_3_alias")
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user