Skip to content

Commit

Permalink
Implement wasmedge-llmc Rust SDK (#1)
Browse files Browse the repository at this point in the history
* Implement wasmedge-llmc Rust SDK

Signed-off-by: Jun Zhang <[email protected]>

* delete useless workflow

Signed-off-by: Jun Zhang <[email protected]>

---------

Signed-off-by: Jun Zhang <[email protected]>
  • Loading branch information
junaire authored Aug 22, 2024
1 parent a322089 commit 41954f5
Show file tree
Hide file tree
Showing 8 changed files with 447 additions and 19 deletions.
42 changes: 42 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# CI workflow: builds the example crate for wasm32-wasi, builds WasmEdge from
# source with the LLM plugin enabled, and runs the example under it.
name: CI

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
env:
CARGO_TERM_COLOR: always
RUST_LOG: DEBUG
RUST_BACKTRACE: full

# Runs are grouped by workflow, ref and event name; a newer run in the same
# group cancels the one still in progress.
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }}
cancel-in-progress: true

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
# Compile the example (and the SDK it depends on) to a release .wasm binary.
- name: Build SDK
run: |
rustup target add wasm32-wasi
cd example
cargo build --target wasm32-wasi --release
# Fetch the GPT-2 checkpoint, tokenizer and tiny-shakespeare splits into
# /tmp/data/, the directory the example reads from.
- name: Download data
run: |
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M.bin
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_train.bin
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_val.bin
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_tokenizer.bin
# Build and install WasmEdge from source with -DWASMEDGE_PLUGIN_LLM=ON.
- name: Build WasmEdge
run: |
sudo apt update && sudo apt install software-properties-common llvm-14-dev liblld-14-dev
git clone https://github.com/WasmEdge/WasmEdge /tmp/WasmEdge
mkdir /tmp/build && cd /tmp/build
cmake /tmp/WasmEdge -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_PLUGIN_LLM=ON -DWASMEDGE_BUILD_TESTS=OFF
make -j$(nproc) && sudo make install
# Run the compiled example, pointing WasmEdge at the installed plugin directory.
- name: Train GPT2
run: |
WASMEDGE_PLUGIN_PATH=/usr/local/lib/wasmedge/ wasmedge --dir .:. ./example/target/wasm32-wasi/release/example.wasm
19 changes: 0 additions & 19 deletions .github/workflows/ci.yml

This file was deleted.

23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,25 @@
# wasmedge-llmc
A Rust library for calling llm.c functions from a WASI program running on WasmEdge.

## Set up WasmEdge
```bash
git clone https://github.com/WasmEdge/WasmEdge.git
cd WasmEdge
cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_BUILD_TESTS=OFF -DWASMEDGE_PLUGIN_LLM=On
cmake --build build
cmake --install build
```

## Download Checkpoints & Training data
```bash
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M.bin
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_train.bin
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_val.bin
wget -P /tmp/data/ https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_tokenizer.bin
```

## Run the example
```bash
cd example
cargo build --target wasm32-wasi --release
wasmedge --dir .:. ./target/wasm32-wasi/release/example.wasm
```
9 changes: 9 additions & 0 deletions example/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[package]
name = "example"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
wasmedge-llmc = {path="../wasmedge-llmc"}
51 changes: 51 additions & 0 deletions example/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use wasmedge_llmc::*;

/// Entry point of the training example.
///
/// Runs [`run`] and, if it fails, prints the failure message to stderr and
/// exits with status 1 so that scripts and CI jobs invoking the binary can
/// detect the failure (previously every error path returned normally and the
/// process exited with status 0).
fn main() {
    if let Err(message) = run() {
        eprintln!("{}", message);
        std::process::exit(1);
    }
}

/// Loads the model, both data loaders and the tokenizer from `/tmp/data/`,
/// then trains the model with a learning rate of `0.0002` for `20` epochs.
///
/// Returns the message to print for the first step that fails.
fn run() -> Result<(), String> {
    let config = ConfigBuilder::default().lr(0.0002).epoch(20).build();

    let model = Model::from_checkpoints("/tmp/data/gpt2_124M.bin")
        .map_err(|e| format!("Failed to load model: {:?}", e))?;

    let train_dataloader = DataLoader::from_file(
        "/tmp/data/tiny_shakespeare_train.bin",
        4,    // batch size
        64,   // sequence length
        0,    // process rank
        1,    // number of processes
        true, // should shuffle
    )
    .map_err(|e| format!("Failed to load training data: {:?}", e))?;

    let val_dataloader = DataLoader::from_file(
        "/tmp/data/tiny_shakespeare_val.bin",
        4,     // batch size
        64,    // sequence length
        0,     // process rank
        1,     // number of processes
        false, // should shuffle
    )
    .map_err(|e| format!("Failed to load validation data: {:?}", e))?;

    let tokenizer = Tokenizer::from_file("/tmp/data/gpt2_tokenizer.bin")
        .map_err(|e| format!("Failed to load tokenizer: {:?}", e))?;

    model
        .train(train_dataloader, val_dataloader, tokenizer, config)
        .map_err(|e| format!("Failed to train model: {:?}", e))?;

    println!("Training completed successfully.");
    Ok(())
}
8 changes: 8 additions & 0 deletions wasmedge-llmc/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[package]
name = "wasmedge-llmc"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
138 changes: 138 additions & 0 deletions wasmedge-llmc/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
pub mod llmc_interface;
use core::mem::MaybeUninit;
use llmc_interface::*;

