Pydantic v2 migration (#5167)

Pydantic v2 has been out for some time now. Until now we have been relying on
the v1 compatibility API that ships inside v2. This is a refresh of #3902
to bring proper v2 support to DeepSpeed.

Corresponding DeepSpeed-MII PR
[here](https://github.com/microsoft/DeepSpeed-MII/pull/423).
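
For reviewers: the same handful of mechanical changes repeats across every file below. Imports move from the `deepspeed.pydantic_v1` shim to `pydantic` itself, `@validator`/`@root_validator` become `@field_validator`/`@model_validator`, the inner `class Config` becomes `model_config = ConfigDict(...)`, and fields defaulting to `None` gain an explicit `Optional[...]` annotation. A minimal before/after sketch of the pattern (toy model, not actual DeepSpeed code):

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict, field_validator


class ExampleConfig(BaseModel):
    # pydantic v1 spelled this as an inner `class Config` with
    # `validate_assignment = True` and `extra = "forbid"`.
    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    nvme_path: Optional[str] = None  # v1 tolerated the bare `nvme_path: str = None`
    num_bits: int = 8  # v1 tolerated the untyped `num_bits = 8`

    @field_validator("num_bits")  # v1: @validator("num_bits")
    def check_bits(cls, value):
        if value not in (4, 8):
            raise ValueError(f"unsupported bit width: {value}")
        return value
```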

@loadams

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Logan Adams <loadams@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Abhishek Kulkarni <11399+adk9@users.noreply.github.com>
Co-authored-by: Abhishek Kulkarni <abkulkarni@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Author: Michael Wyatt
Date: 2024-08-22 15:38:13 -07:00 (committed by GitHub)
Commit: 0a4457cc48 (parent 8c2be7e942)
17 changed files with 198 additions and 188 deletions

View File

@@ -47,7 +47,8 @@ jobs:
       - name: Install deepspeed
         run: |
           python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
-          python -m pip install pydantic==1.10.11
+          # Update packages included in the container that do not support pydantic 2+ to versions that do
+          python -m pip install thinc spacy confection --upgrade
           python -m pip install .[dev,1bit,autotuning,inf]
           ds_report
       - name: Python environment

View File

@@ -3,20 +3,12 @@
 # DeepSpeed Team

+from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from .constants import *
-from ..pydantic_v1 import BaseModel


-class CommsConfig(BaseModel):
-
-    class Config:
-        validate_all = True
-        validate_assignment = True
-        use_enum_values = True
-        extra = 'forbid'
-
-
-class CommsLoggerConfig(CommsConfig):
+class CommsLoggerConfig(DeepSpeedConfigModel):
     enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
     prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
     prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
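
The deleted `CommsConfig` base class existed only to carry pydantic v1 options via an inner `class Config`; those options now live on the shared `DeepSpeedConfigModel` base as a v2 `ConfigDict` (see the config_utils.py hunk further down). A rough sketch of the option mapping, using a hypothetical class name for illustration:

```python
from pydantic import BaseModel, ConfigDict


class CommsBaseSketch(BaseModel):  # hypothetical stand-in for the old CommsConfig
    model_config = ConfigDict(
        validate_default=True,  # v1: validate_all = True
        validate_assignment=True,  # v1: validate_assignment = True
        use_enum_values=True,  # v1: use_enum_values = True
        extra="forbid",  # v1: extra = 'forbid'
    )
```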

View File

@@ -5,38 +5,25 @@
 import torch
 import deepspeed
-from deepspeed.pydantic_v1 import Field, validator
+from pydantic import Field, field_validator
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
-from typing import Dict, Union
+from typing import Dict, Union, Optional
 from enum import Enum


 class DtypeEnum(Enum):
-    # The torch dtype must always be the first value (so we return torch.dtype)
-    fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
-    fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
-    bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat"
-    int8 = torch.int8, "torch.int8", "int8"
+    fp16 = (torch.float16, "torch.float16", "fp16", "float16", "half")
+    fp32 = (torch.float32, "torch.float32", "fp32", "float32", "float")
+    bf16 = (torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat")
+    int8 = (torch.int8, "torch.int8", "int8")

-    # Copied from https://stackoverflow.com/a/43210118
-    # Allows us to use multiple values for each Enum index and returns first
-    # listed value when Enum is called
-    def __new__(cls, *values):
-        obj = object.__new__(cls)
-        # first value is canonical value
-        obj._value_ = values[0]
-        for other_value in values[1:]:
-            cls._value2member_map_[other_value] = obj
-        obj._all_values = values
-        return obj
-
-    def __repr__(self):
-        return "<%s.%s: %s>" % (
-            self.__class__.__name__,
-            self._name_,
-            ", ".join([repr(v) for v in self._all_values]),
-        )
+    @classmethod
+    def from_str(cls, value: str):
+        for dtype in cls:
+            if value in dtype.value:
+                return dtype
+        raise ValueError(f"'{value}' is not a valid DtypeEnum")


 class MoETypeEnum(str, Enum):
@@ -91,24 +78,24 @@ class QuantTypeEnum(str, Enum):

 class BaseQuantConfig(DeepSpeedConfigModel):
-    enabled = True
-    num_bits = 8
+    enabled: bool = True
+    num_bits: int = 8
     q_type: QuantTypeEnum = QuantTypeEnum.sym
     q_groups: int = 1


 class WeightQuantConfig(BaseQuantConfig):
-    enabled = True
+    enabled: bool = True
     quantized_initialization: Dict = {}
     post_init_quant: Dict = {}


 class ActivationQuantConfig(BaseQuantConfig):
-    enabled = True
+    enabled: bool = True


 class QKVQuantConfig(DeepSpeedConfigModel):
-    enabled = True
+    enabled: bool = True


 class QuantizationConfig(DeepSpeedConfigModel):
@@ -120,9 +107,9 @@ class QuantizationConfig(DeepSpeedConfigModel):

 # todo: brainstorm on how to do ckpt loading for DS inference
 class InferenceCheckpointConfig(DeepSpeedConfigModel):
-    checkpoint_dir: str = None
-    save_mp_checkpoint_path: str = None
-    base_dir: str = None
+    checkpoint_dir: Optional[str] = None
+    save_mp_checkpoint_path: Optional[str] = None
+    base_dir: Optional[str] = None


 class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
@@ -136,7 +123,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     `(attention_output projection, transformer output projection)`
     """

-    dtype: DtypeEnum = torch.float16
+    dtype: torch.dtype = torch.float16
     """
     Desired model data type, will convert model to this type.
     Supported target types: `torch.half`, `torch.int8`, `torch.float`
@@ -198,7 +185,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     """

     #todo: refactor the following 3 into the new checkpoint_config
-    checkpoint: Union[str, Dict] = None
+    checkpoint: Optional[Union[str, Dict]] = None
     """
     Path to deepspeed compatible checkpoint or path to JSON with load policy.
     """
@@ -214,7 +201,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     specifying whether the inference-module is created with empty or real Tensor
     """

-    save_mp_checkpoint_path: str = None
+    save_mp_checkpoint_path: Optional[str] = None
     """
     The path for which we want to save the loaded model with a checkpoint. This
     feature is used for adjusting the parallelism degree to help alleviate the
@@ -243,19 +230,21 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):

     replace_method: str = Field(
         "auto",
-        deprecated=True,
-        deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference")
+        json_schema_extra={
+            "deprecated": True,
+            "deprecated_msg": "This parameter is no longer needed, please remove from your call to DeepSpeed-inference"
+        })

-    injection_policy: Dict = Field(None, alias="injection_dict")
+    injection_policy: Optional[Dict] = Field(None, alias="injection_dict")
     """
     Dictionary mapping a client nn.Module to its corresponding injection
     policy. e.g., `{BertLayer : deepspeed.inference.HFBertLayerPolicy}`
     """

-    injection_policy_tuple: tuple = None
+    injection_policy_tuple: Optional[tuple] = None
     """ TODO: Add docs """

-    config: Dict = Field(None, alias="args")  # todo: really no need for this field if we can refactor
+    config: Optional[Dict] = Field(None, alias="args")  # todo: really no need for this field if we can refactor

     max_out_tokens: int = Field(1024, alias="max_tokens")
     """
@@ -274,31 +263,49 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     transposed_mode: bool = Field(False, alias="transposed_mode")

-    mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
+    mp_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.tp_size"})
     """
     Desired model parallel size, default is 1 meaning no model parallelism.
     Deprecated, please use the ``tensor_parallel` config to control model
     parallelism.
     """
-    mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
-    ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
-    ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group")
-    ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group")
-    moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
-    moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type")
+    mpu: object = Field(None, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.mpu"})
+    ep_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "moe.ep_size"})
+    ep_group: object = Field(None,
+                             alias="expert_group",
+                             json_schema_extra={
+                                 "deprecated": True,
+                                 "new_param": "moe.ep_group"
+                             })
+    ep_mp_group: object = Field(None,
+                                alias="expert_mp_group",
+                                json_schema_extra={
+                                    "deprecated": True,
+                                    "new_param": "moe.ep_mp_group"
+                                })
+    moe_experts: list = Field([1], json_schema_extra={"deprecated": True, "new_param": "moe.moe_experts"})
+    moe_type: MoETypeEnum = Field(MoETypeEnum.standard,
+                                  json_schema_extra={
+                                      "deprecated": True,
+                                      "new_param": "moe.type"
+                                  })

-    @validator("moe")
+    @field_validator("dtype", mode="before")
+    def validate_dtype(cls, field_value, values):
+        if isinstance(field_value, str):
+            return DtypeEnum.from_str(field_value).value[0]
+        if isinstance(field_value, torch.dtype):
+            return field_value
+        raise TypeError(f"Invalid type for dtype: {type(field_value)}")
+
+    @field_validator("moe")
     def moe_backward_compat(cls, field_value, values):
         if isinstance(field_value, bool):
             return DeepSpeedMoEConfig(moe=field_value)
         return field_value

-    @validator("use_triton")
+    @field_validator("use_triton")
     def has_triton(cls, field_value, values):
         if field_value and not deepspeed.HAS_TRITON:
             raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
         return field_value
-
-    class Config:
-        # Get the str representation of the datatype for serialization
-        json_encoders = {torch.dtype: lambda x: str(x)}
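
Worth calling out in this file: `dtype` is now annotated as a real `torch.dtype` (allowed because the shared base sets `arbitrary_types_allowed`), string aliases are coerced by the new `mode="before"` validator through `DtypeEnum.from_str`, and serialization moves from v1's `Config.json_encoders` to the `field_serializer` added in config_utils.py below. A self-contained sketch of that round-trip, assuming only `torch` and pydantic v2 (`ExampleModel` and `_ALIASES` are illustrative, not DeepSpeed code):

```python
import torch
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator

_ALIASES = {"fp16": torch.float16, "half": torch.float16, "fp32": torch.float32}


class ExampleModel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    dtype: torch.dtype = torch.float16

    @field_validator("dtype", mode="before")
    def coerce_dtype(cls, value):
        # Map strings like "fp16" to a real torch.dtype before validation runs
        return _ALIASES.get(value, value) if isinstance(value, str) else value

    @field_serializer("dtype")
    def serialize_dtype(self, value: torch.dtype) -> str:
        return str(value)  # replaces v1's Config.json_encoders


assert ExampleModel(dtype="fp16").model_dump()["dtype"] == "torch.float16"
```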

View File

@@ -3,8 +3,9 @@

 # DeepSpeed Team

+from pydantic import Field
+
 from typing import Optional

-from deepspeed.pydantic_v1 import Field
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from .ragged import DSStateManagerConfig

View File

@@ -27,9 +27,9 @@ class TensorMetadata(DeepSpeedConfigModel):
     """
     A class to represent a tensor specification.
     """
-    dtype: Optional[str]
-    shape: Optional[Tuple[int, ...]]
-    strides: Optional[Tuple[int, ...]]
+    dtype: Optional[str] = None
+    shape: Optional[Tuple[int, ...]] = None
+    strides: Optional[Tuple[int, ...]] = None
     offset: int
@@ -37,7 +37,7 @@ class ParameterMetadata(DeepSpeedConfigModel):
     """
     A class to represent a parameter specification.
     """
-    core_param: TensorMetadata = None
+    core_param: Optional[TensorMetadata] = None
     aux_params: Dict[str, TensorMetadata] = {}

View File

@@ -6,7 +6,7 @@
 from enum import Enum
 from typing import Tuple

-from deepspeed.pydantic_v1 import PositiveInt, validator
+from pydantic import PositiveInt, model_validator

 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from ..inference_utils import DtypeEnum
@@ -173,11 +173,9 @@ class DSStateManagerConfig(DeepSpeedConfigModel):
     Enable tracking for offloading KV-cache to host memory. Currently unsupported.
     """

-    @validator("max_ragged_sequence_count")
-    def max_ragged_sequence_count_validator(cls, v: int, values: dict):
-        # If the attributes below failed their validation they won't appear in the values dict.
-        if "max_tracked_sequences" in values and v > values["max_tracked_sequences"]:
-            raise ValueError("max_ragged_sequence_count must be less than max_tracked_sequences")
-        if "max_ragged_batch_size" in values and v > values["max_ragged_batch_size"]:
-            raise ValueError("max_ragged_sequence_count must be less than max_ragged_batch_size")
-        return v
+    @model_validator(mode="after")
+    def max_ragged_sequence_count_validator(self):
+        assert self.max_ragged_sequence_count <= self.max_tracked_sequences, "max_ragged_sequence_count must be less than max_tracked_sequences"
+        assert self.max_ragged_sequence_count <= self.max_ragged_batch_size, "max_ragged_sequence_count must be less than max_ragged_batch_size"
+        return self
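
The v1 pattern of peeking at the partially built `values` dict (and guarding against fields that failed validation) goes away here: a `mode="after"` model validator runs on the fully constructed instance, so every field is guaranteed to be present. A minimal sketch of the pattern, with made-up field names:

```python
from pydantic import BaseModel, model_validator


class LimitsSketch(BaseModel):  # illustrative only
    max_sequences: int = 512
    max_batch_size: int = 768

    @model_validator(mode="after")
    def check_limits(self):
        # A failing assert here surfaces to the caller as a ValidationError
        assert self.max_sequences <= self.max_batch_size, \
            "max_sequences must not exceed max_batch_size"
        return self


LimitsSketch()  # ok; LimitsSketch(max_sequences=1024) would raise
```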

View File

@@ -5,7 +5,7 @@
 from typing import Optional

-from deepspeed.pydantic_v1 import root_validator
+from pydantic import model_validator

 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@@ -36,10 +36,10 @@ class WandbConfig(DeepSpeedConfigModel):
     enabled: bool = False
     """ Whether logging to WandB is enabled. Requires `wandb` package is installed. """

-    group: str = None
+    group: Optional[str] = None
     """ Name for the WandB group. This can be used to group together runs. """

-    team: str = None
+    team: Optional[str] = None
     """ Name for the WandB team. """

     project: str = "deepspeed"
@@ -137,8 +137,8 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
     csv_monitor: CSVConfig = {}
     """ Local CSV output of monitoring data. """

-    @root_validator
-    def check_enabled(cls, values):
-        values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get(
-            "csv_monitor").enabled or values.get("comet").enabled
-        return values
+    @model_validator(mode="after")
+    def check_enabled(self):
+        enabled = self.tensorboard.enabled or self.wandb.enabled or self.csv_monitor.enabled or self.comet.enabled
+        self.__dict__["enabled"] = enabled
+        return self
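
One subtlety in `check_enabled` above: because `DeepSpeedConfigModel` sets `validate_assignment=True`, a plain `self.enabled = ...` inside an "after" validator would re-trigger validation (and can recurse), so the value is written into `self.__dict__` directly. A sketch of the same trick in isolation (toy model, not the DeepSpeed class):

```python
from pydantic import BaseModel, ConfigDict, model_validator


class MonitorSketch(BaseModel):  # illustrative only
    model_config = ConfigDict(validate_assignment=True)

    tensorboard: bool = False
    wandb: bool = False
    enabled: bool = False

    @model_validator(mode="after")
    def derive_enabled(self):
        # Bypass the validate-on-assignment hook to avoid re-entering validation
        self.__dict__["enabled"] = self.tensorboard or self.wandb
        return self


assert MonitorSketch(wandb=True).enabled
```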

View File

@@ -1,16 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-"""Pydantic v1 compatibility module.
-
-Pydantic v2 introduced breaking changes that hinder its adoption:
-https://docs.pydantic.dev/latest/migration/. To provide deepspeed users the option to
-migrate to pydantic v2 on their own timeline, deepspeed uses this compatibility module
-as a pydantic-version-agnostic alias for pydantic's v1 API.
-"""
-
-try:
-    from pydantic.v1 import *  # noqa: F401
-except ImportError:
-    from pydantic import *  # noqa: F401

View File

@@ -5,11 +5,12 @@
 """
 Collection of DeepSpeed configuration utilities
 """
-import json
 import collections
-import collections.abc
+import json
+import torch
 from functools import reduce
-from deepspeed.pydantic_v1 import BaseModel
+from pydantic import BaseModel, ConfigDict, field_serializer

 from deepspeed.utils import logger
@@ -54,67 +55,73 @@ class DeepSpeedConfigModel(BaseModel):
         if (not strict):  # This is temporary until we refactor all DS configs, allows HF to load models
             data = {k: v for k, v in data.items() if (v != "auto" or k == "replace_method")}
         super().__init__(**data)
-        self._deprecated_fields_check(self)
+        self._deprecated_fields_check()

-    def _process_deprecated_field(self, pydantic_config, field):
+    def _process_deprecated_field(self, dep_field):
         # Get information about the deprecated field
-        fields_set = pydantic_config.__fields_set__
-        dep_param = field.name
-        kwargs = field.field_info.extra
+        pydantic_config = self
+        fields_set = pydantic_config.model_fields_set
+        kwargs = pydantic_config.model_fields[dep_field].json_schema_extra
         new_param_fn = kwargs.get("new_param_fn", lambda x: x)
-        param_value = new_param_fn(getattr(pydantic_config, dep_param))
-        new_param = kwargs.get("new_param", "")
+        param_value = new_param_fn(getattr(pydantic_config, dep_field))
+        new_field = kwargs.get("new_param", "")
         dep_msg = kwargs.get("deprecated_msg", "")
-        if dep_param in fields_set:
-            logger.warning(f"Config parameter {dep_param} is deprecated" +
-                           (f" use {new_param} instead" if new_param else "") + (f". {dep_msg}" if dep_msg else ""))
+        if dep_field in fields_set:
+            logger.warning(f"Config parameter {dep_field} is deprecated" +
+                           (f" use {new_field} instead" if new_field else "") + (f". {dep_msg}" if dep_msg else ""))
             # Check if there is a new param and if it should be set with a value
-            if new_param and kwargs.get("set_new_param", True):
+            if new_field and kwargs.get("set_new_param", True):
                 # Remove the deprecate field if there is a replacing field
                 try:
-                    delattr(pydantic_config, dep_param)
+                    delattr(pydantic_config, dep_field)
                 except Exception as e:
-                    logger.error(f"Tried removing deprecated '{dep_param}' from config")
+                    logger.error(f"Tried removing deprecated '{dep_field}' from config")
                     raise e

                 # Set new param value
-                new_param_nested = new_param.split(".")
+                new_param_nested = new_field.split(".")
                 if len(new_param_nested) > 1:
                     # If the new param exists in a subconfig, we need to get
                     # the fields set for that subconfig
                     pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config)
-                    fields_set = pydantic_config.__fields_set__
+                    fields_set = pydantic_config.model_fields_set
                 new_param_name = new_param_nested[-1]
                 assert (
                     new_param_name not in fields_set
-                ), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together"
+                ), f"Cannot provide deprecated parameter '{dep_field}' and replacing parameter '{new_field}' together"
                 # A custom function for converting the old param value to new param value can be provided
                 try:
                     setattr(pydantic_config, new_param_name, param_value)
                 except Exception as e:
-                    logger.error(f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'")
+                    logger.error(f"Tried setting value for '{new_field}' with value from deprecated '{dep_field}'")
                     raise e

-    def _deprecated_fields_check(self, pydantic_config):
-        fields = pydantic_config.__fields__
-        for field in fields.values():
-            if field.field_info.extra.get("deprecated", False):
-                self._process_deprecated_field(pydantic_config, field)
+    def _deprecated_fields_check(self):
+        fields = self.model_fields
+        for field_name, field_info in fields.items():
+            if field_info.json_schema_extra and field_info.json_schema_extra.get("deprecated", False):
+                self._process_deprecated_field(field_name)

-    class Config:
-        validate_all = True
-        validate_assignment = True
-        use_enum_values = True
-        allow_population_by_field_name = True
-        extra = "forbid"
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(
+        validate_default=True,
+        validate_assignment=True,
+        use_enum_values=True,
+        populate_by_name=True,
+        extra="forbid",
+        arbitrary_types_allowed=True,
+        protected_namespaces=(),
+    )
+
+    @field_serializer("dtype", check_fields=False)
+    def serialize_torch_dtype(dtype: torch.dtype) -> str:
+        return str(dtype)


 def get_config_default(config, field_name):
-    assert field_name in config.__fields__, f"'{field_name}' is not a field in {config}"
-    assert not config.__fields__.get(
-        field_name).required, f"'{field_name}' is a required field and does not have a default value"
-    return config.__fields__.get(field_name).default
+    assert field_name in config.model_fields, f"'{field_name}' is not a field in {config}"
+    assert not config.model_fields.get(
+        field_name).is_required(), f"'{field_name}' is a required field and does not have a default value"
+    return config.model_fields.get(field_name).get_default()


 class pp_int(int):
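
With this change, deprecation metadata that v1 accepted as arbitrary `Field(**extra_kwargs)` must ride inside `json_schema_extra`, which is where `_deprecated_fields_check()` now looks. How a config author would declare a deprecated field after this PR (the `old_metric`/`new_metric` fields are hypothetical, shown only for illustration):

```python
from typing import Optional

from pydantic import Field

from deepspeed.runtime.config_utils import DeepSpeedConfigModel


class MetricsConfigSketch(DeepSpeedConfigModel):  # illustrative only
    old_metric: Optional[str] = Field(None,
                                      json_schema_extra={
                                          "deprecated": True,
                                          "new_param": "new_metric"
                                      })
    new_metric: Optional[str] = None


# Logs a deprecation warning and transfers the value to `new_metric`
cfg = MetricsConfigSketch(old_metric="loss")
assert cfg.new_metric == "loss"
```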

View File

@@ -6,7 +6,7 @@
 import sys
 from typing import Optional
 from enum import Enum
-from deepspeed.pydantic_v1 import Field, validator, root_validator
+from pydantic import Field, model_validator
 from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
 from deepspeed.utils import logger
 from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
@@ -30,7 +30,7 @@ ZeRO optimization should be enabled as:
     "reduce_bucket_size": 500000000,
     "load_from_fp32_weights": [true|false],
     "cpu_offload": [true|false] (deprecated),
-    "cpu_offload_params" : [true|false] (deprecated),
+    "cpu_offload_param" : [true|false] (deprecated),
     "cpu_offload_use_pin_memory": [true|false] (deprecated),
     "sub_group_size" : 1000000000000,
     "offload_param": {...},
@@ -128,7 +128,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
     the allgather for large model sizes
     """

-    overlap_comm: bool = None  # None for dynamic default value (see validator `overlap_comm_valid` below)
+    overlap_comm: Optional[bool] = None  # None for dynamic default value (see validator `overlap_comm_valid` below)
     """
     Attempts to overlap the reduction of the gradients with backward computation
     """
@@ -168,27 +168,37 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
     parameters). Used by ZeRO3-Offload and ZeRO-Infinity
     """

-    cpu_offload_param: bool = Field(
+    cpu_offload_param: Optional[bool] = Field(
         None,
-        deprecated=True,
-        new_param="offload_param",
-        new_param_fn=(lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu) if val else None),
+        json_schema_extra={
+            "deprecated": True,
+            "new_param": "offload_param",
+            "new_param_fn": (lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu)
+                             if val else None)
+        },
     )
     """ Deprecated, please use ``offload_param`` """

-    cpu_offload_use_pin_memory: bool = Field(
+    cpu_offload_use_pin_memory: Optional[bool] = Field(
         None,
-        deprecated=True,
-        new_param="offload_param or offload_optimizer",
-        set_new_param=False,
+        json_schema_extra={
+            "deprecated": True,
+            "new_param": "offload_param or offload_optimizer",
+            "set_new_param": False
+        },
     )
     """ Deprecated, please use ``offload_param`` or ``offload_optimizer`` """

-    cpu_offload: bool = Field(
+    cpu_offload: Optional[bool] = Field(
         None,
-        deprecated=True,
-        new_param="offload_optimizer",
-        new_param_fn=(lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu) if val else None),
+        json_schema_extra={
+            "deprecated":
+            True,
+            "new_param":
+            "offload_optimizer",
+            "new_param_fn": (lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu)
+                             if val else None)
+        },
     )
     """ Deprecated, please use ``offload_optimizer`` """
@@ -242,8 +252,10 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
     """

     stage3_gather_fp16_weights_on_model_save: bool = Field(False,
-                                                           deprecated=True,
-                                                           new_param="gather_16bit_weights_on_model_save")
+                                                           json_schema_extra={
+                                                               "deprecated": True,
+                                                               "new_param": "gather_16bit_weights_on_model_save"
+                                                           })
     """ Deprecated, please use ``gather_16bit_weights_on_model_save`` """

     ignore_unused_parameters: bool = True
@@ -309,16 +321,15 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
     """

     # Validators
-    @validator("overlap_comm")
-    def overlap_comm_valid(cls, field_value, values):
-        if field_value is None:
-            assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'"
-            field_value = values["stage"] == ZeroStageEnum.weights
-        return field_value
+    @model_validator(mode="after")
+    def overlap_comm_valid(self):
+        if self.overlap_comm is None:
+            self.overlap_comm = self.stage == ZeroStageEnum.weights
+        return self

-    @root_validator
-    def offload_ratio_check(cls, values):
-        offload_config = getattr(values, "offload_optimizer", {})
+    @model_validator(mode="after")
+    def offload_ratio_check(self):
+        offload_config = self.offload_optimizer
         if offload_config and offload_config.ratio < 1.0:
-            assert values.get("stage") == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
-        return values
+            assert self.stage == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
+        return self

View File

@@ -5,7 +5,9 @@
 from enum import Enum
 from pathlib import Path
-from deepspeed.pydantic_v1 import Field, validator
+
+from pydantic import Field, model_validator
+from typing import Optional

 from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int
@@ -25,7 +27,7 @@ class DeepSpeedZeroOffloadParamConfig(DeepSpeedConfigModel):
     `nvme`.
     """

-    nvme_path: Path = None
+    nvme_path: Optional[Path] = None
     """ Filesystem path for NVMe device for parameter offloading. """

     buffer_count: int = Field(5, ge=0)
@@ -56,7 +58,7 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
     `nvme`. Optimizer computation is offload to CPU regardless of device option.
     """

-    nvme_path: Path = None
+    nvme_path: Optional[Path] = None
     """ Filesystem path for NVMe device for optimizer state offloading. """

     buffer_count: int = Field(4, ge=0)
@@ -88,10 +90,11 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
     fast_init: bool = False
     """ Enable fast optimizer initialization when offloading to NVMe. """

-    @validator("pipeline_read", "pipeline_write", always=True)
-    def set_pipeline(cls, field_value, values):
-        values["pipeline"] = field_value or values.get("pipeline", False)
-        return field_value
-
     ratio: float = Field(1.0, ge=0.0, le=1.0)
     """ Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""
+
+    @model_validator(mode="after")
+    def set_pipeline(self):
+        pipeline = self.pipeline_read or self.pipeline_write
+        self.__dict__["pipeline"] = pipeline
+        return self

View File

@@ -725,8 +725,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def get_first_param_index(self, group_id, param_group, partition_id):
         for index, param in enumerate(param_group):
             param_id = self.get_param_id(param)
-            if partition_id in self.param_to_partition_ids[group_id][param_id]:
-                return index
+            if group_id in self.param_to_partition_ids and param_id in self.param_to_partition_ids[group_id]:
+                if partition_id in self.param_to_partition_ids[group_id][param_id]:
+                    return index
         return None

     def initialize_gradient_partitioning_data_structures(self):

View File

@ -1,10 +1,10 @@
autodoc_pydantic autodoc_pydantic>=2.0.0
docutils<0.18 docutils<0.18
hjson hjson
packaging packaging
psutil psutil
py-cpuinfo py-cpuinfo
pydantic<2.0.0 pydantic>=2.0.0
recommonmark recommonmark
sphinx_rtd_theme sphinx_rtd_theme
torch torch

View File

@@ -4,6 +4,6 @@ numpy
 packaging>=20.0
 psutil
 py-cpuinfo
-pydantic
+pydantic>=2.0.0
 torch
 tqdm

View File

@@ -5,7 +5,7 @@
 import pytest

-from deepspeed.pydantic_v1 import ValidationError
+from pydantic import ValidationError

 from deepspeed.inference.v2.ragged import DSStateManagerConfig

View File

@@ -67,13 +67,11 @@ def _batch_assert(status, ds_config, batch, micro_batch, gas, success):

     if not success:
         assert not status
-        print("Failed but All is well")
         return

     assert ds_config.train_batch_size == batch
     assert ds_config.train_micro_batch_size_per_gpu == micro_batch
     assert ds_config.gradient_accumulation_steps == gas
-    print("All is well")


 #Tests different batch config provided in deepspeed json file

View File

@@ -4,18 +4,25 @@
 # DeepSpeed Team

 import pytest
-import os
 import json
-from typing import List
-from deepspeed.pydantic_v1 import Field, ValidationError
+import os
+from typing import List, Optional
+
+from pydantic import Field, ValidationError
+
 from deepspeed.runtime import config as ds_config
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel


 class SimpleConf(DeepSpeedConfigModel):
     param_1: int = 0
-    param_2_old: str = Field(None, deprecated=True, new_param="param_2", new_param_fn=(lambda x: [x]))
-    param_2: List[str] = None
+    param_2_old: Optional[str] = Field(None,
+                                       json_schema_extra={
+                                           "deprecated": True,
+                                           "new_param": "param_2",
+                                           "new_param_fn": (lambda x: [x])
+                                       })
+    param_2: Optional[List[str]] = None
     param_3: int = Field(0, alias="param_3_alias")