Merge pull request #21 from OpenMOSS/config
Add Configuration Entrypoint
dest1n1s authored Jun 13, 2024
2 parents 83f863c + 7d2751f commit 5dfcdbe
Showing 21 changed files with 1,402 additions and 375 deletions.
22 changes: 20 additions & 2 deletions README.md
@@ -37,9 +37,27 @@ bun install

It's worth noting that `bun` is not well-supported on Windows, so you may need to use WSL or other Linux-based solutions to run the frontend, or consider using a different package manager, such as `pnpm` or `yarn`.

## Training/Analyzing a Dictionary
## Launch an Experiment

We give some basic examples to show how to train a dictionary and analyze the learned dictionary in the [examples](https://github.com/OpenMOSS/Language-Model-SAEs/exapmles). You can copy the example scripts to the `exp` directory and modify them to fit your needs. More examples will be added in the future.
We provide both a programmatic and a configuration-based way to launch an experiment. The configuration-based way is more flexible and is recommended for most users: the configuration files live in the [examples/configuration](https://github.com/OpenMOSS/Language-Model-SAEs/examples/configuration) directory, and you can modify them to fit your needs. The programmatic way is better suited to advanced users who want to customize the training process; the example scripts are in the [examples/programmatic](https://github.com/OpenMOSS/Language-Model-SAEs/examples/programmatic) directory.

To start a basic training run, run the following command:

```bash
lm-saes train examples/configuration/train.toml
```

which will start the training process using the configuration file [examples/configuration/train.toml](https://github.com/OpenMOSS/Language-Model-SAEs/examples/configuration/train.toml).

To analyze a trained dictionary, you can run the following command:

```bash
lm-saes analyze examples/configuration/analyze.toml --sae <path_to_sae_model>
```

which will start the analysis process using the configuration file [examples/configuration/analyze.toml](https://github.com/OpenMOSS/Language-Model-SAEs/examples/configuration/analyze.toml). The analysis process requires a trained SAE model, which can be obtained from the training process. You may need to launch a MongoDB server to store the analysis results; the MongoDB settings can be modified in the configuration file.
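
For reference, the MongoDB connection is configured in the `[mongo]` section of the configuration file; the values below mirror those in examples/configuration/analyze.toml, and `mongo_uri` can be pointed at your own server:

```toml
[mongo]
mongo_db = "mechinterp"                  # database used to store analysis results
mongo_uri = "mongodb://localhost:27017"  # URI of the MongoDB server
```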

Generally, our configuration-based pipeline uses outer-layer settings as the defaults for inner-layer settings. This makes it easy to build deeply nested configurations in which common sub-settings (such as `device` and `dtype`) are specified once and reused. More details are provided in the configuration files.
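
For instance, in a minimal sketch like the one below (illustrative only; the exact values and the override are hypothetical), the top-level `device` and `dtype` act as defaults for the nested `[sae]` and `[act_store]` sections unless a section sets its own value:

```toml
# Top-level settings serve as defaults for nested sections.
device = "cuda"
dtype = "torch.float32"
seed = 42

[sae]
d_model = 768                        # inherits device/dtype from the top level

[act_store]
hook_points = ["blocks.3.hook_mlp_out"]
device = "cpu"                       # hypothetical override of the top-level default
```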

## Visualizing the Learned Dictionary

41 changes: 41 additions & 0 deletions examples/configuration/analyze.toml
@@ -0,0 +1,41 @@
total_analyzing_tokens = 20_000_000

use_ddp = false
device = "cuda"
seed = 42
dtype = "torch.float32"

exp_name = "L3M"
exp_series = "default"
exp_result_dir = "results"

[subsample]
"top_activations" = { "proportion" = 1.0, "n_samples" = 80 }
"subsample-0.9" = { "proportion" = 0.9, "n_samples" = 20}
"subsample-0.8" = { "proportion" = 0.8, "n_samples" = 20}
"subsample-0.7" = { "proportion" = 0.7, "n_samples" = 20}
"subsample-0.5" = { "proportion" = 0.5, "n_samples" = 20}

[lm]
model_name = "gpt2"
d_model = 768

[dataset]
dataset_path = "openwebtext"
is_dataset_tokenized = false
is_dataset_on_disk = true
concat_tokens = false
context_size = 256
store_batch_size = 32

[act_store]
device = "cuda"
seed = 42
dtype = "torch.float32"
hook_points = [ "blocks.3.hook_mlp_out",]
use_cached_activations = false
n_tokens_in_buffer = 500000

[mongo]
mongo_db = "mechinterp"
mongo_uri = "mongodb://localhost:27017"
60 changes: 60 additions & 0 deletions examples/configuration/train.toml
@@ -0,0 +1,60 @@
use_ddp = false
exp_name = "L3M"
exp_result_dir = "results"
device = "cuda"
seed = 42
dtype = "torch.float32"
total_training_tokens = 1_600_000_000
lr = 4e-4
betas = [ 0.0, 0.9999,]
lr_scheduler_name = "constantwithwarmup"
lr_warm_up_steps = 5000
lr_cool_down_steps = 10000
train_batch_size = 4096
finetuning = false
feature_sampling_window = 1000
dead_feature_window = 5000
dead_feature_threshold = 1e-6
eval_frequency = 1000
log_frequency = 100
n_checkpoints = 10


[sae]
hook_point_in = "blocks.3.hook_mlp_out"
hook_point_out = "blocks.3.hook_mlp_out"
strict_loading = true
use_decoder_bias = false
apply_decoder_bias_to_pre_encoder = true
decoder_bias_init_method = "geometric_median"
expansion_factor = 32
d_model = 768
norm_activation = "token-wise"
decoder_exactly_unit_norm = false
use_glu_encoder = false
l1_coefficient = 1.2e-4
lp = 1
use_ghost_grads = true

[lm]
model_name = "gpt2"
d_model = 768

[dataset]
dataset_path = "openwebtext"
is_dataset_on_disk = false
concat_tokens = false
context_size = 256
store_batch_size = 32

[act_store]
device = "cuda"
seed = 42
dtype = "torch.float32"
hook_points = [ "blocks.3.hook_mlp_out",]
use_cached_activations = false
n_tokens_in_buffer = 500000

[wandb]
log_to_wandb = true
wandb_project = "gpt2-sae"
20 changes: 3 additions & 17 deletions examples/analyze.py → examples/programmatic/analyze.py
@@ -1,17 +1,8 @@
import torch
import os
import torch.distributed as dist
from lm_saes.config import LanguageModelSAEAnalysisConfig, SAEConfig
from lm_saes.runner import sample_feature_activations_runner

use_ddp = False

if use_ddp:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank())

cfg = LanguageModelSAEAnalysisConfig(
cfg = LanguageModelSAEAnalysisConfig.from_flattened(dict(
# LanguageModelConfig
model_name = "gpt2",

@@ -44,18 +35,13 @@
mongo_uri="mongodb://localhost:27017", # MongoDB URI.

# RunnerConfig
use_ddp = use_ddp,
device = "cuda",
seed = 42,
dtype = torch.float32,

exp_name = "L3M",
exp_series = "default",
exp_result_dir = "results",
)

sample_feature_activations_runner(cfg)
))

if use_ddp:
    dist.destroy_process_group()
torch.cuda.empty_cache()
sample_feature_activations_runner(cfg)
22 changes: 4 additions & 18 deletions examples/train.py → examples/programmatic/train.py
@@ -1,24 +1,14 @@
import os
import sys
import torch
import torch.distributed as dist
from lm_saes.config import LanguageModelSAETrainingConfig
from lm_saes.runner import language_model_sae_runner

use_ddp = False

if use_ddp:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank())

cfg = LanguageModelSAETrainingConfig(
cfg = LanguageModelSAETrainingConfig.from_flattened(dict(
# LanguageModelConfig
model_name = "gpt2", # The model name or path for the pre-trained model.
d_model = 768, # The hidden size of the model.

# TextDatasetConfig
dataset_path = "data/openwebtext", # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field.
dataset_path = "openwebtext", # The corpus name or path. Each of a data record should contain (and may only contain) a "text" field.
is_dataset_tokenized = False, # Whether the dataset is tokenized.
is_dataset_on_disk = True, # Whether the dataset is on disk. If not on disk, `datasets.load_dataset` will be used to load the dataset, and the train split will be used for training.
concat_tokens = False, # Whether to concatenate tokens into a single sequence. If False, only data records with more than `context_size` non-padding tokens will be used.
@@ -61,17 +51,13 @@
wandb_project= "gpt2-sae", # The wandb project name.

# RunnerConfig
use_ddp = use_ddp, # Whether to use the DistributedDataParallel.
device = "cuda", # The device to place all torch tensors.
seed = 42, # The random seed.
dtype = torch.float32, # The torch data type of non-integer tensors.

exp_name = "L3M", # The experiment name. Would be used for creating exp folder (which may contain checkpoints and analysis results) and setting wandb run name.
exp_series = "default",
exp_result_dir = "results"
)

sparse_autoencoder = language_model_sae_runner(cfg)
))

if use_ddp:
    dist.destroy_process_group()
sparse_autoencoder = language_model_sae_runner(cfg)