/// Handle to a model created through the `llmc_interface` bindings.
///
/// The wrapped `id` is the value written by `model_create` and is handed back
/// to `model_train`; it is not interpreted on the Rust side.
#[derive(Debug)]
pub struct Model {
    id: u32,
}

/// Training settings passed to `Model::train`.
///
/// Construct one with [`ConfigBuilder`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Config {
    /// Value forwarded as the learning-rate argument of the training call.
    lr: f32,
    /// Value forwarded as the epoch argument of the training call.
    epoch: u32,
}

/// Builder for [`Config`].
///
/// Start from [`ConfigBuilder::default`] (`lr = 0.0001`, `epoch = 20`),
/// override the fields you need, then call [`ConfigBuilder::build`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ConfigBuilder {
    lr: f32,
    epoch: u32,
}

impl Default for ConfigBuilder {
    /// Returns a builder with `lr = 0.0001` and `epoch = 20`.
    fn default() -> Self {
        Self {
            lr: 0.0001,
            epoch: 20,
        }
    }
}

impl ConfigBuilder {
    /// Sets the learning rate and returns the updated builder.
    #[must_use]
    pub fn lr(mut self, lr: f32) -> Self {
        self.lr = lr;
        self
    }

    /// Sets the epoch value and returns the updated builder.
    #[must_use]
    pub fn epoch(mut self, epoch: u32) -> Self {
        self.epoch = epoch;
        self
    }

    /// Consumes the builder and produces the final [`Config`].
    #[must_use]
    pub fn build(self) -> Config {
        Config {
            lr: self.lr,
            epoch: self.epoch,
        }
    }
}

impl Model {
pub fn from_checkpoints(checkpoint_path: &str) -> Result<Self, WasmedgeLLMErrno> {
let mut model_id = MaybeUninit::<u32>::uninit();
unsafe {
let result = model_create(checkpoint_path, model_id.as_mut_ptr());
if let Err(code) = result {
return Err(code);
}
Ok(Model {
id: model_id.assume_init(),
})
}
}

pub fn train(
&self,
train_data_loader: DataLoader,
val_data_loader: DataLoader,
tokenizer: Tokenizer,
config: Config,
) -> Result<(), WasmedgeLLMErrno> {
unsafe {
model_train(
self.id,
train_data_loader.id,
val_data_loader.id,
tokenizer.id,
train_data_loader.batch_size,
train_data_loader.sequence_length,
config.lr,
config.epoch,
)
}
}
}

/// Handle to a data loader created through the `llmc_interface` bindings.
///
/// Besides the id written by `dataloader_create`, it remembers the batch size
/// and sequence length it was created with so they can be passed on to the
/// training call.
#[derive(Debug)]
pub struct DataLoader {
    id: u32,
    batch_size: u32,
    sequence_length: u32,
}

impl DataLoader {
    /// Creates a data loader for the file at `data_path`.
    ///
    /// Every argument is forwarded unchanged to `dataloader_create`;
    /// `batch_size` and `sequence_length` are additionally stored on the
    /// returned value.
    ///
    /// # Errors
    ///
    /// Returns the `WasmedgeLLMErrno` produced by `dataloader_create` when
    /// that call fails.
    pub fn from_file(
        data_path: &str,
        batch_size: u32,
        sequence_length: u32,
        process_rank: u32,
        num_processes: u32,
        should_shuffle: bool,
    ) -> Result<Self, WasmedgeLLMErrno> {
        let mut id_slot = MaybeUninit::<u32>::uninit();
        // SAFETY (assumption, not visible in this file): `dataloader_create`
        // is expected to have written the id through the pointer whenever it
        // returns `Ok` — TODO confirm against `llmc_interface`.
        let id = unsafe {
            dataloader_create(
                data_path,
                batch_size,
                sequence_length,
                process_rank,
                num_processes,
                should_shuffle,
                id_slot.as_mut_ptr(),
            )?;
            id_slot.assume_init()
        };
        Ok(Self {
            id,
            batch_size,
            sequence_length,
        })
    }
}

/// Handle to a tokenizer created through the `llmc_interface` bindings.
///
/// The wrapped `id` is the value written by `tokenizer_create`; it is not
/// interpreted on the Rust side.
#[derive(Debug)]
pub struct Tokenizer {
    id: u32,
}

impl Tokenizer {
    /// Creates a tokenizer from the file at `filepath`.
    ///
    /// # Errors
    ///
    /// Returns the `WasmedgeLLMErrno` produced by `tokenizer_create` when
    /// that call fails.
    pub fn from_file(filepath: &str) -> Result<Self, WasmedgeLLMErrno> {
        let mut id_slot = MaybeUninit::<u32>::uninit();
        // SAFETY (assumption, not visible in this file): `tokenizer_create`
        // is expected to have written the id through the pointer whenever it
        // returns `Ok` — TODO confirm against `llmc_interface`.
        let id = unsafe {
            tokenizer_create(filepath, id_slot.as_mut_ptr())?;
            id_slot.assume_init()
        };
        Ok(Self { id })
    }
}
Loading

0 comments on commit 41954f5

Please sign in to comment.