From c25b7d9c81567f0163578cca7f3a5d8a6d38e00d Mon Sep 17 00:00:00 2001 From: Dave Lee Date: Mon, 26 Feb 2024 15:18:09 -0500 Subject: [PATCH] MASS RENAME: o / options => appConfig where appropriate --- core/backend/embeddings.go | 26 +++--- core/backend/image.go | 44 +++++----- core/backend/transcript.go | 14 +-- core/backend/tts.go | 18 ++-- core/config/application_config.go | 1 - core/http/api.go | 96 ++++++++++----------- core/http/endpoints/localai/tts.go | 6 +- core/http/endpoints/openai/completion.go | 10 +-- core/http/endpoints/openai/edit.go | 8 +- core/http/endpoints/openai/files.go | 24 +++--- core/http/endpoints/openai/files_test.go | 8 +- core/http/endpoints/openai/image.go | 12 +-- core/http/endpoints/openai/transcription.go | 8 +- core/services/backend_monitor.go | 4 +- core/startup/config_file_watcher.go | 14 +-- 15 files changed, 146 insertions(+), 147 deletions(-) diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index 5501bf11bd14..0a74ea4cad55 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -9,30 +9,30 @@ import ( model "github.com/go-skynet/LocalAI/pkg/model" ) -func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig) (func() ([]float32, error), error) { - if !c.Embeddings { +func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { + if !backendConfig.Embeddings { return nil, fmt.Errorf("endpoint disabled for this model by API configuration") } - modelFile := c.Model + modelFile := backendConfig.Model - grpcOpts := gRPCModelOpts(c) + grpcOpts := gRPCModelOpts(backendConfig) var inferenceModel interface{} var err error - opts := modelOpts(c, o, []model.Option{ + opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithLoadGRPCLoadModelOpts(grpcOpts), - model.WithThreads(uint32(c.Threads)), - model.WithAssetDir(o.AssetsDestination), + model.WithThreads(uint32(backendConfig.Threads)), + model.WithAssetDir(appConfig.AssetsDestination), model.WithModel(modelFile), - model.WithContext(o.Context), + model.WithContext(appConfig.Context), }) - if c.Backend == "" { + if backendConfig.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) } else { - opts = append(opts, model.WithBackendString(c.Backend)) + opts = append(opts, model.WithBackendString(backendConfig.Backend)) inferenceModel, err = loader.BackendLoader(opts...) } if err != nil { @@ -43,7 +43,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. switch model := inferenceModel.(type) { case grpc.Backend: fn = func() ([]float32, error) { - predictOptions := gRPCPredictOpts(c, loader.ModelPath) + predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath) if len(tokens) > 0 { embeds := []int32{} @@ -52,7 +52,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. } predictOptions.EmbeddingTokens = embeds - res, err := model.Embeddings(o.Context, predictOptions) + res, err := model.Embeddings(appConfig.Context, predictOptions) if err != nil { return nil, err } @@ -61,7 +61,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. } predictOptions.Embeddings = s - res, err := model.Embeddings(o.Context, predictOptions) + res, err := model.Embeddings(appConfig.Context, predictOptions) if err != nil { return nil, err } diff --git a/core/backend/image.go b/core/backend/image.go index 69250fd20ce2..60db48f96ba1 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -7,27 +7,27 @@ import ( model "github.com/go-skynet/LocalAI/pkg/model" ) -func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig) (func() error, error) { +func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { - opts := modelOpts(c, o, []model.Option{ - model.WithBackendString(c.Backend), - model.WithAssetDir(o.AssetsDestination), - model.WithThreads(uint32(c.Threads)), - model.WithContext(o.Context), - model.WithModel(c.Model), + opts := modelOpts(backendConfig, appConfig, []model.Option{ + model.WithBackendString(backendConfig.Backend), + model.WithAssetDir(appConfig.AssetsDestination), + model.WithThreads(uint32(backendConfig.Threads)), + model.WithContext(appConfig.Context), + model.WithModel(backendConfig.Model), model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{ - CUDA: c.CUDA || c.Diffusers.CUDA, - SchedulerType: c.Diffusers.SchedulerType, - PipelineType: c.Diffusers.PipelineType, - CFGScale: c.Diffusers.CFGScale, - LoraAdapter: c.LoraAdapter, - LoraScale: c.LoraScale, - LoraBase: c.LoraBase, - IMG2IMG: c.Diffusers.IMG2IMG, - CLIPModel: c.Diffusers.ClipModel, - CLIPSubfolder: c.Diffusers.ClipSubFolder, - CLIPSkip: int32(c.Diffusers.ClipSkip), - ControlNet: c.Diffusers.ControlNet, + CUDA: backendConfig.CUDA || backendConfig.Diffusers.CUDA, + SchedulerType: backendConfig.Diffusers.SchedulerType, + PipelineType: backendConfig.Diffusers.PipelineType, + CFGScale: backendConfig.Diffusers.CFGScale, + LoraAdapter: backendConfig.LoraAdapter, + LoraScale: backendConfig.LoraScale, + LoraBase: backendConfig.LoraBase, + IMG2IMG: backendConfig.Diffusers.IMG2IMG, + CLIPModel: backendConfig.Diffusers.ClipModel, + CLIPSubfolder: backendConfig.Diffusers.ClipSubFolder, + CLIPSkip: int32(backendConfig.Diffusers.ClipSkip), + ControlNet: backendConfig.Diffusers.ControlNet, }), }) @@ -40,19 +40,19 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat fn := func() error { _, err := inferenceModel.GenerateImage( - o.Context, + appConfig.Context, &proto.GenerateImageRequest{ Height: int32(height), Width: int32(width), Mode: int32(mode), Step: int32(step), Seed: int32(seed), - CLIPSkip: int32(c.Diffusers.ClipSkip), + CLIPSkip: int32(backendConfig.Diffusers.ClipSkip), PositivePrompt: positive_prompt, NegativePrompt: negative_prompt, Dst: dst, Src: src, - EnableParameters: c.Diffusers.EnableParameters, + EnableParameters: backendConfig.Diffusers.EnableParameters, }) return err } diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 2ee14c6d74e2..bbb4f4b4c309 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -11,14 +11,14 @@ import ( model "github.com/go-skynet/LocalAI/pkg/model" ) -func ModelTranscription(audio, language string, ml *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig) (*schema.Result, error) { +func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.Result, error) { - opts := modelOpts(c, o, []model.Option{ + opts := modelOpts(backendConfig, appConfig, []model.Option{ model.WithBackendString(model.WhisperBackend), - model.WithModel(c.Model), - model.WithContext(o.Context), - model.WithThreads(uint32(c.Threads)), - model.WithAssetDir(o.AssetsDestination), + model.WithModel(backendConfig.Model), + model.WithContext(appConfig.Context), + model.WithThreads(uint32(backendConfig.Threads)), + model.WithAssetDir(appConfig.AssetsDestination), }) whisperModel, err := ml.BackendLoader(opts...) @@ -33,6 +33,6 @@ func ModelTranscription(audio, language string, ml *model.ModelLoader, c config. return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ Dst: audio, Language: language, - Threads: uint32(c.Threads), + Threads: uint32(backendConfig.Threads), }) } diff --git a/core/backend/tts.go b/core/backend/tts.go index 036c0ee7494a..85aa345771df 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -29,19 +29,19 @@ func generateUniqueFileName(dir, baseName, ext string) string { } } -func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *config.ApplicationConfig, c config.BackendConfig) (string, *proto.Result, error) { +func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) { bb := backend if bb == "" { bb = model.PiperBackend } - grpcOpts := gRPCModelOpts(c) + grpcOpts := gRPCModelOpts(backendConfig) - opts := modelOpts(config.BackendConfig{}, o, []model.Option{ + opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{ model.WithBackendString(bb), model.WithModel(modelFile), - model.WithContext(o.Context), - model.WithAssetDir(o.AssetsDestination), + model.WithContext(appConfig.Context), + model.WithAssetDir(appConfig.AssetsDestination), model.WithLoadGRPCLoadModelOpts(grpcOpts), }) piperModel, err := loader.BackendLoader(opts...) @@ -53,19 +53,19 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *con return "", nil, fmt.Errorf("could not load piper model") } - if err := os.MkdirAll(o.AudioDir, 0755); err != nil { + if err := os.MkdirAll(appConfig.AudioDir, 0755); err != nil { return "", nil, fmt.Errorf("failed creating audio directory: %s", err) } - fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") - filePath := filepath.Join(o.AudioDir, fileName) + fileName := generateUniqueFileName(appConfig.AudioDir, "piper", ".wav") + filePath := filepath.Join(appConfig.AudioDir, fileName) // If the model file is not empty, we pass it joined with the model path modelPath := "" if modelFile != "" { if bb != model.TransformersMusicGen { modelPath = filepath.Join(loader.ModelPath, modelFile) - if err := utils.VerifyPath(modelPath, o.ModelPath); err != nil { + if err := utils.VerifyPath(modelPath, appConfig.ModelPath); err != nil { return "", nil, err } } else { diff --git a/core/config/application_config.go b/core/config/application_config.go index d99f137bb14e..d90ae9064f44 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -25,7 +25,6 @@ type ApplicationConfig struct { PreloadModelsFromPath string CORSAllowOrigins string ApiKeys []string - // Metrics *metrics.Metrics ModelLibraryURL string diff --git a/core/http/api.go b/core/http/api.go index 98ea8004a6d6..aae0dd74746f 100644 --- a/core/http/api.go +++ b/core/http/api.go @@ -21,11 +21,11 @@ import ( "github.com/gofiber/fiber/v2/middleware/recover" ) -func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, options *config.ApplicationConfig) (*fiber.App, error) { +func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (*fiber.App, error) { // Return errors as JSON responses app := fiber.New(fiber.Config{ - BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - DisableStartupMessage: options.DisableMessage, + BodyLimit: appConfig.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB + DisableStartupMessage: appConfig.DisableMessage, // Override default error handler ErrorHandler: func(ctx *fiber.Ctx, err error) error { // Status code defaults to 500 @@ -46,7 +46,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, options *config. }, }) - if options.Debug { + if appConfig.Debug { app.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) @@ -54,7 +54,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, options *config. // Default middleware config - if !options.Debug { + if !appConfig.Debug { app.Use(recover.New()) } @@ -72,7 +72,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, options *config. // Auth middleware checking if API key is valid. If no API key is set, no auth is required. auth := func(c *fiber.Ctx) error { - if len(options.ApiKeys) == 0 { + if len(appConfig.ApiKeys) == 0 { return c.Next() } @@ -87,10 +87,10 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, options *config. } // Add file keys to options.ApiKeys - options.ApiKeys = append(options.ApiKeys, fileKeys...) + appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...) } - if len(options.ApiKeys) == 0 { + if len(appConfig.ApiKeys) == 0 { return c.Next() } @@ -104,7 +104,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, options *config. } apiKey := authHeaderParts[1] - for _, key := range options.ApiKeys { + for _, key := range appConfig.ApiKeys { if apiKey == key { return c.Next() } @@ -114,20 +114,20 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, options *config. } - if options.CORS { + if appConfig.CORS { var c func(ctx *fiber.Ctx) error - if options.CORSAllowOrigins == "" { + if appConfig.CORSAllowOrigins == "" { c = cors.New() } else { - c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins}) + c = cors.New(cors.Config{AllowOrigins: appConfig.CORSAllowOrigins}) } app.Use(c) } // LocalAI API endpoints - galleryService := services.NewGalleryService(options.ModelPath) - galleryService.Start(options.Context, cl) + galleryService := services.NewGalleryService(appConfig.ModelPath) + galleryService.Start(appConfig.Context, cl) app.Get("/version", auth, func(c *fiber.Ctx) error { return c.JSON(struct { @@ -136,15 +136,15 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, options *config. }) // Make sure directories exists - os.MkdirAll(options.ImageDir, 0755) - os.MkdirAll(options.AudioDir, 0755) - os.MkdirAll(options.UploadDir, 0755) - os.MkdirAll(options.ModelPath, 0755) + os.MkdirAll(appConfig.ImageDir, 0755) + os.MkdirAll(appConfig.AudioDir, 0755) + os.MkdirAll(appConfig.UploadDir, 0755) + os.MkdirAll(appConfig.ModelPath, 0755) // Load upload json - openai.LoadUploadConfig(options.UploadDir) + openai.LoadUploadConfig(appConfig.UploadDir) - modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(options.Galleries, options.ModelPath, galleryService) + modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService) app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint()) app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint()) app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint()) @@ -156,48 +156,48 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, options *config. // openAI compatible API endpoint // chat - app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, options)) - app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, options)) + app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig)) + app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, appConfig)) // edit - app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, options)) - app.Post("/edits", auth, openai.EditEndpoint(cl, ml, options)) + app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, appConfig)) + app.Post("/edits", auth, openai.EditEndpoint(cl, ml, appConfig)) // files - app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, options)) - app.Post("/files", auth, openai.UploadFilesEndpoint(cl, options)) - app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, options)) - app.Get("/files", auth, openai.ListFilesEndpoint(cl, options)) - app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, options)) - app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, options)) - app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options)) - app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options)) - app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options)) - app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options)) + app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, appConfig)) + app.Post("/files", auth, openai.UploadFilesEndpoint(cl, appConfig)) + app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, appConfig)) + app.Get("/files", auth, openai.ListFilesEndpoint(cl, appConfig)) + app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig)) + app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, appConfig)) + app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig)) + app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, appConfig)) + app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig)) + app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, appConfig)) // completion - app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, options)) - app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, options)) - app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, options)) + app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) + app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) + app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, appConfig)) // embeddings - app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options)) - app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options)) - app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options)) + app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) + app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) + app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, appConfig)) // audio - app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, options)) - app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, options)) + app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, appConfig)) + app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, appConfig)) // images - app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, options)) + app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, appConfig)) - if options.ImageDir != "" { - app.Static("/generated-images", options.ImageDir) + if appConfig.ImageDir != "" { + app.Static("/generated-images", appConfig.ImageDir) } - if options.AudioDir != "" { - app.Static("/generated-audio", options.AudioDir) + if appConfig.AudioDir != "" { + app.Static("/generated-audio", appConfig.AudioDir) } ok := func(c *fiber.Ctx) error { @@ -209,7 +209,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, options *config. app.Get("/readyz", ok) // Experimental Backend Statistics Module - backendMonitor := services.NewBackendMonitor(cl, ml, options) // Split out for now + backendMonitor := services.NewBackendMonitor(cl, ml, appConfig) // Split out for now app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor)) app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor)) diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index 44694e3d2a5a..84fb7a555004 100644 --- a/core/http/endpoints/localai/tts.go +++ b/core/http/endpoints/localai/tts.go @@ -11,7 +11,7 @@ import ( "github.com/rs/zerolog/log" ) -func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig) func(c *fiber.Ctx) error { +func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(schema.TTSRequest) @@ -26,7 +26,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o *confi modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) } - cfg, err := config.LoadBackendConfigFileByName(modelFile, o.ModelPath, cl, false, 0, 0, false) + cfg, err := config.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, cl, false, 0, 0, false) if err != nil { modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) @@ -39,7 +39,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o *confi cfg.Backend = input.Backend } - filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, ml, o, *cfg) + filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index 7102015a1a76..9344f9fefdf6 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -21,12 +21,12 @@ import ( ) // https://platform.openai.com/docs/api-reference/completions -func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig) func(c *fiber.Ctx) error { +func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { id := uuid.New().String() created := int(time.Now().Unix()) process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { - ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { resp := schema.OpenAIResponse{ ID: id, Created: created, @@ -53,14 +53,14 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o } return func(c *fiber.Ctx) error { - modelFile, input, err := readRequest(c, ml, o, true) + modelFile, input, err := readRequest(c, ml, appConfig, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } log.Debug().Msgf("`input`: %+v", input) - config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, o.Debug, o.Threads, o.ContextSize, o.F16) + config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -164,7 +164,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o } r, tokenUsage, err := ComputeChoices( - input, i, config, o, ml, func(s string, c *[]schema.Choice) { + input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k}) }, nil) if err != nil { diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go index adbb6fb9fd9d..254970958278 100644 --- a/core/http/endpoints/openai/edit.go +++ b/core/http/endpoints/openai/edit.go @@ -16,14 +16,14 @@ import ( "github.com/rs/zerolog/log" ) -func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig) func(c *fiber.Ctx) error { +func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - modelFile, input, err := readRequest(c, ml, o, true) + modelFile, input, err := readRequest(c, ml, appConfig, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, o.Debug, o.Threads, o.ContextSize, o.F16) + config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -57,7 +57,7 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o *conf } } - r, tokenUsage, err := ComputeChoices(input, i, config, o, ml, func(s string, c *[]schema.Choice) { + r, tokenUsage, err := ComputeChoices(input, i, config, appConfig, ml, func(s string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s}) }, nil) if err != nil { diff --git a/core/http/endpoints/openai/files.go b/core/http/endpoints/openai/files.go index 10ee621f25af..5cb8d7a92eaf 100644 --- a/core/http/endpoints/openai/files.go +++ b/core/http/endpoints/openai/files.go @@ -62,7 +62,7 @@ func LoadUploadConfig(uploadPath string) { } // UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create -func UploadFilesEndpoint(cm *config.BackendConfigLoader, o *config.ApplicationConfig) func(c *fiber.Ctx) error { +func UploadFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { file, err := c.FormFile("file") if err != nil { @@ -70,8 +70,8 @@ func UploadFilesEndpoint(cm *config.BackendConfigLoader, o *config.ApplicationCo } // Check the file size - if file.Size > int64(o.UploadLimitMB*1024*1024) { - return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, o.UploadLimitMB)) + if file.Size > int64(appConfig.UploadLimitMB*1024*1024) { + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, appConfig.UploadLimitMB)) } purpose := c.FormValue("purpose", "") //TODO put in purpose dirs @@ -82,7 +82,7 @@ func UploadFilesEndpoint(cm *config.BackendConfigLoader, o *config.ApplicationCo // Sanitize the filename to prevent directory traversal filename := utils.SanitizeFileName(file.Filename) - savePath := filepath.Join(o.UploadDir, filename) + savePath := filepath.Join(appConfig.UploadDir, filename) // Check if file already exists if _, err := os.Stat(savePath); !os.IsNotExist(err) { @@ -104,13 +104,13 @@ func UploadFilesEndpoint(cm *config.BackendConfigLoader, o *config.ApplicationCo } uploadedFiles = append(uploadedFiles, f) - saveUploadConfig(o.UploadDir) + saveUploadConfig(appConfig.UploadDir) return c.Status(fiber.StatusOK).JSON(f) } } // ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list -func ListFilesEndpoint(cm *config.BackendConfigLoader, o *config.ApplicationConfig) func(c *fiber.Ctx) error { +func ListFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { type ListFiles struct { Data []File Object string @@ -150,7 +150,7 @@ func getFileFromRequest(c *fiber.Ctx) (*File, error) { } // GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve -func GetFilesEndpoint(cm *config.BackendConfigLoader, o *config.ApplicationConfig) func(c *fiber.Ctx) error { +func GetFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { file, err := getFileFromRequest(c) if err != nil { @@ -162,7 +162,7 @@ func GetFilesEndpoint(cm *config.BackendConfigLoader, o *config.ApplicationConfi } // DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete -func DeleteFilesEndpoint(cm *config.BackendConfigLoader, o *config.ApplicationConfig) func(c *fiber.Ctx) error { +func DeleteFilesEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { type DeleteStatus struct { Id string Object string @@ -175,7 +175,7 @@ func DeleteFilesEndpoint(cm *config.BackendConfigLoader, o *config.ApplicationCo return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) } - err = os.Remove(filepath.Join(o.UploadDir, file.Filename)) + err = os.Remove(filepath.Join(appConfig.UploadDir, file.Filename)) if err != nil { // If the file doesn't exist then we should just continue to remove it if !errors.Is(err, os.ErrNotExist) { @@ -191,7 +191,7 @@ func DeleteFilesEndpoint(cm *config.BackendConfigLoader, o *config.ApplicationCo } } - saveUploadConfig(o.UploadDir) + saveUploadConfig(appConfig.UploadDir) return c.JSON(DeleteStatus{ Id: file.ID, Object: "file", @@ -201,14 +201,14 @@ func DeleteFilesEndpoint(cm *config.BackendConfigLoader, o *config.ApplicationCo } // GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents -func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, o *config.ApplicationConfig) func(c *fiber.Ctx) error { +func GetFilesContentsEndpoint(cm *config.BackendConfigLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { file, err := getFileFromRequest(c) if err != nil { return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) } - fileContents, err := os.ReadFile(filepath.Join(o.UploadDir, file.Filename)) + fileContents, err := os.ReadFile(filepath.Join(appConfig.UploadDir, file.Filename)) if err != nil { return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) } diff --git a/core/http/endpoints/openai/files_test.go b/core/http/endpoints/openai/files_test.go index eaf3a40b3660..a036bd0dc2a5 100644 --- a/core/http/endpoints/openai/files_test.go +++ b/core/http/endpoints/openai/files_test.go @@ -174,9 +174,9 @@ func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*htt return app.Test(request) } -func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *config.ApplicationConfig) (*http.Response, error) { +func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) (*http.Response, error) { // Create a file that exceeds the limit - file := createTestFile(t, fileName, fileSize, o) + file := createTestFile(t, fileName, fileSize, appConfig) // Creating a new HTTP Request body, writer := newMultipartFile(file.Name(), tag, purpose) @@ -186,9 +186,9 @@ func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpos return app.Test(req) } -func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *config.ApplicationConfig) File { +func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, appConfig *config.ApplicationConfig) File { // Create a file that exceeds the limit - file := createTestFile(t, fileName, fileSize, o) + file := createTestFile(t, fileName, fileSize, appConfig) // Creating a new HTTP Request body, writer := newMultipartFile(file.Name(), tag, purpose) diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index c715b968356a..8f535801f63f 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -59,9 +59,9 @@ func downloadFile(url string) (string, error) { * */ -func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig) func(c *fiber.Ctx) error { +func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - m, input, err := readRequest(c, ml, o, false) + m, input, err := readRequest(c, ml, appConfig, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -71,7 +71,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o *con } log.Debug().Msgf("Loading model: %+v", m) - config, input, err := mergeRequestWithConfig(m, input, cl, ml, o.Debug, 0, 0, false) + config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -104,7 +104,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o *con } // Create a temporary file - outputFile, err := os.CreateTemp(o.ImageDir, "b64") + outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64") if err != nil { return err } @@ -179,7 +179,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o *con tempDir := "" if !b64JSON { - tempDir = o.ImageDir + tempDir = appConfig.ImageDir } // Create a temporary file outputFile, err := os.CreateTemp(tempDir, "b64") @@ -196,7 +196,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o *con baseURL := c.BaseURL() - fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, o) + fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig) if err != nil { return err } diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index 7cc056b5f625..403f8b021c66 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -17,14 +17,14 @@ import ( ) // https://platform.openai.com/docs/api-reference/audio/create -func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig) func(c *fiber.Ctx) error { +func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - m, input, err := readRequest(c, ml, o, false) + m, input, err := readRequest(c, ml, appConfig, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - config, input, err := mergeRequestWithConfig(m, input, cl, ml, o.Debug, o.Threads, o.ContextSize, o.F16) + config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -59,7 +59,7 @@ func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, o log.Debug().Msgf("Audio file copied to: %+v", dst) - tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, o) + tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig) if err != nil { return err } diff --git a/core/services/backend_monitor.go b/core/services/backend_monitor.go index d995c9c63f3e..881767532664 100644 --- a/core/services/backend_monitor.go +++ b/core/services/backend_monitor.go @@ -21,11 +21,11 @@ type BackendMonitor struct { options *config.ApplicationConfig // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. } -func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, options *config.ApplicationConfig) BackendMonitor { +func NewBackendMonitor(configLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, appConfig *config.ApplicationConfig) BackendMonitor { return BackendMonitor{ configLoader: configLoader, modelLoader: modelLoader, - options: options, + options: appConfig, } } diff --git a/core/startup/config_file_watcher.go b/core/startup/config_file_watcher.go index 577d77243f14..0c7eff2de569 100644 --- a/core/startup/config_file_watcher.go +++ b/core/startup/config_file_watcher.go @@ -14,14 +14,14 @@ import ( type WatchConfigDirectoryCloser func() error -func ReadApiKeysJson(configDir string, options *config.ApplicationConfig) error { +func ReadApiKeysJson(configDir string, appConfig *config.ApplicationConfig) error { fileContent, err := os.ReadFile(path.Join(configDir, "api_keys.json")) if err == nil { // Parse JSON content from the file var fileKeys []string err := json.Unmarshal(fileContent, &fileKeys) if err == nil { - options.ApiKeys = append(options.ApiKeys, fileKeys...) + appConfig.ApiKeys = append(appConfig.ApiKeys, fileKeys...) return nil } return err @@ -29,7 +29,7 @@ func ReadApiKeysJson(configDir string, options *config.ApplicationConfig) error return err } -func ReadExternalBackendsJson(configDir string, options *config.ApplicationConfig) error { +func ReadExternalBackendsJson(configDir string, appConfig *config.ApplicationConfig) error { fileContent, err := os.ReadFile(path.Join(configDir, "external_backends.json")) if err != nil { return err @@ -40,19 +40,19 @@ func ReadExternalBackendsJson(configDir string, options *config.ApplicationConfi if err != nil { return err } - err = mergo.Merge(&options.ExternalGRPCBackends, fileBackends) + err = mergo.Merge(&appConfig.ExternalGRPCBackends, fileBackends) if err != nil { return err } return nil } -var CONFIG_FILE_UPDATES = map[string]func(configDir string, options *config.ApplicationConfig) error{ +var CONFIG_FILE_UPDATES = map[string]func(configDir string, appConfig *config.ApplicationConfig) error{ "api_keys.json": ReadApiKeysJson, "external_backends.json": ReadExternalBackendsJson, } -func WatchConfigDirectory(configDir string, options *config.ApplicationConfig) (WatchConfigDirectoryCloser, error) { +func WatchConfigDirectory(configDir string, appConfig *config.ApplicationConfig) (WatchConfigDirectoryCloser, error) { if len(configDir) == 0 { return nil, fmt.Errorf("configDir blank") } @@ -76,7 +76,7 @@ func WatchConfigDirectory(configDir string, options *config.ApplicationConfig) ( if event.Has(fsnotify.Write) { for targetName, watchFn := range CONFIG_FILE_UPDATES { if event.Name == targetName { - err := watchFn(configDir, options) + err := watchFn(configDir, appConfig) log.Warn().Msgf("WatchConfigDirectory goroutine for %s: failed to update options: %+v", targetName, err) } }