Merge pull request #44 from tracel-ai/chore/bert-update
Update BERT to Burn 0.14
nathanielsimard authored Oct 20, 2024
2 parents 042ba6a + 894f7af commit e659097
Showing 6 changed files with 52 additions and 62 deletions.
7 changes: 4 additions & 3 deletions bert-burn/Cargo.toml
@@ -6,21 +6,22 @@ version = "0.2.0"
edition = "2021"

[features]
default = ["burn/dataset"]
default = ["wgpu", "fusion"]
f16 = []
ndarray = ["burn/ndarray"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
+cuda = ["burn/cuda-jit"]
fusion = ["burn/fusion"]
# To be replaced by burn-safetensors once supported: https://github.com/tracel-ai/burn/issues/626
safetensors = ["candle-core/default"]


[dependencies]
# Burn
burn = { version = "0.13", default-features = false }
candle-core = { version = "0.3.2", optional = true }
burn = { version = "0.14", default-features = false, features = ["dataset", "std"] }
candle-core = { version = "0.3" }
# Tokenizer
tokenizers = { version = "0.15.0", default-features = false, features = [
"onig",
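Note: the `f16` feature above gates the element type the examples run with. A minimal sketch of that pattern, assuming an `ElemType` alias like the one the examples import — the alias definition below is illustrative only; this PR shows just the feature flag:

```rust
// Hypothetical sketch of what the `f16` feature typically gates.
// `ElemType` mirrors the alias the examples reference; its definition
// here is assumed, not shown in this diff.
#[cfg(feature = "f16")]
pub type ElemType = burn::tensor::f16;
#[cfg(not(feature = "f16"))]
pub type ElemType = f32;
```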
2 changes: 1 addition & 1 deletion bert-burn/README.md
@@ -17,7 +17,7 @@ bert-burn = { git = "https://github.com/tracel-ai/models", package = "bert-burn"
## Example Usage

Example usage for getting sentence embedding from given input text. The model supports multiple backends from burn
-(e.g. `ndarray`, `wgpu`, `tch-gpu`, `tch-cpu`) which can be selected using the `--features` flag. An example with `wgpu`
+(e.g. `ndarray`, `wgpu`, `tch-gpu`, `tch-cpu`, `cuda`) which can be selected using the `--features` flag. An example with `wgpu`
backend is shown below. The `fusion` flag is used to enable kernel fusion for the `wgpu` backend. It is not required
with other backends. The `safetensors` flag is used to support loading weights in `safetensors` format via `candle-core`
crate.
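Note: for a concrete picture of the feature selection this README paragraph describes, here is a minimal, self-contained sketch of the cfg-gated pattern the updated examples use under Burn 0.14. `launch` stands in for the real example entry point; its body is illustrative:

```rust
use burn::tensor::backend::Backend;

// Placeholder for the example's real entry point, which builds the model
// and runs inference on the chosen backend.
pub fn launch<B: Backend>(device: B::Device) {
    let _ = device;
}

#[cfg(feature = "wgpu")]
mod wgpu {
    use super::launch;
    use burn::backend::wgpu::{Wgpu, WgpuDevice};

    pub fn run() {
        // Burn 0.14 gives `Wgpu` default type parameters, so the old
        // `Wgpu<AutoGraphicsApi, ElemType, i32>` spelling is no longer needed.
        launch::<Wgpu>(WgpuDevice::default());
    }
}

fn main() {
    #[cfg(feature = "wgpu")]
    wgpu::run();
}
```

Running e.g. `cargo run --example infer-embedding --features wgpu,fusion` compiles in only the selected backend module.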
26 changes: 8 additions & 18 deletions bert-burn/examples/infer-embedding.rs
@@ -67,7 +67,7 @@ pub fn launch<B: Backend>(device: B::Device) {
// Batch input samples using the batcher Shape: [Batch size, Seq_len]
let input = batcher.batch(text_samples.clone());
let [batch_size, _seq_len] = input.tokens.dims();
println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape());
println!("Input: {}", input.tokens);

let output = model.forward(input);

@@ -84,17 +84,12 @@ pub fn launch<B: Backend>(device: B::Device) {

let sentence_embedding: Tensor<B, 2> = sentence_embedding.squeeze(1);
println!(
"Roberta Sentence embedding {:?} // (Batch Size, Embedding_dim)",
sentence_embedding.shape()
"Roberta Sentence embedding: {}",
sentence_embedding
);
}

-#[cfg(any(
-feature = "ndarray",
-feature = "ndarray-blas-netlib",
-feature = "ndarray-blas-openblas",
-feature = "ndarray-blas-accelerate",
-))]
+#[cfg(feature = "ndarray")]
mod ndarray {
use burn::backend::ndarray::{NdArray, NdArrayDevice};

@@ -132,21 +127,16 @@ mod tch_cpu {

#[cfg(feature = "wgpu")]
mod wgpu {
-use crate::{launch, ElemType};
-use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
+use crate::launch;
+use burn::backend::wgpu::{Wgpu, WgpuDevice};

pub fn run() {
launch::<Wgpu<AutoGraphicsApi, ElemType, i32>>(WgpuDevice::default());
launch::<Wgpu>(WgpuDevice::default());
}
}

fn main() {
-#[cfg(any(
-feature = "ndarray",
-feature = "ndarray-blas-netlib",
-feature = "ndarray-blas-openblas",
-feature = "ndarray-blas-accelerate",
-))]
+#[cfg(feature = "ndarray")]
ndarray::run();
#[cfg(feature = "tch-gpu")]
tch_gpu::run();
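Note: two Burn 0.14 changes drive the edits above: `Wgpu` now defaults its type parameters, and tensors implement `Display`, so the example prints the tensor itself instead of a `{:?}`-formatted shape. A minimal sketch of the printing change, shown here on the `ndarray` backend:

```rust
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::tensor::Tensor;

fn main() {
    let device = NdArrayDevice::default();
    // Same shape convention as the example: [batch_size, seq_len].
    let tokens = Tensor::<NdArray, 2>::zeros([2, 8], &device);
    // `{}` renders the values along with shape/dtype metadata,
    // replacing the old `println!("{:?}", tokens.shape())`.
    println!("Input: {}", tokens);
}
```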
20 changes: 5 additions & 15 deletions bert-burn/examples/masked.rs
@@ -82,12 +82,7 @@ pub fn launch<B: Backend>(device: B::Device) {
}
}

-#[cfg(any(
-feature = "ndarray",
-feature = "ndarray-blas-netlib",
-feature = "ndarray-blas-openblas",
-feature = "ndarray-blas-accelerate",
-))]
+#[cfg(feature = "ndarray")]
mod ndarray {
use burn::backend::ndarray::{NdArray, NdArrayDevice};

@@ -125,21 +120,16 @@ mod tch_cpu {

#[cfg(feature = "wgpu")]
mod wgpu {
-use crate::{launch, ElemType};
-use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
+use crate::launch;
+use burn::backend::wgpu::{Wgpu, WgpuDevice};

pub fn run() {
launch::<Wgpu<AutoGraphicsApi, ElemType, i32>>(WgpuDevice::default());
launch::<Wgpu>(WgpuDevice::default());
}
}

fn main() {
-#[cfg(any(
-feature = "ndarray",
-feature = "ndarray-blas-netlib",
-feature = "ndarray-blas-openblas",
-feature = "ndarray-blas-accelerate",
-))]
+#[cfg(feature = "ndarray")]
ndarray::run();
#[cfg(feature = "tch-gpu")]
tch_gpu::run();
26 changes: 13 additions & 13 deletions bert-burn/src/fill_mask.rs
@@ -4,7 +4,7 @@ use crate::{
model::BertMaskedLM,
model::BertModelConfig,
};
-use burn::tensor::{activation::softmax, backend::Backend, Data, Element, Tensor};
+use burn::tensor::{activation::softmax, backend::Backend, Element, Tensor};

type TokenType = usize;
const MASK_TOKEN_ID: TokenType = 50264;
@@ -34,10 +34,10 @@ pub fn fill_mask<B: Backend>(
.tokens
.clone()
.slice([i..i + 1, 0..seq_len])
-.squeeze(0)
-.to_data();
+.squeeze::<1>(0)
+.into_data();
// Find the mask tokens in the input, as a list of indices
let masks = find_masks(&input_tokens, MASK_TOKEN_ID);
let masks = find_masks(input_tokens.as_slice::<B::IntElem>().unwrap(), MASK_TOKEN_ID);
for mask in masks {
let logits = output
.clone()
@@ -60,29 +60,29 @@
results
}

-fn find_masks<T: Element>(tokens: &Data<T, 1>, mask_token_id: TokenType) -> Vec<usize> {
+fn find_masks<T: Element>(tokens: &[T], mask_token_id: TokenType) -> Vec<usize> {
let mut masks = Vec::new();
-for (i, token) in tokens.value.iter().enumerate() {
-if token.to_usize() == Some(mask_token_id) {
+for (i, token) in tokens.iter().enumerate() {
+if token.to_usize() == mask_token_id {
masks.push(i);
}
}
masks
}

-fn data_to_vec_f32<T: Element>(data: &Data<T, 1>) -> Vec<f32> {
-data.value.iter().map(|x| x.to_f32().unwrap()).collect()
+fn data_to_vec_f32<T: Element>(data: &[T]) -> Vec<f32> {
+data.iter().map(|x| x.to_f32()).collect()
}

-fn data_to_vec_usize<T: Element>(data: &Data<T, 1>) -> Vec<usize> {
-data.value.iter().map(|x| x.to_usize().unwrap()).collect()
+fn data_to_vec_usize<T: Element>(data: &[T]) -> Vec<usize> {
+data.iter().map(|x| x.to_usize()).collect()
}

fn top_k<B: Backend>(k: usize, logits: Tensor<B, 1>) -> Vec<(usize, f32)> {
let (pre_soft_probs, indices) = logits.sort_with_indices(0);
let (probabilities, indices) = (
data_to_vec_f32(&softmax(pre_soft_probs, 0).to_data()),
data_to_vec_usize(&indices.to_data()),
data_to_vec_f32(&softmax(pre_soft_probs, 0).into_data().as_slice::<B::FloatElem>().unwrap()),
data_to_vec_usize(&indices.into_data().as_slice::<B::IntElem>().unwrap()),
);
probabilities
.iter()
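Note: the fill-mask changes track the Burn 0.14 data API: `Data` is gone, `into_data()` returns a `TensorData`, and raw values are borrowed with `as_slice::<E>()`, which returns a `Result` (hence the `unwrap`). A minimal sketch of the mask-finding logic under that API, using the `ndarray` backend whose default int element is `i64` (an assumption of this sketch):

```rust
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::tensor::{Int, Tensor};

const MASK_TOKEN_ID: usize = 50264; // RoBERTa's <mask> id, as in fill_mask.rs

fn main() {
    let device = NdArrayDevice::default();
    let tokens = Tensor::<NdArray, 1, Int>::from_ints([0, 50264, 7, 50264], &device);

    // Burn 0.14: `into_data()` -> `TensorData`, then borrow a typed slice.
    let data = tokens.into_data();
    let slice = data.as_slice::<i64>().unwrap();

    let masks: Vec<usize> = slice
        .iter()
        .enumerate()
        .filter(|(_, &t)| t as usize == MASK_TOKEN_ID)
        .map(|(i, _)| i)
        .collect();
    println!("mask positions: {:?}", masks); // [1, 3]
}
```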
33 changes: 21 additions & 12 deletions bert-burn/src/loader.rs
@@ -12,7 +12,7 @@ use burn::nn::transformer::{
};
use burn::nn::{EmbeddingRecord, LayerNormRecord, LinearRecord};
use burn::tensor::backend::Backend;
-use burn::tensor::{Data, Shape, Tensor};
+use burn::tensor::{Shape, Tensor, TensorData};
use candle_core::Tensor as CandleTensor;
use std::collections::HashMap;
use std::path::PathBuf;
@@ -24,7 +24,7 @@ pub(crate) fn load_1d_tensor_from_candle<B: Backend>(
let dims = tensor.dims();
let data = tensor.to_vec1::<f32>().unwrap();
let array: [usize; 1] = dims.try_into().expect("Unexpected size");
-let data = Data::new(data, Shape::new(array));
+let data = TensorData::new(data, Shape::new(array));
let weight = Tensor::<B, 1>::from_floats(data, &device.clone());
weight
}
@@ -41,7 +41,7 @@ pub(crate) fn load_2d_tensor_from_candle<B: Backend>(
.flatten()
.collect::<Vec<f32>>();
let array: [usize; 2] = dims.try_into().expect("Unexpected size");
-let data = Data::new(data, Shape::new(array));
+let data = TensorData::new(data, Shape::new(array));
let weight = Tensor::<B, 2>::from_floats(data, &device.clone());
weight
}
@@ -90,8 +90,8 @@ pub(crate) fn load_intermediate_layer_safetensor<B: Backend>(
let linear_outer = load_linear_safetensor::<B>(linear_outer_bias, linear_outer_weight, device);

let pwff_record = PositionWiseFeedForwardRecord {
-linear_inner: linear_inner,
-linear_outer: linear_outer,
+linear_inner,
+linear_outer,
dropout: ConstantRecord::new(),
gelu: ConstantRecord::new(),
};
@@ -128,10 +128,11 @@ fn load_attention_layer_safetensor<B: Backend>(
);

let attention_record = MultiHeadAttentionRecord {
-query: query,
-key: key,
-value: value,
-output: output,
+query,
+key,
+value,
+output,
+d_model: ConstantRecord::new(),
dropout: ConstantRecord::new(),
activation: ConstantRecord::new(),
n_heads: ConstantRecord::new(),
@@ -211,9 +212,9 @@

let layer_record = TransformerEncoderLayerRecord {
mha: attention_layer,
-pwff: pwff,
-norm_1: norm_1,
-norm_2: norm_2,
+pwff,
+norm_1,
+norm_2,
dropout: ConstantRecord::new(),
norm_first: ConstantRecord::new(),
};
@@ -223,6 +224,14 @@

let encoder_record = TransformerEncoderRecord {
layers: bert_encoder_layers,
+d_model: ConstantRecord::new(),
+d_ff: ConstantRecord::new(),
+n_heads: ConstantRecord::new(),
+n_layers: ConstantRecord::new(),
+dropout: ConstantRecord::new(),
+norm_first: ConstantRecord::new(),
+quiet_softmax: ConstantRecord::new(),
+
};
encoder_record
}
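Note: the loader changes are mechanical: `Data::new` becomes `TensorData::new`, struct-field shorthand replaces `field: field`, and the 0.14 `MultiHeadAttentionRecord`/`TransformerEncoderRecord` gain extra `ConstantRecord` fields that must now be supplied. A minimal sketch of the candle-to-Burn weight conversion path, with a plain `Vec<f32>` standing in for the buffer read out of a `CandleTensor`:

```rust
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::tensor::{Shape, Tensor, TensorData};

fn main() {
    let device = NdArrayDevice::default();
    // Stand-in for the flattened f32 buffer a CandleTensor would yield.
    let raw: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
    let dims: [usize; 2] = [2, 3];

    // Burn 0.14: wrap the buffer in `TensorData`, then build the tensor.
    let data = TensorData::new(raw, Shape::new(dims));
    let weight = Tensor::<NdArray, 2>::from_floats(data, &device);
    println!("{}", weight);
}
```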
