Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
J38 committed Dec 29, 2022
1 parent b4e7544 commit 1659a9e
Showing 1 changed file with 6 additions and 23 deletions.
29 changes: 6 additions & 23 deletions conf/train_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
Cerberus schema used by Quinine for train.py.
"""
import logging
from typing import Any, Dict

from quinine.common.cerberus import (
Expand All @@ -21,42 +20,27 @@
)


def deprecated_field(msg):
    """Build a Cerberus rule fragment marking a schema field as deprecated.

    The returned fragment (``{"check_with": ...}``) installs a validator that
    logs a deprecation warning whenever the field is supplied with a non-None
    value. Validation itself always succeeds — this only warns, never errors.

    Args:
        msg: Extra message to log after the generic deprecation warning.
             Falsy values (e.g. ``""`` or ``None``) suppress the extra line.

    Returns:
        dict: A partial Cerberus schema suitable for combining via ``merge``.
    """

    def _deprecated_field(field, value, _error):
        # Only warn when the deprecated field is actually set; a None value
        # (the usual nullable default) stays silent.
        if value is not None:
            # Lazy %-formatting: the message is only built if the record is emitted.
            logging.warning("%s is deprecated and will be removed in a future release.", field)
            if msg:
                logging.warning(msg)

    return {"check_with": _deprecated_field}


def get_schema() -> Dict[str, Any]:
"""Get the Cerberus schema for the Quinine config used in train.py."""

# Schema for Dataset
data_schema = {
"id": merge(tstring, required),
"name": merge(tstring, nullable, default(None)),
"dataset_dir": merge(tstring, nullable, default(None)),
"source": merge(tstring, nullable, default(None)),
"source_ratios": merge(tstring, nullable, default(None)),
"validation_ratio": merge(tfloat, default(0.0005)),
"num_proc": merge(tinteger, default(64)),
"eval_num_proc": merge(tinteger, default(4)),
"dataset_dir": merge(tstring, nullable, default(None)),
"detokenize": merge(tboolean, default(True)),
}

# Schema for Model
model_schema = {
"id": merge(tstring, required),
"gradient_checkpointing": merge(
tboolean,
nullable,
default(None),
deprecated_field("This config is now in training_arguments to better match HF."),
),
"gradient_checkpointing": merge(tboolean, default(False)),
"pretrained_tokenizer": merge(tboolean, default(True)),
"passthrough_tokenizer": merge(tboolean, default(False)),
"seq_len": merge(tinteger, default(1024)),
"reorder_and_upcast_attn": merge(tboolean, nullable, default(True)),
"scale_attn_by_inverse_layer_idx": merge(tboolean, nullable, default(True)),
Expand Down Expand Up @@ -94,9 +78,8 @@ def get_schema() -> Dict[str, Any]:
"fp16_backend": merge(tstring, default("auto")),
"sharded_ddp": merge(tstring, nullable, default(None)),
"deepspeed": merge(tstring, nullable, default(None)),
"dataloader_num_workers": merge(tinteger, default(0)),
"dataloader_num_workers": merge(tinteger, default(4)),
"local_rank": merge(tinteger, nullable, default(None)),
"gradient_checkpointing": merge(tboolean, default(False)),
}

# Schema for Online Custom Evaluation Datasets (e.g. LAMBADA)
Expand Down

0 comments on commit 1659a9e

Please sign in to comment.