Pydantic v2 migration (#5167)

Pydantic v2 has been out for some time now, but until now we have relied
on the v1 compatibility API that ships inside v2. This PR is a refresh of
#3902 and brings proper v2 support to DeepSpeed.

Corresponding DeepSpeed-MII PR
[here](https://github.com/microsoft/DeepSpeed-MII/pull/423).
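For reviewers less familiar with the v2 surface, here is a minimal sketch of the renames this PR applies throughout (the model below is illustrative only, not a class from the DeepSpeed tree):

```python
# v1 -> v2 renames used across this PR:
#   @validator        -> @field_validator
#   @root_validator   -> @model_validator(mode="after")
#   class Config      -> model_config = ConfigDict(...)
#   __fields__        -> model_fields
#   __fields_set__    -> model_fields_set
from typing import Optional
from pydantic import BaseModel, ConfigDict, field_validator, model_validator

class ExampleConfig(BaseModel):  # hypothetical model, for illustration only
    model_config = ConfigDict(extra="forbid", validate_assignment=True)

    size: int = 1
    path: Optional[str] = None  # v2 no longer infers Optional for None defaults

    @field_validator("size")
    @classmethod
    def size_positive(cls, v):
        if v < 1:
            raise ValueError("size must be >= 1")
        return v

    @model_validator(mode="after")
    def check_consistency(self):
        # v1's root_validator received a values dict; mode="after" gets the instance
        if self.path is None and self.size > 1:
            raise ValueError("path is required when size > 1")
        return self
```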

@loadams

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Logan Adams <loadams@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Abhishek Kulkarni <11399+adk9@users.noreply.github.com>
Co-authored-by: Abhishek Kulkarni <abkulkarni@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Author: Michael Wyatt
Date: 2024-08-22 15:38:13 -07:00 (committed by GitHub)
Parent: 8c2be7e942
Commit: 0a4457cc48
17 changed files with 198 additions and 188 deletions

View File

@@ -47,7 +47,8 @@ jobs:
- name: Install deepspeed
run: |
python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
python -m pip install pydantic==1.10.11
# Update packages included in the container that do not support pydantic 2+ to versions that do
python -m pip install thinc spacy confection --upgrade
python -m pip install .[dev,1bit,autotuning,inf]
ds_report
- name: Python environment

View File

@@ -3,20 +3,12 @@
# DeepSpeed Team
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from .constants import *
from ..pydantic_v1 import BaseModel
class CommsConfig(BaseModel):
class Config:
validate_all = True
validate_assignment = True
use_enum_values = True
extra = 'forbid'
class CommsLoggerConfig(CommsConfig):
class CommsLoggerConfig(DeepSpeedConfigModel):
enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
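The change above is the pattern repeated across the PR: per-model inner `Config` classes (v1) give way to a single `ConfigDict` on the shared base. A rough sketch of the equivalence, with illustrative class names:

```python
from pydantic import BaseModel, ConfigDict

class ConfigModelBase(BaseModel):  # stand-in for DeepSpeedConfigModel
    model_config = ConfigDict(
        validate_default=True,   # v2 spelling of v1's validate_all
        validate_assignment=True,
        use_enum_values=True,
        extra="forbid",
    )

class CommsLoggerLike(ConfigModelBase):  # hypothetical stand-in, inherits the config
    enabled: bool = False
    verbose: bool = False
```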

View File

@@ -5,38 +5,25 @@
import torch
import deepspeed
from deepspeed.pydantic_v1 import Field, validator
from pydantic import Field, field_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from typing import Dict, Union
from typing import Dict, Union, Optional
from enum import Enum
class DtypeEnum(Enum):
# The torch dtype must always be the first value (so we return torch.dtype)
fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat"
int8 = torch.int8, "torch.int8", "int8"
fp16 = (torch.float16, "torch.float16", "fp16", "float16", "half")
fp32 = (torch.float32, "torch.float32", "fp32", "float32", "float")
bf16 = (torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat")
int8 = (torch.int8, "torch.int8", "int8")
# Copied from https://stackoverflow.com/a/43210118
# Allows us to use multiple values for each Enum index and returns first
# listed value when Enum is called
def __new__(cls, *values):
obj = object.__new__(cls)
# first value is canonical value
obj._value_ = values[0]
for other_value in values[1:]:
cls._value2member_map_[other_value] = obj
obj._all_values = values
return obj
def __repr__(self):
return "<%s.%s: %s>" % (
self.__class__.__name__,
self._name_,
", ".join([repr(v) for v in self._all_values]),
)
@classmethod
def from_str(cls, value: str):
for dtype in cls:
if value in dtype.value:
return dtype
raise ValueError(f"'{value}' is not a valid DtypeEnum")
class MoETypeEnum(str, Enum):
@@ -91,24 +78,24 @@ class QuantTypeEnum(str, Enum):
class BaseQuantConfig(DeepSpeedConfigModel):
enabled = True
num_bits = 8
enabled: bool = True
num_bits: int = 8
q_type: QuantTypeEnum = QuantTypeEnum.sym
q_groups: int = 1
class WeightQuantConfig(BaseQuantConfig):
enabled = True
enabled: bool = True
quantized_initialization: Dict = {}
post_init_quant: Dict = {}
class ActivationQuantConfig(BaseQuantConfig):
enabled = True
enabled: bool = True
class QKVQuantConfig(DeepSpeedConfigModel):
enabled = True
enabled: bool = True
class QuantizationConfig(DeepSpeedConfigModel):
@@ -120,9 +107,9 @@ class QuantizationConfig(DeepSpeedConfigModel):
# todo: brainstorm on how to do ckpt loading for DS inference
class InferenceCheckpointConfig(DeepSpeedConfigModel):
checkpoint_dir: str = None
save_mp_checkpoint_path: str = None
base_dir: str = None
checkpoint_dir: Optional[str] = None
save_mp_checkpoint_path: Optional[str] = None
base_dir: Optional[str] = None
class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
@@ -136,7 +123,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
`(attention_output projection, transformer output projection)`
"""
dtype: DtypeEnum = torch.float16
dtype: torch.dtype = torch.float16
"""
Desired model data type, will convert model to this type.
Supported target types: `torch.half`, `torch.int8`, `torch.float`
@@ -198,7 +185,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
"""
#todo: refactor the following 3 into the new checkpoint_config
checkpoint: Union[str, Dict] = None
checkpoint: Optional[Union[str, Dict]] = None
"""
Path to deepspeed compatible checkpoint or path to JSON with load policy.
"""
@@ -214,7 +201,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
specifying whether the inference-module is created with empty or real Tensor
"""
save_mp_checkpoint_path: str = None
save_mp_checkpoint_path: Optional[str] = None
"""
The path for which we want to save the loaded model with a checkpoint. This
feature is used for adjusting the parallelism degree to help alleviate the
@@ -243,19 +230,21 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
replace_method: str = Field(
"auto",
deprecated=True,
deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference")
json_schema_extra={
"deprecated": True,
"deprecated_msg": "This parameter is no longer needed, please remove from your call to DeepSpeed-inference"
})
injection_policy: Dict = Field(None, alias="injection_dict")
injection_policy: Optional[Dict] = Field(None, alias="injection_dict")
"""
Dictionary mapping a client nn.Module to its corresponding injection
policy. e.g., `{BertLayer : deepspeed.inference.HFBertLayerPolicy}`
"""
injection_policy_tuple: tuple = None
injection_policy_tuple: Optional[tuple] = None
""" TODO: Add docs """
config: Dict = Field(None, alias="args") # todo: really no need for this field if we can refactor
config: Optional[Dict] = Field(None, alias="args") # todo: really no need for this field if we can refactor
max_out_tokens: int = Field(1024, alias="max_tokens")
"""
@@ -274,31 +263,49 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
transposed_mode: bool = Field(False, alias="transposed_mode")
mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
mp_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.tp_size"})
"""
Desired model parallel size, default is 1 meaning no model parallelism.
Deprecated, please use the ``tensor_parallel`` config to control model
parallelism.
"""
mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group")
ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group")
moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type")
mpu: object = Field(None, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.mpu"})
ep_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "moe.ep_size"})
ep_group: object = Field(None,
alias="expert_group",
json_schema_extra={
"deprecated": True,
"new_param": "moe.ep_group"
})
ep_mp_group: object = Field(None,
alias="expert_mp_group",
json_schema_extra={
"deprecated": True,
"new_param": "moe.ep_mp_group"
})
moe_experts: list = Field([1], json_schema_extra={"deprecated": True, "new_param": "moe.moe_experts"})
moe_type: MoETypeEnum = Field(MoETypeEnum.standard,
json_schema_extra={
"deprecated": True,
"new_param": "moe.type"
})
@validator("moe")
@field_validator("dtype", mode="before")
def validate_dtype(cls, field_value, values):
if isinstance(field_value, str):
return DtypeEnum.from_str(field_value).value[0]
if isinstance(field_value, torch.dtype):
return field_value
raise TypeError(f"Invalid type for dtype: {type(field_value)}")
@field_validator("moe")
def moe_backward_compat(cls, field_value, values):
if isinstance(field_value, bool):
return DeepSpeedMoEConfig(moe=field_value)
return field_value
@validator("use_triton")
@field_validator("use_triton")
def has_triton(cls, field_value, values):
if field_value and not deepspeed.HAS_TRITON:
raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
return field_value
class Config:
# Get the str representation of the datatype for serialization
json_encoders = {torch.dtype: lambda x: str(x)}
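Because v2 removed `json_encoders`, the `torch.dtype` round-trip is now split between a `mode="before"` validator (parsing, above) and a `field_serializer` (dumping, added in `config_utils.py` below). A self-contained sketch of the pair, separate from the real `DeepSpeedInferenceConfig`:

```python
import torch
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator

# Abridged alias table standing in for DtypeEnum.from_str
_STR_TO_DTYPE = {"fp16": torch.float16, "fp32": torch.float32,
                 "bf16": torch.bfloat16, "int8": torch.int8}

class DtypeModel(BaseModel):  # illustrative model, not DeepSpeedInferenceConfig
    model_config = ConfigDict(arbitrary_types_allowed=True)  # torch.dtype is a foreign type

    dtype: torch.dtype = torch.float16

    @field_validator("dtype", mode="before")
    @classmethod
    def parse_dtype(cls, v):
        if isinstance(v, str):
            return _STR_TO_DTYPE[v]  # the PR routes this through DtypeEnum.from_str
        if isinstance(v, torch.dtype):
            return v
        raise TypeError(f"Invalid type for dtype: {type(v)}")

    @field_serializer("dtype")
    def serialize_dtype(self, v: torch.dtype) -> str:
        return str(v)  # replaces v1's json_encoders = {torch.dtype: str}

# DtypeModel(dtype="bf16").model_dump() == {"dtype": "torch.bfloat16"}
```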

View File

@@ -3,8 +3,9 @@
# DeepSpeed Team
from pydantic import Field
from typing import Optional
from deepspeed.pydantic_v1 import Field
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from .ragged import DSStateManagerConfig

View File

@@ -27,9 +27,9 @@ class TensorMetadata(DeepSpeedConfigModel):
"""
A class to represent a tensor specification.
"""
dtype: Optional[str]
shape: Optional[Tuple[int, ...]]
strides: Optional[Tuple[int, ...]]
dtype: Optional[str] = None
shape: Optional[Tuple[int, ...]] = None
strides: Optional[Tuple[int, ...]] = None
offset: int
@@ -37,7 +37,7 @@ class ParameterMetadata(DeepSpeedConfigModel):
"""
A class to represent a parameter specification.
"""
core_param: TensorMetadata = None
core_param: Optional[TensorMetadata] = None
aux_params: Dict[str, TensorMetadata] = {}

View File

@@ -6,7 +6,7 @@
from enum import Enum
from typing import Tuple
from deepspeed.pydantic_v1 import PositiveInt, validator
from pydantic import PositiveInt, model_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from ..inference_utils import DtypeEnum
@@ -173,11 +173,9 @@ class DSStateManagerConfig(DeepSpeedConfigModel):
Enable tracking for offloading KV-cache to host memory. Currently unsupported.
"""
@validator("max_ragged_sequence_count")
def max_ragged_sequence_count_validator(cls, v: int, values: dict):
@model_validator(mode="after")
def max_ragged_sequence_count_validator(self):
# If the attributes below failed their validation they won't appear in the values dict.
if "max_tracked_sequences" in values and v > values["max_tracked_sequences"]:
raise ValueError("max_ragged_sequence_count must be less than max_tracked_sequences")
if "max_ragged_batch_size" in values and v > values["max_ragged_batch_size"]:
raise ValueError("max_ragged_sequence_count must be less than max_ragged_batch_size")
return v
assert self.max_ragged_sequence_count <= self.max_tracked_sequences, "max_ragged_sequence_count must be less than max_tracked_sequences"
assert self.max_ragged_sequence_count <= self.max_ragged_batch_size, "max_ragged_sequence_count must be less than max_ragged_batch_size"
return self
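The rewrite above is representative: v1 field validators probed a `values` dict of previously-validated fields, while `model_validator(mode="after")` runs on the fully-built instance, so the ordering caveat in the old comment disappears. A minimal standalone sketch with illustrative names:

```python
from pydantic import BaseModel, model_validator

class RaggedLimits(BaseModel):  # hypothetical stand-in for DSStateManagerConfig
    max_tracked_sequences: int = 2048
    max_ragged_batch_size: int = 768
    max_ragged_sequence_count: int = 512

    @model_validator(mode="after")
    def check_ragged_bounds(self):
        # Every field is already validated here, so attributes are safe to compare
        if self.max_ragged_sequence_count > self.max_tracked_sequences:
            raise ValueError("max_ragged_sequence_count must not exceed max_tracked_sequences")
        if self.max_ragged_sequence_count > self.max_ragged_batch_size:
            raise ValueError("max_ragged_sequence_count must not exceed max_ragged_batch_size")
        return self
```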

View File

@@ -5,7 +5,7 @@
from typing import Optional
from deepspeed.pydantic_v1 import root_validator
from pydantic import model_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@@ -36,10 +36,10 @@ class WandbConfig(DeepSpeedConfigModel):
enabled: bool = False
""" Whether logging to WandB is enabled. Requires `wandb` package is installed. """
group: str = None
group: Optional[str] = None
""" Name for the WandB group. This can be used to group together runs. """
team: str = None
team: Optional[str] = None
""" Name for the WandB team. """
project: str = "deepspeed"
@@ -137,8 +137,8 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
csv_monitor: CSVConfig = {}
""" Local CSV output of monitoring data. """
@root_validator
def check_enabled(cls, values):
values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get(
"csv_monitor").enabled or values.get("comet").enabled
return values
@model_validator(mode="after")
def check_enabled(self):
enabled = self.tensorboard.enabled or self.wandb.enabled or self.csv_monitor.enabled or self.comet.enabled
self.__dict__["enabled"] = enabled
return self
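A note on the `self.__dict__["enabled"] = enabled` line: with `validate_assignment=True` (set on `DeepSpeedConfigModel`), a plain `self.enabled = ...` inside an `after` validator re-enters assignment validation and recurses back into the validator, so the write deliberately bypasses `__setattr__`. A hedged sketch of the hazard on an illustrative model:

```python
from pydantic import BaseModel, ConfigDict, model_validator

class MonitorLike(BaseModel):  # hypothetical stand-in for DeepSpeedMonitorConfig
    model_config = ConfigDict(validate_assignment=True)

    tensorboard_enabled: bool = False
    wandb_enabled: bool = False
    enabled: bool = False  # derived from the flags above

    @model_validator(mode="after")
    def derive_enabled(self):
        # `self.enabled = ...` would trigger assignment validation, which runs
        # this validator again; writing to __dict__ sidesteps the recursion.
        self.__dict__["enabled"] = self.tensorboard_enabled or self.wandb_enabled
        return self
```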

View File

@@ -1,16 +0,0 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""Pydantic v1 compatibility module.
Pydantic v2 introduced breaking changes that hinder its adoption:
https://docs.pydantic.dev/latest/migration/. To provide deepspeed users the option to
migrate to pydantic v2 on their own timeline, deepspeed uses this compatibility module
as a pydantic-version-agnostic alias for pydantic's v1 API.
"""
try:
from pydantic.v1 import * # noqa: F401
except ImportError:
from pydantic import * # noqa: F401

View File

@@ -5,11 +5,12 @@
"""
Collection of DeepSpeed configuration utilities
"""
import json
import collections
import collections.abc
import json
import torch
from functools import reduce
from deepspeed.pydantic_v1 import BaseModel
from pydantic import BaseModel, ConfigDict, field_serializer
from deepspeed.utils import logger
@@ -54,67 +55,73 @@ class DeepSpeedConfigModel(BaseModel):
if (not strict): # This is temporary until we refactor all DS configs, allows HF to load models
data = {k: v for k, v in data.items() if (v != "auto" or k == "replace_method")}
super().__init__(**data)
self._deprecated_fields_check(self)
self._deprecated_fields_check()
def _process_deprecated_field(self, pydantic_config, field):
def _process_deprecated_field(self, dep_field):
# Get information about the deprecated field
fields_set = pydantic_config.__fields_set__
dep_param = field.name
kwargs = field.field_info.extra
pydantic_config = self
fields_set = pydantic_config.model_fields_set
kwargs = pydantic_config.model_fields[dep_field].json_schema_extra
new_param_fn = kwargs.get("new_param_fn", lambda x: x)
param_value = new_param_fn(getattr(pydantic_config, dep_param))
new_param = kwargs.get("new_param", "")
param_value = new_param_fn(getattr(pydantic_config, dep_field))
new_field = kwargs.get("new_param", "")
dep_msg = kwargs.get("deprecated_msg", "")
if dep_param in fields_set:
logger.warning(f"Config parameter {dep_param} is deprecated" +
(f" use {new_param} instead" if new_param else "") + (f". {dep_msg}" if dep_msg else ""))
if dep_field in fields_set:
logger.warning(f"Config parameter {dep_field} is deprecated" +
(f" use {new_field} instead" if new_field else "") + (f". {dep_msg}" if dep_msg else ""))
# Check if there is a new param and if it should be set with a value
if new_param and kwargs.get("set_new_param", True):
if new_field and kwargs.get("set_new_param", True):
# Remove the deprecated field if there is a replacing field
try:
delattr(pydantic_config, dep_param)
delattr(pydantic_config, dep_field)
except Exception as e:
logger.error(f"Tried removing deprecated '{dep_param}' from config")
logger.error(f"Tried removing deprecated '{dep_field}' from config")
raise e
# Set new param value
new_param_nested = new_param.split(".")
new_param_nested = new_field.split(".")
if len(new_param_nested) > 1:
# If the new param exists in a subconfig, we need to get
# the fields set for that subconfig
pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config)
fields_set = pydantic_config.__fields_set__
fields_set = pydantic_config.model_fields_set
new_param_name = new_param_nested[-1]
assert (
new_param_name not in fields_set
), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together"
), f"Cannot provide deprecated parameter '{dep_field}' and replacing parameter '{new_field}' together"
# A custom function for converting the old param value to new param value can be provided
try:
setattr(pydantic_config, new_param_name, param_value)
except Exception as e:
logger.error(f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'")
logger.error(f"Tried setting value for '{new_field}' with value from deprecated '{dep_field}'")
raise e
def _deprecated_fields_check(self, pydantic_config):
fields = pydantic_config.__fields__
for field in fields.values():
if field.field_info.extra.get("deprecated", False):
self._process_deprecated_field(pydantic_config, field)
def _deprecated_fields_check(self):
fields = self.model_fields
for field_name, field_info in fields.items():
if field_info.json_schema_extra and field_info.json_schema_extra.get("deprecated", False):
self._process_deprecated_field(field_name)
class Config:
validate_all = True
validate_assignment = True
use_enum_values = True
allow_population_by_field_name = True
extra = "forbid"
arbitrary_types_allowed = True
model_config = ConfigDict(
validate_default=True,
validate_assignment=True,
use_enum_values=True,
populate_by_name=True,
extra="forbid",
arbitrary_types_allowed=True,
protected_namespaces=(),
)
@field_serializer("dtype", check_fields=False)
def serialize_torch_dtype(dtype: torch.dtype) -> str:
return str(dtype)
def get_config_default(config, field_name):
assert field_name in config.__fields__, f"'{field_name}' is not a field in {config}"
assert not config.__fields__.get(
field_name).required, f"'{field_name}' is a required field and does not have a default value"
return config.__fields__.get(field_name).default
assert field_name in config.model_fields, f"'{field_name}' is not a field in {config}"
assert not config.model_fields.get(
field_name).is_required(), f"'{field_name}' is a required field and does not have a default value"
return config.model_fields.get(field_name).get_default()
class pp_int(int):
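For reference, the deprecated-field plumbing above now relies on the v2 `FieldInfo` accessors; their behavior, shown on a toy model (not the DeepSpeed class):

```python
from pydantic import BaseModel, Field

class Toy(BaseModel):  # hypothetical model for illustration only
    old_knob: int = Field(0, json_schema_extra={"deprecated": True, "new_param": "knob"})
    knob: int = 1

m = Toy(old_knob=5)
print(m.model_fields_set)            # {'old_knob'}  (v1: __fields_set__)
info = Toy.model_fields["old_knob"]  # (v1: __fields__[...].field_info)
print(info.json_schema_extra)        # {'deprecated': True, 'new_param': 'knob'}
print(info.is_required())            # False         (v1: .required)
print(info.get_default())            # 0             (v1: .default)
```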

View File

@@ -6,7 +6,7 @@
import sys
from typing import Optional
from enum import Enum
from deepspeed.pydantic_v1 import Field, validator, root_validator
from pydantic import Field, model_validator
from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
from deepspeed.utils import logger
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
@@ -30,7 +30,7 @@ ZeRO optimization should be enabled as:
"reduce_bucket_size": 500000000,
"load_from_fp32_weights": [true|false],
"cpu_offload": [true|false] (deprecated),
"cpu_offload_params" : [true|false] (deprecated),
"cpu_offload_param" : [true|false] (deprecated),
"cpu_offload_use_pin_memory": [true|false] (deprecated),
"sub_group_size" : 1000000000000,
"offload_param": {...},
@@ -128,7 +128,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
the allgather for large model sizes
"""
overlap_comm: bool = None # None for dynamic default value (see validator `overlap_comm_valid` below)
overlap_comm: Optional[bool] = None # None for dynamic default value (see validator `overlap_comm_valid` below)
"""
Attempts to overlap the reduction of the gradients with backward computation
"""
@@ -168,27 +168,37 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
parameters). Used by ZeRO3-Offload and ZeRO-Infinity
"""
cpu_offload_param: bool = Field(
cpu_offload_param: Optional[bool] = Field(
None,
deprecated=True,
new_param="offload_param",
new_param_fn=(lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu) if val else None),
json_schema_extra={
"deprecated": True,
"new_param": "offload_param",
"new_param_fn": (lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu)
if val else None)
},
)
""" Deprecated, please use ``offload_param`` """
cpu_offload_use_pin_memory: bool = Field(
cpu_offload_use_pin_memory: Optional[bool] = Field(
None,
deprecated=True,
new_param="offload_param or offload_optimizer",
set_new_param=False,
json_schema_extra={
"deprecated": True,
"new_param": "offload_param or offload_optimizer",
"set_new_param": False
},
)
""" Deprecated, please use ``offload_param`` or ``offload_optimizer`` """
cpu_offload: bool = Field(
cpu_offload: Optional[bool] = Field(
None,
deprecated=True,
new_param="offload_optimizer",
new_param_fn=(lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu) if val else None),
json_schema_extra={
"deprecated":
True,
"new_param":
"offload_optimizer",
"new_param_fn": (lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu)
if val else None)
},
)
""" Deprecated, please use ``offload_optimizer`` """
@@ -242,8 +252,10 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
"""
stage3_gather_fp16_weights_on_model_save: bool = Field(False,
deprecated=True,
new_param="gather_16bit_weights_on_model_save")
json_schema_extra={
"deprecated": True,
"new_param": "gather_16bit_weights_on_model_save"
})
""" Deprecated, please use ``gather_16bit_weights_on_model_save`` """
ignore_unused_parameters: bool = True
@@ -309,16 +321,15 @@
"""
# Validators
@validator("overlap_comm")
def overlap_comm_valid(cls, field_value, values):
if field_value is None:
assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'"
field_value = values["stage"] == ZeroStageEnum.weights
return field_value
@model_validator(mode="after")
def overlap_comm_valid(self):
if self.overlap_comm is None:
self.overlap_comm = self.stage == ZeroStageEnum.weights
return self
@root_validator
def offload_ratio_check(cls, values):
offload_config = getattr(values, "offload_optimizer", {})
@model_validator(mode="after")
def offload_ratio_check(self):
offload_config = self.offload_optimizer
if offload_config and offload_config.ratio < 1.0:
assert values.get("stage") == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
return values
assert self.stage == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
return self
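The `overlap_comm` rewrite shows the v2 idiom for dynamic defaults: v1's `@validator(..., always=True)` plus a `values` lookup becomes an `after` validator that fills in the `None` sentinel. A standalone sketch under illustrative names:

```python
from typing import Optional
from pydantic import BaseModel, model_validator

class ZeroLike(BaseModel):  # hypothetical stand-in for DeepSpeedZeroConfig
    stage: int = 0
    overlap_comm: Optional[bool] = None  # None means "derive from stage"

    @model_validator(mode="after")
    def default_overlap_comm(self):
        # Dynamic default: overlap_comm is enabled by default only for stage 3.
        # (This toy leaves validate_assignment at its default False, so the
        # plain assignment below cannot loop back into this validator.)
        if self.overlap_comm is None:
            self.overlap_comm = self.stage == 3
        return self
```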

View File

@@ -5,7 +5,9 @@
from enum import Enum
from pathlib import Path
from deepspeed.pydantic_v1 import Field, validator
from pydantic import Field, model_validator
from typing import Optional
from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int
@@ -25,7 +27,7 @@ class DeepSpeedZeroOffloadParamConfig(DeepSpeedConfigModel):
`nvme`.
"""
nvme_path: Path = None
nvme_path: Optional[Path] = None
""" Filesystem path for NVMe device for parameter offloading. """
buffer_count: int = Field(5, ge=0)
@@ -56,7 +58,7 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
`nvme`. Optimizer computation is offload to CPU regardless of device option.
"""
nvme_path: Path = None
nvme_path: Optional[Path] = None
""" Filesystem path for NVMe device for optimizer state offloading. """
buffer_count: int = Field(4, ge=0)
@@ -88,10 +90,11 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
fast_init: bool = False
""" Enable fast optimizer initialization when offloading to NVMe. """
@validator("pipeline_read", "pipeline_write", always=True)
def set_pipeline(cls, field_value, values):
values["pipeline"] = field_value or values.get("pipeline", False)
return field_value
ratio: float = Field(1.0, ge=0.0, le=1.0)
""" Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""
@model_validator(mode="after")
def set_pipeline(self):
pipeline = self.pipeline_read or self.pipeline_write
self.__dict__["pipeline"] = pipeline
return self

View File

@@ -725,8 +725,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
def get_first_param_index(self, group_id, param_group, partition_id):
for index, param in enumerate(param_group):
param_id = self.get_param_id(param)
if partition_id in self.param_to_partition_ids[group_id][param_id]:
return index
if group_id in self.param_to_partition_ids and param_id in self.param_to_partition_ids[group_id]:
if partition_id in self.param_to_partition_ids[group_id][param_id]:
return index
return None
def initialize_gradient_partitioning_data_structures(self):

View File

@@ -1,10 +1,10 @@
autodoc_pydantic
autodoc_pydantic>=2.0.0
docutils<0.18
hjson
packaging
psutil
py-cpuinfo
pydantic<2.0.0
pydantic>=2.0.0
recommonmark
sphinx_rtd_theme
torch

View File

@@ -4,6 +4,6 @@ numpy
packaging>=20.0
psutil
py-cpuinfo
pydantic
pydantic>=2.0.0
torch
tqdm

View File

@@ -5,7 +5,7 @@
import pytest
from deepspeed.pydantic_v1 import ValidationError
from pydantic import ValidationError
from deepspeed.inference.v2.ragged import DSStateManagerConfig

View File

@@ -67,13 +67,11 @@ def _batch_assert(status, ds_config, batch, micro_batch, gas, success):
if not success:
assert not status
print("Failed but All is well")
return
assert ds_config.train_batch_size == batch
assert ds_config.train_micro_batch_size_per_gpu == micro_batch
assert ds_config.gradient_accumulation_steps == gas
print("All is well")
#Tests different batch config provided in deepspeed json file

View File

@@ -4,18 +4,25 @@
# DeepSpeed Team
import pytest
import os
import json
from typing import List
from deepspeed.pydantic_v1 import Field, ValidationError
import os
from typing import List, Optional
from pydantic import Field, ValidationError
from deepspeed.runtime import config as ds_config
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
class SimpleConf(DeepSpeedConfigModel):
param_1: int = 0
param_2_old: str = Field(None, deprecated=True, new_param="param_2", new_param_fn=(lambda x: [x]))
param_2: List[str] = None
param_2_old: Optional[str] = Field(None,
json_schema_extra={
"deprecated": True,
"new_param": "param_2",
"new_param_fn": (lambda x: [x])
})
param_2: Optional[List[str]] = None
param_3: int = Field(0, alias="param_3_alias")
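For context, a hedged usage sketch of the updated test model above (it needs DeepSpeed installed, since the deprecation plumbing lives in `DeepSpeedConfigModel`):

```python
conf = SimpleConf(param_2_old="foo")  # logs a deprecation warning, then...
assert conf.param_2 == ["foo"]        # ...new_param_fn routes the value to param_2

conf = SimpleConf(param_3_alias=10)   # populate_by_name keeps aliases working
assert conf.param_3 == 10
```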