diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml
index 3ce406948432..484948b28e34 100644
--- a/.github/workflows/nv-a6000.yml
+++ b/.github/workflows/nv-a6000.yml
@@ -47,7 +47,8 @@ jobs:
       - name: Install deepspeed
        run: |
           python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
-          python -m pip install pydantic==1.10.11
+          # Update packages included in the container that do not support pydantic 2+ to versions that do
+          python -m pip install thinc spacy confection --upgrade
           python -m pip install .[dev,1bit,autotuning,inf]
           ds_report
       - name: Python environment
diff --git a/deepspeed/comm/config.py b/deepspeed/comm/config.py
index 1c441bb6bfe9..57501c9dd237 100644
--- a/deepspeed/comm/config.py
+++ b/deepspeed/comm/config.py
@@ -3,20 +3,12 @@
 
 # DeepSpeed Team
 
-from .constants import *
-from ..pydantic_v1 import BaseModel
-
+from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 
-class CommsConfig(BaseModel):
-
-    class Config:
-        validate_all = True
-        validate_assignment = True
-        use_enum_values = True
-        extra = 'forbid'
+from .constants import *
 
-class CommsLoggerConfig(CommsConfig):
+class CommsLoggerConfig(DeepSpeedConfigModel):
     enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
     prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
     prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
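The per-class `class Config` removed above is the central v1-to-v2 pattern in this PR: those options now live once on the shared `DeepSpeedConfigModel` base (see `deepspeed/runtime/config_utils.py` further down) as a pydantic v2 `ConfigDict`. A minimal sketch of the idiom, with illustrative names rather than DeepSpeed code; `validate_default` is v2's rename of v1's `validate_all`:

from pydantic import BaseModel, ConfigDict

class ExampleConfig(BaseModel):
    # v2 replacement for a v1 nested `class Config`
    model_config = ConfigDict(
        validate_default=True,      # v1: validate_all
        validate_assignment=True,   # re-validate on attribute assignment
        use_enum_values=True,       # store enum members by value
        extra="forbid",             # reject unknown keys
    )

    enabled: bool = False

cfg = ExampleConfig(enabled=True)
cfg.enabled = False  # allowed, and re-validated because validate_assignment=True
# ExampleConfig(enabledd=True) would raise a ValidationError (extra="forbid")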
diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index 1d5018aaa75b..c7c7684fff79 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -5,38 +5,25 @@
 
 import torch
 import deepspeed
-from deepspeed.pydantic_v1 import Field, validator
+from pydantic import Field, field_validator
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
-from typing import Dict, Union
+from typing import Dict, Union, Optional
 from enum import Enum
 
 
 class DtypeEnum(Enum):
-    # The torch dtype must always be the first value (so we return torch.dtype)
-    fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
-    fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
-    bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat"
-    int8 = torch.int8, "torch.int8", "int8"
-
-    # Copied from https://stackoverflow.com/a/43210118
-    # Allows us to use multiple values for each Enum index and returns first
-    # listed value when Enum is called
-    def __new__(cls, *values):
-        obj = object.__new__(cls)
-        # first value is canonical value
-        obj._value_ = values[0]
-        for other_value in values[1:]:
-            cls._value2member_map_[other_value] = obj
-        obj._all_values = values
-        return obj
-
-    def __repr__(self):
-        return "<%s.%s: %s>" % (
-            self.__class__.__name__,
-            self._name_,
-            ", ".join([repr(v) for v in self._all_values]),
-        )
+    fp16 = (torch.float16, "torch.float16", "fp16", "float16", "half")
+    fp32 = (torch.float32, "torch.float32", "fp32", "float32", "float")
+    bf16 = (torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat")
+    int8 = (torch.int8, "torch.int8", "int8")
+
+    @classmethod
+    def from_str(cls, value: str):
+        for dtype in cls:
+            if value in dtype.value:
+                return dtype
+        raise ValueError(f"'{value}' is not a valid DtypeEnum")
 
 
 class MoETypeEnum(str, Enum):
@@ -91,24 +78,24 @@ class QuantTypeEnum(str, Enum):
 
 
 class BaseQuantConfig(DeepSpeedConfigModel):
-    enabled = True
-    num_bits = 8
+    enabled: bool = True
+    num_bits: int = 8
     q_type: QuantTypeEnum = QuantTypeEnum.sym
     q_groups: int = 1
 
 
 class WeightQuantConfig(BaseQuantConfig):
-    enabled = True
+    enabled: bool = True
     quantized_initialization: Dict = {}
     post_init_quant: Dict = {}
 
 
 class ActivationQuantConfig(BaseQuantConfig):
-    enabled = True
+    enabled: bool = True
 
 
 class QKVQuantConfig(DeepSpeedConfigModel):
-    enabled = True
+    enabled: bool = True
 
 
 class QuantizationConfig(DeepSpeedConfigModel):
@@ -120,9 +107,9 @@ class QuantizationConfig(DeepSpeedConfigModel):
 
 # todo: brainstorm on how to do ckpt loading for DS inference
 class InferenceCheckpointConfig(DeepSpeedConfigModel):
-    checkpoint_dir: str = None
-    save_mp_checkpoint_path: str = None
-    base_dir: str = None
+    checkpoint_dir: Optional[str] = None
+    save_mp_checkpoint_path: Optional[str] = None
+    base_dir: Optional[str] = None
 
 
 class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
@@ -136,7 +123,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     `(attention_output projection, transformer output projection)`
     """
 
-    dtype: DtypeEnum = torch.float16
+    dtype: torch.dtype = torch.float16
     """
     Desired model data type, will convert model to this type.
     Supported target types: `torch.half`, `torch.int8`, `torch.float`
@@ -198,7 +185,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     """
 
     #todo: refactor the following 3 into the new checkpoint_config
-    checkpoint: Union[str, Dict] = None
+    checkpoint: Optional[Union[str, Dict]] = None
     """
     Path to deepspeed compatible checkpoint or path to JSON with load policy.
     """
@@ -214,7 +201,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     specifying whether the inference-module is created with empty or real Tensor
     """
 
-    save_mp_checkpoint_path: str = None
+    save_mp_checkpoint_path: Optional[str] = None
     """
     The path for which we want to save the loaded model with a checkpoint. This
     feature is used for adjusting the parallelism degree to help alleviate the
@@ -243,19 +230,21 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
 
     replace_method: str = Field(
         "auto",
-        deprecated=True,
-        deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference")
+        json_schema_extra={
+            "deprecated": True,
+            "deprecated_msg": "This parameter is no longer needed, please remove from your call to DeepSpeed-inference"
+        })
 
-    injection_policy: Dict = Field(None, alias="injection_dict")
+    injection_policy: Optional[Dict] = Field(None, alias="injection_dict")
     """
     Dictionary mapping a client nn.Module to its corresponding injection policy.
     e.g., `{BertLayer : deepspeed.inference.HFBertLayerPolicy}`
     """
 
-    injection_policy_tuple: tuple = None
+    injection_policy_tuple: Optional[tuple] = None
     """ TODO: Add docs """
 
-    config: Dict = Field(None, alias="args")  # todo: really no need for this field if we can refactor
+    config: Optional[Dict] = Field(None, alias="args")  # todo: really no need for this field if we can refactor
 
     max_out_tokens: int = Field(1024, alias="max_tokens")
     """
@@ -274,31 +263,49 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
 
     transposed_mode: bool = Field(False, alias="transposed_mode")
 
-    mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
+    mp_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.tp_size"})
     """
     Desired model parallel size, default is 1 meaning no model parallelism.
     Deprecated, please use the ``tensor_parallel` config to control model
     parallelism.
     """
-    mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
-    ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
-    ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group")
-    ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group")
-    moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
-    moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type")
-
-    @validator("moe")
+    mpu: object = Field(None, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.mpu"})
+    ep_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "moe.ep_size"})
+    ep_group: object = Field(None,
+                             alias="expert_group",
+                             json_schema_extra={
+                                 "deprecated": True,
+                                 "new_param": "moe.ep_group"
+                             })
+    ep_mp_group: object = Field(None,
+                                alias="expert_mp_group",
+                                json_schema_extra={
+                                    "deprecated": True,
+                                    "new_param": "moe.ep_mp_group"
+                                })
+    moe_experts: list = Field([1], json_schema_extra={"deprecated": True, "new_param": "moe.moe_experts"})
+    moe_type: MoETypeEnum = Field(MoETypeEnum.standard,
+                                  json_schema_extra={
+                                      "deprecated": True,
+                                      "new_param": "moe.type"
+                                  })
+
+    @field_validator("dtype", mode="before")
+    def validate_dtype(cls, field_value, values):
+        if isinstance(field_value, str):
+            return DtypeEnum.from_str(field_value).value[0]
+        if isinstance(field_value, torch.dtype):
+            return field_value
+        raise TypeError(f"Invalid type for dtype: {type(field_value)}")
+
+    @field_validator("moe")
     def moe_backward_compat(cls, field_value, values):
         if isinstance(field_value, bool):
             return DeepSpeedMoEConfig(moe=field_value)
         return field_value
 
-    @validator("use_triton")
+    @field_validator("use_triton")
     def has_triton(cls, field_value, values):
         if field_value and not deepspeed.HAS_TRITON:
             raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
         return field_value
-
-    class Config:
-        # Get the str representation of the datatype for serialization
-        json_encoders = {torch.dtype: lambda x: str(x)}
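Two things changed together here: `dtype` is now a plain `torch.dtype`, coerced from string aliases by a `mode="before"` validator plus `DtypeEnum.from_str`, and the removed `json_encoders` Config is superseded by a `field_serializer` on the base model (see `config_utils.py` below). A self-contained sketch of the before-validator pattern, assuming only `torch` and `pydantic>=2`; `_Dtype` and `Cfg` are illustrative stand-ins, not DeepSpeed code:

import torch
from enum import Enum
from pydantic import BaseModel, ConfigDict, field_validator

class _Dtype(Enum):
    fp16 = (torch.float16, "torch.float16", "fp16", "float16", "half")
    fp32 = (torch.float32, "torch.float32", "fp32", "float32", "float")

    @classmethod
    def from_str(cls, value: str):
        for dtype in cls:
            if value in dtype.value:  # match any alias in the tuple
                return dtype
        raise ValueError(f"'{value}' is not a valid dtype alias")

class Cfg(BaseModel):
    # torch.dtype is not a native pydantic type
    model_config = ConfigDict(arbitrary_types_allowed=True)

    dtype: torch.dtype = torch.float16

    @field_validator("dtype", mode="before")
    def _coerce(cls, v):
        if isinstance(v, str):
            return _Dtype.from_str(v).value[0]  # first tuple entry is the torch dtype
        return v

assert Cfg(dtype="half").dtype is torch.float16
assert Cfg(dtype=torch.float32).dtype is torch.float32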
""" - mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu") - ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size") - ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group") - ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group") - moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts") - moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type") - - @validator("moe") + mpu: object = Field(None, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.mpu"}) + ep_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "moe.ep_size"}) + ep_group: object = Field(None, + alias="expert_group", + json_schema_extra={ + "deprecated": True, + "new_param": "moe.ep_group" + }) + ep_mp_group: object = Field(None, + alias="expert_mp_group", + json_schema_extra={ + "deprecated": True, + "new_param": "moe.ep_mp_group" + }) + moe_experts: list = Field([1], json_schema_extra={"deprecated": True, "new_param": "moe.moe_experts"}) + moe_type: MoETypeEnum = Field(MoETypeEnum.standard, + json_schema_extra={ + "deprecated": True, + "new_param": "moe.type" + }) + + @field_validator("dtype", mode="before") + def validate_dtype(cls, field_value, values): + if isinstance(field_value, str): + return DtypeEnum.from_str(field_value).value[0] + if isinstance(field_value, torch.dtype): + return field_value + raise TypeError(f"Invalid type for dtype: {type(field_value)}") + + @field_validator("moe") def moe_backward_compat(cls, field_value, values): if isinstance(field_value, bool): return DeepSpeedMoEConfig(moe=field_value) return field_value - @validator("use_triton") + @field_validator("use_triton") def has_triton(cls, field_value, values): if field_value and not deepspeed.HAS_TRITON: raise ValueError('Triton needs to be installed to use deepspeed with triton kernels') return field_value - - class Config: - # Get the str representation of the datatype for serialization - json_encoders = {torch.dtype: lambda x: str(x)} diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py index 85e4b7a0e0a0..325b57d8f56a 100644 --- a/deepspeed/inference/v2/config_v2.py +++ b/deepspeed/inference/v2/config_v2.py @@ -3,8 +3,9 @@ # DeepSpeed Team +from pydantic import Field from typing import Optional -from deepspeed.pydantic_v1 import Field + from deepspeed.runtime.config_utils import DeepSpeedConfigModel from .ragged import DSStateManagerConfig diff --git a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py index ebdb59bca920..c5e02adaffc4 100644 --- a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py +++ b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py @@ -27,9 +27,9 @@ class TensorMetadata(DeepSpeedConfigModel): """ A class to represent a tensor specification. """ - dtype: Optional[str] - shape: Optional[Tuple[int, ...]] - strides: Optional[Tuple[int, ...]] + dtype: Optional[str] = None + shape: Optional[Tuple[int, ...]] = None + strides: Optional[Tuple[int, ...]] = None offset: int @@ -37,7 +37,7 @@ class ParameterMetadata(DeepSpeedConfigModel): """ A class to represent a parameter specification. 
""" - core_param: TensorMetadata = None + core_param: Optional[TensorMetadata] = None aux_params: Dict[str, TensorMetadata] = {} diff --git a/deepspeed/inference/v2/ragged/manager_configs.py b/deepspeed/inference/v2/ragged/manager_configs.py index a5e98e5bcef1..17283b8bc0c4 100644 --- a/deepspeed/inference/v2/ragged/manager_configs.py +++ b/deepspeed/inference/v2/ragged/manager_configs.py @@ -6,7 +6,7 @@ from enum import Enum from typing import Tuple -from deepspeed.pydantic_v1 import PositiveInt, validator +from pydantic import PositiveInt, model_validator from deepspeed.runtime.config_utils import DeepSpeedConfigModel from ..inference_utils import DtypeEnum @@ -173,11 +173,9 @@ class DSStateManagerConfig(DeepSpeedConfigModel): Enable tracking for offloading KV-cache to host memory. Currently unsupported. """ - @validator("max_ragged_sequence_count") - def max_ragged_sequence_count_validator(cls, v: int, values: dict): + @model_validator(mode="after") + def max_ragged_sequence_count_validator(self): # If the attributes below failed their validation they won't appear in the values dict. - if "max_tracked_sequences" in values and v > values["max_tracked_sequences"]: - raise ValueError("max_ragged_sequence_count must be less than max_tracked_sequences") - if "max_ragged_batch_size" in values and v > values["max_ragged_batch_size"]: - raise ValueError("max_ragged_sequence_count must be less than max_ragged_batch_size") - return v + assert self.max_ragged_sequence_count <= self.max_tracked_sequences, "max_ragged_sequence_count must be less than max_tracked_sequences" + assert self.max_ragged_sequence_count <= self.max_ragged_batch_size, "max_ragged_sequence_count must be less than max_ragged_batch_size" + return self diff --git a/deepspeed/monitor/config.py b/deepspeed/monitor/config.py index c4200877089a..960ce1ba997a 100644 --- a/deepspeed/monitor/config.py +++ b/deepspeed/monitor/config.py @@ -5,7 +5,7 @@ from typing import Optional -from deepspeed.pydantic_v1 import root_validator +from pydantic import model_validator from deepspeed.runtime.config_utils import DeepSpeedConfigModel @@ -36,10 +36,10 @@ class WandbConfig(DeepSpeedConfigModel): enabled: bool = False """ Whether logging to WandB is enabled. Requires `wandb` package is installed. """ - group: str = None + group: Optional[str] = None """ Name for the WandB group. This can be used to group together runs. """ - team: str = None + team: Optional[str] = None """ Name for the WandB team. """ project: str = "deepspeed" @@ -137,8 +137,8 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel): csv_monitor: CSVConfig = {} """ Local CSV output of monitoring data. """ - @root_validator - def check_enabled(cls, values): - values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get( - "csv_monitor").enabled or values.get("comet").enabled - return values + @model_validator(mode="after") + def check_enabled(self): + enabled = self.tensorboard.enabled or self.wandb.enabled or self.csv_monitor.enabled or self.comet.enabled + self.__dict__["enabled"] = enabled + return self diff --git a/deepspeed/pydantic_v1.py b/deepspeed/pydantic_v1.py deleted file mode 100644 index 6aba072ad929..000000000000 --- a/deepspeed/pydantic_v1.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team -"""Pydantic v1 compatibility module. - -Pydantic v2 introduced breaking changes that hinder its adoption: -https://docs.pydantic.dev/latest/migration/. 
diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py
index 5522a8e79d69..d5c3a1548360 100755
--- a/deepspeed/runtime/config_utils.py
+++ b/deepspeed/runtime/config_utils.py
@@ -5,11 +5,12 @@
 """
 Collection of DeepSpeed configuration utilities
 """
-import json
 import collections
-import collections.abc
+import json
+import torch
 from functools import reduce
-from deepspeed.pydantic_v1 import BaseModel
+from pydantic import BaseModel, ConfigDict, field_serializer
+
 from deepspeed.utils import logger
 
 
@@ -54,67 +55,73 @@ def __init__(self, strict=False, **data):
         if (not strict):  # This is temporary until we refactor all DS configs, allows HF to load models
             data = {k: v for k, v in data.items() if (v != "auto" or k == "replace_method")}
         super().__init__(**data)
-        self._deprecated_fields_check(self)
+        self._deprecated_fields_check()
 
-    def _process_deprecated_field(self, pydantic_config, field):
+    def _process_deprecated_field(self, dep_field):
         # Get information about the deprecated field
-        fields_set = pydantic_config.__fields_set__
-        dep_param = field.name
-        kwargs = field.field_info.extra
+        pydantic_config = self
+        fields_set = pydantic_config.model_fields_set
+        kwargs = pydantic_config.model_fields[dep_field].json_schema_extra
         new_param_fn = kwargs.get("new_param_fn", lambda x: x)
-        param_value = new_param_fn(getattr(pydantic_config, dep_param))
-        new_param = kwargs.get("new_param", "")
+        param_value = new_param_fn(getattr(pydantic_config, dep_field))
+        new_field = kwargs.get("new_param", "")
         dep_msg = kwargs.get("deprecated_msg", "")
-        if dep_param in fields_set:
-            logger.warning(f"Config parameter {dep_param} is deprecated" +
-                           (f" use {new_param} instead" if new_param else "") + (f". {dep_msg}" if dep_msg else ""))
+        if dep_field in fields_set:
+            logger.warning(f"Config parameter {dep_field} is deprecated" +
+                           (f" use {new_field} instead" if new_field else "") + (f". {dep_msg}" if dep_msg else ""))
             # Check if there is a new param and if it should be set with a value
-            if new_param and kwargs.get("set_new_param", True):
+            if new_field and kwargs.get("set_new_param", True):
                 # Remove the deprecate field if there is a replacing field
                 try:
-                    delattr(pydantic_config, dep_param)
+                    delattr(pydantic_config, dep_field)
                 except Exception as e:
-                    logger.error(f"Tried removing deprecated '{dep_param}' from config")
+                    logger.error(f"Tried removing deprecated '{dep_field}' from config")
                     raise e
                 # Set new param value
-                new_param_nested = new_param.split(".")
+                new_param_nested = new_field.split(".")
                 if len(new_param_nested) > 1:
                     # If the new param exists in a subconfig, we need to get
                     # the fields set for that subconfig
                     pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config)
-                    fields_set = pydantic_config.__fields_set__
+                    fields_set = pydantic_config.model_fields_set
                 new_param_name = new_param_nested[-1]
                 assert (
                     new_param_name not in fields_set
-                ), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together"
+                ), f"Cannot provide deprecated parameter '{dep_field}' and replacing parameter '{new_field}' together"
                 # A custom function for converting the old param value to new param value can be provided
                 try:
                     setattr(pydantic_config, new_param_name, param_value)
                 except Exception as e:
-                    logger.error(f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'")
+                    logger.error(f"Tried setting value for '{new_field}' with value from deprecated '{dep_field}'")
                     raise e
 
-    def _deprecated_fields_check(self, pydantic_config):
-        fields = pydantic_config.__fields__
-        for field in fields.values():
-            if field.field_info.extra.get("deprecated", False):
-                self._process_deprecated_field(pydantic_config, field)
+    def _deprecated_fields_check(self):
+        fields = self.model_fields
+        for field_name, field_info in fields.items():
+            if field_info.json_schema_extra and field_info.json_schema_extra.get("deprecated", False):
+                self._process_deprecated_field(field_name)
+
+    model_config = ConfigDict(
+        validate_default=True,
+        validate_assignment=True,
+        use_enum_values=True,
+        populate_by_name=True,
+        extra="forbid",
+        arbitrary_types_allowed=True,
+        protected_namespaces=(),
+    )
 
-    class Config:
-        validate_all = True
-        validate_assignment = True
-        use_enum_values = True
-        allow_population_by_field_name = True
-        extra = "forbid"
-        arbitrary_types_allowed = True
+    @field_serializer("dtype", check_fields=False)
+    def serialize_torch_dtype(dtype: torch.dtype) -> str:
+        return str(dtype)
 
 
 def get_config_default(config, field_name):
-    assert field_name in config.__fields__, f"'{field_name}' is not a field in {config}"
-    assert not config.__fields__.get(
-        field_name).required, f"'{field_name}' is a required field and does not have a default value"
-    return config.__fields__.get(field_name).default
+    assert field_name in config.model_fields, f"'{field_name}' is not a field in {config}"
+    assert not config.model_fields.get(
+        field_name).is_required(), f"'{field_name}' is a required field and does not have a default value"
+    return config.model_fields.get(field_name).get_default()
 
 
 class pp_int(int):
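The `@field_serializer("dtype", check_fields=False)` above replaces the v1 `json_encoders = {torch.dtype: lambda x: str(x)}` escape hatch removed from `inference/config.py`; `check_fields=False` appears to be needed because the serializer is declared on the base class while only some subclasses define a `dtype` field. A standalone sketch of the serializer, not the DeepSpeed class itself:

import torch
from pydantic import BaseModel, ConfigDict, field_serializer

class Cfg(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    dtype: torch.dtype = torch.float16

    @field_serializer("dtype")
    def _dtype_to_str(self, dtype: torch.dtype) -> str:
        return str(dtype)  # e.g. "torch.float16"

assert Cfg().model_dump_json() == '{"dtype":"torch.float16"}'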
{dep_msg}" if dep_msg else "")) # Check if there is a new param and if it should be set with a value - if new_param and kwargs.get("set_new_param", True): + if new_field and kwargs.get("set_new_param", True): # Remove the deprecate field if there is a replacing field try: - delattr(pydantic_config, dep_param) + delattr(pydantic_config, dep_field) except Exception as e: - logger.error(f"Tried removing deprecated '{dep_param}' from config") + logger.error(f"Tried removing deprecated '{dep_field}' from config") raise e # Set new param value - new_param_nested = new_param.split(".") + new_param_nested = new_field.split(".") if len(new_param_nested) > 1: # If the new param exists in a subconfig, we need to get # the fields set for that subconfig pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config) - fields_set = pydantic_config.__fields_set__ + fields_set = pydantic_config.model_fields_set new_param_name = new_param_nested[-1] assert ( new_param_name not in fields_set - ), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together" + ), f"Cannot provide deprecated parameter '{dep_field}' and replacing parameter '{new_field}' together" # A custom function for converting the old param value to new param value can be provided try: setattr(pydantic_config, new_param_name, param_value) except Exception as e: - logger.error(f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'") + logger.error(f"Tried setting value for '{new_field}' with value from deprecated '{dep_field}'") raise e - def _deprecated_fields_check(self, pydantic_config): - fields = pydantic_config.__fields__ - for field in fields.values(): - if field.field_info.extra.get("deprecated", False): - self._process_deprecated_field(pydantic_config, field) + def _deprecated_fields_check(self): + fields = self.model_fields + for field_name, field_info in fields.items(): + if field_info.json_schema_extra and field_info.json_schema_extra.get("deprecated", False): + self._process_deprecated_field(field_name) + + model_config = ConfigDict( + validate_default=True, + validate_assignment=True, + use_enum_values=True, + populate_by_name=True, + extra="forbid", + arbitrary_types_allowed=True, + protected_namespaces=(), + ) - class Config: - validate_all = True - validate_assignment = True - use_enum_values = True - allow_population_by_field_name = True - extra = "forbid" - arbitrary_types_allowed = True + @field_serializer("dtype", check_fields=False) + def serialize_torch_dtype(dtype: torch.dtype) -> str: + return str(dtype) def get_config_default(config, field_name): - assert field_name in config.__fields__, f"'{field_name}' is not a field in {config}" - assert not config.__fields__.get( - field_name).required, f"'{field_name}' is a required field and does not have a default value" - return config.__fields__.get(field_name).default + assert field_name in config.model_fields, f"'{field_name}' is not a field in {config}" + assert not config.model_fields.get( + field_name).is_required(), f"'{field_name}' is a required field and does not have a default value" + return config.model_fields.get(field_name).get_default() class pp_int(int): diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 2089d59dbce4..1cfcd784e2ce 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -6,7 +6,7 @@ import sys from typing import Optional from enum import Enum -from deepspeed.pydantic_v1 import Field, validator, 
diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py
index b7adc13a0ea2..74a5673bc1bc 100644
--- a/deepspeed/runtime/zero/offload_config.py
+++ b/deepspeed/runtime/zero/offload_config.py
@@ -5,7 +5,9 @@
 
 from enum import Enum
 from pathlib import Path
-from deepspeed.pydantic_v1 import Field, validator
+from pydantic import Field, model_validator
+from typing import Optional
+
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int
 
 
@@ -25,7 +27,7 @@ class DeepSpeedZeroOffloadParamConfig(DeepSpeedConfigModel):
     `nvme`.
     """
 
-    nvme_path: Path = None
+    nvme_path: Optional[Path] = None
     """ Filesystem path for NVMe device for parameter offloading. """
 
     buffer_count: int = Field(5, ge=0)
@@ -56,7 +58,7 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
     `nvme`. Optimizer computation is offload to CPU regardless of device option.
     """
 
-    nvme_path: Path = None
+    nvme_path: Optional[Path] = None
     """ Filesystem path for NVMe device for optimizer state offloading. """
 
     buffer_count: int = Field(4, ge=0)
@@ -88,10 +90,11 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
     fast_init: bool = False
     """ Enable fast optimizer initialization when offloading to NVMe. """
 
-    @validator("pipeline_read", "pipeline_write", always=True)
-    def set_pipeline(cls, field_value, values):
-        values["pipeline"] = field_value or values.get("pipeline", False)
-        return field_value
-
     ratio: float = Field(1.0, ge=0.0, le=1.0)
     """ Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""
+
+    @model_validator(mode="after")
+    def set_pipeline(self):
+        pipeline = self.pipeline_read or self.pipeline_write
+        self.__dict__["pipeline"] = pipeline
+        return self
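Usage-level behavior the deprecation plumbing and these validators are meant to preserve, assuming a DeepSpeed build that includes this change: deprecated flat keys still load, log a warning, and are remapped onto the nested offload configs.

from deepspeed.runtime.zero.config import DeepSpeedZeroConfig

# Deprecated spelling; logs a warning and populates offload_optimizer instead
cfg = DeepSpeedZeroConfig(stage=3, cpu_offload=True)
print(cfg.offload_optimizer.device)  # expected: "cpu" (OffloadDeviceEnum.cpu stored by value)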
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 57e80911d645..83cf996ca019 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -725,8 +725,9 @@ def reduce_gradients(self, pipeline_parallel=False):
     def get_first_param_index(self, group_id, param_group, partition_id):
         for index, param in enumerate(param_group):
             param_id = self.get_param_id(param)
-            if partition_id in self.param_to_partition_ids[group_id][param_id]:
-                return index
+            if group_id in self.param_to_partition_ids and param_id in self.param_to_partition_ids[group_id]:
+                if partition_id in self.param_to_partition_ids[group_id][param_id]:
+                    return index
         return None
 
     def initialize_gradient_partitioning_data_structures(self):
diff --git a/requirements/requirements-readthedocs.txt b/requirements/requirements-readthedocs.txt
index 1a2ad18611e7..a48a47e4428d 100644
--- a/requirements/requirements-readthedocs.txt
+++ b/requirements/requirements-readthedocs.txt
@@ -1,10 +1,10 @@
-autodoc_pydantic
+autodoc_pydantic>=2.0.0
 docutils<0.18
 hjson
 packaging
 psutil
 py-cpuinfo
-pydantic<2.0.0
+pydantic>=2.0.0
 recommonmark
 sphinx_rtd_theme
 torch
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 6840d6dbcc98..70c94a745435 100755
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -4,6 +4,6 @@ numpy
 packaging>=20.0
 psutil
 py-cpuinfo
-pydantic
+pydantic>=2.0.0
 torch
 tqdm
diff --git a/tests/unit/inference/v2/ragged/test_manager_configs.py b/tests/unit/inference/v2/ragged/test_manager_configs.py
index a5f270cced8c..bdd513445ddb 100644
--- a/tests/unit/inference/v2/ragged/test_manager_configs.py
+++ b/tests/unit/inference/v2/ragged/test_manager_configs.py
@@ -5,7 +5,7 @@
 
 import pytest
 
-from deepspeed.pydantic_v1 import ValidationError
+from pydantic import ValidationError
 
 from deepspeed.inference.v2.ragged import DSStateManagerConfig
diff --git a/tests/unit/runtime/test_ds_config_dict.py b/tests/unit/runtime/test_ds_config_dict.py
index c11c63d04867..d06b35e208fe 100644
--- a/tests/unit/runtime/test_ds_config_dict.py
+++ b/tests/unit/runtime/test_ds_config_dict.py
@@ -67,13 +67,11 @@ def _batch_assert(status, ds_config, batch, micro_batch, gas, success):
 
     if not success:
         assert not status
-        print("Failed but All is well")
         return
 
     assert ds_config.train_batch_size == batch
     assert ds_config.train_micro_batch_size_per_gpu == micro_batch
     assert ds_config.gradient_accumulation_steps == gas
-    print("All is well")
 
 
 #Tests different batch config provided in deepspeed json file
diff --git a/tests/unit/runtime/test_ds_config_model.py b/tests/unit/runtime/test_ds_config_model.py
index 87ea747cf423..4d184b2858a8 100644
--- a/tests/unit/runtime/test_ds_config_model.py
+++ b/tests/unit/runtime/test_ds_config_model.py
@@ -4,18 +4,25 @@
 # DeepSpeed Team
 
 import pytest
-import os
 import json
-from typing import List
-from deepspeed.pydantic_v1 import Field, ValidationError
+import os
+from typing import List, Optional
+
+from pydantic import Field, ValidationError
+
 from deepspeed.runtime import config as ds_config
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 
 
 class SimpleConf(DeepSpeedConfigModel):
     param_1: int = 0
-    param_2_old: str = Field(None, deprecated=True, new_param="param_2", new_param_fn=(lambda x: [x]))
-    param_2: List[str] = None
+    param_2_old: Optional[str] = Field(None,
+                                       json_schema_extra={
+                                           "deprecated": True,
+                                           "new_param": "param_2",
+                                           "new_param_fn": (lambda x: [x])
+                                       })
+    param_2: Optional[List[str]] = None
     param_3: int = Field(0, alias="param_3_alias")