Optim package (#457)
dlwh committed Feb 7, 2024
1 parent 378470a commit e034003
Showing 2 changed files with 4 additions and 5 deletions.
6 changes: 3 additions & 3 deletions docs/Configuration-Guide.md
@@ -13,7 +13,7 @@ class TrainLmConfig:
data: LMDatasetConfig = field(default_factory=LMDatasetConfig)
trainer: TrainerConfig = field(default_factory=TrainerConfig)
model: LmConfig = field(default_factory=Gpt2Config)
- optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
+ optimizer: OptimizerConfig = field(default_factory=AdamConfig)
```

Your training run will typically be associated with a single config file. For instance, you might have a file
@@ -290,7 +290,7 @@ If you're not using SLURM or TPUs, you can specify the cluster manually using th

## Optimizer

- [levanter.trainer.OptimizerConfig][] is a dataclass that specifies the optimizer configuration. It has the following fields:
+ [levanter.optim.OptimizerConfig][] is a dataclass that specifies the optimizer configuration. It has the following fields:

| Parameter | Description | Default |
|-----------------|-------------------------------------------------------------------|----------|
@@ -358,7 +358,7 @@ trainer:

### Optimizer

- ::: levanter.trainer.OptimizerConfig
+ ::: levanter.optim.OptimizerConfig

### LM Model

3 changes: 1 addition & 2 deletions src/levanter/trainer.py
@@ -5,7 +5,6 @@
import os
import sys
import typing
- import warnings
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
@@ -51,7 +50,7 @@
from levanter.distributed import DistributedConfig, RayConfig
from levanter.grad_accum import microbatched
from levanter.logging import capture_time
- from levanter.optim import SecondOrderTransformation
+ from levanter.optim import OptimizerConfig, SecondOrderTransformation  # noqa: F401
from levanter.tracker import TrackerConfig
from levanter.types import FilterSpec
from levanter.utils import cloud_utils
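
Note on the `default_factory` change in `docs/Configuration-Guide.md`: the diff does not say why the default moves from `OptimizerConfig` to `AdamConfig`. One plausible reading, assuming `OptimizerConfig` became an abstract base class in this refactor (an assumption, not shown in the diff), is that the base class can no longer be instantiated as a default. The sketch below uses hypothetical stand-in classes, not the real `levanter.optim` definitions; the field names, default values, and the `build` method are illustrative only.

```python
import abc
from dataclasses import dataclass, field


@dataclass
class OptimizerConfig(abc.ABC):
    # Hypothetical stand-in for levanter.optim.OptimizerConfig.
    # The field and its default are illustrative, not Levanter's.
    learning_rate: float = 1e-3

    @abc.abstractmethod
    def build(self, num_train_steps: int):
        """Return an optimizer description; subclasses must implement this."""


@dataclass
class AdamConfig(OptimizerConfig):
    # Hypothetical concrete subclass; hyperparameter defaults are illustrative.
    beta1: float = 0.9
    beta2: float = 0.999

    def build(self, num_train_steps: int):
        # A real implementation would return an optimizer object;
        # a plain tuple keeps this sketch dependency-free.
        return ("adam", self.learning_rate, self.beta1, self.beta2)


@dataclass
class TrainLmConfig:
    # Mirrors the changed line: the default is a concrete subclass,
    # because the abstract base cannot be constructed.
    optimizer: OptimizerConfig = field(default_factory=AdamConfig)


config = TrainLmConfig()
```

Under that assumption, `field(default_factory=OptimizerConfig)` would raise `TypeError` when `TrainLmConfig()` is constructed, while `AdamConfig` gives a usable default. The `# noqa: F401` re-export in `src/levanter/trainer.py` appears to keep `from levanter.trainer import OptimizerConfig` importable for existing callers.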
