From e034003d379d311a7f542e7ca70136d0564fd2ee Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 7 Feb 2024 10:58:28 -0800 Subject: [PATCH] Optim package (#457) --- docs/Configuration-Guide.md | 6 +++--- src/levanter/trainer.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md index f203f0dcc..bdb09e4f1 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -13,7 +13,7 @@ class TrainLmConfig: data: LMDatasetConfig = field(default_factory=LMDatasetConfig) trainer: TrainerConfig = field(default_factory=TrainerConfig) model: LmConfig = field(default_factory=Gpt2Config) - optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) ``` Your training run will typically be associated with a single config file. For instance, you might have a file @@ -290,7 +290,7 @@ If you're not using SLURM or TPUs, you can specify the cluster manually using th ## Optimizer -[levanter.trainer.OptimizerConfig][] is a dataclass that specifies the optimizer configuration. It has the following fields: +[levanter.optim.OptimizerConfig][] is a dataclass that specifies the optimizer configuration. It has the following fields: | Parameter | Description | Default | |-----------------|-------------------------------------------------------------------|----------| @@ -358,7 +358,7 @@ trainer: ### Optimizer -::: levanter.trainer.OptimizerConfig +::: levanter.optim.OptimizerConfig ### LM Model diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 847bb0e06..a98e0c10b 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -5,7 +5,6 @@ import os import sys import typing -import warnings from dataclasses import dataclass from functools import cached_property from pathlib import Path @@ -51,7 +50,7 @@ from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import microbatched from levanter.logging import capture_time -from levanter.optim import SecondOrderTransformation +from levanter.optim import OptimizerConfig, SecondOrderTransformation # noqa: F401 from levanter.tracker import TrackerConfig from levanter.types import FilterSpec from levanter.utils import cloud_utils