-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #34 from seurimas/main
Add LM head and masked fill_mask for bert-burn.
- Loading branch information
Showing
5 changed files
with
370 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
use bert_burn::data::{BertInputBatcher, BertTokenizer}; | ||
use bert_burn::fill_mask::fill_mask; | ||
use bert_burn::loader::{download_hf_model, load_model_config}; | ||
use bert_burn::model::{BertMaskedLM, BertMaskedLMRecord}; | ||
use burn::data::dataloader::batcher::Batcher; | ||
use burn::module::Module; | ||
use burn::tensor::backend::Backend; | ||
use std::env; | ||
use std::sync::Arc; | ||
|
||
#[cfg(not(feature = "f16"))] | ||
#[allow(dead_code)] | ||
type ElemType = f32; | ||
#[cfg(feature = "f16")] | ||
type ElemType = burn::tensor::f16; | ||
|
||
pub fn launch<B: Backend>(device: B::Device) { | ||
let args: Vec<String> = env::args().collect(); | ||
let default_model = "roberta-base".to_string(); | ||
let model_variant = if args.len() > 1 { | ||
// Use the argument provided by the user | ||
// Possible values: "bert-base-uncased", "roberta-large" etc. | ||
&args[1] | ||
} else { | ||
// Use the default value if no argument is provided | ||
&default_model | ||
}; | ||
|
||
println!("Model variant: {}", model_variant); | ||
|
||
let text_samples = vec![ | ||
"Paris is the <mask> of France.".to_string(), | ||
"The goal of life is <mask>.".to_string(), | ||
]; | ||
|
||
let (config_file, model_file) = download_hf_model(model_variant); | ||
let model_config = load_model_config(config_file); | ||
|
||
let model_record: BertMaskedLMRecord<B> = | ||
BertMaskedLM::from_safetensors(model_file, &device, model_config.clone()); | ||
|
||
let model = model_config | ||
.init_with_lm_head(&device) | ||
.load_record(model_record); | ||
|
||
let tokenizer = Arc::new(BertTokenizer::new( | ||
model_variant.to_string(), | ||
model_config.pad_token_id, | ||
)); | ||
|
||
// Batch the input samples to max sequence length with padding | ||
let batcher = Arc::new(BertInputBatcher::<B>::new( | ||
tokenizer.clone(), | ||
device.clone(), | ||
model_config.max_seq_len.unwrap(), | ||
)); | ||
|
||
// Batch input samples using the batcher Shape: [Batch size, Seq_len] | ||
let input = batcher.batch(text_samples.clone()); | ||
let [batch_size, _seq_len] = input.tokens.dims(); | ||
println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape()); | ||
|
||
let output = fill_mask(&model, &model_config, tokenizer.as_ref(), input); | ||
|
||
for i in 0..batch_size { | ||
let input = &text_samples[i]; | ||
let result = &output[i]; | ||
println!("Input: {}", input); | ||
for fill_mask_result in result.iter() { | ||
let mask_idx = fill_mask_result.mask_idx; | ||
let top_k = &fill_mask_result.top_k; | ||
for (k, (score, token)) in top_k.iter().enumerate() { | ||
println!( | ||
"Top {} Prediction for {}: {} (Score: {:.4})", | ||
k + 1, | ||
mask_idx, | ||
token, | ||
score | ||
); | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[cfg(any( | ||
feature = "ndarray", | ||
feature = "ndarray-blas-netlib", | ||
feature = "ndarray-blas-openblas", | ||
feature = "ndarray-blas-accelerate", | ||
))] | ||
mod ndarray { | ||
use burn::backend::ndarray::{NdArray, NdArrayDevice}; | ||
|
||
use crate::{launch, ElemType}; | ||
|
||
pub fn run() { | ||
launch::<NdArray<ElemType>>(NdArrayDevice::Cpu); | ||
} | ||
} | ||
|
||
#[cfg(feature = "tch-gpu")] | ||
mod tch_gpu { | ||
use crate::{launch, ElemType}; | ||
use burn::backend::libtorch::{LibTorch, LibTorchDevice}; | ||
|
||
pub fn run() { | ||
#[cfg(not(target_os = "macos"))] | ||
let device = LibTorchDevice::Cuda(0); | ||
#[cfg(target_os = "macos")] | ||
let device = LibTorchDevice::Mps; | ||
|
||
launch::<LibTorch<ElemType>>(device); | ||
} | ||
} | ||
|
||
#[cfg(feature = "tch-cpu")] | ||
mod tch_cpu { | ||
use crate::{launch, ElemType}; | ||
use burn::backend::libtorch::{LibTorch, LibTorchDevice}; | ||
|
||
pub fn run() { | ||
launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu); | ||
} | ||
} | ||
|
||
#[cfg(feature = "wgpu")] | ||
mod wgpu { | ||
use crate::{launch, ElemType}; | ||
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; | ||
|
||
pub fn run() { | ||
launch::<Wgpu<AutoGraphicsApi, ElemType, i32>>(WgpuDevice::default()); | ||
} | ||
} | ||
|
||
fn main() { | ||
#[cfg(any( | ||
feature = "ndarray", | ||
feature = "ndarray-blas-netlib", | ||
feature = "ndarray-blas-openblas", | ||
feature = "ndarray-blas-accelerate", | ||
))] | ||
ndarray::run(); | ||
#[cfg(feature = "tch-gpu")] | ||
tch_gpu::run(); | ||
#[cfg(feature = "tch-cpu")] | ||
tch_cpu::run(); | ||
#[cfg(feature = "wgpu")] | ||
wgpu::run(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
use crate::{ | ||
data::Tokenizer, | ||
data::{BertInferenceBatch, BertTokenizer}, | ||
model::BertMaskedLM, | ||
model::BertModelConfig, | ||
}; | ||
use burn::tensor::{activation::softmax, backend::Backend, Data, Element, Tensor}; | ||
|
||
type TokenType = usize; | ||
const MASK_TOKEN_ID: TokenType = 50264; | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct FillMaskResult { | ||
pub mask_idx: usize, | ||
pub top_k: Vec<(f32, String)>, | ||
} | ||
|
||
pub fn fill_mask<B: Backend>( | ||
model: &BertMaskedLM<B>, | ||
model_config: &BertModelConfig, | ||
tokenizer: &BertTokenizer, | ||
input: BertInferenceBatch<B>, | ||
) -> Vec<Vec<FillMaskResult>> { | ||
let [batch_size, seq_len] = input.tokens.dims(); | ||
let output = model.forward(input.clone()); | ||
|
||
let mut results = vec![]; | ||
|
||
// Embedding size | ||
let d_model = model_config.vocab_size.clone(); | ||
for i in 0..batch_size { | ||
let mut batch_results = vec![]; | ||
let input_tokens = input | ||
.tokens | ||
.clone() | ||
.slice([i..i + 1, 0..seq_len]) | ||
.squeeze(0) | ||
.to_data(); | ||
// Find the mask tokens in the input, as a list of indices | ||
let masks = find_masks(&input_tokens, MASK_TOKEN_ID); | ||
for mask in masks { | ||
let logits = output | ||
.clone() | ||
.slice([i..i + 1, mask..(mask + 1), 0..d_model]) | ||
.squeeze::<2>(0) | ||
.squeeze(0); | ||
// Find the top k tokens with the highest probabilities | ||
let top_k = top_k(5, logits); | ||
batch_results.push(FillMaskResult { | ||
mask_idx: mask, | ||
top_k: top_k | ||
.iter() | ||
.map(|(k, prob)| (*prob, tokenizer.decode(&[*k]))) | ||
.collect(), | ||
}); | ||
} | ||
results.push(batch_results); | ||
} | ||
|
||
results | ||
} | ||
|
||
fn find_masks<T: Element>(tokens: &Data<T, 1>, mask_token_id: TokenType) -> Vec<usize> { | ||
let mut masks = Vec::new(); | ||
for (i, token) in tokens.value.iter().enumerate() { | ||
if token.to_usize() == Some(mask_token_id) { | ||
masks.push(i); | ||
} | ||
} | ||
masks | ||
} | ||
|
||
fn data_to_vec_f32<T: Element>(data: &Data<T, 1>) -> Vec<f32> { | ||
data.value.iter().map(|x| x.to_f32().unwrap()).collect() | ||
} | ||
|
||
fn data_to_vec_usize<T: Element>(data: &Data<T, 1>) -> Vec<usize> { | ||
data.value.iter().map(|x| x.to_usize().unwrap()).collect() | ||
} | ||
|
||
fn top_k<B: Backend>(k: usize, logits: Tensor<B, 1>) -> Vec<(usize, f32)> { | ||
let (pre_soft_probs, indices) = logits.sort_with_indices(0); | ||
let (probabilities, indices) = ( | ||
data_to_vec_f32(&softmax(pre_soft_probs, 0).to_data()), | ||
data_to_vec_usize(&indices.to_data()), | ||
); | ||
probabilities | ||
.iter() | ||
.enumerate() | ||
.rev() | ||
.take(k) | ||
.map(|(i, &p)| (indices[i], p)) | ||
.collect() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.