From 03b1cf51fddc958848fdd9929530d10af58a8dff Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Mon, 24 Jun 2024 19:21:22 +0200 Subject: [PATCH] feat(whisper): add translate option (#2649) Signed-off-by: Ettore Di Giacinto --- backend/backend.proto | 1 + backend/go/transcribe/transcript.go | 6 +++++- backend/go/transcribe/whisper.go | 2 +- core/backend/transcript.go | 9 +++++---- core/cli/transcript.go | 3 ++- core/http/endpoints/openai/transcription.go | 2 +- core/schema/prediction.go | 3 +++ 7 files changed, 18 insertions(+), 8 deletions(-) diff --git a/backend/backend.proto b/backend/backend.proto index aec0c00e74e4..0d3d5f7f795e 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -230,6 +230,7 @@ message TranscriptRequest { string dst = 2; string language = 3; uint32 threads = 4; + bool translate = 5; } message TranscriptResult { diff --git a/backend/go/transcribe/transcript.go b/backend/go/transcribe/transcript.go index 9b3da01cd284..6831167f3aac 100644 --- a/backend/go/transcribe/transcript.go +++ b/backend/go/transcribe/transcript.go @@ -29,7 +29,7 @@ func audioToWav(src, dst string) error { return nil } -func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.TranscriptionResult, error) { +func Transcript(model whisper.Model, audiopath, language string, translate bool, threads uint) (schema.TranscriptionResult, error) { res := schema.TranscriptionResult{} dir, err := os.MkdirTemp("", "whisper") @@ -75,6 +75,10 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) ( context.SetLanguage("auto") } + if translate { + context.SetTranslate(true) + } + if err := context.Process(data, nil, nil); err != nil { return res, err } diff --git a/backend/go/transcribe/whisper.go b/backend/go/transcribe/whisper.go index 7ed60c821266..61ae98e943de 100644 --- a/backend/go/transcribe/whisper.go +++ b/backend/go/transcribe/whisper.go @@ -22,5 +22,5 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) 
error { } func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) { - return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) + return Transcript(sd.whisper, opts.Dst, opts.Language, opts.Translate, uint(opts.Threads)) } diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 1807f5101058..0980288f644c 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -11,7 +11,7 @@ import ( model "github.com/mudler/LocalAI/pkg/model" ) -func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { +func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithBackendString(model.WhisperBackend), @@ -31,8 +31,9 @@ func ModelTranscription(audio, language string, ml *model.ModelLoader, backendCo } return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ - Dst: audio, - Language: language, - Threads: uint32(*backendConfig.Threads), + Dst: audio, + Language: language, + Translate: translate, + Threads: uint32(*backendConfig.Threads), }) } diff --git a/core/cli/transcript.go b/core/cli/transcript.go index 52381741f73a..fd78557fd9cd 100644 --- a/core/cli/transcript.go +++ b/core/cli/transcript.go @@ -18,6 +18,7 @@ type TranscriptCMD struct { Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"` Model string `short:"m" required:"" help:"Model name to run the TTS"` Language string `short:"l" help:"Language of the audio file"` + Translate bool `help:"Translate the transcription to english"` Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"` 
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"` @@ -50,7 +51,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error { } }() - tr, err := backend.ModelTranscription(t.Filename, t.Language, ml, c, opts) + tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, ml, c, opts) if err != nil { return err } diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index eddcc6fcda8c..c8e447f79cb0 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -65,7 +65,7 @@ func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a log.Debug().Msgf("Audio file copied to: %+v", dst) - tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig) + tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, ml, *config, appConfig) if err != nil { return err } diff --git a/core/schema/prediction.go b/core/schema/prediction.go index 7e509167d5fa..8ad5692806d0 100644 --- a/core/schema/prediction.go +++ b/core/schema/prediction.go @@ -8,6 +8,9 @@ type PredictionOptions struct { // Also part of the OpenAI official spec Language string `json:"language"` + // Only for audio transcription + Translate bool `json:"translate"` + // Also part of the OpenAI official spec. use it for returning multiple results N int `json:"n"`