Skip to content

Commit

Permalink
feat(bark-cpp): add new bark.cpp backend (#4287)
Browse files Browse the repository at this point in the history
* feat(bark-cpp): add new bark.cpp backend

Signed-off-by: Ettore Di Giacinto <[email protected]>

* build on linux only for now

Signed-off-by: Ettore Di Giacinto <[email protected]>

* track bark.cpp in CI bumps

Signed-off-by: Ettore Di Giacinto <[email protected]>

* Drop old entries from bumper

Signed-off-by: Ettore Di Giacinto <[email protected]>

* No need to test rwkv specifically, now part of llama.cpp

Signed-off-by: Ettore Di Giacinto <[email protected]>

---------

Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler authored Nov 28, 2024
1 parent 0d6c3a7 commit 58ff47d
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 82 deletions.
16 changes: 2 additions & 14 deletions .github/workflows/bump_deps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,12 @@ jobs:
- repository: "ggerganov/llama.cpp"
variable: "CPPLLAMA_VERSION"
branch: "master"
- repository: "go-skynet/go-ggml-transformers.cpp"
variable: "GOGGMLTRANSFORMERS_VERSION"
branch: "master"
- repository: "donomii/go-rwkv.cpp"
variable: "RWKV_VERSION"
branch: "main"
- repository: "ggerganov/whisper.cpp"
variable: "WHISPER_CPP_VERSION"
branch: "master"
- repository: "go-skynet/go-bert.cpp"
variable: "BERT_VERSION"
branch: "master"
- repository: "go-skynet/bloomz.cpp"
variable: "BLOOMZ_VERSION"
- repository: "PABannier/bark.cpp"
variable: "BARKCPP_VERSION"
branch: "main"
- repository: "mudler/go-ggllm.cpp"
variable: "GOGGLLM_VERSION"
branch: "master"
- repository: "mudler/go-stable-diffusion"
variable: "STABLEDIFFUSION_VERSION"
branch: "master"
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
/sources/
__pycache__/
*.a
*.o
get-sources
prepare-sources
/backend/cpp/llama/grpc-server
Expand Down
37 changes: 36 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ STABLEDIFFUSION_VERSION?=4a3cd6aeae6f66ee57eae9a0075f8c58c3a6a38f
TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream
TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057

# bark.cpp
BARKCPP_REPO?=https://github.com/PABannier/bark.cpp.git
BARKCPP_VERSION?=v1.0.0

ONNX_VERSION?=1.20.0
ONNX_ARCH?=x64
ONNX_OS?=linux
Expand Down Expand Up @@ -201,6 +205,13 @@ ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-ggml
ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-cpp-grpc
ALL_GRPC_BACKENDS+=backend-assets/util/llama-cpp-rpc-server
ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper

ifeq ($(ONNX_OS),linux)
ifeq ($(ONNX_ARCH),x64)
ALL_GRPC_BACKENDS+=backend-assets/grpc/bark-cpp
endif
endif

ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
ALL_GRPC_BACKENDS+=backend-assets/grpc/silero-vad
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
Expand Down Expand Up @@ -233,6 +244,22 @@ sources/go-llama.cpp:
git checkout $(GOLLAMA_VERSION) && \
git submodule update --init --recursive --depth 1 --single-branch

sources/bark.cpp:
git clone --recursive https://github.com/PABannier/bark.cpp.git sources/bark.cpp && \
cd sources/bark.cpp && \
git checkout $(BARKCPP_VERSION) && \
git submodule update --init --recursive --depth 1 --single-branch

sources/bark.cpp/build/libbark.a: sources/bark.cpp
cd sources/bark.cpp && \
mkdir build && \
cd build && \
cmake $(CMAKE_ARGS) .. && \
cmake --build . --config Release

backend/go/bark/libbark.a: sources/bark.cpp/build/libbark.a
$(MAKE) -C backend/go/bark libbark.a

sources/go-llama.cpp/libbinding.a: sources/go-llama.cpp
$(MAKE) -C sources/go-llama.cpp BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a

Expand Down Expand Up @@ -302,7 +329,7 @@ sources/whisper.cpp:
sources/whisper.cpp/libwhisper.a: sources/whisper.cpp
cd sources/whisper.cpp && $(MAKE) libwhisper.a libggml.a

get-sources: sources/go-llama.cpp sources/go-piper sources/whisper.cpp sources/go-stable-diffusion sources/go-tiny-dream backend/cpp/llama/llama.cpp
get-sources: sources/go-llama.cpp sources/go-piper sources/bark.cpp sources/whisper.cpp sources/go-stable-diffusion sources/go-tiny-dream backend/cpp/llama/llama.cpp

replace:
$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(CURDIR)/sources/whisper.cpp
Expand Down Expand Up @@ -343,6 +370,7 @@ clean: ## Remove build related file
rm -rf release/
rm -rf backend-assets/*
$(MAKE) -C backend/cpp/grpc clean
$(MAKE) -C backend/go/bark clean
$(MAKE) -C backend/cpp/llama clean
rm -rf backend/cpp/llama-* || true
$(MAKE) dropreplace
Expand Down Expand Up @@ -792,6 +820,13 @@ ifneq ($(UPX),)
$(UPX) backend-assets/grpc/llama-ggml
endif

backend-assets/grpc/bark-cpp: backend/go/bark/libbark.a backend-assets/grpc
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/backend/go/bark/ LIBRARY_PATH=$(CURDIR)/backend/go/bark/ \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bark-cpp ./backend/go/bark/
ifneq ($(UPX),)
$(UPX) backend-assets/grpc/bark-cpp
endif

backend-assets/grpc/piper: sources/go-piper sources/go-piper/libpiper_binding.a backend-assets/grpc backend-assets/espeak-ng-data
CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/sources/go-piper \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/
Expand Down
25 changes: 25 additions & 0 deletions backend/go/bark/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
INCLUDE_PATH := $(abspath ./)
LIBRARY_PATH := $(abspath ./)

AR?=ar

BUILD_TYPE?=
# keep standard at C11 and C++11
CXXFLAGS = -I. -I$(INCLUDE_PATH)/../../../sources/bark.cpp/examples -I$(INCLUDE_PATH)/../../../sources/bark.cpp/spm-headers -I$(INCLUDE_PATH)/../../../sources/bark.cpp -O3 -DNDEBUG -std=c++17 -fPIC
LDFLAGS = -L$(LIBRARY_PATH) -L$(LIBRARY_PATH)/../../../sources/bark.cpp/build/examples -lbark -lstdc++ -lm

# warnings
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function

gobark.o:
$(CXX) $(CXXFLAGS) gobark.cpp -o gobark.o -c $(LDFLAGS)

libbark.a: gobark.o
cp $(INCLUDE_PATH)/../../../sources/bark.cpp/build/libbark.a ./
$(AR) rcs libbark.a gobark.o
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml.c.o
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-alloc.c.o
$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-backend.c.o

clean:
rm -f gobark.o libbark.a
85 changes: 85 additions & 0 deletions backend/go/bark/gobark.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include <iostream>
#include <tuple>

#include "bark.h"
#include "gobark.h"
#include "common.h"
#include "ggml.h"

struct bark_context *c;

void bark_print_progress_callback(struct bark_context *bctx, enum bark_encoding_step step, int progress, void *user_data) {
if (step == bark_encoding_step::SEMANTIC) {
printf("\rGenerating semantic tokens... %d%%", progress);
} else if (step == bark_encoding_step::COARSE) {
printf("\rGenerating coarse tokens... %d%%", progress);
} else if (step == bark_encoding_step::FINE) {
printf("\rGenerating fine tokens... %d%%", progress);
}
fflush(stdout);
}

int load_model(char *model) {
// initialize bark context
struct bark_context_params ctx_params = bark_context_default_params();
bark_params params;

params.model_path = model;

// ctx_params.verbosity = verbosity;
ctx_params.progress_callback = bark_print_progress_callback;
ctx_params.progress_callback_user_data = nullptr;

struct bark_context *bctx = bark_load_model(params.model_path.c_str(), ctx_params, params.seed);
if (!bctx) {
fprintf(stderr, "%s: Could not load model\n", __func__);
return 1;
}

c = bctx;

return 0;
}

int tts(char *text,int threads, char *dst ) {

ggml_time_init();
const int64_t t_main_start_us = ggml_time_us();

// generate audio
if (!bark_generate_audio(c, text, threads)) {
fprintf(stderr, "%s: An error occured. If the problem persists, feel free to open an issue to report it.\n", __func__);
return 1;
}

const float *audio_data = bark_get_audio_data(c);
if (audio_data == NULL) {
fprintf(stderr, "%s: Could not get audio data\n", __func__);
return 1;
}

const int audio_arr_size = bark_get_audio_data_size(c);

std::vector<float> audio_arr(audio_data, audio_data + audio_arr_size);

write_wav_on_disk(audio_arr, dst);

// report timing
{
const int64_t t_main_end_us = ggml_time_us();
const int64_t t_load_us = bark_get_load_time(c);
const int64_t t_eval_us = bark_get_eval_time(c);

printf("\n\n");
printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
printf("%s: eval time = %8.2f ms\n", __func__, t_eval_us / 1000.0f);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
}

return 0;
}

int unload() {
bark_free(c);
}

52 changes: 52 additions & 0 deletions backend/go/bark/gobark.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package main

// #cgo CXXFLAGS: -I${SRCDIR}/../../../sources/bark.cpp/ -I${SRCDIR}/../../../sources/bark.cpp/encodec.cpp -I${SRCDIR}/../../../sources/bark.cpp/examples -I${SRCDIR}/../../../sources/bark.cpp/spm-headers
// #cgo LDFLAGS: -L${SRCDIR}/ -L${SRCDIR}/../../../sources/bark.cpp/build/examples -L${SRCDIR}/../../../sources/bark.cpp/build/encodec.cpp/ -lbark -lencodec -lcommon
// #include <gobark.h>
// #include <stdlib.h>
import "C"

import (
"fmt"
"unsafe"

"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

type Bark struct {
base.SingleThread
threads int
}

func (sd *Bark) Load(opts *pb.ModelOptions) error {

sd.threads = int(opts.Threads)

modelFile := C.CString(opts.ModelFile)
defer C.free(unsafe.Pointer(modelFile))

ret := C.load_model(modelFile)
if ret != 0 {
return fmt.Errorf("inference failed")
}

return nil
}

func (sd *Bark) TTS(opts *pb.TTSRequest) error {
t := C.CString(opts.Text)
defer C.free(unsafe.Pointer(t))

dst := C.CString(opts.Dst)
defer C.free(unsafe.Pointer(dst))

threads := C.int(sd.threads)

ret := C.tts(t, threads, dst)
if ret != 0 {
return fmt.Errorf("inference failed")
}

return nil
}
8 changes: 8 additions & 0 deletions backend/go/bark/gobark.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifdef __cplusplus
extern "C" {
#endif
int load_model(char *model);
int tts(char *text,int threads, char *dst );
#ifdef __cplusplus
}
#endif
20 changes: 20 additions & 0 deletions backend/go/bark/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package main

// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"

grpc "github.com/mudler/LocalAI/pkg/grpc"
)

var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

func main() {
flag.Parse()

if err := grpc.StartServer(*addr, &Bark{}); err != nil {
panic(err)
}
}
67 changes: 0 additions & 67 deletions core/http/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ import (
"context"
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"

"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http"
Expand Down Expand Up @@ -913,71 +911,6 @@ var _ = Describe("API test", func() {
})
})

Context("backends", func() {
It("runs rwkv completion", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue())
Expect(resp.Choices[0].Text).To(ContainSubstring("five"))

stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{
Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true,
})
Expect(err).ToNot(HaveOccurred())
defer stream.Close()

tokens := 0
text := ""
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}

Expect(err).ToNot(HaveOccurred())
text += response.Choices[0].Text
tokens++
}
Expect(text).ToNot(BeEmpty())
Expect(text).To(ContainSubstring("five"))
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})
It("runs rwkv chat completion", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateChatCompletion(context.TODO(),
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue())
Expect(strings.ToLower(resp.Choices[0].Message.Content)).To(Or(ContainSubstring("sure"), ContainSubstring("five"), ContainSubstring("5")))

stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
defer stream.Close()

tokens := 0
text := ""
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}

Expect(err).ToNot(HaveOccurred())
text += response.Choices[0].Delta.Content
tokens++
}
Expect(text).ToNot(BeEmpty())
Expect(strings.ToLower(text)).To(Or(ContainSubstring("sure"), ContainSubstring("five")))

Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})
})

// See tests/integration/stores_test
Context("Stores", Label("stores"), func() {

Expand Down

0 comments on commit 58ff47d

Please sign in to comment.