mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Pydantic v2 migration (#5167)
Pydantic v2 has been out for some time now. We have been relying on using the v1 API available in v2 until now. This is a refresh of #3902 to bring proper v2 support to DeepSpeed. Corresponding DeepSpeed-MII PR [here](https://github.com/microsoft/DeepSpeed-MII/pull/423). @loadams --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Logan Adams <loadams@microsoft.com> Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com> Co-authored-by: Abhishek Kulkarni <11399+adk9@users.noreply.github.com> Co-authored-by: Abhishek Kulkarni <abkulkarni@microsoft.com> Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
This commit is contained in:
3
.github/workflows/nv-a6000.yml
vendored
3
.github/workflows/nv-a6000.yml
vendored
@ -47,7 +47,8 @@ jobs:
|
||||
- name: Install deepspeed
|
||||
run: |
|
||||
python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
|
||||
python -m pip install pydantic==1.10.11
|
||||
# Update packages included in the container that do not support pydantic 2+ to versions that do
|
||||
python -m pip install thinc spacy confection --upgrade
|
||||
python -m pip install .[dev,1bit,autotuning,inf]
|
||||
ds_report
|
||||
- name: Python environment
|
||||
|
@ -3,20 +3,12 @@
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||
|
||||
from .constants import *
|
||||
from ..pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class CommsConfig(BaseModel):
|
||||
|
||||
class Config:
|
||||
validate_all = True
|
||||
validate_assignment = True
|
||||
use_enum_values = True
|
||||
extra = 'forbid'
|
||||
|
||||
|
||||
class CommsLoggerConfig(CommsConfig):
|
||||
class CommsLoggerConfig(DeepSpeedConfigModel):
|
||||
enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
|
||||
prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
|
||||
prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
|
||||
|
@ -5,38 +5,25 @@
|
||||
|
||||
import torch
|
||||
import deepspeed
|
||||
from deepspeed.pydantic_v1 import Field, validator
|
||||
from pydantic import Field, field_validator
|
||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
|
||||
from typing import Dict, Union
|
||||
from typing import Dict, Union, Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DtypeEnum(Enum):
|
||||
# The torch dtype must always be the first value (so we return torch.dtype)
|
||||
fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
|
||||
fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
|
||||
bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat"
|
||||
int8 = torch.int8, "torch.int8", "int8"
|
||||
fp16 = (torch.float16, "torch.float16", "fp16", "float16", "half")
|
||||
fp32 = (torch.float32, "torch.float32", "fp32", "float32", "float")
|
||||
bf16 = (torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat")
|
||||
int8 = (torch.int8, "torch.int8", "int8")
|
||||
|
||||
# Copied from https://stackoverflow.com/a/43210118
|
||||
# Allows us to use multiple values for each Enum index and returns first
|
||||
# listed value when Enum is called
|
||||
def __new__(cls, *values):
|
||||
obj = object.__new__(cls)
|
||||
# first value is canonical value
|
||||
obj._value_ = values[0]
|
||||
for other_value in values[1:]:
|
||||
cls._value2member_map_[other_value] = obj
|
||||
obj._all_values = values
|
||||
return obj
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s.%s: %s>" % (
|
||||
self.__class__.__name__,
|
||||
self._name_,
|
||||
", ".join([repr(v) for v in self._all_values]),
|
||||
)
|
||||
@classmethod
|
||||
def from_str(cls, value: str):
|
||||
for dtype in cls:
|
||||
if value in dtype.value:
|
||||
return dtype
|
||||
raise ValueError(f"'{value}' is not a valid DtypeEnum")
|
||||
|
||||
|
||||
class MoETypeEnum(str, Enum):
|
||||
@ -91,24 +78,24 @@ class QuantTypeEnum(str, Enum):
|
||||
|
||||
|
||||
class BaseQuantConfig(DeepSpeedConfigModel):
|
||||
enabled = True
|
||||
num_bits = 8
|
||||
enabled: bool = True
|
||||
num_bits: int = 8
|
||||
q_type: QuantTypeEnum = QuantTypeEnum.sym
|
||||
q_groups: int = 1
|
||||
|
||||
|
||||
class WeightQuantConfig(BaseQuantConfig):
|
||||
enabled = True
|
||||
enabled: bool = True
|
||||
quantized_initialization: Dict = {}
|
||||
post_init_quant: Dict = {}
|
||||
|
||||
|
||||
class ActivationQuantConfig(BaseQuantConfig):
|
||||
enabled = True
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class QKVQuantConfig(DeepSpeedConfigModel):
|
||||
enabled = True
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class QuantizationConfig(DeepSpeedConfigModel):
|
||||
@ -120,9 +107,9 @@ class QuantizationConfig(DeepSpeedConfigModel):
|
||||
|
||||
# todo: brainstorm on how to do ckpt loading for DS inference
|
||||
class InferenceCheckpointConfig(DeepSpeedConfigModel):
|
||||
checkpoint_dir: str = None
|
||||
save_mp_checkpoint_path: str = None
|
||||
base_dir: str = None
|
||||
checkpoint_dir: Optional[str] = None
|
||||
save_mp_checkpoint_path: Optional[str] = None
|
||||
base_dir: Optional[str] = None
|
||||
|
||||
|
||||
class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
||||
@ -136,7 +123,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
||||
`(attention_output projection, transformer output projection)`
|
||||
"""
|
||||
|
||||
dtype: DtypeEnum = torch.float16
|
||||
dtype: torch.dtype = torch.float16
|
||||
"""
|
||||
Desired model data type, will convert model to this type.
|
||||
Supported target types: `torch.half`, `torch.int8`, `torch.float`
|
||||
@ -198,7 +185,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
||||
"""
|
||||
|
||||
#todo: refactor the following 3 into the new checkpoint_config
|
||||
checkpoint: Union[str, Dict] = None
|
||||
checkpoint: Optional[Union[str, Dict]] = None
|
||||
"""
|
||||
Path to deepspeed compatible checkpoint or path to JSON with load policy.
|
||||
"""
|
||||
@ -214,7 +201,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
||||
specifying whether the inference-module is created with empty or real Tensor
|
||||
"""
|
||||
|
||||
save_mp_checkpoint_path: str = None
|
||||
save_mp_checkpoint_path: Optional[str] = None
|
||||
"""
|
||||
The path for which we want to save the loaded model with a checkpoint. This
|
||||
feature is used for adjusting the parallelism degree to help alleviate the
|
||||
@ -243,19 +230,21 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
||||
|
||||
replace_method: str = Field(
|
||||
"auto",
|
||||
deprecated=True,
|
||||
deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference")
|
||||
json_schema_extra={
|
||||
"deprecated": True,
|
||||
"deprecated_msg": "This parameter is no longer needed, please remove from your call to DeepSpeed-inference"
|
||||
})
|
||||
|
||||
injection_policy: Dict = Field(None, alias="injection_dict")
|
||||
injection_policy: Optional[Dict] = Field(None, alias="injection_dict")
|
||||
"""
|
||||
Dictionary mapping a client nn.Module to its corresponding injection
|
||||
policy. e.g., `{BertLayer : deepspeed.inference.HFBertLayerPolicy}`
|
||||
"""
|
||||
|
||||
injection_policy_tuple: tuple = None
|
||||
injection_policy_tuple: Optional[tuple] = None
|
||||
""" TODO: Add docs """
|
||||
|
||||
config: Dict = Field(None, alias="args") # todo: really no need for this field if we can refactor
|
||||
config: Optional[Dict] = Field(None, alias="args") # todo: really no need for this field if we can refactor
|
||||
|
||||
max_out_tokens: int = Field(1024, alias="max_tokens")
|
||||
"""
|
||||
@ -274,31 +263,49 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
|
||||
|
||||
transposed_mode: bool = Field(False, alias="transposed_mode")
|
||||
|
||||
mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
|
||||
mp_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.tp_size"})
|
||||
"""
|
||||
Desired model parallel size, default is 1 meaning no model parallelism.
|
||||
Deprecated, please use the ``tensor_parallel` config to control model
|
||||
parallelism.
|
||||
"""
|
||||
mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
|
||||
ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
|
||||
ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group")
|
||||
ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group")
|
||||
moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
|
||||
moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type")
|
||||
mpu: object = Field(None, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.mpu"})
|
||||
ep_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "moe.ep_size"})
|
||||
ep_group: object = Field(None,
|
||||
alias="expert_group",
|
||||
json_schema_extra={
|
||||
"deprecated": True,
|
||||
"new_param": "moe.ep_group"
|
||||
})
|
||||
ep_mp_group: object = Field(None,
|
||||
alias="expert_mp_group",
|
||||
json_schema_extra={
|
||||
"deprecated": True,
|
||||
"new_param": "moe.ep_mp_group"
|
||||
})
|
||||
moe_experts: list = Field([1], json_schema_extra={"deprecated": True, "new_param": "moe.moe_experts"})
|
||||
moe_type: MoETypeEnum = Field(MoETypeEnum.standard,
|
||||
json_schema_extra={
|
||||
"deprecated": True,
|
||||
"new_param": "moe.type"
|
||||
})
|
||||
|
||||
@validator("moe")
|
||||
@field_validator("dtype", mode="before")
|
||||
def validate_dtype(cls, field_value, values):
|
||||
if isinstance(field_value, str):
|
||||
return DtypeEnum.from_str(field_value).value[0]
|
||||
if isinstance(field_value, torch.dtype):
|
||||
return field_value
|
||||
raise TypeError(f"Invalid type for dtype: {type(field_value)}")
|
||||
|
||||
@field_validator("moe")
|
||||
def moe_backward_compat(cls, field_value, values):
|
||||
if isinstance(field_value, bool):
|
||||
return DeepSpeedMoEConfig(moe=field_value)
|
||||
return field_value
|
||||
|
||||
@validator("use_triton")
|
||||
@field_validator("use_triton")
|
||||
def has_triton(cls, field_value, values):
|
||||
if field_value and not deepspeed.HAS_TRITON:
|
||||
raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
|
||||
return field_value
|
||||
|
||||
class Config:
|
||||
# Get the str representation of the datatype for serialization
|
||||
json_encoders = {torch.dtype: lambda x: str(x)}
|
||||
|
@ -3,8 +3,9 @@
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from pydantic import Field
|
||||
from typing import Optional
|
||||
from deepspeed.pydantic_v1 import Field
|
||||
|
||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||
from .ragged import DSStateManagerConfig
|
||||
|
||||
|
@ -27,9 +27,9 @@ class TensorMetadata(DeepSpeedConfigModel):
|
||||
"""
|
||||
A class to represent a tensor specification.
|
||||
"""
|
||||
dtype: Optional[str]
|
||||
shape: Optional[Tuple[int, ...]]
|
||||
strides: Optional[Tuple[int, ...]]
|
||||
dtype: Optional[str] = None
|
||||
shape: Optional[Tuple[int, ...]] = None
|
||||
strides: Optional[Tuple[int, ...]] = None
|
||||
offset: int
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ class ParameterMetadata(DeepSpeedConfigModel):
|
||||
"""
|
||||
A class to represent a parameter specification.
|
||||
"""
|
||||
core_param: TensorMetadata = None
|
||||
core_param: Optional[TensorMetadata] = None
|
||||
aux_params: Dict[str, TensorMetadata] = {}
|
||||
|
||||
|
||||
|
@ -6,7 +6,7 @@
|
||||
from enum import Enum
|
||||
from typing import Tuple
|
||||
|
||||
from deepspeed.pydantic_v1 import PositiveInt, validator
|
||||
from pydantic import PositiveInt, model_validator
|
||||
|
||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||
from ..inference_utils import DtypeEnum
|
||||
@ -173,11 +173,9 @@ class DSStateManagerConfig(DeepSpeedConfigModel):
|
||||
Enable tracking for offloading KV-cache to host memory. Currently unsupported.
|
||||
"""
|
||||
|
||||
@validator("max_ragged_sequence_count")
|
||||
def max_ragged_sequence_count_validator(cls, v: int, values: dict):
|
||||
@model_validator(mode="after")
|
||||
def max_ragged_sequence_count_validator(self):
|
||||
# If the attributes below failed their validation they won't appear in the values dict.
|
||||
if "max_tracked_sequences" in values and v > values["max_tracked_sequences"]:
|
||||
raise ValueError("max_ragged_sequence_count must be less than max_tracked_sequences")
|
||||
if "max_ragged_batch_size" in values and v > values["max_ragged_batch_size"]:
|
||||
raise ValueError("max_ragged_sequence_count must be less than max_ragged_batch_size")
|
||||
return v
|
||||
assert self.max_ragged_sequence_count <= self.max_tracked_sequences, "max_ragged_sequence_count must be less than max_tracked_sequences"
|
||||
assert self.max_ragged_sequence_count <= self.max_ragged_batch_size, "max_ragged_sequence_count must be less than max_ragged_batch_size"
|
||||
return self
|
||||
|
@ -5,7 +5,7 @@
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from deepspeed.pydantic_v1 import root_validator
|
||||
from pydantic import model_validator
|
||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||
|
||||
|
||||
@ -36,10 +36,10 @@ class WandbConfig(DeepSpeedConfigModel):
|
||||
enabled: bool = False
|
||||
""" Whether logging to WandB is enabled. Requires `wandb` package is installed. """
|
||||
|
||||
group: str = None
|
||||
group: Optional[str] = None
|
||||
""" Name for the WandB group. This can be used to group together runs. """
|
||||
|
||||
team: str = None
|
||||
team: Optional[str] = None
|
||||
""" Name for the WandB team. """
|
||||
|
||||
project: str = "deepspeed"
|
||||
@ -137,8 +137,8 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
|
||||
csv_monitor: CSVConfig = {}
|
||||
""" Local CSV output of monitoring data. """
|
||||
|
||||
@root_validator
|
||||
def check_enabled(cls, values):
|
||||
values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get(
|
||||
"csv_monitor").enabled or values.get("comet").enabled
|
||||
return values
|
||||
@model_validator(mode="after")
|
||||
def check_enabled(self):
|
||||
enabled = self.tensorboard.enabled or self.wandb.enabled or self.csv_monitor.enabled or self.comet.enabled
|
||||
self.__dict__["enabled"] = enabled
|
||||
return self
|
||||
|
@ -1,16 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
"""Pydantic v1 compatibility module.
|
||||
|
||||
Pydantic v2 introduced breaking changes that hinder its adoption:
|
||||
https://docs.pydantic.dev/latest/migration/. To provide deepspeed users the option to
|
||||
migrate to pydantic v2 on their own timeline, deepspeed uses this compatibility module
|
||||
as a pydantic-version-agnostic alias for pydantic's v1 API.
|
||||
"""
|
||||
|
||||
try:
|
||||
from pydantic.v1 import * # noqa: F401
|
||||
except ImportError:
|
||||
from pydantic import * # noqa: F401
|
@ -5,11 +5,12 @@
|
||||
"""
|
||||
Collection of DeepSpeed configuration utilities
|
||||
"""
|
||||
import json
|
||||
import collections
|
||||
import collections.abc
|
||||
import json
|
||||
import torch
|
||||
from functools import reduce
|
||||
from deepspeed.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict, field_serializer
|
||||
|
||||
from deepspeed.utils import logger
|
||||
|
||||
|
||||
@ -54,67 +55,73 @@ class DeepSpeedConfigModel(BaseModel):
|
||||
if (not strict): # This is temporary until we refactor all DS configs, allows HF to load models
|
||||
data = {k: v for k, v in data.items() if (v != "auto" or k == "replace_method")}
|
||||
super().__init__(**data)
|
||||
self._deprecated_fields_check(self)
|
||||
self._deprecated_fields_check()
|
||||
|
||||
def _process_deprecated_field(self, pydantic_config, field):
|
||||
def _process_deprecated_field(self, dep_field):
|
||||
# Get information about the deprecated field
|
||||
fields_set = pydantic_config.__fields_set__
|
||||
dep_param = field.name
|
||||
kwargs = field.field_info.extra
|
||||
pydantic_config = self
|
||||
fields_set = pydantic_config.model_fields_set
|
||||
kwargs = pydantic_config.model_fields[dep_field].json_schema_extra
|
||||
new_param_fn = kwargs.get("new_param_fn", lambda x: x)
|
||||
param_value = new_param_fn(getattr(pydantic_config, dep_param))
|
||||
new_param = kwargs.get("new_param", "")
|
||||
param_value = new_param_fn(getattr(pydantic_config, dep_field))
|
||||
new_field = kwargs.get("new_param", "")
|
||||
dep_msg = kwargs.get("deprecated_msg", "")
|
||||
if dep_param in fields_set:
|
||||
logger.warning(f"Config parameter {dep_param} is deprecated" +
|
||||
(f" use {new_param} instead" if new_param else "") + (f". {dep_msg}" if dep_msg else ""))
|
||||
if dep_field in fields_set:
|
||||
logger.warning(f"Config parameter {dep_field} is deprecated" +
|
||||
(f" use {new_field} instead" if new_field else "") + (f". {dep_msg}" if dep_msg else ""))
|
||||
# Check if there is a new param and if it should be set with a value
|
||||
if new_param and kwargs.get("set_new_param", True):
|
||||
if new_field and kwargs.get("set_new_param", True):
|
||||
# Remove the deprecate field if there is a replacing field
|
||||
try:
|
||||
delattr(pydantic_config, dep_param)
|
||||
delattr(pydantic_config, dep_field)
|
||||
except Exception as e:
|
||||
logger.error(f"Tried removing deprecated '{dep_param}' from config")
|
||||
logger.error(f"Tried removing deprecated '{dep_field}' from config")
|
||||
raise e
|
||||
|
||||
# Set new param value
|
||||
new_param_nested = new_param.split(".")
|
||||
new_param_nested = new_field.split(".")
|
||||
if len(new_param_nested) > 1:
|
||||
# If the new param exists in a subconfig, we need to get
|
||||
# the fields set for that subconfig
|
||||
pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config)
|
||||
fields_set = pydantic_config.__fields_set__
|
||||
fields_set = pydantic_config.model_fields_set
|
||||
new_param_name = new_param_nested[-1]
|
||||
assert (
|
||||
new_param_name not in fields_set
|
||||
), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together"
|
||||
), f"Cannot provide deprecated parameter '{dep_field}' and replacing parameter '{new_field}' together"
|
||||
# A custom function for converting the old param value to new param value can be provided
|
||||
try:
|
||||
setattr(pydantic_config, new_param_name, param_value)
|
||||
except Exception as e:
|
||||
logger.error(f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'")
|
||||
logger.error(f"Tried setting value for '{new_field}' with value from deprecated '{dep_field}'")
|
||||
raise e
|
||||
|
||||
def _deprecated_fields_check(self, pydantic_config):
|
||||
fields = pydantic_config.__fields__
|
||||
for field in fields.values():
|
||||
if field.field_info.extra.get("deprecated", False):
|
||||
self._process_deprecated_field(pydantic_config, field)
|
||||
def _deprecated_fields_check(self):
|
||||
fields = self.model_fields
|
||||
for field_name, field_info in fields.items():
|
||||
if field_info.json_schema_extra and field_info.json_schema_extra.get("deprecated", False):
|
||||
self._process_deprecated_field(field_name)
|
||||
|
||||
class Config:
|
||||
validate_all = True
|
||||
validate_assignment = True
|
||||
use_enum_values = True
|
||||
allow_population_by_field_name = True
|
||||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(
|
||||
validate_default=True,
|
||||
validate_assignment=True,
|
||||
use_enum_values=True,
|
||||
populate_by_name=True,
|
||||
extra="forbid",
|
||||
arbitrary_types_allowed=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
@field_serializer("dtype", check_fields=False)
|
||||
def serialize_torch_dtype(dtype: torch.dtype) -> str:
|
||||
return str(dtype)
|
||||
|
||||
|
||||
def get_config_default(config, field_name):
|
||||
assert field_name in config.__fields__, f"'{field_name}' is not a field in {config}"
|
||||
assert not config.__fields__.get(
|
||||
field_name).required, f"'{field_name}' is a required field and does not have a default value"
|
||||
return config.__fields__.get(field_name).default
|
||||
assert field_name in config.model_fields, f"'{field_name}' is not a field in {config}"
|
||||
assert not config.model_fields.get(
|
||||
field_name).is_required(), f"'{field_name}' is a required field and does not have a default value"
|
||||
return config.model_fields.get(field_name).get_default()
|
||||
|
||||
|
||||
class pp_int(int):
|
||||
|
@ -6,7 +6,7 @@
|
||||
import sys
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from deepspeed.pydantic_v1 import Field, validator, root_validator
|
||||
from pydantic import Field, model_validator
|
||||
from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
|
||||
from deepspeed.utils import logger
|
||||
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
|
||||
@ -30,7 +30,7 @@ ZeRO optimization should be enabled as:
|
||||
"reduce_bucket_size": 500000000,
|
||||
"load_from_fp32_weights": [true|false],
|
||||
"cpu_offload": [true|false] (deprecated),
|
||||
"cpu_offload_params" : [true|false] (deprecated),
|
||||
"cpu_offload_param" : [true|false] (deprecated),
|
||||
"cpu_offload_use_pin_memory": [true|false] (deprecated),
|
||||
"sub_group_size" : 1000000000000,
|
||||
"offload_param": {...},
|
||||
@ -128,7 +128,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
|
||||
the allgather for large model sizes
|
||||
"""
|
||||
|
||||
overlap_comm: bool = None # None for dynamic default value (see validator `overlap_comm_valid` below)
|
||||
overlap_comm: Optional[bool] = None # None for dynamic default value (see validator `overlap_comm_valid` below)
|
||||
"""
|
||||
Attempts to overlap the reduction of the gradients with backward computation
|
||||
"""
|
||||
@ -168,27 +168,37 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
|
||||
parameters). Used by ZeRO3-Offload and ZeRO-Infinity
|
||||
"""
|
||||
|
||||
cpu_offload_param: bool = Field(
|
||||
cpu_offload_param: Optional[bool] = Field(
|
||||
None,
|
||||
deprecated=True,
|
||||
new_param="offload_param",
|
||||
new_param_fn=(lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu) if val else None),
|
||||
json_schema_extra={
|
||||
"deprecated": True,
|
||||
"new_param": "offload_param",
|
||||
"new_param_fn": (lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu)
|
||||
if val else None)
|
||||
},
|
||||
)
|
||||
""" Deprecated, please use ``offload_param`` """
|
||||
|
||||
cpu_offload_use_pin_memory: bool = Field(
|
||||
cpu_offload_use_pin_memory: Optional[bool] = Field(
|
||||
None,
|
||||
deprecated=True,
|
||||
new_param="offload_param or offload_optimizer",
|
||||
set_new_param=False,
|
||||
json_schema_extra={
|
||||
"deprecated": True,
|
||||
"new_param": "offload_param or offload_optimizer",
|
||||
"set_new_param": False
|
||||
},
|
||||
)
|
||||
""" Deprecated, please use ``offload_param`` or ``offload_optimizer`` """
|
||||
|
||||
cpu_offload: bool = Field(
|
||||
cpu_offload: Optional[bool] = Field(
|
||||
None,
|
||||
deprecated=True,
|
||||
new_param="offload_optimizer",
|
||||
new_param_fn=(lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu) if val else None),
|
||||
json_schema_extra={
|
||||
"deprecated":
|
||||
True,
|
||||
"new_param":
|
||||
"offload_optimizer",
|
||||
"new_param_fn": (lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu)
|
||||
if val else None)
|
||||
},
|
||||
)
|
||||
""" Deprecated, please use ``offload_optimizer`` """
|
||||
|
||||
@ -242,8 +252,10 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
|
||||
"""
|
||||
|
||||
stage3_gather_fp16_weights_on_model_save: bool = Field(False,
|
||||
deprecated=True,
|
||||
new_param="gather_16bit_weights_on_model_save")
|
||||
json_schema_extra={
|
||||
"deprecated": True,
|
||||
"new_param": "gather_16bit_weights_on_model_save"
|
||||
})
|
||||
""" Deprecated, please use ``gather_16bit_weights_on_model_save`` """
|
||||
|
||||
ignore_unused_parameters: bool = True
|
||||
@ -309,16 +321,15 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
|
||||
"""
|
||||
|
||||
# Validators
|
||||
@validator("overlap_comm")
|
||||
def overlap_comm_valid(cls, field_value, values):
|
||||
if field_value is None:
|
||||
assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'"
|
||||
field_value = values["stage"] == ZeroStageEnum.weights
|
||||
return field_value
|
||||
@model_validator(mode="after")
|
||||
def overlap_comm_valid(self):
|
||||
if self.overlap_comm is None:
|
||||
self.overlap_comm = self.stage == ZeroStageEnum.weights
|
||||
return self
|
||||
|
||||
@root_validator
|
||||
def offload_ratio_check(cls, values):
|
||||
offload_config = getattr(values, "offload_optimizer", {})
|
||||
@model_validator(mode="after")
|
||||
def offload_ratio_check(self):
|
||||
offload_config = self.offload_optimizer
|
||||
if offload_config and offload_config.ratio < 1.0:
|
||||
assert values.get("stage") == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
|
||||
return values
|
||||
assert self.stage == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
|
||||
return self
|
||||
|
@ -5,7 +5,9 @@
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from deepspeed.pydantic_v1 import Field, validator
|
||||
from pydantic import Field, model_validator
|
||||
from typing import Optional
|
||||
|
||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int
|
||||
|
||||
|
||||
@ -25,7 +27,7 @@ class DeepSpeedZeroOffloadParamConfig(DeepSpeedConfigModel):
|
||||
`nvme`.
|
||||
"""
|
||||
|
||||
nvme_path: Path = None
|
||||
nvme_path: Optional[Path] = None
|
||||
""" Filesystem path for NVMe device for parameter offloading. """
|
||||
|
||||
buffer_count: int = Field(5, ge=0)
|
||||
@ -56,7 +58,7 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
|
||||
`nvme`. Optimizer computation is offload to CPU regardless of device option.
|
||||
"""
|
||||
|
||||
nvme_path: Path = None
|
||||
nvme_path: Optional[Path] = None
|
||||
""" Filesystem path for NVMe device for optimizer state offloading. """
|
||||
|
||||
buffer_count: int = Field(4, ge=0)
|
||||
@ -88,10 +90,11 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
|
||||
fast_init: bool = False
|
||||
""" Enable fast optimizer initialization when offloading to NVMe. """
|
||||
|
||||
@validator("pipeline_read", "pipeline_write", always=True)
|
||||
def set_pipeline(cls, field_value, values):
|
||||
values["pipeline"] = field_value or values.get("pipeline", False)
|
||||
return field_value
|
||||
|
||||
ratio: float = Field(1.0, ge=0.0, le=1.0)
|
||||
""" Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_pipeline(self):
|
||||
pipeline = self.pipeline_read or self.pipeline_write
|
||||
self.__dict__["pipeline"] = pipeline
|
||||
return self
|
||||
|
@ -725,8 +725,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
|
||||
def get_first_param_index(self, group_id, param_group, partition_id):
|
||||
for index, param in enumerate(param_group):
|
||||
param_id = self.get_param_id(param)
|
||||
if partition_id in self.param_to_partition_ids[group_id][param_id]:
|
||||
return index
|
||||
if group_id in self.param_to_partition_ids and param_id in self.param_to_partition_ids[group_id]:
|
||||
if partition_id in self.param_to_partition_ids[group_id][param_id]:
|
||||
return index
|
||||
return None
|
||||
|
||||
def initialize_gradient_partitioning_data_structures(self):
|
||||
|
@ -1,10 +1,10 @@
|
||||
autodoc_pydantic
|
||||
autodoc_pydantic>=2.0.0
|
||||
docutils<0.18
|
||||
hjson
|
||||
packaging
|
||||
psutil
|
||||
py-cpuinfo
|
||||
pydantic<2.0.0
|
||||
pydantic>=2.0.0
|
||||
recommonmark
|
||||
sphinx_rtd_theme
|
||||
torch
|
||||
|
@ -4,6 +4,6 @@ numpy
|
||||
packaging>=20.0
|
||||
psutil
|
||||
py-cpuinfo
|
||||
pydantic
|
||||
pydantic>=2.0.0
|
||||
torch
|
||||
tqdm
|
||||
|
@ -5,7 +5,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from deepspeed.pydantic_v1 import ValidationError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from deepspeed.inference.v2.ragged import DSStateManagerConfig
|
||||
|
||||
|
@ -67,13 +67,11 @@ def _batch_assert(status, ds_config, batch, micro_batch, gas, success):
|
||||
|
||||
if not success:
|
||||
assert not status
|
||||
print("Failed but All is well")
|
||||
return
|
||||
|
||||
assert ds_config.train_batch_size == batch
|
||||
assert ds_config.train_micro_batch_size_per_gpu == micro_batch
|
||||
assert ds_config.gradient_accumulation_steps == gas
|
||||
print("All is well")
|
||||
|
||||
|
||||
#Tests different batch config provided in deepspeed json file
|
||||
|
@ -4,18 +4,25 @@
|
||||
# DeepSpeed Team
|
||||
|
||||
import pytest
|
||||
import os
|
||||
import json
|
||||
from typing import List
|
||||
from deepspeed.pydantic_v1 import Field, ValidationError
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field, ValidationError
|
||||
|
||||
from deepspeed.runtime import config as ds_config
|
||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||
|
||||
|
||||
class SimpleConf(DeepSpeedConfigModel):
|
||||
param_1: int = 0
|
||||
param_2_old: str = Field(None, deprecated=True, new_param="param_2", new_param_fn=(lambda x: [x]))
|
||||
param_2: List[str] = None
|
||||
param_2_old: Optional[str] = Field(None,
|
||||
json_schema_extra={
|
||||
"deprecated": True,
|
||||
"new_param": "param_2",
|
||||
"new_param_fn": (lambda x: [x])
|
||||
})
|
||||
param_2: Optional[List[str]] = None
|
||||
param_3: int = Field(0, alias="param_3_alias")
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user