diff --git a/api/api.go b/api/api.go deleted file mode 100644 index e3409effe17a..000000000000 --- a/api/api.go +++ /dev/null @@ -1,280 +0,0 @@ -package api - -import ( - "encoding/json" - "errors" - "fmt" - "os" - "strings" - - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/localai" - "github.com/go-skynet/LocalAI/api/openai" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" - "github.com/go-skynet/LocalAI/internal" - "github.com/go-skynet/LocalAI/metrics" - "github.com/go-skynet/LocalAI/pkg/assets" - "github.com/go-skynet/LocalAI/pkg/model" - - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/cors" - "github.com/gofiber/fiber/v2/middleware/logger" - "github.com/gofiber/fiber/v2/middleware/recover" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" -) - -func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) { - options := options.NewOptions(opts...) - - zerolog.SetGlobalLevel(zerolog.InfoLevel) - if options.Debug { - zerolog.SetGlobalLevel(zerolog.DebugLevel) - } - - log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath) - log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) - - cl := config.NewConfigLoader() - if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil { - log.Error().Msgf("error loading config files: %s", err.Error()) - } - - if options.ConfigFile != "" { - if err := cl.LoadConfigFile(options.ConfigFile); err != nil { - log.Error().Msgf("error loading config file: %s", err.Error()) - } - } - - if err := cl.Preload(options.Loader.ModelPath); err != nil { - log.Error().Msgf("error downloading models: %s", err.Error()) - } - - if options.Debug { - for _, v := range cl.ListConfigs() { - cfg, _ := cl.GetConfig(v) - log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) - } - } - - if options.AssetsDestination != "" { - // Extract files from the embedded FS - err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) - log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) - if err != nil { - log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) - } - } - - if options.PreloadJSONModels != "" { - if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { - return nil, nil, err - } - } - - if options.PreloadModelsFromPath != "" { - if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { - return nil, nil, err - } - } - - // turn off any process that was started by GRPC if the context is canceled - go func() { - <-options.Context.Done() - log.Debug().Msgf("Context canceled, shutting down") - options.Loader.StopAllGRPC() - }() - - if options.WatchDog { - wd := model.NewWatchDog( - options.Loader, - options.WatchDogBusyTimeout, - options.WatchDogIdleTimeout, - options.WatchDogBusy, - options.WatchDogIdle) - options.Loader.SetWatchDog(wd) - go wd.Run() - go func() { - <-options.Context.Done() - log.Debug().Msgf("Context canceled, shutting down") - wd.Shutdown() - }() - } - - return options, cl, nil -} - -func App(opts ...options.AppOption) (*fiber.App, error) { - - options, cl, err := Startup(opts...) - if err != nil { - return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error()) - } - - // Return errors as JSON responses - app := fiber.New(fiber.Config{ - BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - DisableStartupMessage: options.DisableMessage, - // Override default error handler - ErrorHandler: func(ctx *fiber.Ctx, err error) error { - // Status code defaults to 500 - code := fiber.StatusInternalServerError - - // Retrieve the custom status code if it's a *fiber.Error - var e *fiber.Error - if errors.As(err, &e) { - code = e.Code - } - - // Send custom error page - return ctx.Status(code).JSON( - schema.ErrorResponse{ - Error: &schema.APIError{Message: err.Error(), Code: code}, - }, - ) - }, - }) - - if options.Debug { - app.Use(logger.New(logger.Config{ - Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", - })) - } - - // Default middleware config - app.Use(recover.New()) - if options.Metrics != nil { - app.Use(metrics.APIMiddleware(options.Metrics)) - } - - // Auth middleware checking if API key is valid. If no API key is set, no auth is required. - auth := func(c *fiber.Ctx) error { - if len(options.ApiKeys) == 0 { - return c.Next() - } - - // Check for api_keys.json file - fileContent, err := os.ReadFile("api_keys.json") - if err == nil { - // Parse JSON content from the file - var fileKeys []string - err := json.Unmarshal(fileContent, &fileKeys) - if err != nil { - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"}) - } - - // Add file keys to options.ApiKeys - options.ApiKeys = append(options.ApiKeys, fileKeys...) - } - - if len(options.ApiKeys) == 0 { - return c.Next() - } - - authHeader := c.Get("Authorization") - if authHeader == "" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) - } - authHeaderParts := strings.Split(authHeader, " ") - if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) - } - - apiKey := authHeaderParts[1] - for _, key := range options.ApiKeys { - if apiKey == key { - return c.Next() - } - } - - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) - - } - - if options.CORS { - var c func(ctx *fiber.Ctx) error - if options.CORSAllowOrigins == "" { - c = cors.New() - } else { - c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins}) - } - - app.Use(c) - } - - // LocalAI API endpoints - galleryService := localai.NewGalleryService(options.Loader.ModelPath) - galleryService.Start(options.Context, cl) - - app.Get("/version", auth, func(c *fiber.Ctx) error { - return c.JSON(struct { - Version string `json:"version"` - }{Version: internal.PrintableVersion()}) - }) - - modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService) - app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint()) - app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint()) - app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint()) - app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint()) - app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint()) - app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint()) - app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint()) - - // openAI compatible API endpoint - - // chat - app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options)) - app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options)) - - // edit - app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options)) - app.Post("/edits", auth, openai.EditEndpoint(cl, options)) - - // completion - app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options)) - app.Post("/completions", auth, openai.CompletionEndpoint(cl, options)) - app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options)) - - // embeddings - app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) - app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) - app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) - - // audio - app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options)) - app.Post("/tts", auth, localai.TTSEndpoint(cl, options)) - - // images - app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options)) - - if options.ImageDir != "" { - app.Static("/generated-images", options.ImageDir) - } - - if options.AudioDir != "" { - app.Static("/generated-audio", options.AudioDir) - } - - ok := func(c *fiber.Ctx) error { - return c.SendStatus(200) - } - - // Kubernetes health checks - app.Get("/healthz", ok) - app.Get("/readyz", ok) - - // Experimental Backend Statistics Module - backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now - app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor)) - app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor)) - - // models - app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) - app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) - - app.Get("/metrics", metrics.MetricsHandler()) - - return app, nil -}