Skip to content

Commit

Permalink
switch instructlab dolomite
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim committed Jun 20, 2024
1 parent c6bffd4 commit b9bd82f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
# Third Party
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from deepspeed.runtime.zero.utils import ZeRORuntimeException
from instructlab.dolomite.enums import GradientCheckpointingMethod

Check failure on line 14 in src/instructlab/training/main_ds.py

View workflow job for this annotation

GitHub Actions / lint

E0401: Unable to import 'instructlab.dolomite.enums' (import-error)

Check failure on line 14 in src/instructlab/training/main_ds.py

View workflow job for this annotation

GitHub Actions / lint

E0611: No name 'dolomite' in module 'instructlab' (no-name-in-module)
from instructlab.dolomite.gradient_checkpointing import apply_gradient_checkpointing

Check failure on line 15 in src/instructlab/training/main_ds.py

View workflow job for this annotation

GitHub Actions / lint

E0401: Unable to import 'instructlab.dolomite.gradient_checkpointing' (import-error)

Check failure on line 15 in src/instructlab/training/main_ds.py

View workflow job for this annotation

GitHub Actions / lint

E0611: No name 'dolomite' in module 'instructlab' (no-name-in-module)
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM

Check failure on line 16 in src/instructlab/training/main_ds.py

View workflow job for this annotation

GitHub Actions / lint

E0401: Unable to import 'instructlab.dolomite.hf_models' (import-error)

Check failure on line 16 in src/instructlab/training/main_ds.py

View workflow job for this annotation

GitHub Actions / lint

E0611: No name 'dolomite' in module 'instructlab' (no-name-in-module)
from torch.distributed import ReduceOp, all_reduce
from tqdm import tqdm
from transformers import AutoModelForCausalLM, get_scheduler
Expand Down Expand Up @@ -89,7 +92,7 @@ def setup_model(args, tokenizer, train_loader, grad_accum):

if args.is_granite:
# Third Party
from dolomite_engine.hf_models.models import GPTDolomiteForCausalLM
# from dolomite_engine.hf_models.models import GPTDolomiteForCausalLM

model = GPTDolomiteForCausalLM.from_pretrained(
args.model_name_or_path,
Expand Down Expand Up @@ -202,8 +205,8 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
# for both lora and full here
if args.is_granite:
# Third Party
from dolomite_engine.enums import GradientCheckpointingMethod
from dolomite_engine.gradient_checkpointing import apply_gradient_checkpointing
# from dolomite_engine.enums import GradientCheckpointingMethod
# from dolomite_engine.gradient_checkpointing import apply_gradient_checkpointing

block_name = model._no_split_modules[0]
apply_gradient_checkpointing(
Expand Down
2 changes: 1 addition & 1 deletion src/instructlab/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import warnings

# Third Party
from instructlab.dolomite.hf_models import export_to_huggingface

Check failure on line 16 in src/instructlab/training/utils.py

View workflow job for this annotation

GitHub Actions / lint

E0401: Unable to import 'instructlab.dolomite.hf_models' (import-error)

Check failure on line 16 in src/instructlab/training/utils.py

View workflow job for this annotation

GitHub Actions / lint

E0611: No name 'dolomite' in module 'instructlab' (no-name-in-module)
from rich.logging import RichHandler
from torch import distributed as dist
from torch.distributed import get_rank, is_initialized
Expand Down Expand Up @@ -539,7 +540,6 @@ def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True
from tempfile import TemporaryDirectory

# Third Party
from dolomite_engine.hf_models import export_to_huggingface
from safetensors.torch import save_file

with TemporaryDirectory("w") as tmpdir:
Expand Down

0 comments on commit b9bd82f

Please sign in to comment.