-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement wasmedge-llmc Rust SDK (#1)
* Implement wasmedge-llmc Rust SDK Signed-off-by: Jun Zhang <[email protected]> * delete useless workflow Signed-off-by: Jun Zhang <[email protected]> --------- Signed-off-by: Jun Zhang <[email protected]>
- Loading branch information
Showing
8 changed files
with
447 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
name: CI | ||
|
||
on: | ||
push: | ||
branches: [ main ] | ||
pull_request: | ||
branches: [ main ] | ||
env: | ||
CARGO_TERM_COLOR: always | ||
RUST_LOG: DEBUG | ||
RUST_BACKTRACE: full | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Build SDK | ||
run: | | ||
rustup target add wasm32-wasi | ||
cd example | ||
cargo build --target wasm32-wasi --release | ||
- name: Download data | ||
run: | | ||
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M.bin | ||
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_train.bin | ||
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_val.bin | ||
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_tokenizer.bin | ||
- name: Build WasmEdge | ||
run: | | ||
sudo apt update && sudo apt install software-properties-common llvm-14-dev liblld-14-dev | ||
git clone https://github.com/WasmEdge/WasmEdge /tmp/WasmEdge | ||
mkdir /tmp/build && cd /tmp/build | ||
cmake /tmp/WasmEdge -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_PLUGIN_LLM=ON -DWASMEDGE_BUILD_TESTS=OFF | ||
make -j$(nproc) && sudo make install | ||
- name: Train GPT2 | ||
run: | | ||
WASMEDGE_PLUGIN_PATH=/usr/local/lib/wasmedge/ wasmedge --dir .:. ./example/target/wasm32-wasi/release/example.wasm |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,25 @@ | ||
# wasmedge-llmc | ||
A Rust library for using llm.c functions when the Wasi is being executed on WasmEdge. | ||
|
||
## Set up WasmEdge | ||
```bash | ||
git clone https://github.com/WasmEdge/WasmEdge.git | ||
cd WasmEdge | ||
cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_BUILD_TESTS=OFF -DWASMEDGE_PLUGIN_LLM=On | ||
cmake --build build | ||
cmake --install build | ||
``` | ||
|
||
## Download Checkpoints & Training data | ||
```bash | ||
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M.bin | ||
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_train.bin | ||
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_val.bin | ||
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_tokenizer.bin | ||
``` | ||
|
||
## Run the example | ||
```bash | ||
cargo build --target wasm32-wasi --release | ||
wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
[package] | ||
name = "example" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] | ||
wasmedge-llmc = {path="../wasmedge-llmc"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
use wasmedge_llmc::*; | ||
|
||
fn main() { | ||
let config = ConfigBuilder::default().lr(0.0002).epoch(20).build(); | ||
let model = match Model::from_checkpoints("/tmp/data/gpt2_124M.bin") { | ||
Ok(m) => m, | ||
Err(e) => { | ||
eprintln!("Failed to load model: {:?}", e); | ||
return; | ||
} | ||
}; | ||
let train_dataloader = match DataLoader::from_file( | ||
"/tmp/data/tiny_shakespeare_train.bin", | ||
4, // batch size | ||
64, // sequence length | ||
0, // process rank | ||
1, // number of processes | ||
true, // should shuffle | ||
) { | ||
Ok(loader) => loader, | ||
Err(e) => { | ||
eprintln!("Failed to load training data: {:?}", e); | ||
return; | ||
} | ||
}; | ||
let val_dataloader = match DataLoader::from_file( | ||
"/tmp/data/tiny_shakespeare_val.bin", | ||
4, // batch size | ||
64, // sequence length | ||
0, // process rank | ||
1, // number of processes | ||
false, // should shuffle | ||
) { | ||
Ok(loader) => loader, | ||
Err(e) => { | ||
eprintln!("Failed to load validation data: {:?}", e); | ||
return; | ||
} | ||
}; | ||
let tokenizer = match Tokenizer::from_file("/tmp/data/gpt2_tokenizer.bin") { | ||
Ok(t) => t, | ||
Err(e) => { | ||
eprintln!("Failed to load tokenizer: {:?}", e); | ||
return; | ||
} | ||
}; | ||
match model.train(train_dataloader, val_dataloader, tokenizer, config) { | ||
Ok(_) => println!("Training completed successfully."), | ||
Err(e) => eprintln!("Failed to train model: {:?}", e), | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
[package] | ||
name = "wasmedge-llmc" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
pub mod llmc_interface; | ||
use core::mem::MaybeUninit; | ||
use llmc_interface::*; | ||
|
||
pub struct Model { | ||
id: u32, | ||
} | ||
|
||
pub struct Config { | ||
lr: f32, | ||
epoch: u32, | ||
} | ||
|
||
pub struct ConfigBuilder { | ||
lr: f32, | ||
epoch: u32, | ||
} | ||
|
||
impl Default for ConfigBuilder { | ||
fn default() -> Self { | ||
Self { | ||
lr: 0.0001, | ||
epoch: 20, | ||
} | ||
} | ||
} | ||
|
||
impl ConfigBuilder { | ||
pub fn lr(mut self, lr: f32) -> Self { | ||
self.lr = lr; | ||
self | ||
} | ||
|
||
pub fn epoch(mut self, epoch: u32) -> Self { | ||
self.epoch = epoch; | ||
self | ||
} | ||
|
||
pub fn build(self) -> Config { | ||
Config { | ||
lr: self.lr, | ||
epoch: self.epoch, | ||
} | ||
} | ||
} | ||
|
||
impl Model { | ||
pub fn from_checkpoints(checkpoint_path: &str) -> Result<Self, WasmedgeLLMErrno> { | ||
let mut model_id = MaybeUninit::<u32>::uninit(); | ||
unsafe { | ||
let result = model_create(checkpoint_path, model_id.as_mut_ptr()); | ||
if let Err(code) = result { | ||
return Err(code); | ||
} | ||
Ok(Model { | ||
id: model_id.assume_init(), | ||
}) | ||
} | ||
} | ||
|
||
pub fn train( | ||
&self, | ||
train_data_loader: DataLoader, | ||
val_data_loader: DataLoader, | ||
tokenizer: Tokenizer, | ||
config: Config, | ||
) -> Result<(), WasmedgeLLMErrno> { | ||
unsafe { | ||
model_train( | ||
self.id, | ||
train_data_loader.id, | ||
val_data_loader.id, | ||
tokenizer.id, | ||
train_data_loader.batch_size, | ||
train_data_loader.sequence_length, | ||
config.lr, | ||
config.epoch, | ||
) | ||
} | ||
} | ||
} | ||
|
||
pub struct DataLoader { | ||
id: u32, | ||
batch_size: u32, | ||
sequence_length: u32, | ||
} | ||
|
||
impl DataLoader { | ||
pub fn from_file( | ||
data_path: &str, | ||
batch_size: u32, | ||
sequence_length: u32, | ||
process_rank: u32, | ||
num_processes: u32, | ||
should_shuffle: bool, | ||
) -> Result<Self, WasmedgeLLMErrno> { | ||
let mut dataloader_id = MaybeUninit::<u32>::uninit(); | ||
unsafe { | ||
let result = dataloader_create( | ||
data_path, | ||
batch_size, | ||
sequence_length, | ||
process_rank, | ||
num_processes, | ||
should_shuffle, | ||
dataloader_id.as_mut_ptr(), | ||
); | ||
if let Err(code) = result { | ||
return Err(code); | ||
} | ||
Ok(DataLoader { | ||
id: dataloader_id.assume_init(), | ||
batch_size, | ||
sequence_length, | ||
}) | ||
} | ||
} | ||
} | ||
|
||
pub struct Tokenizer { | ||
id: u32, | ||
} | ||
|
||
impl Tokenizer { | ||
pub fn from_file(filepath: &str) -> Result<Self, WasmedgeLLMErrno> { | ||
let mut tokenizer_id = MaybeUninit::<u32>::uninit(); | ||
unsafe { | ||
let result = tokenizer_create(filepath, tokenizer_id.as_mut_ptr()); | ||
if let Err(code) = result { | ||
return Err(code); | ||
} | ||
Ok(Tokenizer { | ||
id: tokenizer_id.assume_init(), | ||
}) | ||
} | ||
} | ||
} |
Oops, something went wrong.