Skip to content

Commit

Permalink
Merge pull request #8 from edgenai/feat/cuda
Browse files Browse the repository at this point in the history
Feat/cuda
  • Loading branch information
pedro-devv authored Feb 18, 2024
2 parents 0a5369c + 1bad47f commit 5833e4e
Show file tree
Hide file tree
Showing 8 changed files with 280 additions and 140 deletions.
47 changes: 27 additions & 20 deletions crates/whisper_cpp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use whisper_cpp_sys::{
whisper_full_params__bindgen_ty_2, whisper_full_with_state,
whisper_init_from_file_with_params_no_state, whisper_init_state, whisper_log_set,
whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH,
whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY, whisper_state, whisper_token,
whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY, whisper_state,
};

/// Boolean indicating if a logger has already been set using [`whisper_log_set`].
Expand Down Expand Up @@ -64,15 +64,19 @@ pub struct WhisperModel {
}

impl WhisperModel {
/// Loads a new *ggml* *whisper* model, given its file path.
/// Loads a new *ggml* *whisper* model, given its file path. If a device (GPU) index is
/// provided, the model is loaded into the GPU.
#[doc(alias = "whisper_init_from_file_with_params_no_state")]
pub fn new_from_file<P>(model_path: P, use_gpu: bool) -> Result<Self, WhisperError>
pub fn new_from_file<P>(model_path: P, device: Option<u32>) -> Result<Self, WhisperError>
where
P: AsRef<std::path::Path>,
{
set_log();

let params = whisper_context_params { use_gpu };
let params = whisper_context_params {
use_gpu: device.is_some(),
gpu_device: device.unwrap_or(0) as i32,
};

let path_bytes = model_path
.as_ref()
Expand Down Expand Up @@ -152,7 +156,6 @@ pub enum WhisperSessionError {
pub struct WhisperSession {
context: Arc<RwLock<WhisperContext>>,
state: WhisperState,
prompt: Vec<whisper_token>,
}

impl WhisperSession {
Expand All @@ -171,7 +174,6 @@ impl WhisperSession {
Ok(Self {
context,
state: WhisperState(state),
prompt: vec![],
})
}

Expand Down Expand Up @@ -237,13 +239,11 @@ impl WhisperSession {
/// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
/// Uses the specified decoding strategy to obtain the text.
#[doc(alias = "whisper_full_with_state")]
pub async fn full(
pub async fn advance(
&mut self,
mut params: WhisperParams,
params: WhisperParams,
samples: &[f32],
) -> Result<(), WhisperSessionError> {
// TODO use no_context from whisper_params instead
params.prompt_tokens = self.prompt.clone();
let locked = self.context.read().await;
let res = unsafe {
let (_vec, c_params) = params.c_params()?;
Expand All @@ -260,15 +260,6 @@ impl WhisperSession {
return Err(WhisperSessionError::Internal);
}

let segments = self.segment_count();

for s in 0..segments {
let tokens = self.token_count(s);
for t in 0..tokens {
self.prompt.push(self.token_id(s, t));
}
}

Ok(())
}

Expand Down Expand Up @@ -305,6 +296,11 @@ impl WhisperSession {
pub fn segment_text(&self, segment: u32) -> Result<String, WhisperSessionError> {
let text = unsafe {
let res = whisper_full_get_segment_text_from_state(self.state.0, segment as c_int);

if res.is_null() {
return Err(WhisperSessionError::Internal);
}

CStr::from_ptr(res.cast_mut())
};

Expand Down Expand Up @@ -345,6 +341,17 @@ impl WhisperSession {
pub fn token_probability(&self) {
// Unimplemented stub: `todo!()` panics unconditionally, so calling this method
// currently aborts the calling thread instead of returning a probability.
todo!()
}

/// Concatenates the text of every segment currently held by this session, in
/// ascending segment order (`0..self.segment_count()`), into a single [`String`].
///
/// No separator is inserted between segments. Presumably the segments are those
/// produced by the most recent decoding call on this session — confirm against
/// the callers.
///
/// # Errors
///
/// Returns the first error reported by [`Self::segment_text`]; segments after the
/// failing one are not read.
pub fn new_context(&self) -> Result<String, WhisperSessionError> {
    // Collecting into `Result<String, _>` stops at the first `Err`, which mirrors
    // an explicit loop that uses `?` on each segment.
    (0..self.segment_count())
        .map(|segment| self.segment_text(segment))
        .collect()
}
}

#[derive(Debug, Error)]
Expand Down Expand Up @@ -402,7 +409,7 @@ pub struct WhisperParams {
translate: bool,

/// Do not use past transcription (if any) as initial prompt for the decoder.
no_context: bool,
pub no_context: bool,

/// Do not generate timestamps.
no_timestamps: bool,
Expand Down
Loading

0 comments on commit 5833e4e

Please sign in to comment.