From 797a4ab4f05a91fcd4755a7e42a3c73e863d99e2 Mon Sep 17 00:00:00 2001 From: AI-WAIFU <67525070+AI-WAIFU@users.noreply.github.com> Date: Wed, 13 Nov 2024 23:04:18 +0000 Subject: [PATCH] fix bug (#1311) --- megatron/neox_arguments/arguments.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index dec886541..9735a58be 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -956,12 +956,19 @@ def calculate_derived(self): ) # derive precision - fp16_conflict = "DeepSpeed fp16 field was set but precision conflicts" if self.fp16 and self.fp16.get("enabled", False): if self.precision is None: self.update_value("precision", "fp16") else: + fp16_conflict = "DeepSpeed fp16 field was set but precision conflicts" assert self.precision == "fp16", fp16_conflict + + if self.bf16 and self.bf16.get("enabled", False): + if self.precision is None: + self.update_value("precision", "bfloat16") + else: + bf16_conflict = "DeepSpeed bf16 field was set but precision conflicts" + assert self.precision == "bfloat16", bf16_conflict if self.precision == "fp16": if isinstance(self.fp16, dict) and len(self.fp16) > 0: @@ -971,14 +978,15 @@ def calculate_derived(self): fp16_args = {"type": "fp16", "enabled": True} self.update_value("fp16", fp16_args) elif self.precision == "bfloat16": - bf_config = {"bf16": {"enabled": True}} - # dt_config = {"grad_accum_dtype": "fp32"} - if self.deepspeed_extra_args is None: - self.update_value("deepspeed_extra_args", bf_config) - else: - extra_args = copy.deepcopy(self.deepspeed_extra_args) - extra_args.update(bf_config) - self.update_value("deepspeed_extra_args", extra_args) + if not self.bf16: + bf_config = {"bf16": {"enabled": True}} + # dt_config = {"grad_accum_dtype": "fp32"} + if self.deepspeed_extra_args is None: + self.update_value("deepspeed_extra_args", bf_config) + else: + extra_args = copy.deepcopy(self.deepspeed_extra_args) + extra_args.update(bf_config) + self.update_value("deepspeed_extra_args", extra_args) zero_stage = self.zero_optimization["stage"] if self.data_types is None: