Skip to content

Commit

Permalink
feat: add bench method & update example (#269)
Browse files Browse the repository at this point in the history
* feat(example): bump to rn 0.74 & add react-navigation

* feat(example): wip Bench container

* feat(example): download models

* feat(ios, cpp): add bench method

* feat(ios): log as markdown

* fix: timings

* feat(example): copy to clipboard for logs

* fix(ts): lint errors

* feat(example): update default select

* feat(cpp, android): add system config & update android module

* fix(cpp): system_info parse

* fix: patch

* feat(cpp): refactor whisper_timings

* fix(example): use react-native-gesture-handler

* feat(example): split copy logs button

* fix(example): log

* feat(example): move useFlashAttn usage to context-opts

* feat(example): minor refactor

* chore: update deps & docgen

* fix(example): remove default props

* feat(ts): update mock

* fix: deps

* feat(example): update model button style

* fix(example): realtime button title
  • Loading branch information
jhen0409 authored Nov 8, 2024
1 parent 48de6ed commit 45afdc7
Show file tree
Hide file tree
Showing 46 changed files with 3,140 additions and 1,623 deletions.
9 changes: 9 additions & 0 deletions android/src/main/java/com/rnwhisper/RNWhisper.java
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,15 @@ protected void onPostExecute(Void result) {
tasks.put(task, "abortTranscribe-" + id);
}

/**
 * Runs the benchmark of the context registered under {@code id} and settles
 * {@code promise} with the outcome.
 *
 * The call is made directly on the invoking thread (no task is scheduled here).
 *
 * @param id       context id; truncated to int before the lookup
 * @param nThreads thread count; truncated to int and forwarded to the context
 * @param promise  rejected with "Context not found" when no context is registered
 *                 under {@code id}, otherwise resolved with the string returned by
 *                 {@code WhisperContext.bench}
 */
public void bench(double id, double nThreads, Promise promise) {
  final int contextId = (int) id;
  final WhisperContext context = contexts.get(contextId);
  if (context != null) {
    final int threadCount = (int) nThreads;
    promise.resolve(context.bench(threadCount));
    return;
  }
  promise.reject("Context not found");
}

public void releaseContext(double id, Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Void>() {
Expand Down
5 changes: 5 additions & 0 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,10 @@ public void stopCurrentTranscribe() {
stopTranscribe(this.jobId);
}

/**
 * Benchmarks this context by forwarding its native handle to the static
 * native {@code bench(long, int)} method.
 *
 * @param n_threads thread count passed through to the native benchmark
 * @return the string produced by the native benchmark, unmodified
 */
public String bench(int n_threads) {
  final String report = bench(context, n_threads);
  return report;
}

public void release() {
stopCurrentTranscribe();
freeContext(context);
Expand Down Expand Up @@ -527,4 +531,5 @@ protected static native int fullWithJob(
int slice_index,
int n_samples
);
// Native benchmark entry point; bound to Java_com_rnwhisper_WhisperContext_bench in jni.cpp,
// which casts `context` to a whisper_context pointer and returns rnwhisper::bench's string.
protected static native String bench(long context, int n_threads);
}
13 changes: 13 additions & 0 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,4 +508,17 @@ Java_com_rnwhisper_WhisperContext_getTextSegmentSpeakerTurnNext(
return whisper_full_get_segment_speaker_turn_next(context, index);
}

// JNI bridge for WhisperContext.bench(long, int): reinterprets the Java-side
// handle as a whisper_context pointer, runs rnwhisper::bench on it and hands
// the resulting text back to Java as a new string.
JNIEXPORT jstring JNICALL
Java_com_rnwhisper_WhisperContext_bench(
    JNIEnv *env,
    jobject thiz,
    jlong context_ptr,
    jint n_threads
) {
    UNUSED(thiz);
    auto *ctx = reinterpret_cast<struct whisper_context *>(context_ptr);
    const std::string report = rnwhisper::bench(ctx, n_threads);
    const char *report_chars = report.c_str();
    return env->NewStringUTF(report_chars);
}

} // extern "C"
5 changes: 5 additions & 0 deletions android/src/newarch/java/com/rnwhisper/RNWhisperModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ public void abortTranscribe(double contextId, double jobId, Promise promise) {
rnwhisper.abortTranscribe(contextId, jobId, promise);
}

/**
 * React Native entry point for the benchmark (new-architecture module).
 * Forwards all arguments unchanged to {@code rnwhisper.bench}, which is
 * responsible for settling {@code promise}.
 */
@ReactMethod
public void bench(double id, double nThreads, Promise promise) {
rnwhisper.bench(id, nThreads, promise);
}

@ReactMethod
public void releaseContext(double id, Promise promise) {
rnwhisper.releaseContext(id, promise);
Expand Down
5 changes: 5 additions & 0 deletions android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ public void abortTranscribe(double contextId, double jobId, Promise promise) {
rnwhisper.abortTranscribe(contextId, jobId, promise);
}

/**
 * React Native entry point for the benchmark (old-architecture module).
 * Forwards all arguments unchanged to {@code rnwhisper.bench}, which is
 * responsible for settling {@code promise}.
 */
@ReactMethod
public void bench(double id, double nThreads, Promise promise) {
rnwhisper.bench(id, nThreads, promise);
}

@ReactMethod
public void releaseContext(double id, Promise promise) {
rnwhisper.releaseContext(id, promise);
Expand Down
91 changes: 91 additions & 0 deletions cpp/rn-whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,97 @@

namespace rnwhisper {

// Builds a space-separated list of the CPU/backend features reported by the
// wsp_ggml_cpu_has_*() probes (plus "COREML" when WHISPER_USE_COREML is defined).
//
// The text lives in a function-local static string that is rebuilt on every
// call, so the returned pointer is only meaningful until the next call.
const char * system_info(void) {
    static std::string info;
    info.clear();

    // Appends `label` only when the probe reported exactly 1.
    const auto add_if = [](std::string & out, int probe, const char * label) {
        if (probe == 1) {
            out += label;
        }
    };

    add_if(info, wsp_ggml_cpu_has_avx(),     "AVX ");
    add_if(info, wsp_ggml_cpu_has_avx2(),    "AVX2 ");
    add_if(info, wsp_ggml_cpu_has_avx512(),  "AVX512 ");
    add_if(info, wsp_ggml_cpu_has_fma(),     "FMA ");
    add_if(info, wsp_ggml_cpu_has_neon(),    "NEON ");
    add_if(info, wsp_ggml_cpu_has_arm_fma(), "ARM_FMA ");
    add_if(info, wsp_ggml_cpu_has_metal(),   "METAL ");
    add_if(info, wsp_ggml_cpu_has_f16c(),    "F16C ");
    add_if(info, wsp_ggml_cpu_has_fp16_va(), "FP16_VA ");
    add_if(info, wsp_ggml_cpu_has_blas(),    "BLAS ");
    add_if(info, wsp_ggml_cpu_has_sse3(),    "SSE3 ");
    add_if(info, wsp_ggml_cpu_has_ssse3(),   "SSSE3 ");
    add_if(info, wsp_ggml_cpu_has_vsx(),     "VSX ");
#ifdef WHISPER_USE_COREML
    info += "COREML ";
#endif

    // Drop the separator left behind by the last appended label.
    while (!info.empty() && info.back() == ' ') {
        info.pop_back();
    }
    return info.c_str();
}

std::string bench(struct whisper_context * ctx, int n_threads) {
const int n_mels = whisper_model_n_mels(ctx);

if (int ret = whisper_set_mel(ctx, nullptr, 0, n_mels)) {
return "error: failed to set mel: " + std::to_string(ret);
}
// heat encoder
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
return "error: failed to encode: " + std::to_string(ret);
}

whisper_token tokens[512];
memset(tokens, 0, sizeof(tokens));

// prompt heat
if (int ret = whisper_decode(ctx, tokens, 256, 0, n_threads) != 0) {
return "error: failed to decode: " + std::to_string(ret);
}

// text-generation heat
if (int ret = whisper_decode(ctx, tokens, 1, 256, n_threads) != 0) {
return "error: failed to decode: " + std::to_string(ret);
}

whisper_reset_timings(ctx);

// actual run
if (int ret = whisper_encode(ctx, 0, n_threads) != 0) {
return "error: failed to encode: " + std::to_string(ret);
}

// text-generation
for (int i = 0; i < 256; i++) {
if (int ret = whisper_decode(ctx, tokens, 1, i, n_threads) != 0) {
return "error: failed to decode: " + std::to_string(ret);
}
}

// batched decoding
for (int i = 0; i < 64; i++) {
if (int ret = whisper_decode(ctx, tokens, 5, 0, n_threads) != 0) {
return "error: failed to decode: " + std::to_string(ret);
}
}

// prompt processing
for (int i = 0; i < 16; i++) {
if (int ret = whisper_decode(ctx, tokens, 256, 0, n_threads) != 0) {
return "error: failed to decode: " + std::to_string(ret);
}
}

const struct whisper_timings * timings = whisper_get_timings(ctx);

const int32_t n_encode = std::max(1, timings->n_encode);
const int32_t n_decode = std::max(1, timings->n_decode);
const int32_t n_batchd = std::max(1, timings->n_batchd);
const int32_t n_prompt = std::max(1, timings->n_prompt);

return std::string("[") +
"\"" + system_info() + "\"," +
std::to_string(n_threads) + "," +
std::to_string(1e-3f * timings->t_encode_us / n_encode) + "," +
std::to_string(1e-3f * timings->t_decode_us / n_decode) + "," +
std::to_string(1e-3f * timings->t_batchd_us / n_batchd) + "," +
std::to_string(1e-3f * timings->t_prompt_us / n_prompt) + "]";
}

void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
Expand Down
2 changes: 2 additions & 0 deletions cpp/rn-whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

namespace rnwhisper {

// Runs the encode/decode benchmark on `ctx` with `n_threads` threads.
// Returns a bracketed, comma-separated result string on success, or a string
// starting with "error: " when a whisper call fails (see rn-whisper.cpp).
std::string bench(whisper_context * ctx, int n_threads);

struct vad_params {
bool use_vad = false;
float vad_thold = 0.6f;
Expand Down
43 changes: 33 additions & 10 deletions cpp/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4190,28 +4190,51 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
return ctx->vocab.token_transcribe;
}

// Copies the performance counters of the context and its default state into a
// freshly allocated whisper_timings.
//
// Returns nullptr when the context has no state. Otherwise the result is
// allocated with `new` and is not released by this function, so the caller is
// the only party able to free it.
struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
    const auto * state = ctx->state;
    if (state == nullptr) {
        return nullptr;
    }

    auto * timings = new whisper_timings();

    // context-level values
    timings->load_us     = ctx->t_load_us;
    timings->t_start_us  = ctx->t_start_us;

    // state-level counters
    timings->fail_p      = state->n_fail_p;
    timings->fail_h      = state->n_fail_h;
    timings->t_mel_us    = state->t_mel_us;
    timings->n_sample    = state->n_sample;
    timings->n_encode    = state->n_encode;
    timings->n_decode    = state->n_decode;
    timings->n_batchd    = state->n_batchd;
    timings->n_prompt    = state->n_prompt;
    timings->t_sample_us = state->t_sample_us;
    timings->t_encode_us = state->t_encode_us;
    timings->t_decode_us = state->t_decode_us;
    timings->t_batchd_us = state->t_batchd_us;
    timings->t_prompt_us = state->t_prompt_us;

    return timings;
}

// Logs the load time, the per-stage timings of the default state (when the
// context has one) and the total elapsed time since the context's start time.
//
// Each line is logged exactly once. The state-dependent section is skipped
// when whisper_get_timings() returns nullptr (context without state), and the
// snapshot it allocates is released before returning.
void whisper_print_timings(struct whisper_context * ctx) {
    const int64_t t_end_us = wsp_ggml_time_us();
    const struct whisper_timings * timings = whisper_get_timings(ctx);

    WHISPER_LOG_INFO("\n");
    // Load/start times live on the context itself, so they are read from `ctx`
    // and stay available even when no timings snapshot could be produced.
    WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
    if (timings != nullptr) {
        // Clamp run counts to 1 so the per-run averages never divide by zero.
        const int32_t n_sample = std::max(1, timings->n_sample);
        const int32_t n_encode = std::max(1, timings->n_encode);
        const int32_t n_decode = std::max(1, timings->n_decode);
        const int32_t n_batchd = std::max(1, timings->n_batchd);
        const int32_t n_prompt = std::max(1, timings->n_prompt);

        WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, timings->fail_p, timings->fail_h);
        WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, timings->t_mel_us/1000.0f);
        WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_sample_us, n_sample, 1e-3f * timings->t_sample_us / n_sample);
        WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_encode_us, n_encode, 1e-3f * timings->t_encode_us / n_encode);
        WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_decode_us, n_decode, 1e-3f * timings->t_decode_us / n_decode);
        WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_batchd_us, n_batchd, 1e-3f * timings->t_batchd_us / n_batchd);
        WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * timings->t_prompt_us, n_prompt, 1e-3f * timings->t_prompt_us / n_prompt);
    }
    WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);

    // The snapshot was allocated with `new` by whisper_get_timings();
    // deleting a null pointer is a no-op.
    delete timings;
}

void whisper_reset_timings(struct whisper_context * ctx) {
Expand Down
18 changes: 18 additions & 0 deletions cpp/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,24 @@ extern "C" {
WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);

// Performance information from the default state.
// Plain snapshot of the counters copied out by whisper_get_timings().
// Fields ending in `_us` hold microseconds (whisper_print_timings divides them
// by 1000 and labels the result "ms"); `n_*` fields are run counts.
// NOTE: whisper_get_timings() initializes these fields by name in declaration
// order, so keep that function in sync when reordering or adding fields.
struct whisper_timings {
int64_t load_us;     // copied from ctx->t_load_us
int64_t t_start_us;  // copied from ctx->t_start_us
int32_t fail_p;      // copied from state->n_fail_p (logged as "fallbacks ... p")
int32_t fail_h;      // copied from state->n_fail_h (logged as "fallbacks ... h")
int64_t t_mel_us;    // time logged as "mel time"
int32_t n_sample;    // number of runs logged for "sample time"
int32_t n_encode;    // number of runs logged for "encode time"
int32_t n_decode;    // number of runs logged for "decode time"
int32_t n_batchd;    // number of runs logged for "batchd time"
int32_t n_prompt;    // number of runs logged for "prompt time"
int64_t t_sample_us; // accumulated time logged as "sample time"
int64_t t_encode_us; // accumulated time logged as "encode time"
int64_t t_decode_us; // accumulated time logged as "decode time"
int64_t t_batchd_us; // accumulated time logged as "batchd time"
int64_t t_prompt_us; // accumulated time logged as "prompt time"
};
// Returns nullptr when ctx->state is null; otherwise a snapshot allocated with
// `new` that the implementation never frees itself.
WHISPER_API struct whisper_timings * whisper_get_timings(struct whisper_context * ctx);
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
WHISPER_API void whisper_reset_timings(struct whisper_context * ctx);

Expand Down
Loading

0 comments on commit 45afdc7

Please sign in to comment.