
Commit

fixes post-rebase
djsaunde committed Dec 18, 2024
1 parent b17a1b7 commit d1ba285
Showing 4 changed files with 7 additions and 19 deletions.
3 changes: 3 additions & 0 deletions src/axolotl/cli/main.py
@@ -82,6 +82,9 @@ def evaluate(config: str, accelerate: bool, **kwargs):
"""Evaluate a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()

if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
if config:
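Not part of this commit's diff: the call added above runs before the `accelerate launch` branch in the same function, and child processes inherit the parent's environment, so setting PYTORCH_CUDA_ALLOC_CONF once at the CLI entry point also covers the launched subprocess. Below is a small, self-contained sketch of that behavior; the subprocess command is illustrative, not Axolotl code.

import os
import subprocess
import sys

# Set the allocator config in the parent process (mirrors what
# set_pytorch_cuda_alloc_conf() does), without clobbering a user override.
os.environ.setdefault(
    "PYTORCH_CUDA_ALLOC_CONF",
    "expandable_segments:True,roundup_power2_divisions:16",
)

# The child process sees the same variable because the environment is inherited.
subprocess.run(
    [sys.executable, "-c", "import os; print(os.environ['PYTORCH_CUDA_ALLOC_CONF'])"],
    check=True,
)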
8 changes: 2 additions & 6 deletions src/axolotl/evaluate.py
@@ -12,10 +12,9 @@
from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
from axolotl.utils.trainer import setup_trainer

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
@@ -62,6 +61,7 @@ def evaluate_dataset(
return metrics


# pylint: disable=duplicate-code
def evaluate(
*, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
@@ -79,10 +79,6 @@ def evaluate(
- The tokenizer
- Dictionary of evaluation metrics
"""
# pylint: disable=duplicate-code
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()

# Load model
LOG.debug("loading model for evaluation...")

4 changes: 2 additions & 2 deletions src/axolotl/train.py
@@ -26,7 +26,7 @@
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
from axolotl.utils.trainer import setup_trainer

try:
from optimum.bettertransformer import BetterTransformer
@@ -87,7 +87,7 @@ def train(
)
resume_from_checkpoint = cfg.resume_from_checkpoint

# Load the model
# Load the model and tokenizer
msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."
11 changes: 0 additions & 11 deletions src/axolotl/utils/trainer.py
@@ -512,17 +512,6 @@ def prepare_opinionated_env(cfg):
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def set_pytorch_cuda_alloc_conf():
"""Set up CUDA allocation config if using PyTorch >= 2.2"""
torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ[
"PYTORCH_CUDA_ALLOC_CONF"
] = "expandable_segments:True,roundup_power2_divisions:16"


def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
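For context: the helper deleted above is not removed from the codebase. The new import in src/axolotl/evaluate.py (`from axolotl.utils import set_pytorch_cuda_alloc_conf`) indicates it was relocated to the `axolotl.utils` package. A minimal sketch of the relocated helper, assuming it keeps the same body it had here; the exact destination module and any refactoring are not shown in this diff.

import os

import torch


def set_pytorch_cuda_alloc_conf():
    """Set up CUDA allocation config if using PyTorch >= 2.2."""
    torch_version = torch.__version__.split(".")
    torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
    if torch_major == 2 and torch_minor >= 2:
        # Respect an existing user-provided allocator config.
        if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
            os.environ[
                "PYTORCH_CUDA_ALLOC_CONF"
            ] = "expandable_segments:True,roundup_power2_divisions:16"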
