diff --git a/api/api.go b/api/api.go deleted file mode 100644 index e3409effe17a..000000000000 --- a/api/api.go +++ /dev/null @@ -1,280 +0,0 @@ -package api - -import ( - "encoding/json" - "errors" - "fmt" - "os" - "strings" - - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/localai" - "github.com/go-skynet/LocalAI/api/openai" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" - "github.com/go-skynet/LocalAI/internal" - "github.com/go-skynet/LocalAI/metrics" - "github.com/go-skynet/LocalAI/pkg/assets" - "github.com/go-skynet/LocalAI/pkg/model" - - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/cors" - "github.com/gofiber/fiber/v2/middleware/logger" - "github.com/gofiber/fiber/v2/middleware/recover" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" -) - -func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) { - options := options.NewOptions(opts...) - - zerolog.SetGlobalLevel(zerolog.InfoLevel) - if options.Debug { - zerolog.SetGlobalLevel(zerolog.DebugLevel) - } - - log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath) - log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) - - cl := config.NewConfigLoader() - if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil { - log.Error().Msgf("error loading config files: %s", err.Error()) - } - - if options.ConfigFile != "" { - if err := cl.LoadConfigFile(options.ConfigFile); err != nil { - log.Error().Msgf("error loading config file: %s", err.Error()) - } - } - - if err := cl.Preload(options.Loader.ModelPath); err != nil { - log.Error().Msgf("error downloading models: %s", err.Error()) - } - - if options.Debug { - for _, v := range cl.ListConfigs() { - cfg, _ := cl.GetConfig(v) - log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) - } - } - - if options.AssetsDestination != "" { - // Extract files from the embedded FS - err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) - log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) - if err != nil { - log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) - } - } - - if options.PreloadJSONModels != "" { - if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { - return nil, nil, err - } - } - - if options.PreloadModelsFromPath != "" { - if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { - return nil, nil, err - } - } - - // turn off any process that was started by GRPC if the context is canceled - go func() { - <-options.Context.Done() - log.Debug().Msgf("Context canceled, shutting down") - options.Loader.StopAllGRPC() - }() - - if options.WatchDog { - wd := model.NewWatchDog( - options.Loader, - options.WatchDogBusyTimeout, - options.WatchDogIdleTimeout, - options.WatchDogBusy, - options.WatchDogIdle) - options.Loader.SetWatchDog(wd) - go wd.Run() - go func() { - <-options.Context.Done() - log.Debug().Msgf("Context canceled, shutting down") - wd.Shutdown() - }() - } - - return options, cl, nil -} - -func App(opts ...options.AppOption) (*fiber.App, error) { - - options, cl, err := Startup(opts...) - if err != nil { - return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error()) - } - - // Return errors as JSON responses - app := fiber.New(fiber.Config{ - BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - DisableStartupMessage: options.DisableMessage, - // Override default error handler - ErrorHandler: func(ctx *fiber.Ctx, err error) error { - // Status code defaults to 500 - code := fiber.StatusInternalServerError - - // Retrieve the custom status code if it's a *fiber.Error - var e *fiber.Error - if errors.As(err, &e) { - code = e.Code - } - - // Send custom error page - return ctx.Status(code).JSON( - schema.ErrorResponse{ - Error: &schema.APIError{Message: err.Error(), Code: code}, - }, - ) - }, - }) - - if options.Debug { - app.Use(logger.New(logger.Config{ - Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", - })) - } - - // Default middleware config - app.Use(recover.New()) - if options.Metrics != nil { - app.Use(metrics.APIMiddleware(options.Metrics)) - } - - // Auth middleware checking if API key is valid. If no API key is set, no auth is required. - auth := func(c *fiber.Ctx) error { - if len(options.ApiKeys) == 0 { - return c.Next() - } - - // Check for api_keys.json file - fileContent, err := os.ReadFile("api_keys.json") - if err == nil { - // Parse JSON content from the file - var fileKeys []string - err := json.Unmarshal(fileContent, &fileKeys) - if err != nil { - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"}) - } - - // Add file keys to options.ApiKeys - options.ApiKeys = append(options.ApiKeys, fileKeys...) - } - - if len(options.ApiKeys) == 0 { - return c.Next() - } - - authHeader := c.Get("Authorization") - if authHeader == "" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) - } - authHeaderParts := strings.Split(authHeader, " ") - if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) - } - - apiKey := authHeaderParts[1] - for _, key := range options.ApiKeys { - if apiKey == key { - return c.Next() - } - } - - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) - - } - - if options.CORS { - var c func(ctx *fiber.Ctx) error - if options.CORSAllowOrigins == "" { - c = cors.New() - } else { - c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins}) - } - - app.Use(c) - } - - // LocalAI API endpoints - galleryService := localai.NewGalleryService(options.Loader.ModelPath) - galleryService.Start(options.Context, cl) - - app.Get("/version", auth, func(c *fiber.Ctx) error { - return c.JSON(struct { - Version string `json:"version"` - }{Version: internal.PrintableVersion()}) - }) - - modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService) - app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint()) - app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint()) - app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint()) - app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint()) - app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint()) - app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint()) - app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint()) - - // openAI compatible API endpoint - - // chat - app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options)) - app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options)) - - // edit - app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options)) - app.Post("/edits", auth, openai.EditEndpoint(cl, options)) - - // completion - app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options)) - app.Post("/completions", auth, openai.CompletionEndpoint(cl, options)) - app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options)) - - // embeddings - app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) - app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) - app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) - - // audio - app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options)) - app.Post("/tts", auth, localai.TTSEndpoint(cl, options)) - - // images - app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options)) - - if options.ImageDir != "" { - app.Static("/generated-images", options.ImageDir) - } - - if options.AudioDir != "" { - app.Static("/generated-audio", options.AudioDir) - } - - ok := func(c *fiber.Ctx) error { - return c.SendStatus(200) - } - - // Kubernetes health checks - app.Get("/healthz", ok) - app.Get("/readyz", ok) - - // Experimental Backend Statistics Module - backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now - app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor)) - app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor)) - - // models - app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) - app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) - - app.Get("/metrics", metrics.MetricsHandler()) - - return app, nil -} diff --git a/api/backend/image.go b/api/backend/image.go deleted file mode 100644 index 6183269fd3ca..000000000000 --- a/api/backend/image.go +++ /dev/null @@ -1,61 +0,0 @@ -package backend - -import ( - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" -) - -func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { - - opts := modelOpts(c, o, []model.Option{ - model.WithBackendString(c.Backend), - model.WithAssetDir(o.AssetsDestination), - model.WithThreads(uint32(c.Threads)), - model.WithContext(o.Context), - model.WithModel(c.Model), - model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{ - CUDA: c.CUDA || c.Diffusers.CUDA, - SchedulerType: c.Diffusers.SchedulerType, - PipelineType: c.Diffusers.PipelineType, - CFGScale: c.Diffusers.CFGScale, - LoraAdapter: c.LoraAdapter, - LoraScale: c.LoraScale, - LoraBase: c.LoraBase, - IMG2IMG: c.Diffusers.IMG2IMG, - CLIPModel: c.Diffusers.ClipModel, - CLIPSubfolder: c.Diffusers.ClipSubFolder, - CLIPSkip: int32(c.Diffusers.ClipSkip), - ControlNet: c.Diffusers.ControlNet, - }), - }) - - inferenceModel, err := loader.BackendLoader( - opts..., - ) - if err != nil { - return nil, err - } - - fn := func() error { - _, err := inferenceModel.GenerateImage( - o.Context, - &proto.GenerateImageRequest{ - Height: int32(height), - Width: int32(width), - Mode: int32(mode), - Step: int32(step), - Seed: int32(seed), - CLIPSkip: int32(c.Diffusers.ClipSkip), - PositivePrompt: positive_prompt, - NegativePrompt: negative_prompt, - Dst: dst, - Src: src, - EnableParameters: c.Diffusers.EnableParameters, - }) - return err - } - - return fn, nil -} diff --git a/api/backend/llm.go b/api/backend/llm.go deleted file mode 100644 index 62eef4d8f1fa..000000000000 --- a/api/backend/llm.go +++ /dev/null @@ -1,164 +0,0 @@ -package backend - -import ( - "context" - "os" - "regexp" - "strings" - "sync" - "unicode/utf8" - - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/pkg/gallery" - "github.com/go-skynet/LocalAI/pkg/grpc" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/utils" -) - -type LLMResponse struct { - Response string // should this be []byte? - Usage TokenUsage -} - -type TokenUsage struct { - Prompt int - Completion int -} - -func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { - modelFile := c.Model - - grpcOpts := gRPCModelOpts(c) - - var inferenceModel *grpc.Client - var err error - - opts := modelOpts(c, o, []model.Option{ - model.WithLoadGRPCLoadModelOpts(grpcOpts), - model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup - model.WithAssetDir(o.AssetsDestination), - model.WithModel(modelFile), - model.WithContext(o.Context), - }) - - if c.Backend != "" { - opts = append(opts, model.WithBackendString(c.Backend)) - } - - // Check if the modelFile exists, if it doesn't try to load it from the gallery - if o.AutoloadGalleries { // experimental - if _, err := os.Stat(modelFile); os.IsNotExist(err) { - utils.ResetDownloadTimers() - // if we failed to load the model, we try to download it - err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) - if err != nil { - return nil, err - } - } - } - - if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) - } else { - inferenceModel, err = loader.BackendLoader(opts...) - } - - if err != nil { - return nil, err - } - - // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported - fn := func() (LLMResponse, error) { - opts := gRPCPredictOpts(c, loader.ModelPath) - opts.Prompt = s - opts.Images = images - - tokenUsage := TokenUsage{} - - // check the per-model feature flag for usage, since tokenCallback may have a cost. - // Defaults to off as for now it is still experimental - if c.FeatureFlag.Enabled("usage") { - userTokenCallback := tokenCallback - if userTokenCallback == nil { - userTokenCallback = func(token string, usage TokenUsage) bool { - return true - } - } - - promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts) - if pErr == nil && promptInfo.Length > 0 { - tokenUsage.Prompt = int(promptInfo.Length) - } - - tokenCallback = func(token string, usage TokenUsage) bool { - tokenUsage.Completion++ - return userTokenCallback(token, tokenUsage) - } - } - - if tokenCallback != nil { - ss := "" - - var partialRune []byte - err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) { - partialRune = append(partialRune, chars...) - - for len(partialRune) > 0 { - r, size := utf8.DecodeRune(partialRune) - if r == utf8.RuneError { - // incomplete rune, wait for more bytes - break - } - - tokenCallback(string(r), tokenUsage) - ss += string(r) - - partialRune = partialRune[size:] - } - }) - return LLMResponse{ - Response: ss, - Usage: tokenUsage, - }, err - } else { - // TODO: Is the chicken bit the only way to get here? is that acceptable? - reply, err := inferenceModel.Predict(ctx, opts) - if err != nil { - return LLMResponse{}, err - } - return LLMResponse{ - Response: string(reply.Message), - Usage: tokenUsage, - }, err - } - } - - return fn, nil -} - -var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) -var mu sync.Mutex = sync.Mutex{} - -func Finetune(config config.Config, input, prediction string) string { - if config.Echo { - prediction = input + prediction - } - - for _, c := range config.Cutstrings { - mu.Lock() - reg, ok := cutstrings[c] - if !ok { - cutstrings[c] = regexp.MustCompile(c) - reg = cutstrings[c] - } - mu.Unlock() - prediction = reg.ReplaceAllString(prediction, "") - } - - for _, c := range config.TrimSpace { - prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) - } - return prediction - -} diff --git a/api/backend/transcript.go b/api/backend/transcript.go deleted file mode 100644 index 77427839992a..000000000000 --- a/api/backend/transcript.go +++ /dev/null @@ -1,39 +0,0 @@ -package backend - -import ( - "context" - "fmt" - - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/schema" - - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" -) - -func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*schema.Result, error) { - - opts := modelOpts(c, o, []model.Option{ - model.WithBackendString(model.WhisperBackend), - model.WithModel(c.Model), - model.WithContext(o.Context), - model.WithThreads(uint32(c.Threads)), - model.WithAssetDir(o.AssetsDestination), - }) - - whisperModel, err := o.Loader.BackendLoader(opts...) - if err != nil { - return nil, err - } - - if whisperModel == nil { - return nil, fmt.Errorf("could not load whisper model") - } - - return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ - Dst: audio, - Language: language, - Threads: uint32(c.Threads), - }) -} diff --git a/api/localai/backend_monitor.go b/api/localai/backend_monitor.go deleted file mode 100644 index 8cb0bb45ed14..000000000000 --- a/api/localai/backend_monitor.go +++ /dev/null @@ -1,162 +0,0 @@ -package localai - -import ( - "context" - "fmt" - "strings" - - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - - "github.com/go-skynet/LocalAI/api/options" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" - - gopsutil "github.com/shirou/gopsutil/v3/process" -) - -type BackendMonitorRequest struct { - Model string `json:"model" yaml:"model"` -} - -type BackendMonitorResponse struct { - MemoryInfo *gopsutil.MemoryInfoStat - MemoryPercent float32 - CPUPercent float64 -} - -type BackendMonitor struct { - configLoader *config.ConfigLoader - options *options.Option // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. -} - -func NewBackendMonitor(configLoader *config.ConfigLoader, options *options.Option) BackendMonitor { - return BackendMonitor{ - configLoader: configLoader, - options: options, - } -} - -func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*BackendMonitorResponse, error) { - config, exists := bm.configLoader.GetConfig(model) - var backend string - if exists { - backend = config.Model - } else { - // Last ditch effort: use it raw, see if a backend happens to match. - backend = model - } - - if !strings.HasSuffix(backend, ".bin") { - backend = fmt.Sprintf("%s.bin", backend) - } - - pid, err := bm.options.Loader.GetGRPCPID(backend) - - if err != nil { - log.Error().Msgf("model %s : failed to find pid %+v", model, err) - return nil, err - } - - // Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID. - backendProcess, err := gopsutil.NewProcess(int32(pid)) - - if err != nil { - log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err) - return nil, err - } - - memInfo, err := backendProcess.MemoryInfo() - - if err != nil { - log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err) - return nil, err - } - - memPercent, err := backendProcess.MemoryPercent() - if err != nil { - log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err) - return nil, err - } - - cpuPercent, err := backendProcess.CPUPercent() - if err != nil { - log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err) - return nil, err - } - - return &BackendMonitorResponse{ - MemoryInfo: memInfo, - MemoryPercent: memPercent, - CPUPercent: cpuPercent, - }, nil -} - -func (bm BackendMonitor) getModelLoaderIDFromCtx(c *fiber.Ctx) (string, error) { - input := new(BackendMonitorRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return "", err - } - - config, exists := bm.configLoader.GetConfig(input.Model) - var backendId string - if exists { - backendId = config.Model - } else { - // Last ditch effort: use it raw, see if a backend happens to match. - backendId = input.Model - } - - if !strings.HasSuffix(backendId, ".bin") { - backendId = fmt.Sprintf("%s.bin", backendId) - } - - return backendId, nil -} - -func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - - backendId, err := bm.getModelLoaderIDFromCtx(c) - if err != nil { - return err - } - - model := bm.options.Loader.CheckIsLoaded(backendId) - if model == "" { - return fmt.Errorf("backend %s is not currently loaded", backendId) - } - - status, rpcErr := model.GRPC(false, nil).Status(context.TODO()) - if rpcErr != nil { - log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) - val, slbErr := bm.SampleLocalBackendProcess(backendId) - if slbErr != nil { - return fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error()) - } - return c.JSON(proto.StatusResponse{ - State: proto.StatusResponse_ERROR, - Memory: &proto.MemoryUsageData{ - Total: val.MemoryInfo.VMS, - Breakdown: map[string]uint64{ - "gopsutil-RSS": val.MemoryInfo.RSS, - }, - }, - }) - } - - return c.JSON(status) - } -} - -func BackendShutdownEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - backendId, err := bm.getModelLoaderIDFromCtx(c) - if err != nil { - return err - } - - return bm.options.Loader.ShutdownModel(backendId) - } -} diff --git a/api/localai/gallery.go b/api/localai/gallery.go deleted file mode 100644 index a1cab6dc3edd..000000000000 --- a/api/localai/gallery.go +++ /dev/null @@ -1,320 +0,0 @@ -package localai - -import ( - "context" - "fmt" - "os" - "slices" - "strings" - "sync" - - json "github.com/json-iterator/go" - "gopkg.in/yaml.v3" - - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/pkg/gallery" - "github.com/go-skynet/LocalAI/pkg/utils" - - "github.com/gofiber/fiber/v2" - "github.com/google/uuid" - "github.com/rs/zerolog/log" -) - -type galleryOp struct { - req gallery.GalleryModel - id string - galleries []gallery.Gallery - galleryName string -} - -type galleryOpStatus struct { - FileName string `json:"file_name"` - Error error `json:"error"` - Processed bool `json:"processed"` - Message string `json:"message"` - Progress float64 `json:"progress"` - TotalFileSize string `json:"file_size"` - DownloadedFileSize string `json:"downloaded_size"` -} - -type galleryApplier struct { - modelPath string - sync.Mutex - C chan galleryOp - statuses map[string]*galleryOpStatus -} - -func NewGalleryService(modelPath string) *galleryApplier { - return &galleryApplier{ - modelPath: modelPath, - C: make(chan galleryOp), - statuses: make(map[string]*galleryOpStatus), - } -} - -func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error { - - config, err := gallery.GetGalleryConfigFromURL(req.URL) - if err != nil { - return err - } - - config.Files = append(config.Files, req.AdditionalFiles...) - - return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) -} - -func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) { - g.Lock() - defer g.Unlock() - g.statuses[s] = op -} - -func (g *galleryApplier) getStatus(s string) *galleryOpStatus { - g.Lock() - defer g.Unlock() - - return g.statuses[s] -} - -func (g *galleryApplier) getAllStatus() map[string]*galleryOpStatus { - g.Lock() - defer g.Unlock() - - return g.statuses -} - -func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { - go func() { - for { - select { - case <-c.Done(): - return - case op := <-g.C: - utils.ResetDownloadTimers() - - g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) - - // updates the status with an error - updateError := func(e error) { - g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()}) - } - - // displayDownload displays the download progress - progressCallback := func(fileName string, current string, total string, percentage float64) { - g.updateStatus(op.id, &galleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) - utils.DisplayDownloadFunction(fileName, current, total, percentage) - } - - var err error - // if the request contains a gallery name, we apply the gallery from the gallery list - if op.galleryName != "" { - if strings.Contains(op.galleryName, "@") { - err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) - } else { - err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) - } - } else { - err = prepareModel(g.modelPath, op.req, cm, progressCallback) - } - - if err != nil { - updateError(err) - continue - } - - // Reload models - err = cm.LoadConfigs(g.modelPath) - if err != nil { - updateError(err) - continue - } - - g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100}) - } - } - }() -} - -type galleryModel struct { - gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63 - ID string `json:"id"` -} - -func processRequests(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error { - var err error - for _, r := range requests { - utils.ResetDownloadTimers() - if r.ID == "" { - err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) - } else { - if strings.Contains(r.ID, "@") { - err = gallery.InstallModelFromGallery( - galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) - } else { - err = gallery.InstallModelFromGalleryByName( - galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) - } - } - } - return err -} - -func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { - dat, err := os.ReadFile(s) - if err != nil { - return err - } - var requests []galleryModel - - if err := yaml.Unmarshal(dat, &requests); err != nil { - return err - } - - return processRequests(modelPath, s, cm, galleries, requests) -} - -func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { - var requests []galleryModel - err := json.Unmarshal([]byte(s), &requests) - if err != nil { - return err - } - - return processRequests(modelPath, s, cm, galleries, requests) -} - -/// Endpoint Service - -type ModelGalleryService struct { - galleries []gallery.Gallery - modelPath string - galleryApplier *galleryApplier -} - -type GalleryModel struct { - ID string `json:"id"` - gallery.GalleryModel -} - -func CreateModelGalleryService(galleries []gallery.Gallery, modelPath string, galleryApplier *galleryApplier) ModelGalleryService { - return ModelGalleryService{ - galleries: galleries, - modelPath: modelPath, - galleryApplier: galleryApplier, - } -} - -func (mgs *ModelGalleryService) GetOpStatusEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - status := mgs.galleryApplier.getStatus(c.Params("uuid")) - if status == nil { - return fmt.Errorf("could not find any status for ID") - } - return c.JSON(status) - } -} - -func (mgs *ModelGalleryService) GetAllStatusEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - return c.JSON(mgs.galleryApplier.getAllStatus()) - } -} - -func (mgs *ModelGalleryService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(GalleryModel) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - - uuid, err := uuid.NewUUID() - if err != nil { - return err - } - mgs.galleryApplier.C <- galleryOp{ - req: input.GalleryModel, - id: uuid.String(), - galleryName: input.ID, - galleries: mgs.galleries, - } - return c.JSON(struct { - ID string `json:"uuid"` - StatusURL string `json:"status"` - }{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) - } -} - -func (mgs *ModelGalleryService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries) - - models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath) - if err != nil { - return err - } - log.Debug().Msgf("Models found from galleries: %+v", models) - for _, m := range models { - log.Debug().Msgf("Model found from galleries: %+v", m) - } - dat, err := json.Marshal(models) - if err != nil { - return err - } - return c.Send(dat) - } -} - -// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents! -func (mgs *ModelGalleryService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - log.Debug().Msgf("Listing model galleries %+v", mgs.galleries) - dat, err := json.Marshal(mgs.galleries) - if err != nil { - return err - } - return c.Send(dat) - } -} - -func (mgs *ModelGalleryService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(gallery.Gallery) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { - return gallery.Name == input.Name - }) { - return fmt.Errorf("%s already exists", input.Name) - } - dat, err := json.Marshal(mgs.galleries) - if err != nil { - return err - } - log.Debug().Msgf("Adding %+v to gallery list", *input) - mgs.galleries = append(mgs.galleries, *input) - return c.Send(dat) - } -} - -func (mgs *ModelGalleryService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - input := new(gallery.Gallery) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { - return gallery.Name == input.Name - }) { - return fmt.Errorf("%s is not currently registered", input.Name) - } - mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool { - return gallery.Name == input.Name - }) - return c.Send(nil) - } -} diff --git a/api/localai/localai.go b/api/localai/localai.go deleted file mode 100644 index c9aee2ae5c34..000000000000 --- a/api/localai/localai.go +++ /dev/null @@ -1,32 +0,0 @@ -package localai - -import ( - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - - "github.com/go-skynet/LocalAI/api/options" - "github.com/gofiber/fiber/v2" -) - -type TTSRequest struct { - Model string `json:"model" yaml:"model"` - Input string `json:"input" yaml:"input"` - Backend string `json:"backend" yaml:"backend"` -} - -func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - - input := new(TTSRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - - filePath, _, err := backend.ModelTTS(input.Backend, input.Input, input.Model, o.Loader, o) - if err != nil { - return err - } - return c.Download(filePath) - } -} diff --git a/api/openai/chat.go b/api/openai/chat.go deleted file mode 100644 index 02bf6149499e..000000000000 --- a/api/openai/chat.go +++ /dev/null @@ -1,399 +0,0 @@ -package openai - -import ( - "bufio" - "bytes" - "encoding/json" - "fmt" - "strings" - "time" - - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" - "github.com/go-skynet/LocalAI/pkg/grammar" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/gofiber/fiber/v2" - "github.com/google/uuid" - "github.com/rs/zerolog/log" - "github.com/valyala/fasthttp" -) - -func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { - emptyMessage := "" - id := uuid.New().String() - created := int(time.Now().Unix()) - - process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { - initialMessage := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}}, - Object: "chat.completion.chunk", - } - responses <- initialMessage - - ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { - resp := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}}, - Object: "chat.completion.chunk", - Usage: schema.OpenAIUsage{ - PromptTokens: usage.Prompt, - CompletionTokens: usage.Completion, - TotalTokens: usage.Prompt + usage.Completion, - }, - } - - responses <- resp - return true - }) - close(responses) - } - return func(c *fiber.Ctx) error { - processFunctions := false - funcs := grammar.Functions{} - modelFile, input, err := readInput(c, o, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - log.Debug().Msgf("Configuration read: %+v", config) - - // Allow the user to set custom actions via config file - // to be "embedded" in each model - noActionName := "answer" - noActionDescription := "use this action to answer without performing any action" - - if config.FunctionsConfig.NoActionFunctionName != "" { - noActionName = config.FunctionsConfig.NoActionFunctionName - } - if config.FunctionsConfig.NoActionDescriptionName != "" { - noActionDescription = config.FunctionsConfig.NoActionDescriptionName - } - - if input.ResponseFormat.Type == "json_object" { - input.Grammar = grammar.JSONBNF - } - - // process functions if we have any defined or if we have a function call string - if len(input.Functions) > 0 && config.ShouldUseFunctions() { - log.Debug().Msgf("Response needs to process functions") - - processFunctions = true - - noActionGrammar := grammar.Function{ - Name: noActionName, - Description: noActionDescription, - Parameters: map[string]interface{}{ - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "description": "The message to reply the user with", - }}, - }, - } - - // Append the no action function - funcs = append(funcs, input.Functions...) - if !config.FunctionsConfig.DisableNoAction { - funcs = append(funcs, noActionGrammar) - } - - // Force picking one of the functions by the request - if config.FunctionToCall() != "" { - funcs = funcs.Select(config.FunctionToCall()) - } - - // Update input grammar - jsStruct := funcs.ToJSONStructure() - config.Grammar = jsStruct.Grammar("") - } else if input.JSONFunctionGrammarObject != nil { - config.Grammar = input.JSONFunctionGrammarObject.Grammar("") - } - - // functions are not supported in stream mode (yet?) - toStream := input.Stream && !processFunctions - - log.Debug().Msgf("Parameters: %+v", config) - - var predInput string - - suppressConfigSystemPrompt := false - mess := []string{} - for messageIndex, i := range input.Messages { - var content string - role := i.Role - - // if function call, we might want to customize the role so we can display better that the "assistant called a json action" - // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request - if i.FunctionCall != nil && i.Role == "assistant" { - roleFn := "assistant_function_call" - r := config.Roles[roleFn] - if r != "" { - role = roleFn - } - } - r := config.Roles[role] - contentExists := i.Content != nil && i.StringContent != "" - // First attempt to populate content via a chat message specific template - if config.TemplateConfig.ChatMessage != "" { - chatMessageData := model.ChatMessageTemplateData{ - SystemPrompt: config.SystemPrompt, - Role: r, - RoleName: role, - Content: i.StringContent, - MessageIndex: messageIndex, - } - templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) - if err != nil { - log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err) - } else { - if templatedChatMessage == "" { - log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData) - continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf - } - log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) - content = templatedChatMessage - } - } - // If this model doesn't have such a template, or if that template fails to return a value, template at the message level. - if content == "" { - if r != "" { - if contentExists { - content = fmt.Sprint(r, i.StringContent) - } - if i.FunctionCall != nil { - j, err := json.Marshal(i.FunctionCall) - if err == nil { - if contentExists { - content += "\n" + fmt.Sprint(r, " ", string(j)) - } else { - content = fmt.Sprint(r, " ", string(j)) - } - } - } - } else { - if contentExists { - content = fmt.Sprint(i.StringContent) - } - if i.FunctionCall != nil { - j, err := json.Marshal(i.FunctionCall) - if err == nil { - if contentExists { - content += "\n" + string(j) - } else { - content = string(j) - } - } - } - } - // Special Handling: System. We care if it was printed at all, not the r branch, so check seperately - if contentExists && role == "system" { - suppressConfigSystemPrompt = true - } - } - - mess = append(mess, content) - } - - predInput = strings.Join(mess, "\n") - log.Debug().Msgf("Prompt (before templating): %s", predInput) - - if toStream { - log.Debug().Msgf("Stream request received") - c.Context().SetContentType("text/event-stream") - //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - // c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - c.Set("Transfer-Encoding", "chunked") - } - - templateFile := "" - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { - templateFile = config.Model - } - - if config.TemplateConfig.Chat != "" && !processFunctions { - templateFile = config.TemplateConfig.Chat - } - - if config.TemplateConfig.Functions != "" && processFunctions { - templateFile = config.TemplateConfig.Functions - } - - if templateFile != "" { - templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ - SystemPrompt: config.SystemPrompt, - SuppressSystemPrompt: suppressConfigSystemPrompt, - Input: predInput, - Functions: funcs, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } else { - log.Debug().Msgf("Template failed loading: %s", err.Error()) - } - } - - log.Debug().Msgf("Prompt (after templating): %s", predInput) - if processFunctions { - log.Debug().Msgf("Grammar: %+v", config.Grammar) - } - - if toStream { - responses := make(chan schema.OpenAIResponse) - - go process(predInput, input, config, o.Loader, responses) - - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - - usage := &schema.OpenAIUsage{} - - for ev := range responses { - usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.Encode(ev) - log.Debug().Msgf("Sending chunk: %s", buf.String()) - _, err := fmt.Fprintf(w, "data: %v\n", buf.String()) - if err != nil { - log.Debug().Msgf("Sending chunk failed: %v", err) - input.Cancel() - break - } - w.Flush() - } - - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{ - { - FinishReason: "stop", - Index: 0, - Delta: &schema.Message{Content: &emptyMessage}, - }}, - Object: "chat.completion.chunk", - Usage: *usage, - } - respData, _ := json.Marshal(resp) - - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - w.WriteString("data: [DONE]\n\n") - w.Flush() - })) - return nil - } - - result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) { - if processFunctions { - // As we have to change the result before processing, we can't stream the answer (yet?) - ss := map[string]interface{}{} - // This prevent newlines to break JSON parsing for clients - s = utils.EscapeNewLines(s) - json.Unmarshal([]byte(s), &ss) - log.Debug().Msgf("Function return: %s %+v", s, ss) - - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name := ss["function"] - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - d, _ := json.Marshal(args) - - ss["arguments"] = string(d) - ss["name"] = func_name - - // if do nothing, reply with a message - if func_name == noActionName { - log.Debug().Msgf("nothing to do, computing a reply") - - // If there is a message that the LLM already sends as part of the JSON reply, use it - arguments := map[string]interface{}{} - json.Unmarshal([]byte(d), &arguments) - m, exists := arguments["message"] - if exists { - switch message := m.(type) { - case string: - if message != "" { - log.Debug().Msgf("Reply received from LLM: %s", message) - message = backend.Finetune(*config, predInput, message) - log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - - *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}}) - return - } - } - } - - log.Debug().Msgf("No action received from LLM, without a message, computing a reply") - // Otherwise ask the LLM to understand the JSON output and the context, and return a message - // Note: This costs (in term of CPU) another computation - config.Grammar = "" - images := []string{} - for _, m := range input.Messages { - images = append(images, m.StringImages...) - } - predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil) - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return - } - - prediction, err := predFunc() - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return - } - - fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response) - *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}}) - } else { - // otherwise reply with the function call - *c = append(*c, schema.Choice{ - FinishReason: "function_call", - Message: &schema.Message{Role: "assistant", FunctionCall: ss}, - }) - } - - return - } - *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) - }, nil) - if err != nil { - return err - } - - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "chat.completion", - Usage: schema.OpenAIUsage{ - PromptTokens: tokenUsage.Prompt, - CompletionTokens: tokenUsage.Completion, - TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, - }, - } - respData, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", respData) - - // Return the prediction in the response body - return c.JSON(resp) - } -} diff --git a/api/openai/completion.go b/api/openai/completion.go deleted file mode 100644 index c0607632b93b..000000000000 --- a/api/openai/completion.go +++ /dev/null @@ -1,199 +0,0 @@ -package openai - -import ( - "bufio" - "bytes" - "encoding/json" - "errors" - "fmt" - "time" - - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" - "github.com/go-skynet/LocalAI/pkg/grammar" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/gofiber/fiber/v2" - "github.com/google/uuid" - "github.com/rs/zerolog/log" - "github.com/valyala/fasthttp" -) - -// https://platform.openai.com/docs/api-reference/completions -func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { - id := uuid.New().String() - created := int(time.Now().Unix()) - - process := func(s string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { - ComputeChoices(req, s, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { - resp := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{ - { - Index: 0, - Text: s, - }, - }, - Object: "text_completion", - Usage: schema.OpenAIUsage{ - PromptTokens: usage.Prompt, - CompletionTokens: usage.Completion, - TotalTokens: usage.Prompt + usage.Completion, - }, - } - log.Debug().Msgf("Sending goroutine: %s", s) - - responses <- resp - return true - }) - close(responses) - } - - return func(c *fiber.Ctx) error { - modelFile, input, err := readInput(c, o, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("`input`: %+v", input) - - config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - if input.ResponseFormat.Type == "json_object" { - input.Grammar = grammar.JSONBNF - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - if input.Stream { - log.Debug().Msgf("Stream request received") - c.Context().SetContentType("text/event-stream") - //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - //c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - c.Set("Transfer-Encoding", "chunked") - } - - templateFile := "" - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { - templateFile = config.Model - } - - if config.TemplateConfig.Completion != "" { - templateFile = config.TemplateConfig.Completion - } - - if input.Stream { - if len(config.PromptStrings) > 1 { - return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") - } - - predInput := config.PromptStrings[0] - - if templateFile != "" { - templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ - Input: predInput, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } - } - - responses := make(chan schema.OpenAIResponse) - - go process(predInput, input, config, o.Loader, responses) - - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - - for ev := range responses { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.Encode(ev) - - log.Debug().Msgf("Sending chunk: %s", buf.String()) - fmt.Fprintf(w, "data: %v\n", buf.String()) - w.Flush() - } - - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []schema.Choice{ - { - Index: 0, - FinishReason: "stop", - }, - }, - Object: "text_completion", - } - respData, _ := json.Marshal(resp) - - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - w.WriteString("data: [DONE]\n\n") - w.Flush() - })) - return nil - } - - var result []schema.Choice - - totalTokenUsage := backend.TokenUsage{} - - for k, i := range config.PromptStrings { - if templateFile != "" { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{ - SystemPrompt: config.SystemPrompt, - Input: i, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } - } - - r, tokenUsage, err := ComputeChoices( - input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) { - *c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k}) - }, nil) - if err != nil { - return err - } - - totalTokenUsage.Prompt += tokenUsage.Prompt - totalTokenUsage.Completion += tokenUsage.Completion - - result = append(result, r...) - } - - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "text_completion", - Usage: schema.OpenAIUsage{ - PromptTokens: totalTokenUsage.Prompt, - CompletionTokens: totalTokenUsage.Completion, - TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, - }, - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} diff --git a/api/openai/edit.go b/api/openai/edit.go deleted file mode 100644 index 888b9db7ffd4..000000000000 --- a/api/openai/edit.go +++ /dev/null @@ -1,94 +0,0 @@ -package openai - -import ( - "encoding/json" - "fmt" - "time" - - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/gofiber/fiber/v2" - "github.com/google/uuid" - - "github.com/rs/zerolog/log" -) - -func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - modelFile, input, err := readInput(c, o, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(modelFile, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - templateFile := "" - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { - templateFile = config.Model - } - - if config.TemplateConfig.Edit != "" { - templateFile = config.TemplateConfig.Edit - } - - var result []schema.Choice - totalTokenUsage := backend.TokenUsage{} - - for _, i := range config.InputStrings { - if templateFile != "" { - templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{ - Input: i, - Instruction: input.Instruction, - SystemPrompt: config.SystemPrompt, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } - } - - r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) { - *c = append(*c, schema.Choice{Text: s}) - }, nil) - if err != nil { - return err - } - - totalTokenUsage.Prompt += tokenUsage.Prompt - totalTokenUsage.Completion += tokenUsage.Completion - - result = append(result, r...) - } - - id := uuid.New().String() - created := int(time.Now().Unix()) - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "edit", - Usage: schema.OpenAIUsage{ - PromptTokens: totalTokenUsage.Prompt, - CompletionTokens: totalTokenUsage.Completion, - TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, - }, - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} diff --git a/api/openai/embeddings.go b/api/openai/embeddings.go deleted file mode 100644 index 15e31e92c6eb..000000000000 --- a/api/openai/embeddings.go +++ /dev/null @@ -1,78 +0,0 @@ -package openai - -import ( - "encoding/json" - "fmt" - "time" - - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/schema" - "github.com/google/uuid" - - "github.com/go-skynet/LocalAI/api/options" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" -) - -// https://platform.openai.com/docs/api-reference/embeddings -func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - model, input, err := readInput(c, o, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - items := []schema.Item{} - - for i, s := range config.InputToken { - // get the model function to call for the result - embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o) - if err != nil { - return err - } - - embeddings, err := embedFn() - if err != nil { - return err - } - items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - for i, s := range config.InputStrings { - // get the model function to call for the result - embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o) - if err != nil { - return err - } - - embeddings, err := embedFn() - if err != nil { - return err - } - items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - id := uuid.New().String() - created := int(time.Now().Unix()) - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Data: items, - Object: "list", - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} diff --git a/api/openai/image.go b/api/openai/image.go deleted file mode 100644 index 3e4bc349af3a..000000000000 --- a/api/openai/image.go +++ /dev/null @@ -1,239 +0,0 @@ -package openai - -import ( - "bufio" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/go-skynet/LocalAI/api/schema" - "github.com/google/uuid" - - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" -) - -func downloadFile(url string) (string, error) { - // Get the data - resp, err := http.Get(url) - if err != nil { - return "", err - } - defer resp.Body.Close() - - // Create the file - out, err := os.CreateTemp("", "image") - if err != nil { - return "", err - } - defer out.Close() - - // Write the body to file - _, err = io.Copy(out, resp.Body) - return out.Name(), err -} - -// https://platform.openai.com/docs/api-reference/images/create - -/* -* - - curl http://localhost:8080/v1/images/generations \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "A cute baby sea otter", - "n": 1, - "size": "512x512" - }' - -* -*/ -func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - m, input, err := readInput(c, o, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - if m == "" { - m = model.StableDiffusionBackend - } - log.Debug().Msgf("Loading model: %+v", m) - - config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - src := "" - if input.File != "" { - - fileData := []byte{} - // check if input.File is an URL, if so download it and save it - // to a temporary file - if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") { - out, err := downloadFile(input.File) - if err != nil { - return fmt.Errorf("failed downloading file:%w", err) - } - defer os.RemoveAll(out) - - fileData, err = os.ReadFile(out) - if err != nil { - return fmt.Errorf("failed reading file:%w", err) - } - - } else { - // base 64 decode the file and write it somewhere - // that we will cleanup - fileData, err = base64.StdEncoding.DecodeString(input.File) - if err != nil { - return err - } - } - - // Create a temporary file - outputFile, err := os.CreateTemp(o.ImageDir, "b64") - if err != nil { - return err - } - // write the base64 result - writer := bufio.NewWriter(outputFile) - _, err = writer.Write(fileData) - if err != nil { - outputFile.Close() - return err - } - outputFile.Close() - src = outputFile.Name() - defer os.RemoveAll(src) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - switch config.Backend { - case "stablediffusion": - config.Backend = model.StableDiffusionBackend - case "tinydream": - config.Backend = model.TinyDreamBackend - case "": - config.Backend = model.StableDiffusionBackend - } - - sizeParts := strings.Split(input.Size, "x") - if len(sizeParts) != 2 { - return fmt.Errorf("Invalid value for 'size'") - } - width, err := strconv.Atoi(sizeParts[0]) - if err != nil { - return fmt.Errorf("Invalid value for 'size'") - } - height, err := strconv.Atoi(sizeParts[1]) - if err != nil { - return fmt.Errorf("Invalid value for 'size'") - } - - b64JSON := false - if input.ResponseFormat.Type == "b64_json" { - b64JSON = true - } - // src and clip_skip - var result []schema.Item - for _, i := range config.PromptStrings { - n := input.N - if input.N == 0 { - n = 1 - } - for j := 0; j < n; j++ { - prompts := strings.Split(i, "|") - positive_prompt := prompts[0] - negative_prompt := "" - if len(prompts) > 1 { - negative_prompt = prompts[1] - } - - mode := 0 - step := config.Step - if step == 0 { - step = 15 - } - - if input.Mode != 0 { - mode = input.Mode - } - - if input.Step != 0 { - step = input.Step - } - - tempDir := "" - if !b64JSON { - tempDir = o.ImageDir - } - // Create a temporary file - outputFile, err := os.CreateTemp(tempDir, "b64") - if err != nil { - return err - } - outputFile.Close() - output := outputFile.Name() + ".png" - // Rename the temporary file - err = os.Rename(outputFile.Name(), output) - if err != nil { - return err - } - - baseURL := c.BaseURL() - - fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, o.Loader, *config, o) - if err != nil { - return err - } - if err := fn(); err != nil { - return err - } - - item := &schema.Item{} - - if b64JSON { - defer os.RemoveAll(output) - data, err := os.ReadFile(output) - if err != nil { - return err - } - item.B64JSON = base64.StdEncoding.EncodeToString(data) - } else { - base := filepath.Base(output) - item.URL = baseURL + "/generated-images/" + base - } - - result = append(result, *item) - } - } - - id := uuid.New().String() - created := int(time.Now().Unix()) - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Data: result, - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} diff --git a/api/openai/inference.go b/api/openai/inference.go deleted file mode 100644 index 816c960c3798..000000000000 --- a/api/openai/inference.go +++ /dev/null @@ -1,55 +0,0 @@ -package openai - -import ( - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" - model "github.com/go-skynet/LocalAI/pkg/model" -) - -func ComputeChoices( - req *schema.OpenAIRequest, - predInput string, - config *config.Config, - o *options.Option, - loader *model.ModelLoader, - cb func(string, *[]schema.Choice), - tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) { - n := req.N // number of completions to return - result := []schema.Choice{} - - if n == 0 { - n = 1 - } - - images := []string{} - for _, m := range req.Messages { - images = append(images, m.StringImages...) - } - - // get the model function to call for the result - predFunc, err := backend.ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback) - if err != nil { - return result, backend.TokenUsage{}, err - } - - tokenUsage := backend.TokenUsage{} - - for i := 0; i < n; i++ { - prediction, err := predFunc() - if err != nil { - return result, backend.TokenUsage{}, err - } - - tokenUsage.Prompt += prediction.Usage.Prompt - tokenUsage.Completion += prediction.Usage.Completion - - finetunedResponse := backend.Finetune(*config, predInput, prediction.Response) - cb(finetunedResponse, &result) - - //result = append(result, Choice{Text: prediction}) - - } - return result, tokenUsage, err -} diff --git a/api/openai/request.go b/api/openai/request.go deleted file mode 100644 index cc15fe409c27..000000000000 --- a/api/openai/request.go +++ /dev/null @@ -1,336 +0,0 @@ -package openai - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "os" - "path/filepath" - "strings" - - config "github.com/go-skynet/LocalAI/api/config" - options "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" -) - -func readInput(c *fiber.Ctx, o *options.Option, randomModel bool) (string, *schema.OpenAIRequest, error) { - loader := o.Loader - input := new(schema.OpenAIRequest) - ctx, cancel := context.WithCancel(o.Context) - input.Context = ctx - input.Cancel = cancel - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return "", nil, fmt.Errorf("failed parsing request body: %w", err) - } - - modelFile := input.Model - - if c.Params("model") != "" { - modelFile = c.Params("model") - } - - received, _ := json.Marshal(input) - - log.Debug().Msgf("Request received: %s", string(received)) - - // Set model from bearer token, if available - bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") - bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) - - // If no model was specified, take the first available - if modelFile == "" && !bearerExists && randomModel { - models, _ := loader.ListModels() - if len(models) > 0 { - modelFile = models[0] - log.Debug().Msgf("No model specified, using: %s", modelFile) - } else { - log.Debug().Msgf("No model specified, returning error") - return "", nil, fmt.Errorf("no model specified") - } - } - - // If a model is found in bearer token takes precedence - if bearerExists { - log.Debug().Msgf("Using model from bearer token: %s", bearer) - modelFile = bearer - } - return modelFile, input, nil -} - -// this function check if the string is an URL, if it's an URL downloads the image in memory -// encodes it in base64 and returns the base64 string -func getBase64Image(s string) (string, error) { - if strings.HasPrefix(s, "http") { - // download the image - resp, err := http.Get(s) - if err != nil { - return "", err - } - defer resp.Body.Close() - - // read the image data into memory - data, err := ioutil.ReadAll(resp.Body) - if err != nil { - return "", err - } - - // encode the image data in base64 - encoded := base64.StdEncoding.EncodeToString(data) - - // return the base64 string - return encoded, nil - } - - // if the string instead is prefixed with "data:image/jpeg;base64,", drop it - if strings.HasPrefix(s, "data:image/jpeg;base64,") { - return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil - } - return "", fmt.Errorf("not valid string") -} - -func updateConfig(config *config.Config, input *schema.OpenAIRequest) { - if input.Echo { - config.Echo = input.Echo - } - if input.TopK != 0 { - config.TopK = input.TopK - } - if input.TopP != 0 { - config.TopP = input.TopP - } - - if input.Backend != "" { - config.Backend = input.Backend - } - - if input.ClipSkip != 0 { - config.Diffusers.ClipSkip = input.ClipSkip - } - - if input.ModelBaseName != "" { - config.AutoGPTQ.ModelBaseName = input.ModelBaseName - } - - if input.NegativePromptScale != 0 { - config.NegativePromptScale = input.NegativePromptScale - } - - if input.UseFastTokenizer { - config.UseFastTokenizer = input.UseFastTokenizer - } - - if input.NegativePrompt != "" { - config.NegativePrompt = input.NegativePrompt - } - - if input.RopeFreqBase != 0 { - config.RopeFreqBase = input.RopeFreqBase - } - - if input.RopeFreqScale != 0 { - config.RopeFreqScale = input.RopeFreqScale - } - - if input.Grammar != "" { - config.Grammar = input.Grammar - } - - if input.Temperature != 0 { - config.Temperature = input.Temperature - } - - if input.Maxtokens != 0 { - config.Maxtokens = input.Maxtokens - } - - switch stop := input.Stop.(type) { - case string: - if stop != "" { - config.StopWords = append(config.StopWords, stop) - } - case []interface{}: - for _, pp := range stop { - if s, ok := pp.(string); ok { - config.StopWords = append(config.StopWords, s) - } - } - } - - // Decode each request's message content - index := 0 - for i, m := range input.Messages { - switch content := m.Content.(type) { - case string: - input.Messages[i].StringContent = content - case []interface{}: - dat, _ := json.Marshal(content) - c := []schema.Content{} - json.Unmarshal(dat, &c) - for _, pp := range c { - if pp.Type == "text" { - input.Messages[i].StringContent = pp.Text - } else if pp.Type == "image_url" { - // Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64: - base64, err := getBase64Image(pp.ImageURL.URL) - if err == nil { - input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff - // set a placeholder for each image - input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent - index++ - } else { - fmt.Print("Failed encoding image", err) - } - } - } - } - } - - if input.RepeatPenalty != 0 { - config.RepeatPenalty = input.RepeatPenalty - } - - if input.Keep != 0 { - config.Keep = input.Keep - } - - if input.Batch != 0 { - config.Batch = input.Batch - } - - if input.F16 { - config.F16 = input.F16 - } - - if input.IgnoreEOS { - config.IgnoreEOS = input.IgnoreEOS - } - - if input.Seed != 0 { - config.Seed = input.Seed - } - - if input.Mirostat != 0 { - config.LLMConfig.Mirostat = input.Mirostat - } - - if input.MirostatETA != 0 { - config.LLMConfig.MirostatETA = input.MirostatETA - } - - if input.MirostatTAU != 0 { - config.LLMConfig.MirostatTAU = input.MirostatTAU - } - - if input.TypicalP != 0 { - config.TypicalP = input.TypicalP - } - - switch inputs := input.Input.(type) { - case string: - if inputs != "" { - config.InputStrings = append(config.InputStrings, inputs) - } - case []interface{}: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - config.InputStrings = append(config.InputStrings, i) - case []interface{}: - tokens := []int{} - for _, ii := range i { - tokens = append(tokens, int(ii.(float64))) - } - config.InputToken = append(config.InputToken, tokens) - } - } - } - - // Can be either a string or an object - switch fnc := input.FunctionCall.(type) { - case string: - if fnc != "" { - config.SetFunctionCallString(fnc) - } - case map[string]interface{}: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } - } - config.SetFunctionCallNameString(name) - } - - switch p := input.Prompt.(type) { - case string: - config.PromptStrings = append(config.PromptStrings, p) - case []interface{}: - for _, pp := range p { - if s, ok := pp.(string); ok { - config.PromptStrings = append(config.PromptStrings, s) - } - } - } -} - -func readConfig(modelFile string, input *schema.OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *schema.OpenAIRequest, error) { - // Load a config file if present after the model name - modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") - - var cfg *config.Config - - defaults := func() { - cfg = config.DefaultConfig(modelFile) - cfg.ContextSize = ctx - cfg.Threads = threads - cfg.F16 = f16 - cfg.Debug = debug - } - - cfgExisting, exists := cm.GetConfig(modelFile) - if !exists { - if _, err := os.Stat(modelConfig); err == nil { - if err := cm.LoadConfig(modelConfig); err != nil { - return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) - } - cfgExisting, exists = cm.GetConfig(modelFile) - if exists { - cfg = &cfgExisting - } else { - defaults() - } - } else { - defaults() - } - } else { - cfg = &cfgExisting - } - - // Set the parameters for the language model prediction - updateConfig(cfg, input) - - // Don't allow 0 as setting - if cfg.Threads == 0 { - if threads != 0 { - cfg.Threads = threads - } else { - cfg.Threads = 4 - } - } - - // Enforce debug flag if passed from CLI - if debug { - cfg.Debug = true - } - - return cfg, input, nil -} diff --git a/api/openai/transcription.go b/api/openai/transcription.go deleted file mode 100644 index 895c110f5df4..000000000000 --- a/api/openai/transcription.go +++ /dev/null @@ -1,71 +0,0 @@ -package openai - -import ( - "fmt" - "io" - "net/http" - "os" - "path" - "path/filepath" - - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" -) - -// https://platform.openai.com/docs/api-reference/audio/create -func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - m, input, err := readInput(c, o, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - // retrieve the file data from the request - file, err := c.FormFile("file") - if err != nil { - return err - } - f, err := file.Open() - if err != nil { - return err - } - defer f.Close() - - dir, err := os.MkdirTemp("", "whisper") - - if err != nil { - return err - } - defer os.RemoveAll(dir) - - dst := filepath.Join(dir, path.Base(file.Filename)) - dstFile, err := os.Create(dst) - if err != nil { - return err - } - - if _, err := io.Copy(dstFile, f); err != nil { - log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) - return err - } - - log.Debug().Msgf("Audio file copied to: %+v", dst) - - tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o) - if err != nil { - return err - } - - log.Debug().Msgf("Trascribed: %+v", tr) - // TODO: handle different outputs here - return c.Status(http.StatusOK).JSON(tr) - } -} diff --git a/api/backend/embeddings.go b/core/backend/embeddings.go similarity index 50% rename from api/backend/embeddings.go rename to core/backend/embeddings.go index 63f1a831e26d..7995c3971642 100644 --- a/api/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -2,14 +2,17 @@ package backend import ( "fmt" + "time" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" "github.com/go-skynet/LocalAI/pkg/grpc" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/google/uuid" + "github.com/rs/zerolog/log" ) -func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { +func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c datamodel.Config, o *datamodel.StartupOptions) (func() ([]float32, error), error) { if !c.Embeddings { return nil, fmt.Errorf("endpoint disabled for this model by API configuration") } @@ -27,6 +30,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. model.WithAssetDir(o.AssetsDestination), model.WithModel(modelFile), model.WithContext(o.Context), + model.WithExternalBackends(o.ExternalGRPCBackends, false), }) if c.Backend == "" { @@ -90,3 +94,51 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. return embeds, nil }, nil } + +func EmbeddingOpenAIRequest(modelName string, input *datamodel.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *datamodel.StartupOptions) (*datamodel.OpenAIResponse, error) { + config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions) + if err != nil { + return nil, fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + items := []datamodel.Item{} + + for i, s := range config.InputToken { + // get the model function to call for the result + embedFn, err := ModelEmbedding("", s, ml, *config, startupOptions) + if err != nil { + return nil, err + } + + embeddings, err := embedFn() + if err != nil { + return nil, err + } + items = append(items, datamodel.Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + for i, s := range config.InputStrings { + // get the model function to call for the result + embedFn, err := ModelEmbedding(s, []int{}, ml, *config, startupOptions) + if err != nil { + return nil, err + } + + embeddings, err := embedFn() + if err != nil { + return nil, err + } + items = append(items, datamodel.Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + id := uuid.New().String() + created := int(time.Now().Unix()) + return &datamodel.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Data: items, + Object: "list", + }, nil +} diff --git a/core/backend/image.go b/core/backend/image.go new file mode 100644 index 000000000000..dee0615e6b40 --- /dev/null +++ b/core/backend/image.go @@ -0,0 +1,212 @@ +package backend + +import ( + "encoding/base64" + "fmt" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/google/uuid" + "github.com/rs/zerolog/log" +) + +func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, c datamodel.Config, o *datamodel.StartupOptions) (func() error, error) { + + opts := modelOpts(c, o, []model.Option{ + model.WithBackendString(c.Backend), + model.WithAssetDir(o.AssetsDestination), + model.WithThreads(uint32(c.Threads)), + model.WithContext(o.Context), + model.WithModel(c.Model), + model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{ + CUDA: c.CUDA || c.Diffusers.CUDA, + SchedulerType: c.Diffusers.SchedulerType, + PipelineType: c.Diffusers.PipelineType, + CFGScale: c.Diffusers.CFGScale, + LoraAdapter: c.LoraAdapter, + LoraScale: c.LoraScale, + LoraBase: c.LoraBase, + IMG2IMG: c.Diffusers.IMG2IMG, + CLIPModel: c.Diffusers.ClipModel, + CLIPSubfolder: c.Diffusers.ClipSubFolder, + CLIPSkip: int32(c.Diffusers.ClipSkip), + ControlNet: c.Diffusers.ControlNet, + }), + model.WithExternalBackends(o.ExternalGRPCBackends, false), + }) + + inferenceModel, err := loader.BackendLoader( + opts..., + ) + if err != nil { + return nil, err + } + + fn := func() error { + _, err := inferenceModel.GenerateImage( + o.Context, + &proto.GenerateImageRequest{ + Height: int32(height), + Width: int32(width), + Mode: int32(mode), + Step: int32(step), + Seed: int32(seed), + CLIPSkip: int32(c.Diffusers.ClipSkip), + PositivePrompt: positive_prompt, + NegativePrompt: negative_prompt, + Dst: dst, + Src: src, + EnableParameters: c.Diffusers.EnableParameters, + }) + return err + } + + return fn, nil +} + +func ImageGenerationOpenAIRequest(modelName string, input *datamodel.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *datamodel.StartupOptions) (*datamodel.OpenAIResponse, error) { + id := uuid.New().String() + created := int(time.Now().Unix()) + + if modelName == "" { + modelName = model.StableDiffusionBackend + } + log.Debug().Msgf("Loading model: %+v", modelName) + + config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions) + if err != nil { + return nil, fmt.Errorf("failed reading parameters from request: %w", err) + } + + src := "" + if input.File != "" { + if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") { + src, err = utils.CreateTempFileFromUrl(input.File, "", "image-src") + if err != nil { + return nil, fmt.Errorf("failed downloading file:%w", err) + } + } else { + src, err = utils.CreateTempFileFromBase64(input.File, "", "base64-image-src") + if err != nil { + return nil, fmt.Errorf("error creating temporary image source file: %w", err) + } + } + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + switch config.Backend { + case "stablediffusion": + config.Backend = model.StableDiffusionBackend + case "tinydream": + config.Backend = model.TinyDreamBackend + case "": + config.Backend = model.StableDiffusionBackend + default: + config.Backend = model.StableDiffusionBackend + } + + sizeParts := strings.Split(input.Size, "x") + if len(sizeParts) != 2 { + return nil, fmt.Errorf("invalid value for 'size'") + } + width, err := strconv.Atoi(sizeParts[0]) + if err != nil { + return nil, fmt.Errorf("invalid value for 'size'") + } + height, err := strconv.Atoi(sizeParts[1]) + if err != nil { + return nil, fmt.Errorf("invalid value for 'size'") + } + + b64JSON := false + if input.ResponseFormat.Type == "b64_json" { + b64JSON = true + } + // src and clip_skip + var result []datamodel.Item + for _, i := range config.PromptStrings { + n := input.N + if input.N == 0 { + n = 1 + } + for j := 0; j < n; j++ { + prompts := strings.Split(i, "|") + positive_prompt := prompts[0] + negative_prompt := "" + if len(prompts) > 1 { + negative_prompt = prompts[1] + } + + mode := 0 + step := config.Step + if step == 0 { + step = 15 + } + + if input.Mode != 0 { + mode = input.Mode + } + + if input.Step != 0 { + step = input.Step + } + + tempDir := "" + if !b64JSON { + tempDir = startupOptions.ImageDir + } + // Create a temporary file + outputFile, err := os.CreateTemp(tempDir, "b64") + if err != nil { + return nil, err + } + outputFile.Close() + output := outputFile.Name() + ".png" + // Rename the temporary file + err = os.Rename(outputFile.Name(), output) + if err != nil { + return nil, err + } + + fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, src, output, ml, *config, startupOptions) + if err != nil { + return nil, err + } + if err := fn(); err != nil { + return nil, err + } + + item := &datamodel.Item{} + + if b64JSON { + defer os.RemoveAll(output) + data, err := os.ReadFile(output) + if err != nil { + return nil, err + } + item.B64JSON = base64.StdEncoding.EncodeToString(data) + } else { + base := filepath.Base(output) + item.URL = path.Join(startupOptions.ImageDir, base) + } + + result = append(result, *item) + } + } + + return &datamodel.OpenAIResponse{ + ID: id, + Created: created, + Data: result, + }, nil +} diff --git a/core/backend/llm.go b/core/backend/llm.go new file mode 100644 index 000000000000..edd4d6661bc3 --- /dev/null +++ b/core/backend/llm.go @@ -0,0 +1,857 @@ +package backend + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + "time" + "unicode/utf8" + + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/grammar" + "github.com/go-skynet/LocalAI/pkg/grpc" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/google/uuid" + "github.com/rs/zerolog/log" +) + +////////// TYPES ////////////// + +type LLMResponse struct { + Response string // should this be []byte? + Usage TokenUsage +} + +// TODO: Test removing this and using datamodel? +type TokenUsage struct { + Prompt int + Completion int +} + +type TemplateConfigBindingFn func(*datamodel.Config) *string + +// type LLMStreamProcessor func(s string, req *datamodel.OpenAIRequest, config *datamodel.Config, loader *model.ModelLoader, responses chan datamodel.OpenAIResponse) + +/////// CONSTS /////////// + +const DEFAULT_NO_ACTION_NAME = "answer" +const DEFAULT_NO_ACTION_DESCRIPTION = "use this action to answer without performing any action" + +////// INFERENCE ///////// + +func ModelInference(ctx context.Context, s string, images []string, loader *model.ModelLoader, c datamodel.Config, o *datamodel.StartupOptions, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { + modelFile := c.Model + + grpcOpts := gRPCModelOpts(c) + + var inferenceModel *grpc.Client + var err error + + opts := modelOpts(c, o, []model.Option{ + model.WithLoadGRPCLoadModelOpts(grpcOpts), + model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup + model.WithAssetDir(o.AssetsDestination), + model.WithModel(modelFile), + model.WithContext(o.Context), + model.WithExternalBackends(o.ExternalGRPCBackends, false), + }) + + if c.Backend != "" { + opts = append(opts, model.WithBackendString(c.Backend)) + } + + // Check if the modelFile exists, if it doesn't try to load it from the gallery + if o.AutoloadGalleries { // experimental + if _, err := os.Stat(modelFile); os.IsNotExist(err) { + utils.ResetDownloadTimers() + // if we failed to load the model, we try to download it + err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) + if err != nil { + return nil, err + } + } + } + + if c.Backend == "" { + inferenceModel, err = loader.GreedyLoader(opts...) + } else { + inferenceModel, err = loader.BackendLoader(opts...) + } + + if err != nil { + return nil, err + } + + // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported + fn := func() (LLMResponse, error) { + opts := gRPCPredictOpts(c, loader.ModelPath) + opts.Prompt = s + opts.Images = images + + tokenUsage := TokenUsage{} + + // check the per-model feature flag for usage, since tokenCallback may have a cost. + // Defaults to off as for now it is still experimental + if c.FeatureFlag.Enabled("usage") { + userTokenCallback := tokenCallback + if userTokenCallback == nil { + userTokenCallback = func(token string, usage TokenUsage) bool { + return true + } + } + + promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts) + if pErr == nil && promptInfo.Length > 0 { + tokenUsage.Prompt = int(promptInfo.Length) + } + + tokenCallback = func(token string, usage TokenUsage) bool { + tokenUsage.Completion++ + return userTokenCallback(token, tokenUsage) + } + } + + if tokenCallback != nil { + ss := "" + + var partialRune []byte + err := inferenceModel.PredictStream(ctx, opts, func(chars []byte) { + partialRune = append(partialRune, chars...) + + for len(partialRune) > 0 { + r, size := utf8.DecodeRune(partialRune) + if r == utf8.RuneError { + // incomplete rune, wait for more bytes + break + } + + tokenCallback(string(r), tokenUsage) + ss += string(r) + + partialRune = partialRune[size:] + } + }) + return LLMResponse{ + Response: ss, + Usage: tokenUsage, + }, err + } else { + // TODO: Is the chicken bit the only way to get here? is that acceptable? + reply, err := inferenceModel.Predict(ctx, opts) + if err != nil { + return LLMResponse{}, err + } + return LLMResponse{ + Response: string(reply.Message), + Usage: tokenUsage, + }, err + } + } + + return fn, nil +} + +var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) +var mu sync.Mutex = sync.Mutex{} + +func Finetune(config datamodel.Config, input, prediction string) string { + if config.Echo { + prediction = input + prediction + } + + for _, c := range config.Cutstrings { + mu.Lock() + reg, ok := cutstrings[c] + if !ok { + cutstrings[c] = regexp.MustCompile(c) + reg = cutstrings[c] + } + mu.Unlock() + prediction = reg.ReplaceAllString(prediction, "") + } + + for _, c := range config.TrimSpace { + prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) + } + return prediction + +} + +////// CONFIG AND REQUEST HANDLING /////////////// + +func ReadConfigFromFileAndCombineWithOpenAIRequest(modelFile string, input *datamodel.OpenAIRequest, cm *services.ConfigLoader, startupOptions *datamodel.StartupOptions) (*datamodel.Config, *datamodel.OpenAIRequest, error) { + // Load a config file if present after the model name + modelConfig := filepath.Join(startupOptions.ModelPath, modelFile+".yaml") + + var cfg *datamodel.Config + + defaults := func() { + cfg = datamodel.DefaultConfig(modelFile) + cfg.ContextSize = startupOptions.ContextSize + cfg.Threads = startupOptions.Threads + cfg.F16 = startupOptions.F16 + cfg.Debug = startupOptions.Debug + } + + cfgExisting, exists := cm.GetConfig(modelFile) + if !exists { + if _, err := os.Stat(modelConfig); err == nil { + if err := cm.LoadConfig(modelConfig); err != nil { + return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) + } + cfgExisting, exists = cm.GetConfig(modelFile) + if exists { + cfg = &cfgExisting + } else { + defaults() + } + } else { + defaults() + } + } else { + cfg = &cfgExisting + } + + // Set the parameters for the language model prediction + datamodel.UpdateConfigFromOpenAIRequest(cfg, input) + + // Don't allow 0 as setting + if cfg.Threads == 0 { + if startupOptions.Threads != 0 { + cfg.Threads = startupOptions.Threads + } else { + cfg.Threads = 4 + } + } + + // Enforce debug flag if passed from CLI + if startupOptions.Debug { + cfg.Debug = true + } + + return cfg, input, nil +} + +func ComputeChoices( + req *datamodel.OpenAIRequest, + predInput string, + config *datamodel.Config, + o *datamodel.StartupOptions, + loader *model.ModelLoader, + cb func(string, *[]datamodel.Choice), + tokenCallback func(string, TokenUsage) bool) ([]datamodel.Choice, TokenUsage, error) { + n := req.N // number of completions to return + result := []datamodel.Choice{} + + if n == 0 { + n = 1 + } + + images := []string{} + for _, m := range req.Messages { + images = append(images, m.StringImages...) + } + + // get the model function to call for the result + predFunc, err := ModelInference(req.Context, predInput, images, loader, *config, o, tokenCallback) + if err != nil { + return result, TokenUsage{}, err + } + + tokenUsage := TokenUsage{} + + for i := 0; i < n; i++ { + prediction, err := predFunc() + if err != nil { + return result, TokenUsage{}, err + } + + tokenUsage.Prompt += prediction.Usage.Prompt + tokenUsage.Completion += prediction.Usage.Completion + + finetunedResponse := Finetune(*config, predInput, prediction.Response) + cb(finetunedResponse, &result) + + //result = append(result, Choice{Text: prediction}) + + } + return result, tokenUsage, err +} + +// TODO: No functions???? Commonize with prepareChatGenerationOpenAIRequest below? +func prepareGenerationOpenAIRequest(bindingFn TemplateConfigBindingFn, modelName string, input *datamodel.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *datamodel.StartupOptions) (*datamodel.Config, error) { + config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions) + if err != nil { + return nil, fmt.Errorf("failed reading parameters from request:%w", err) + } + + if input.ResponseFormat.Type == "json_object" { + input.Grammar = grammar.JSONBNF + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + configTemplate := bindingFn(config) + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if (*configTemplate == "") && (ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model))) { + *configTemplate = config.Model + } + if *configTemplate == "" { + return nil, fmt.Errorf(("failed to find templateConfig")) + } + + return config, nil +} + +////////// SPECIFIC REQUESTS ////////////// +// TODO: For round one of the refactor, give each of the three primary text endpoints their own function? +// SEMITODO: During a merge, edit/completion were semi-combined - but remain nominally split +// Can cleanup into a common form later if possible easier if they are all here for now +// If they remain different, extract each of these named segments to a seperate file + +func prepareChatGenerationOpenAIRequest(modelName string, input *datamodel.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *datamodel.StartupOptions) (*datamodel.Config, string, bool, error) { + + // IMPORTANT DEFS + funcs := grammar.Functions{} + + // The Basic Begining + + config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions) + if err != nil { + return nil, "", false, fmt.Errorf("failed reading parameters from request:%w", err) + } + log.Debug().Msgf("Configuration read: %+v", config) + + // Special Input/Config Handling + + // Allow the user to set custom actions via config file + // to be "embedded" in each model - but if they are missing, use defaults. + if config.FunctionsConfig.NoActionFunctionName == "" { + config.FunctionsConfig.NoActionFunctionName = DEFAULT_NO_ACTION_NAME + } + if config.FunctionsConfig.NoActionDescriptionName == "" { + config.FunctionsConfig.NoActionDescriptionName = DEFAULT_NO_ACTION_DESCRIPTION + } + + if input.ResponseFormat.Type == "json_object" { + input.Grammar = grammar.JSONBNF + } + + processFunctions := len(input.Functions) > 0 && config.ShouldUseFunctions() + + if processFunctions { + log.Debug().Msgf("Response needs to process functions") + + noActionGrammar := grammar.Function{ + Name: config.FunctionsConfig.NoActionFunctionName, + Description: config.FunctionsConfig.NoActionDescriptionName, + Parameters: map[string]interface{}{ + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to reply the user with", + }}, + }, + } + + // Append the no action function + funcs = append(funcs, input.Functions...) + if !config.FunctionsConfig.DisableNoAction { + funcs = append(funcs, noActionGrammar) + } + + // Force picking one of the functions by the request + if config.FunctionToCall() != "" { + funcs = funcs.Select(config.FunctionToCall()) + } + + // Update input grammar + jsStruct := funcs.ToJSONStructure() + config.Grammar = jsStruct.Grammar("") + } else if input.JSONFunctionGrammarObject != nil { + config.Grammar = input.JSONFunctionGrammarObject.Grammar("") + } + + log.Debug().Msgf("Parameters: %+v", config) + + var predInput string + + suppressConfigSystemPrompt := false + mess := []string{} + for messageIndex, i := range input.Messages { + var content string + role := i.Role + + // if function call, we might want to customize the role so we can display better that the "assistant called a json action" + // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request + if i.FunctionCall != nil && i.Role == "assistant" { + roleFn := "assistant_function_call" + r := config.Roles[roleFn] + if r != "" { + role = roleFn + } + } + r := config.Roles[role] + contentExists := i.Content != nil && i.StringContent != "" + // First attempt to populate content via a chat message specific template + if config.TemplateConfig.ChatMessage != "" { + chatMessageData := model.ChatMessageTemplateData{ + SystemPrompt: config.SystemPrompt, + Role: r, + RoleName: role, + Content: i.StringContent, + MessageIndex: messageIndex, + } + templatedChatMessage, err := ml.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) + if err != nil { + log.Error().Msgf("error processing message %+v using template \"%s\": %v. Skipping!", chatMessageData, config.TemplateConfig.ChatMessage, err) + } else { + if templatedChatMessage == "" { + log.Warn().Msgf("template \"%s\" produced blank output for %+v. Skipping!", config.TemplateConfig.ChatMessage, chatMessageData) + continue // TODO: This continue is here intentionally to skip over the line `mess = append(mess, content)` below, and to prevent the sprintf + } + log.Debug().Msgf("templated message for chat: %s", templatedChatMessage) + content = templatedChatMessage + } + } + // If this model doesn't have such a template, or if that template fails to return a value, template at the message level. + if content == "" { + if r != "" { + if contentExists { + content = fmt.Sprint(r, i.StringContent) + } + if i.FunctionCall != nil { + j, err := json.Marshal(i.FunctionCall) + if err == nil { + if contentExists { + content += "\n" + fmt.Sprint(r, " ", string(j)) + } else { + content = fmt.Sprint(r, " ", string(j)) + } + } + } + } else { + if contentExists { + content = fmt.Sprint(i.StringContent) + } + if i.FunctionCall != nil { + j, err := json.Marshal(i.FunctionCall) + if err == nil { + if contentExists { + content += "\n" + string(j) + } else { + content = string(j) + } + } + } + } + // Special Handling: System. We care if it was printed at all, not the r branch, so check seperately + if contentExists && role == "system" { + suppressConfigSystemPrompt = true + } + } + + mess = append(mess, content) + } + + predInput = strings.Join(mess, "\n") + log.Debug().Msgf("Prompt (before templating): %s", predInput) + + templateFile := "" + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + if ml.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) { + templateFile = config.Model + } + + if config.TemplateConfig.Chat != "" && !processFunctions { + templateFile = config.TemplateConfig.Chat + } + + if config.TemplateConfig.Functions != "" && processFunctions { + templateFile = config.TemplateConfig.Functions + } + + if templateFile != "" { + templatedInput, err := ml.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{ + SystemPrompt: config.SystemPrompt, + SuppressSystemPrompt: suppressConfigSystemPrompt, + Input: predInput, + Functions: funcs, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } else { + log.Debug().Msgf("Template failed loading: %s", err.Error()) + } + } + + log.Debug().Msgf("Prompt (after templating): %s", predInput) + if processFunctions { + log.Debug().Msgf("Grammar: %+v", config.Grammar) + } + + return config, predInput, processFunctions, nil + +} + +func EditGenerationOpenAIRequest(modelName string, input *datamodel.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *datamodel.StartupOptions) (*datamodel.OpenAIResponse, error) { + id := uuid.New().String() + created := int(time.Now().Unix()) + + binding := func(config *datamodel.Config) *string { + return &config.TemplateConfig.Edit + } + + config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions) + if err != nil { + return nil, err + } + + var result []datamodel.Choice + totalTokenUsage := TokenUsage{} + + for _, i := range config.InputStrings { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := ml.EvaluateTemplateForPrompt(model.EditPromptTemplate, config.TemplateConfig.Edit, model.PromptTemplateData{ + Input: i, + Instruction: input.Instruction, + SystemPrompt: config.SystemPrompt, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + + r, tokenUsage, err := ComputeChoices(input, i, config, startupOptions, ml, func(s string, c *[]datamodel.Choice) { + *c = append(*c, datamodel.Choice{Text: s}) + }, nil) + if err != nil { + return nil, err + } + + totalTokenUsage.Prompt += tokenUsage.Prompt + totalTokenUsage.Completion += tokenUsage.Completion + + result = append(result, r...) + } + + return &datamodel.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "edit", + Usage: datamodel.OpenAIUsage{ + PromptTokens: totalTokenUsage.Prompt, + CompletionTokens: totalTokenUsage.Completion, + TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, + }, + }, nil +} + +func ChatGenerationOpenAIRequest(modelName string, input *datamodel.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *datamodel.StartupOptions) (*datamodel.OpenAIResponse, error) { + + // DEFS + id := uuid.New().String() + created := int(time.Now().Unix()) + + // Prepare + config, predInput, processFunctions, err := prepareChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions) + if err != nil { + return nil, err + } + + result, tokenUsage, err := ComputeChoices(input, predInput, config, startupOptions, ml, func(s string, c *[]datamodel.Choice) { + if processFunctions { + // As we have to change the result before processing, we can't stream the answer (yet?) + ss := map[string]interface{}{} + // This prevent newlines to break JSON parsing for clients + s = utils.EscapeNewLines(s) + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name := ss["function"] + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + d, _ := json.Marshal(args) + + ss["arguments"] = string(d) + ss["name"] = func_name + + // if do nothing, reply with a message + if func_name == config.FunctionsConfig.NoActionFunctionName { + log.Debug().Msgf("nothing to do, computing a reply") + + // If there is a message that the LLM already sends as part of the JSON reply, use it + arguments := map[string]interface{}{} + json.Unmarshal([]byte(d), &arguments) + m, exists := arguments["message"] + if exists { + switch message := m.(type) { + case string: + if message != "" { + log.Debug().Msgf("Reply received from LLM: %s", message) + message = Finetune(*config, predInput, message) + log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) + + *c = append(*c, datamodel.Choice{Message: &datamodel.Message{Role: "assistant", Content: &message}}) + return + } + } + } + + log.Debug().Msgf("No action received from LLM, without a message, computing a reply") + // Otherwise ask the LLM to understand the JSON output and the context, and return a message + // Note: This costs (in term of CPU) another computation + config.Grammar = "" + images := []string{} + for _, m := range input.Messages { + images = append(images, m.StringImages...) + } + predFunc, err := ModelInference(input.Context, predInput, images, ml, *config, startupOptions, nil) + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return + } + + prediction, err := predFunc() + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return + } + + fineTunedResponse := Finetune(*config, predInput, prediction.Response) + *c = append(*c, datamodel.Choice{Message: &datamodel.Message{Role: "assistant", Content: &fineTunedResponse}}) + } else { + // otherwise reply with the function call + *c = append(*c, datamodel.Choice{ + FinishReason: "function_call", + Message: &datamodel.Message{Role: "assistant", FunctionCall: ss}, + }) + } + + return + } + *c = append(*c, datamodel.Choice{FinishReason: "stop", Index: 0, Message: &datamodel.Message{Role: "assistant", Content: &s}}) + }, nil) + if err != nil { + return nil, err + } + + return &datamodel.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "chat.completion", + Usage: datamodel.OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + }, + }, nil + +} + +func CompletionGenerationOpenAIRequest(modelName string, input *datamodel.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *datamodel.StartupOptions) (*datamodel.OpenAIResponse, error) { + // Prepare + id := uuid.New().String() + created := int(time.Now().Unix()) + + binding := func(config *datamodel.Config) *string { + return &config.TemplateConfig.Completion + } + + config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions) + if err != nil { + return nil, err + } + + var result []datamodel.Choice + + totalTokenUsage := TokenUsage{} + + for k, i := range config.PromptStrings { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, config.TemplateConfig.Completion, model.PromptTemplateData{ + SystemPrompt: config.SystemPrompt, + Input: i, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + + r, tokenUsage, err := ComputeChoices( + input, i, config, startupOptions, ml, func(s string, c *[]datamodel.Choice) { + *c = append(*c, datamodel.Choice{Text: s, FinishReason: "stop", Index: k}) + }, nil) + if err != nil { + return nil, err + } + + totalTokenUsage.Prompt += tokenUsage.Prompt + totalTokenUsage.Completion += tokenUsage.Completion + + result = append(result, r...) + } + + return &datamodel.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "text_completion", + Usage: datamodel.OpenAIUsage{ + PromptTokens: totalTokenUsage.Prompt, + CompletionTokens: totalTokenUsage.Completion, + TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, + }, + }, nil +} + +func StreamingChatGenerationOpenAIRequest(modelName string, input *datamodel.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *datamodel.StartupOptions) (chan datamodel.OpenAIResponse, error) { + + // DEFS + emptyMessage := "" + id := uuid.New().String() + created := int(time.Now().Unix()) + + // Prepare + config, predInput, processFunctions, err := prepareChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions) + if err != nil { + return nil, err + } + + if processFunctions { + // TODO: unused variable means I did something wrong. investigate once stable + log.Debug().Msgf("StreamingChatGenerationOpenAIRequest with processFunctions=true for %s?", config.Name) + } + + processor := func(s string, req *datamodel.OpenAIRequest, config *datamodel.Config, loader *model.ModelLoader, responses chan datamodel.OpenAIResponse) { + initialMessage := datamodel.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []datamodel.Choice{{Delta: &datamodel.Message{Role: "assistant", Content: &emptyMessage}}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]datamodel.Choice) {}, func(s string, usage TokenUsage) bool { + resp := datamodel.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []datamodel.Choice{{Delta: &datamodel.Message{Content: &s}, Index: 0}}, + Object: "chat.completion.chunk", + Usage: datamodel.OpenAIUsage{ + PromptTokens: usage.Prompt, + CompletionTokens: usage.Completion, + TotalTokens: usage.Prompt + usage.Completion, + }, + } + + responses <- resp + return true + }) + close(responses) + } + log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: About to create response channel") + + responses := make(chan datamodel.OpenAIResponse) + + log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: About to start processor goroutine") + + go processor(predInput, input, config, ml, responses) + + log.Trace().Msg("StreamingChatGenerationOpenAIRequest :: DONE! successfully returning to caller!") + + return responses, nil + +} + +func StreamingCompletionGenerationOpenAIRequest(modelName string, input *datamodel.OpenAIRequest, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *datamodel.StartupOptions) (chan datamodel.OpenAIResponse, error) { + // DEFS + id := uuid.New().String() + created := int(time.Now().Unix()) + + binding := func(config *datamodel.Config) *string { + return &config.TemplateConfig.Completion + } + + // Prepare + + config, err := prepareGenerationOpenAIRequest(binding, modelName, input, cl, ml, startupOptions) + if err != nil { + return nil, err + } + + processor := func(s string, req *datamodel.OpenAIRequest, config *datamodel.Config, loader *model.ModelLoader, responses chan datamodel.OpenAIResponse) { + ComputeChoices(req, s, config, startupOptions, loader, func(s string, c *[]datamodel.Choice) {}, func(s string, usage TokenUsage) bool { + resp := datamodel.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []datamodel.Choice{ + { + Index: 0, + Text: s, + }, + }, + Object: "text_completion", + Usage: datamodel.OpenAIUsage{ + PromptTokens: usage.Prompt, + CompletionTokens: usage.Completion, + TotalTokens: usage.Prompt + usage.Completion, + }, + } + log.Debug().Msgf("Sending goroutine: %s", s) + + responses <- resp + return true + }) + close(responses) + } + + if len(config.PromptStrings) > 1 { + return nil, errors.New("cannot handle more than 1 `PromptStrings` when Streaming") + + } + + predInput := config.PromptStrings[0] + + //A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := ml.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, config.TemplateConfig.Completion, model.PromptTemplateData{ + Input: predInput, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } + + log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: About to create response channel") + + responses := make(chan datamodel.OpenAIResponse) + + log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: About to start processor goroutine") + + go processor(predInput, input, config, ml, responses) + + log.Trace().Msg("StreamingCompletionGenerationOpenAIRequest :: DONE! successfully returning to caller!") + + return responses, nil +} diff --git a/api/backend/options.go b/core/backend/options.go similarity index 90% rename from api/backend/options.go rename to core/backend/options.go index 3266d602cce2..53ae87112353 100644 --- a/api/backend/options.go +++ b/core/backend/options.go @@ -4,14 +4,12 @@ import ( "os" "path/filepath" + "github.com/go-skynet/LocalAI/pkg/datamodel" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" - - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/model" ) -func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option { +func modelOpts(c datamodel.Config, o *datamodel.StartupOptions, opts []model.Option) []model.Option { if o.SingleBackend { opts = append(opts, model.WithSingleActiveBackend()) } @@ -35,7 +33,7 @@ func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model. return opts } -func gRPCModelOpts(c config.Config) *pb.ModelOptions { +func gRPCModelOpts(c datamodel.Config) *pb.ModelOptions { b := 512 if c.Batch != 0 { b = c.Batch @@ -82,7 +80,7 @@ func gRPCModelOpts(c config.Config) *pb.ModelOptions { } } -func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions { +func gRPCPredictOpts(c datamodel.Config, modelPath string) *pb.PredictOptions { promptCachePath := "" if c.PromptCachePath != "" { p := filepath.Join(modelPath, c.PromptCachePath) diff --git a/core/backend/transcription.go b/core/backend/transcription.go new file mode 100644 index 000000000000..d2f27adf586f --- /dev/null +++ b/core/backend/transcription.go @@ -0,0 +1,52 @@ +package backend + +import ( + "context" + "fmt" + + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/model" +) + +func ModelTranscription(audio, language string, loader *model.ModelLoader, c datamodel.Config, o *datamodel.StartupOptions) (*datamodel.WhisperResult, error) { + + opts := modelOpts(c, o, []model.Option{ + model.WithBackendString(model.WhisperBackend), + model.WithModel(c.Model), + model.WithContext(o.Context), + model.WithThreads(uint32(c.Threads)), + model.WithAssetDir(o.AssetsDestination), + model.WithExternalBackends(o.ExternalGRPCBackends, false), + }) + + whisperModel, err := loader.BackendLoader(opts...) + if err != nil { + return nil, err + } + + if whisperModel == nil { + return nil, fmt.Errorf("could not load whisper model") + } + + return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ + Dst: audio, + Language: language, + Threads: uint32(c.Threads), + }) +} + +func TranscriptionOpenAIRequest(modelName string, input *datamodel.OpenAIRequest, audioFilePath string, cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *datamodel.StartupOptions) (*datamodel.WhisperResult, error) { + config, input, err := ReadConfigFromFileAndCombineWithOpenAIRequest(modelName, input, cl, startupOptions) + if err != nil { + return nil, fmt.Errorf("failed reading parameters from request:%w", err) + } + + tr, err := ModelTranscription(audioFilePath, input.Language, ml, *config, startupOptions) + if err != nil { + return nil, err + } + + return tr, nil +} diff --git a/api/backend/tts.go b/core/backend/tts.go similarity index 76% rename from api/backend/tts.go rename to core/backend/tts.go index ae8f53eea938..fe6ff09320b4 100644 --- a/api/backend/tts.go +++ b/core/backend/tts.go @@ -6,10 +6,9 @@ import ( "os" "path/filepath" - api_config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/datamodel" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" ) @@ -29,18 +28,19 @@ func generateUniqueFileName(dir, baseName, ext string) string { } } -func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) { +func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *datamodel.StartupOptions) (string, *proto.Result, error) { bb := backend if bb == "" { bb = model.PiperBackend } - opts := modelOpts(api_config.Config{}, o, []model.Option{ + opts := modelOpts(datamodel.Config{}, o, []model.Option{ model.WithBackendString(bb), model.WithModel(modelFile), model.WithContext(o.Context), model.WithAssetDir(o.AssetsDestination), + model.WithExternalBackends(o.ExternalGRPCBackends, false), }) - piperModel, err := o.Loader.BackendLoader(opts...) + piperModel, err := loader.BackendLoader(opts...) if err != nil { return "", nil, err } @@ -60,8 +60,8 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *opt modelPath := "" if modelFile != "" { if bb != model.TransformersMusicGen { - modelPath = filepath.Join(o.Loader.ModelPath, modelFile) - if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { + modelPath = filepath.Join(o.ModelPath, modelFile) + if err := utils.VerifyPath(modelPath, o.ModelPath); err != nil { return "", nil, err } } else { diff --git a/core/http/api.go b/core/http/api.go new file mode 100644 index 000000000000..41978ebb8a1f --- /dev/null +++ b/core/http/api.go @@ -0,0 +1,169 @@ +package http + +import ( + "errors" + "strings" + + "github.com/go-skynet/LocalAI/core/http/endpoints/localai" + "github.com/go-skynet/LocalAI/core/http/endpoints/openai" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/internal" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/cors" + "github.com/gofiber/fiber/v2/middleware/logger" + "github.com/gofiber/fiber/v2/middleware/recover" +) + +func App(cl *services.ConfigLoader, ml *model.ModelLoader, options *datamodel.StartupOptions) (*fiber.App, error) { + + // Return errors as JSON responses + app := fiber.New(fiber.Config{ + BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB + DisableStartupMessage: options.DisableMessage, + // Override default error handler + ErrorHandler: func(ctx *fiber.Ctx, err error) error { + // Status code defaults to 500 + code := fiber.StatusInternalServerError + + // Retrieve the custom status code if it's a *fiber.Error + var e *fiber.Error + if errors.As(err, &e) { + code = e.Code + } + + // Send custom error page + return ctx.Status(code).JSON( + datamodel.ErrorResponse{ + Error: &datamodel.APIError{Message: err.Error(), Code: code}, + }, + ) + }, + }) + + if options.Debug { + app.Use(logger.New(logger.Config{ + Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", + })) + } + + // Default middleware config + app.Use(recover.New()) + + if options.Metrics != nil { + app.Use(localai.MetricsAPIMiddleware(options.Metrics)) + } + + // Auth middleware checking if API key is valid. If no API key is set, no auth is required. + auth := func(c *fiber.Ctx) error { + if len(options.ApiKeys) == 0 { + return c.Next() + } + + authHeader := c.Get("Authorization") + if authHeader == "" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) + } + authHeaderParts := strings.Split(authHeader, " ") + if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) + } + + apiKey := authHeaderParts[1] + for _, key := range options.ApiKeys { + if apiKey == key { + return c.Next() + } + } + + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) + + } + + if options.CORS { + var c func(ctx *fiber.Ctx) error + if options.CORSAllowOrigins == "" { + c = cors.New() + } else { + c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins}) + } + + app.Use(c) + } + + // LocalAI API endpoints + galleryService := services.NewGalleryApplier(options.ModelPath) + galleryService.Start(options.Context, cl) + + app.Get("/version", auth, func(c *fiber.Ctx) error { + return c.JSON(struct { + Version string `json:"version"` + }{Version: internal.PrintableVersion()}) + }) + + modelGalleryService := localai.CreateModelGalleryEndpointService(options.Galleries, options.ModelPath, galleryService) + app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint()) + app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint()) + app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint()) + app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint()) + app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint()) + app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint()) + app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint()) + + // openAI compatible API endpoint + + // chat + app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, ml, options)) + app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, ml, options)) + + // edit + app.Post("/v1/edits", auth, openai.EditEndpoint(cl, ml, options)) + app.Post("/edits", auth, openai.EditEndpoint(cl, ml, options)) + + // completion + app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, ml, options)) + app.Post("/completions", auth, openai.CompletionEndpoint(cl, ml, options)) + app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, ml, options)) + + // embeddings + app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options)) + app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options)) + app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, ml, options)) + + // audio + app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, ml, options)) + app.Post("/tts", auth, localai.TTSEndpoint(cl, ml, options)) + + // images + app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, ml, options)) + + if options.ImageDir != "" { + app.Static("/generated-images", options.ImageDir) + } + + if options.AudioDir != "" { + app.Static("/generated-audio", options.AudioDir) + } + + ok := func(c *fiber.Ctx) error { + return c.SendStatus(200) + } + + // Kubernetes health checks + app.Get("/healthz", ok) + app.Get("/readyz", ok) + + app.Get("/metrics", localai.MetricsHandler()) + + backendMonitor := services.NewBackendMonitor(cl, ml, options) + app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor)) + app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor)) + + // model listing + app.Get("/v1/models", auth, openai.ListModelsEndpoint(cl, ml)) + app.Get("/models", auth, openai.ListModelsEndpoint(cl, ml)) + + return app, nil +} diff --git a/api/api_test.go b/core/http/api_test.go similarity index 93% rename from api/api_test.go rename to core/http/api_test.go index a71b450ada7d..b85fc056dcf0 100644 --- a/api/api_test.go +++ b/core/http/api_test.go @@ -1,4 +1,4 @@ -package api_test +package http_test import ( "bytes" @@ -13,9 +13,10 @@ import ( "path/filepath" "runtime" - . "github.com/go-skynet/LocalAI/api" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/metrics" + server "github.com/go-skynet/LocalAI/core/http" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/core/startup" + "github.com/go-skynet/LocalAI/pkg/datamodel" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" @@ -118,16 +119,15 @@ var backendAssets embed.FS var _ = Describe("API test", func() { var app *fiber.App - var modelLoader *model.ModelLoader var client *openai.Client var client2 *openaigo.Client var c context.Context var cancel context.CancelFunc var tmpdir string - commonOpts := []options.AppOption{ - options.WithDebug(true), - options.WithDisableMessage(true), + commonOpts := []datamodel.AppOption{ + datamodel.WithDebug(true), + datamodel.WithDisableMessage(true), } Context("API with ephemeral models", func() { @@ -136,7 +136,6 @@ var _ = Describe("API test", func() { tmpdir, err = os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) - modelLoader = model.NewModelLoader(tmpdir) c, cancel = context.WithCancel(context.Background()) g := []gallery.GalleryModel{ @@ -163,15 +162,20 @@ var _ = Describe("API test", func() { }, } - metricsService, err := metrics.SetupMetrics() + metricsService, err := services.SetupMetrics() Expect(err).ToNot(HaveOccurred()) - app, err = App( + cl, ml, options, err := startup.Startup( append(commonOpts, - options.WithMetrics(metricsService), - options.WithContext(c), - options.WithGalleries(galleries), - options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...) + datamodel.WithMetrics(metricsService), + datamodel.WithContext(c), + datamodel.WithGalleries(galleries), + datamodel.WithModelPath(tmpdir), + datamodel.WithBackendAssets(backendAssets), + datamodel.WithBackendAssetsOutput(tmpdir))...) + + Expect(err).ToNot(HaveOccurred()) + app, err = server.App(cl, ml, options) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -475,7 +479,6 @@ var _ = Describe("API test", func() { tmpdir, err = os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) - modelLoader = model.NewModelLoader(tmpdir) c, cancel = context.WithCancel(context.Background()) galleries := []gallery.Gallery{ @@ -485,21 +488,22 @@ var _ = Describe("API test", func() { }, } - metricsService, err := metrics.SetupMetrics() + metricsService, err := services.SetupMetrics() Expect(err).ToNot(HaveOccurred()) - app, err = App( + cl, ml, options, err := startup.Startup( append(commonOpts, - options.WithContext(c), - options.WithMetrics(metricsService), - options.WithAudioDir(tmpdir), - options.WithImageDir(tmpdir), - options.WithGalleries(galleries), - options.WithModelLoader(modelLoader), - options.WithBackendAssets(backendAssets), - options.WithBackendAssetsOutput(tmpdir))..., + datamodel.WithContext(c), + datamodel.WithMetrics(metricsService), + datamodel.WithAudioDir(tmpdir), + datamodel.WithImageDir(tmpdir), + datamodel.WithGalleries(galleries), + datamodel.WithModelPath(tmpdir), + datamodel.WithBackendAssets(backendAssets), + datamodel.WithBackendAssetsOutput(tmpdir))..., ) Expect(err).ToNot(HaveOccurred()) + app, err = server.App(cl, ml, options) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -590,20 +594,21 @@ var _ = Describe("API test", func() { Context("API query", func() { BeforeEach(func() { - modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) c, cancel = context.WithCancel(context.Background()) - metricsService, err := metrics.SetupMetrics() + metricsService, err := services.SetupMetrics() Expect(err).ToNot(HaveOccurred()) - app, err = App( + cl, ml, options, err := startup.Startup( append(commonOpts, - options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), - options.WithContext(c), - options.WithModelLoader(modelLoader), - options.WithMetrics(metricsService), + datamodel.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), + datamodel.WithContext(c), + datamodel.WithModelPath(os.Getenv("MODELS_PATH")), + datamodel.WithMetrics(metricsService), )...) Expect(err).ToNot(HaveOccurred()) + app, err = server.App(cl, ml, options) + Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -802,20 +807,21 @@ var _ = Describe("API test", func() { Context("Config file", func() { BeforeEach(func() { - modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) c, cancel = context.WithCancel(context.Background()) - metricsService, err := metrics.SetupMetrics() + metricsService, err := services.SetupMetrics() Expect(err).ToNot(HaveOccurred()) - app, err = App( + cl, ml, options, err := startup.Startup( append(commonOpts, - options.WithContext(c), - options.WithMetrics(metricsService), - options.WithModelLoader(modelLoader), - options.WithConfigFile(os.Getenv("CONFIG_FILE")))..., + datamodel.WithContext(c), + datamodel.WithMetrics(metricsService), + datamodel.WithModelPath(os.Getenv("MODELS_PATH")), + datamodel.WithConfigFile(os.Getenv("CONFIG_FILE")))..., ) Expect(err).ToNot(HaveOccurred()) + app, err = server.App(cl, ml, options) + Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") diff --git a/api/apt_suite_test.go b/core/http/apt_suite_test.go similarity index 90% rename from api/apt_suite_test.go rename to core/http/apt_suite_test.go index e3c15c048b14..0269a97321df 100644 --- a/api/apt_suite_test.go +++ b/core/http/apt_suite_test.go @@ -1,4 +1,4 @@ -package api_test +package http_test import ( "testing" diff --git a/core/http/endpoints/localai/backend_monitor.go b/core/http/endpoints/localai/backend_monitor.go new file mode 100644 index 000000000000..45c82cb024e1 --- /dev/null +++ b/core/http/endpoints/localai/backend_monitor.go @@ -0,0 +1,34 @@ +package localai + +import ( + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/gofiber/fiber/v2" +) + +func BackendMonitorEndpoint(bm *services.BackendMonitor) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(datamodel.BackendMonitorRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + + resp, err := bm.CheckAndSample(input.Model) + if err != nil { + return err + } + return c.JSON(resp) + } +} + +func BackendShutdownEndpoint(bm *services.BackendMonitor) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(datamodel.BackendMonitorRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + return bm.ShutdownModel(input.Model) + } +} diff --git a/core/http/endpoints/localai/gallery.go b/core/http/endpoints/localai/gallery.go new file mode 100644 index 000000000000..6b4d73b596e0 --- /dev/null +++ b/core/http/endpoints/localai/gallery.go @@ -0,0 +1,148 @@ +package localai + +import ( + "encoding/json" + "fmt" + "slices" + + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "github.com/rs/zerolog/log" +) + +/// Endpoint Service + +type ModelGalleryEndpointService struct { + galleries []gallery.Gallery + modelPath string + galleryApplier *services.GalleryApplier +} + +type GalleryModel struct { + ID string `json:"id"` + gallery.GalleryModel +} + +func CreateModelGalleryEndpointService(galleries []gallery.Gallery, modelPath string, galleryApplier *services.GalleryApplier) ModelGalleryEndpointService { + return ModelGalleryEndpointService{ + galleries: galleries, + modelPath: modelPath, + galleryApplier: galleryApplier, + } +} + +func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + status := mgs.galleryApplier.GetStatus(c.Params("uuid")) + if status == nil { + return fmt.Errorf("could not find any status for ID") + } + return c.JSON(status) + } +} + +func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + return c.JSON(mgs.galleryApplier.GetAllStatus()) + } +} + +func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(GalleryModel) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + + uuid, err := uuid.NewUUID() + if err != nil { + return err + } + mgs.galleryApplier.C <- gallery.GalleryOp{ + Req: input.GalleryModel, + Id: uuid.String(), + GalleryName: input.ID, + Galleries: mgs.galleries, + } + return c.JSON(struct { + ID string `json:"uuid"` + StatusURL string `json:"status"` + }{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) + } +} + +func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries) + + models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath) + if err != nil { + return err + } + log.Debug().Msgf("Models found from galleries: %+v", models) + for _, m := range models { + log.Debug().Msgf("Model found from galleries: %+v", m) + } + dat, err := json.Marshal(models) + if err != nil { + return err + } + return c.Send(dat) + } +} + +// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents! +func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + log.Debug().Msgf("Listing model galleries %+v", mgs.galleries) + dat, err := json.Marshal(mgs.galleries) + if err != nil { + return err + } + return c.Send(dat) + } +} + +func (mgs *ModelGalleryEndpointService) AddModelGalleryEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(gallery.Gallery) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + if slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { + return gallery.Name == input.Name + }) { + return fmt.Errorf("%s already exists", input.Name) + } + dat, err := json.Marshal(mgs.galleries) + if err != nil { + return err + } + log.Debug().Msgf("Adding %+v to gallery list", *input) + mgs.galleries = append(mgs.galleries, *input) + return c.Send(dat) + } +} + +func (mgs *ModelGalleryEndpointService) RemoveModelGalleryEndpoint() func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(gallery.Gallery) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + if !slices.ContainsFunc(mgs.galleries, func(gallery gallery.Gallery) bool { + return gallery.Name == input.Name + }) { + return fmt.Errorf("%s is not currently registered", input.Name) + } + mgs.galleries = slices.DeleteFunc(mgs.galleries, func(gallery gallery.Gallery) bool { + return gallery.Name == input.Name + }) + return c.Send(nil) + } +} diff --git a/core/http/endpoints/localai/metrics.go b/core/http/endpoints/localai/metrics.go new file mode 100644 index 000000000000..c77ba08ca5dc --- /dev/null +++ b/core/http/endpoints/localai/metrics.go @@ -0,0 +1,42 @@ +package localai + +import ( + "time" + + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/adaptor" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +func MetricsHandler() fiber.Handler { + return adaptor.HTTPHandler(promhttp.Handler()) +} + +type apiMiddlewareConfig struct { + Filter func(c *fiber.Ctx) bool + metrics *datamodel.LocalAIMetrics +} + +func MetricsAPIMiddleware(metrics *datamodel.LocalAIMetrics) fiber.Handler { + cfg := apiMiddlewareConfig{ + metrics: metrics, + Filter: func(c *fiber.Ctx) bool { + return c.Path() == "/metrics" + }, + } + + return func(c *fiber.Ctx) error { + if cfg.Filter != nil && cfg.Filter(c) { + return c.Next() + } + path := c.Path() + method := c.Method() + + start := time.Now() + err := c.Next() + elapsed := float64(time.Since(start)) / float64(time.Second) + cfg.metrics.ObserveAPICall(method, path, elapsed) + return err + } +} diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go new file mode 100644 index 000000000000..518f7915876d --- /dev/null +++ b/core/http/endpoints/localai/tts.go @@ -0,0 +1,26 @@ +package localai + +import ( + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" +) + +func TTSEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *datamodel.StartupOptions) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + + input := new(datamodel.TTSRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + + filePath, _, err := backend.ModelTTS(input.Backend, input.Input, input.Model, ml, so) + if err != nil { + return err + } + return c.Download(filePath) + } +} diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go new file mode 100644 index 000000000000..d767e36b5e0c --- /dev/null +++ b/core/http/endpoints/openai/chat.go @@ -0,0 +1,98 @@ +package openai + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" + "github.com/valyala/fasthttp" +) + +func ChatEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, startupOptions *datamodel.StartupOptions) func(c *fiber.Ctx) error { + + emptyMessage := "" + + return func(c *fiber.Ctx) error { + modelName, input, err := readInput(c, startupOptions, ml, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + // The scary comment I feel like I forgot about along the way: + // + // functions are not supported in stream mode (yet?) + // + if input.Stream { + log.Debug().Msgf("Stream request received") + c.Context().SetContentType("text/event-stream") + //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) + // c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("Transfer-Encoding", "chunked") + + responses, err := backend.StreamingChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions) + if err != nil { + return fmt.Errorf("failed establishing streaming chat request :%w", err) + } + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { + + usage := &datamodel.OpenAIUsage{} + id := "" + created := 0 + for ev := range responses { + usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it + id = ev.ID + created = ev.Created // Similarly, grab the ID and created from any / the last response so we can use it for the stop + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.Encode(ev) + log.Debug().Msgf("Sending chunk: %s", buf.String()) + _, err := fmt.Fprintf(w, "data: %v\n", buf.String()) + if err != nil { + log.Debug().Msgf("Sending chunk failed: %v", err) + input.Cancel() + break + } + w.Flush() + } + + resp := &datamodel.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []datamodel.Choice{ + { + FinishReason: "stop", + Index: 0, + Delta: &datamodel.Message{Content: &emptyMessage}, + }}, + Object: "chat.completion.chunk", + Usage: *usage, + } + respData, _ := json.Marshal(resp) + + w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) + w.WriteString("data: [DONE]\n\n") + w.Flush() + })) + return nil + } + ////////////////////////////////////////// + + resp, err := backend.ChatGenerationOpenAIRequest(modelName, input, cl, ml, startupOptions) + if err != nil { + return fmt.Errorf("error generating chat request: +%w", err) + } + respData, _ := json.Marshal(resp) // TODO this is only used for the debug log and costs performance. monitor this? + log.Debug().Msgf("Response: %s", respData) + return c.JSON(resp) + } +} diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go new file mode 100644 index 000000000000..b8b9e626a66e --- /dev/null +++ b/core/http/endpoints/openai/completion.go @@ -0,0 +1,91 @@ +package openai + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "time" + + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "github.com/rs/zerolog/log" + "github.com/valyala/fasthttp" +) + +// https://platform.openai.com/docs/api-reference/completions +func CompletionEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *datamodel.StartupOptions) func(c *fiber.Ctx) error { + id := uuid.New().String() + created := int(time.Now().Unix()) + + return func(c *fiber.Ctx) error { + modelName, input, err := readInput(c, so, ml, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("`input`: %+v", input) + + if input.Stream { + log.Debug().Msgf("Stream request received") + c.Context().SetContentType("text/event-stream") + //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) + //c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("Transfer-Encoding", "chunked") + + responses, err := backend.StreamingCompletionGenerationOpenAIRequest(modelName, input, cl, ml, so) + if err != nil { + return fmt.Errorf("failed establishing streaming completion request :%w", err) + } + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { + + for ev := range responses { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.Encode(ev) + + log.Debug().Msgf("Sending chunk: %s", buf.String()) + fmt.Fprintf(w, "data: %v\n", buf.String()) + w.Flush() + } + + resp := &datamodel.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []datamodel.Choice{ + { + Index: 0, + FinishReason: "stop", + }, + }, + Object: "text_completion", + } + respData, _ := json.Marshal(resp) + + w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) + w.WriteString("data: [DONE]\n\n") + w.Flush() + })) + return nil + } + + /////////// + + resp, err := backend.CompletionGenerationOpenAIRequest(modelName, input, cl, ml, so) + if err != nil { + return fmt.Errorf("error generating completion request: +%w", err) + } + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go new file mode 100644 index 000000000000..f167a757febc --- /dev/null +++ b/core/http/endpoints/openai/edit.go @@ -0,0 +1,34 @@ +package openai + +import ( + "encoding/json" + "fmt" + + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + + "github.com/rs/zerolog/log" +) + +func EditEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *datamodel.StartupOptions) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + modelFile, input, err := readInput(c, so, ml, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + resp, err := backend.EditGenerationOpenAIRequest(modelFile, input, cl, ml, so) + if err != nil { + return err + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go new file mode 100644 index 000000000000..b7db48babbea --- /dev/null +++ b/core/http/endpoints/openai/embeddings.go @@ -0,0 +1,35 @@ +package openai + +import ( + "encoding/json" + "fmt" + + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" + + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/embeddings +func EmbeddingsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *datamodel.StartupOptions) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + modelFile, input, err := readInput(c, so, ml, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + resp, err := backend.EmbeddingOpenAIRequest(modelFile, input, cl, ml, so) + if err != nil { + return err + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go new file mode 100644 index 000000000000..a3284c3b468d --- /dev/null +++ b/core/http/endpoints/openai/image.go @@ -0,0 +1,48 @@ +package openai + +import ( + "encoding/json" + "fmt" + + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/images/create + +/* +* + + curl http://localhost:8080/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A cute baby sea otter", + "n": 1, + "size": "512x512" + }' + +* +*/ +func ImageEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *datamodel.StartupOptions) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + modelName, input, err := readInput(c, so, ml, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + resp, err := backend.ImageGenerationOpenAIRequest(modelName, input, cl, ml, so) + if err != nil { + return fmt.Errorf("error generating image request: +%w", err) + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/list.go b/core/http/endpoints/openai/list.go similarity index 66% rename from api/openai/list.go rename to core/http/endpoints/openai/list.go index 8bc5bbe22bee..327f7b7c7fad 100644 --- a/api/openai/list.go +++ b/core/http/endpoints/openai/list.go @@ -3,21 +3,21 @@ package openai import ( "regexp" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/schema" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" ) -func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error { +func ListModelsEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader) func(ctx *fiber.Ctx) error { return func(c *fiber.Ctx) error { - models, err := loader.ListModels() + models, err := ml.ListModels() if err != nil { return err } var mm map[string]interface{} = map[string]interface{}{} - dataModels := []schema.OpenAIModel{} + dataModels := []datamodel.OpenAIModel{} var filterFn func(name string) bool filter := c.Query("filter") @@ -40,13 +40,13 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func excludeConfigured := c.QueryBool("excludeConfigured", true) // Start with the known configurations - for _, c := range cm.GetAllConfigs() { + for _, c := range cl.GetAllConfigs() { if excludeConfigured { mm[c.Model] = nil } if filterFn(c.Name) { - dataModels = append(dataModels, schema.OpenAIModel{ID: c.Name, Object: "model"}) + dataModels = append(dataModels, datamodel.OpenAIModel{ID: c.Name, Object: "model"}) } } @@ -54,13 +54,13 @@ func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func for _, m := range models { // And only adds them if they shouldn't be skipped. if _, exists := mm[m]; !exists && filterFn(m) { - dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) + dataModels = append(dataModels, datamodel.OpenAIModel{ID: m, Object: "model"}) } } return c.JSON(struct { - Object string `json:"object"` - Data []schema.OpenAIModel `json:"data"` + Object string `json:"object"` + Data []datamodel.OpenAIModel `json:"data"` }{ Object: "list", Data: dataModels, diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go new file mode 100644 index 000000000000..2d039b99b2bf --- /dev/null +++ b/core/http/endpoints/openai/request.go @@ -0,0 +1,57 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +func readInput(c *fiber.Ctx, o *datamodel.StartupOptions, ml *model.ModelLoader, randomModel bool) (string, *datamodel.OpenAIRequest, error) { + input := new(datamodel.OpenAIRequest) + ctx, cancel := context.WithCancel(o.Context) + input.Context = ctx + input.Cancel = cancel + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return "", nil, fmt.Errorf("failed parsing request body: %w", err) + } + + modelFile := input.Model + + if c.Params("model") != "" { + modelFile = c.Params("model") + } + + received, _ := json.Marshal(input) + + log.Debug().Msgf("Request received: %s", string(received)) + + // Set model from bearer token, if available + bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") + bearerExists := bearer != "" && ml.ExistsInModelPath(bearer) + + // If no model was specified, take the first available + if modelFile == "" && !bearerExists && randomModel { + models, _ := ml.ListModels() + if len(models) > 0 { + modelFile = models[0] + log.Debug().Msgf("No model specified, using: %s", modelFile) + } else { + log.Debug().Msgf("No model specified, returning error") + return "", nil, fmt.Errorf("no model specified") + } + } + + // If a model is found in bearer token takes precedence + if bearerExists { + log.Debug().Msgf("Using model from bearer token: %s", bearer) + modelFile = bearer + } + return modelFile, input, nil +} diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go new file mode 100644 index 000000000000..bece31aa7e75 --- /dev/null +++ b/core/http/endpoints/openai/transcription.go @@ -0,0 +1,49 @@ +package openai + +import ( + "fmt" + "net/http" + "os" + "path" + + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" + + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/audio/create +func TranscriptEndpoint(cl *services.ConfigLoader, ml *model.ModelLoader, so *datamodel.StartupOptions) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + modelName, input, err := readInput(c, so, ml, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + // retrieve the file data from the request + file, err := c.FormFile("file") + if err != nil { + return err + } + + dst, err := utils.CreateTempFileFromMultipartFile(file, "", "transcription") // 3rd param formerly whisper + if err != nil { + return err + } + + log.Debug().Msgf("Audio file copied to: %+v", dst) + defer os.RemoveAll(path.Dir(dst)) + + tr, err := backend.TranscriptionOpenAIRequest(modelName, input, dst, cl, ml, so) + if err != nil { + return fmt.Errorf("error generating transcription request: +%w", err) + } + log.Debug().Msgf("Trascribed: %+v", tr) + // TODO: handle different outputs here + return c.Status(http.StatusOK).JSON(tr) + } +} diff --git a/core/mqtt/manager.go b/core/mqtt/manager.go new file mode 100644 index 000000000000..7fa525a6790b --- /dev/null +++ b/core/mqtt/manager.go @@ -0,0 +1,24 @@ +package mqtt + +import ( + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" +) + +// PLACEHOLDER DURING PART 1 OF THE REFACTOR + +type MQTTManager struct { + configLoader *services.ConfigLoader + modelLoader *model.ModelLoader + startupOptions *datamodel.StartupOptions +} + +func NewMQTTManager(cl *services.ConfigLoader, ml *model.ModelLoader, options *datamodel.StartupOptions) (*MQTTManager, error) { + + return &MQTTManager{ + configLoader: cl, + modelLoader: ml, + startupOptions: options, + }, nil +} diff --git a/core/services/backend_monitor.go b/core/services/backend_monitor.go new file mode 100644 index 000000000000..28412e77f90c --- /dev/null +++ b/core/services/backend_monitor.go @@ -0,0 +1,138 @@ +package services + +import ( + "context" + "fmt" + "strings" + + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/rs/zerolog/log" + + gopsutil "github.com/shirou/gopsutil/v3/process" +) + +type BackendMonitor struct { + configLoader *ConfigLoader + modelLoader *model.ModelLoader + options *datamodel.StartupOptions // Taking options in case we need to inspect ExternalGRPCBackends, though that's out of scope for now, hence the name. +} + +func NewBackendMonitor(configLoader *ConfigLoader, modelLoader *model.ModelLoader, options *datamodel.StartupOptions) *BackendMonitor { + return &BackendMonitor{ + configLoader: configLoader, + modelLoader: modelLoader, + options: options, + } +} + +func (bm *BackendMonitor) SampleLocalBackendProcess(model string) (*datamodel.BackendMonitorResponse, error) { + config, exists := bm.configLoader.GetConfig(model) + var backend string + if exists { + backend = config.Model + } else { + // Last ditch effort: use it raw, see if a backend happens to match. + backend = model + } + + if !strings.HasSuffix(backend, ".bin") { + backend = fmt.Sprintf("%s.bin", backend) + } + + pid, err := bm.modelLoader.GetGRPCPID(backend) + + if err != nil { + log.Error().Msgf("model %s : failed to find pid %+v", model, err) + return nil, err + } + + // Name is slightly frightening but this does _not_ create a new process, rather it looks up an existing process by PID. + backendProcess, err := gopsutil.NewProcess(int32(pid)) + + if err != nil { + log.Error().Msgf("model %s [PID %d] : error getting process info %+v", model, pid, err) + return nil, err + } + + memInfo, err := backendProcess.MemoryInfo() + + if err != nil { + log.Error().Msgf("model %s [PID %d] : error getting memory info %+v", model, pid, err) + return nil, err + } + + memPercent, err := backendProcess.MemoryPercent() + if err != nil { + log.Error().Msgf("model %s [PID %d] : error getting memory percent %+v", model, pid, err) + return nil, err + } + + cpuPercent, err := backendProcess.CPUPercent() + if err != nil { + log.Error().Msgf("model %s [PID %d] : error getting cpu percent %+v", model, pid, err) + return nil, err + } + + return &datamodel.BackendMonitorResponse{ + MemoryInfo: memInfo, + MemoryPercent: memPercent, + CPUPercent: cpuPercent, + }, nil +} + +func (bm BackendMonitor) getModelLoaderIDFromModelName(modelName string) (string, error) { + config, exists := bm.configLoader.GetConfig(modelName) + var backendId string + if exists { + backendId = config.Model + } else { + // Last ditch effort: use it raw, see if a backend happens to match. + backendId = modelName + } + + if !strings.HasSuffix(backendId, ".bin") { + backendId = fmt.Sprintf("%s.bin", backendId) + } + + return backendId, nil +} + +func (bm BackendMonitor) CheckAndSample(modelName string) (*proto.StatusResponse, error) { + backendId, err := bm.getModelLoaderIDFromModelName(modelName) + if err != nil { + return nil, err + } + modelAddr := bm.modelLoader.CheckIsLoaded(backendId) + if modelAddr == "" { + return nil, fmt.Errorf("backend %s is not currently loaded", backendId) + } + + status, rpcErr := modelAddr.GRPC(false, nil).Status(context.TODO()) + if rpcErr != nil { + log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error()) + val, slbErr := bm.SampleLocalBackendProcess(backendId) + if slbErr != nil { + return nil, fmt.Errorf("backend %s experienced an error retrieving status info via rpc: %s, then failed local node process sample: %s", backendId, rpcErr.Error(), slbErr.Error()) + } + return &proto.StatusResponse{ + State: proto.StatusResponse_ERROR, + Memory: &proto.MemoryUsageData{ + Total: val.MemoryInfo.VMS, + Breakdown: map[string]uint64{ + "gopsutil-RSS": val.MemoryInfo.RSS, + }, + }, + }, nil + } + return status, nil +} + +func (bm BackendMonitor) ShutdownModel(modelName string) error { + backendId, err := bm.getModelLoaderIDFromModelName(modelName) + if err != nil { + return err + } + return bm.modelLoader.ShutdownModel(backendId) +} diff --git a/core/services/config.go b/core/services/config.go new file mode 100644 index 000000000000..1e76a8543f97 --- /dev/null +++ b/core/services/config.go @@ -0,0 +1,139 @@ +package services + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/rs/zerolog/log" +) + +type ConfigLoader struct { + configs map[string]datamodel.Config + sync.Mutex +} + +func NewConfigLoader() *ConfigLoader { + return &ConfigLoader{ + configs: make(map[string]datamodel.Config), + } +} + +// TODO: check this is correct post-merge +func (cm *ConfigLoader) LoadConfig(file string) error { + cm.Lock() + defer cm.Unlock() + c, err := datamodel.ReadSingleConfigFile(file) + if err != nil { + return fmt.Errorf("cannot read config file: %w", err) + } + + cm.configs[c.Name] = *c + return nil +} + +func (cm *ConfigLoader) GetConfig(m string) (datamodel.Config, bool) { + cm.Lock() + defer cm.Unlock() + v, exists := cm.configs[m] + return v, exists +} + +func (cm *ConfigLoader) GetAllConfigs() []datamodel.Config { + cm.Lock() + defer cm.Unlock() + var res []datamodel.Config + for _, v := range cm.configs { + res = append(res, v) + } + return res +} + +func (cm *ConfigLoader) ListConfigs() []string { + cm.Lock() + defer cm.Unlock() + var res []string + for k := range cm.configs { + res = append(res, k) + } + return res +} + +func (cm *ConfigLoader) LoadConfigs(path string) error { + cm.Lock() + defer cm.Unlock() + entries, err := os.ReadDir(path) + if err != nil { + return err + } + files := make([]fs.FileInfo, 0, len(entries)) + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + return err + } + files = append(files, info) + } + for _, file := range files { + // Skip templates, YAML and .keep files + if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { + continue + } + c, err := datamodel.ReadSingleConfigFile(filepath.Join(path, file.Name())) + if err == nil { + cm.configs[c.Name] = *c + } + } + + return nil +} + +// TODO: Does this belong under ConfigLoader? +func (cl *ConfigLoader) Preload(modelPath string) error { + cl.Lock() + defer cl.Unlock() + + for i, config := range cl.configs { + modelURL := config.PredictionOptions.Model + modelURL = utils.ConvertURL(modelURL) + if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") { + // md5 of model name + md5Name := utils.MD5(modelURL) + + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, md5Name)); err == os.ErrNotExist { + err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) { + log.Info().Msgf("Downloading %s: %s/%s (%.2f%%)", fileName, current, total, percent) + }) + if err != nil { + return err + } + } + + cc := cl.configs[i] + c := &cc + c.PredictionOptions.Model = md5Name + cl.configs[i] = *c + } + } + return nil +} + +func (cl *ConfigLoader) LoadConfigFile(file string) error { + cl.Lock() + defer cl.Unlock() + c, err := datamodel.ReadConfigFile(file) + if err != nil { + return fmt.Errorf("cannot load config file: %w", err) + } + + for _, cc := range c { + cl.configs[cc.Name] = *cc + } + return nil +} diff --git a/core/services/gallery.go b/core/services/gallery.go new file mode 100644 index 000000000000..edc4e6cc71d9 --- /dev/null +++ b/core/services/gallery.go @@ -0,0 +1,160 @@ +package services + +import ( + "context" + "encoding/json" + "os" + "strings" + "sync" + + "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/utils" + "gopkg.in/yaml.v2" +) + +type GalleryApplier struct { + modelPath string + sync.Mutex + C chan gallery.GalleryOp + statuses map[string]*gallery.GalleryOpStatus +} + +func NewGalleryApplier(modelPath string) *GalleryApplier { + return &GalleryApplier{ + modelPath: modelPath, + C: make(chan gallery.GalleryOp), + statuses: make(map[string]*gallery.GalleryOpStatus), + } +} + +func (g *GalleryApplier) UpdateStatus(s string, op *gallery.GalleryOpStatus) { + g.Lock() + defer g.Unlock() + g.statuses[s] = op +} + +func (g *GalleryApplier) GetStatus(s string) *gallery.GalleryOpStatus { + g.Lock() + defer g.Unlock() + + return g.statuses[s] +} + +func (g *GalleryApplier) GetAllStatus() map[string]*gallery.GalleryOpStatus { + g.Lock() + defer g.Unlock() + + return g.statuses +} + +func (g *GalleryApplier) Start(c context.Context, cm *ConfigLoader) { + go func() { + for { + select { + case <-c.Done(): + return + case op := <-g.C: + utils.ResetDownloadTimers() + + g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", Progress: 0}) + + // updates the status with an error + updateError := func(e error) { + g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()}) + } + + // displayDownload displays the download progress + progressCallback := func(fileName string, current string, total string, percentage float64) { + g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) + utils.DisplayDownloadFunction(fileName, current, total, percentage) + } + + var err error + // if the request contains a gallery name, we apply the gallery from the gallery list + if op.GalleryName != "" { + if strings.Contains(op.GalleryName, "@") { + err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) + } else { + err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback) + } + } else { + err = PrepareModel(g.modelPath, op.Req, cm, progressCallback) + } + + if err != nil { + updateError(err) + continue + } + + // Reload models + err = cm.LoadConfigs(g.modelPath) + if err != nil { + updateError(err) + continue + } + + g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100}) + } + } + }() +} + +type galleryModel struct { + gallery.GalleryModel `yaml:",inline"` // https://github.com/go-yaml/yaml/issues/63 + ID string `json:"id"` +} + +func PrepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigLoader, downloadStatus func(string, string, string, float64)) error { + + config, err := gallery.GetInstallableModelFromURL(req.URL) + if err != nil { + return err + } + + config.Files = append(config.Files, req.AdditionalFiles...) + + return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) +} + +func processRequests(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery, requests []galleryModel) error { + var err error + for _, r := range requests { + utils.ResetDownloadTimers() + if r.ID == "" { + err = PrepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) + } else { + if strings.Contains(r.ID, "@") { + err = gallery.InstallModelFromGallery( + galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) + } else { + err = gallery.InstallModelFromGalleryByName( + galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) + } + } + } + return err +} + +func ApplyGalleryFromFile(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error { + dat, err := os.ReadFile(s) + if err != nil { + return err + } + var requests []galleryModel + + if err := yaml.Unmarshal(dat, &requests); err != nil { + return err + } + + return processRequests(modelPath, s, cm, galleries, requests) +} + +func ApplyGalleryFromString(modelPath, s string, cm *ConfigLoader, galleries []gallery.Gallery) error { + var requests []galleryModel + err := json.Unmarshal([]byte(s), &requests) + if err != nil { + return err + } + + return processRequests(modelPath, s, cm, galleries, requests) +} diff --git a/core/services/metrics.go b/core/services/metrics.go new file mode 100644 index 000000000000..c28774c0cea0 --- /dev/null +++ b/core/services/metrics.go @@ -0,0 +1,29 @@ +package services + +import ( + "github.com/go-skynet/LocalAI/pkg/datamodel" + "go.opentelemetry.io/otel/exporters/prometheus" + api "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/sdk/metric" +) + +// setupOTelSDK bootstraps the OpenTelemetry pipeline. +// If it does not return an error, make sure to call shutdown for proper cleanup. +func SetupMetrics() (*datamodel.LocalAIMetrics, error) { + exporter, err := prometheus.New() + if err != nil { + return nil, err + } + provider := metric.NewMeterProvider(metric.WithReader(exporter)) + meter := provider.Meter("github.com/go-skynet/LocalAI") + + apiTimeMetric, err := meter.Float64Histogram("api_call", api.WithDescription("api calls")) + if err != nil { + return nil, err + } + + return &datamodel.LocalAIMetrics{ + Meter: meter, + ApiTimeMetric: apiTimeMetric, + }, nil +} diff --git a/core/startup/config_file_watcher.go b/core/startup/config_file_watcher.go new file mode 100644 index 000000000000..9d60a43096c1 --- /dev/null +++ b/core/startup/config_file_watcher.go @@ -0,0 +1,100 @@ +package startup + +import ( + "encoding/json" + "fmt" + "os" + "path" + + "github.com/fsnotify/fsnotify" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/imdario/mergo" + "github.com/rs/zerolog/log" +) + +type WatchConfigDirectoryCloser func() error + +func ReadApiKeysJson(configDir string, options *datamodel.StartupOptions) error { + fileContent, err := os.ReadFile(path.Join(configDir, "api_keys.json")) + if err == nil { + // Parse JSON content from the file + var fileKeys []string + err := json.Unmarshal(fileContent, &fileKeys) + if err == nil { + options.ApiKeys = append(options.ApiKeys, fileKeys...) + return nil + } + return err + } + return err +} + +func ReadExternalBackendsJson(configDir string, options *datamodel.StartupOptions) error { + fileContent, err := os.ReadFile(path.Join(configDir, "external_backends.json")) + if err != nil { + return err + } + // Parse JSON content from the file + var fileBackends map[string]string + err = json.Unmarshal(fileContent, &fileBackends) + if err != nil { + return err + } + err = mergo.Merge(&options.ExternalGRPCBackends, fileBackends) + if err != nil { + return err + } + return nil +} + +var CONFIG_FILE_UPDATES = map[string]func(configDir string, options *datamodel.StartupOptions) error{ + "api_keys.json": ReadApiKeysJson, + "external_backends.json": ReadExternalBackendsJson, +} + +func WatchConfigDirectory(configDir string, options *datamodel.StartupOptions) (WatchConfigDirectoryCloser, error) { + if len(configDir) == 0 { + return nil, fmt.Errorf("configDir blank") + } + configWatcher, err := fsnotify.NewWatcher() + if err != nil { + log.Fatal().Msgf("Unable to create a watcher for the LocalAI Configuration Directory: %+v", err) + } + ret := func() error { + configWatcher.Close() + return nil + } + + // Start listening for events. + go func() { + for { + select { + case event, ok := <-configWatcher.Events: + if !ok { + return + } + if event.Has(fsnotify.Write) { + for targetName, watchFn := range CONFIG_FILE_UPDATES { + if event.Name == targetName { + err := watchFn(configDir, options) + log.Warn().Msgf("WatchConfigDirectory goroutine for %s: failed to update options: %+v", targetName, err) + } + } + } + case _, ok := <-configWatcher.Errors: + if !ok { + return + } + log.Error().Msgf("WatchConfigDirectory goroutine error: %+v", err) + } + } + }() + + // Add a path. + err = configWatcher.Add(configDir) + if err != nil { + return ret, fmt.Errorf("unable to establish watch on the LocalAI Configuration Directory: %+v", err) + } + + return ret, nil +} diff --git a/core/startup/startup.go b/core/startup/startup.go new file mode 100644 index 000000000000..bf2c01177dcb --- /dev/null +++ b/core/startup/startup.go @@ -0,0 +1,93 @@ +package startup + +import ( + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/internal" + "github.com/go-skynet/LocalAI/pkg/assets" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +func Startup(opts ...datamodel.AppOption) (*services.ConfigLoader, *model.ModelLoader, *datamodel.StartupOptions, error) { + options := datamodel.NewStartupOptions(opts...) + + ml := model.NewModelLoader(options.ModelPath) + + zerolog.SetGlobalLevel(zerolog.InfoLevel) + if options.Debug { + zerolog.SetGlobalLevel(zerolog.DebugLevel) + } + + log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.ModelPath) + log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) + + cl := services.NewConfigLoader() + if err := cl.LoadConfigs(options.ModelPath); err != nil { + log.Error().Msgf("error loading config files: %s", err.Error()) + } + + if options.ConfigFile != "" { + if err := cl.LoadConfigFile(options.ConfigFile); err != nil { + log.Error().Msgf("error loading config file: %s", err.Error()) + } + } + + if err := cl.Preload(options.ModelPath); err != nil { + log.Error().Msgf("error downloading models: %s", err.Error()) + } + + if options.Debug { + for _, v := range cl.ListConfigs() { + cfg, _ := cl.GetConfig(v) + log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) + } + } + + if options.AssetsDestination != "" { + // Extract files from the embedded FS + err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) + log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) + if err != nil { + log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) + } + } + + if options.PreloadJSONModels != "" { + if err := services.ApplyGalleryFromString(options.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { + return nil, nil, nil, err + } + } + + if options.PreloadModelsFromPath != "" { + if err := services.ApplyGalleryFromFile(options.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { + return nil, nil, nil, err + } + } + + // turn off any process that was started by GRPC if the context is canceled + go func() { + <-options.Context.Done() + log.Debug().Msgf("Context canceled, shutting down") + ml.StopAllGRPC() + }() + + if options.WatchDog { + wd := model.NewWatchDog( + ml, + options.WatchDogBusyTimeout, + options.WatchDogIdleTimeout, + options.WatchDogBusy, + options.WatchDogIdle) + ml.SetWatchDog(wd) + go wd.Run() + go func() { + <-options.Context.Done() + log.Debug().Msgf("Context canceled, shutting down") + wd.Shutdown() + }() + } + + return cl, ml, options, nil +} diff --git a/docs/content/advanced/development.md b/docs/content/advanced/development.md index 9f73b8a5b84f..afc6ce3bad8d 100644 --- a/docs/content/advanced/development.md +++ b/docs/content/advanced/development.md @@ -17,6 +17,53 @@ This section will collect how-to, notes and development documentation We use conventional commits and semantic versioning. Please follow the [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) specification when writing commit messages. +## LocalAI Project Structure + +**LocalAI is made of multiple components, developed in multiple repositories:** + +The core repository, containing the primary `local-ai` server code, gRPC stubs, this documentation website, and docker container building resources are all located at [mudler/LocalAI](https://github.com/mudler/LocalAI). + +As LocalAI is designed to make use of multiple, independent model galleries, those are maintained seperately. The following public model galleries are available for use: + +* [go-skynet/model-gallery](https://github.com/go-skynet/model-gallery) - The original gallery, the `golang` huggingface scraper ran into limits and was largely retired, so this now holds handmade yaml configs +* [dave-gray101/model-gallery](https://github.com/dave-gray101/model-gallery) - An automated gallery designed to track HuggingFace uploads and produce best-effort automatically generated configurations for LocalAI. It is designed to produce one LocalAI gallery per repository on HuggingFace. + +### Directory Structure of this Repo + +The core repository is broken up into the following primary chunks: + +* `/backend`: gRPC protobuf specification and gRPC backends. Subfolders for each language. +* **`/core`**: golang sourcecode for the core LocalAI application. Broken down below. +* `/docs`: localai.io website that you are reading now +* `/examples`: example code integrating LocalAI to other projects and/or developer samples and tools +* `/internal`: **here be dragons**. Don't touch this, it's used for automatic versioning. +* `/models`: _No code here!_ This is where models are installed! +* **`/pkg`**: golang sourcecode that is intended to be reusable or at least widely imported across LocalAI. Broken down below +* `/prompt-templates`: _No code here!_ This is where **example** prompt templates were historically stored. Somewhat obsolete these days, model-galleries tend to replace manually creating these? +* `/tests`: Does what it says on the tin. Please write tests and put them here when you do. + +The `core` folder is broken down further: + +* **`/core/backend`**: code that interacts with a gRPC backend to perform AI tasks. +* `/core/http`: code specifically related to the REST server +* `/core/http/endpoints`: Has two subdirectories, `openai` and `localai` for binding the respective endpoints to the correct backend or service. +* `/core/mqtt`: core specifically related to the MQTT server. Stub for now. Coming soon! +* **`/core/services`**: code implementing functionality performed by `local-ai` itself, rather than delegated to a backend. +* `/core/startup`: code related specifically to application startup of `local-ai`. Potentially to be refactored to become a part of `/core/services` at a later date, or not. + +The `pkg` folder is broken down further: + +* `/pkg/assets`: Currently contains a single function related to extracting files from archives. Potentially to be refactored to become a part of `/core/utils` at a later date? +* **`/pkg/datamodel`**: Contains the data types and definitions used by the LocalAI project. Imported widely! +* `/pkg/gallery`: Code related to interacting with a `model-gallery` +* `/pkg/grammar`: Code related to BNF / functions for LLM +* `/pkg/grpc`: base classes and interfaces for gRPC backends to implement +* `/pkg/langchain`: langchain related code in golang +* **`/pkg/model`**: Code related to loading and initializing a model and creating the appropriate gRPC backend. +* `/pkg/stablediffusion`: Code related to stablediffusion in golang. +* `/pkg/utils`: Every real programmer knows what they are going to find in here... it's our junk drawer of utility functions. + + ## Creating a gRPC backend LocalAI backends are `gRPC` servers. diff --git a/docs/content/features/text-to-audio.md b/docs/content/features/text-to-audio.md index ab038d2f5e5b..88aba2f1f7da 100644 --- a/docs/content/features/text-to-audio.md +++ b/docs/content/features/text-to-audio.md @@ -20,7 +20,7 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{ Returns an `audio/wav` file. -#### Setup +#### Text-To-Speech Setup LocalAI supports [bark]({{%relref "model-compatibility/bark" %}}) , `piper` and `vall-e-x`: @@ -52,6 +52,8 @@ Note: - The model name is case sensitive. - LocalAI must be compiled with the `GO_TAGS=tts` flag. +#### Music + LocalAI also has experimental support for `transformers-musicgen` for the generation of short musical compositions. Currently, this is implemented via the same requests used for text to speech: ``` @@ -62,7 +64,8 @@ curl --request POST \ "backend": "transformers-musicgen", "model": "facebook/musicgen-medium", "input": "Cello Rave" -}' | aplay``` +}' | aplay +``` Future versions of LocalAI will expose additional control over audio generation beyond the text prompt. diff --git a/go.mod b/go.mod index 250a2361796f..acb9eb267ea3 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.21 require ( github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df + github.com/fsnotify/fsnotify v1.7.0 github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e github.com/go-audio/wav v1.1.0 github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 diff --git a/main.go b/main.go index 97b258c088ce..a10babc27b70 100644 --- a/main.go +++ b/main.go @@ -12,14 +12,14 @@ import ( "syscall" "time" - api "github.com/go-skynet/LocalAI/api" - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/core/backend" + "github.com/go-skynet/LocalAI/core/http" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/core/startup" "github.com/go-skynet/LocalAI/internal" - "github.com/go-skynet/LocalAI/metrics" + "github.com/go-skynet/LocalAI/pkg/datamodel" "github.com/go-skynet/LocalAI/pkg/gallery" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/model" "github.com/rs/zerolog" "github.com/rs/zerolog/log" progressbar "github.com/schollz/progressbar/v3" @@ -185,6 +185,12 @@ func main() { EnvVars: []string{"PRELOAD_BACKEND_ONLY"}, Value: false, }, + &cli.StringFlag{ + Name: "localai-config-dir", + Usage: "Directory to use for the configuration files of LocalAI itself. This is NOT where model files should be placed.", + EnvVars: []string{"LOCALAI_CONFIG_DIR"}, + Value: "./config", + }, }, Description: ` LocalAI is a drop-in replacement OpenAI API which runs inference locally. @@ -203,53 +209,53 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit UsageText: `local-ai [options]`, Copyright: "Ettore Di Giacinto", Action: func(ctx *cli.Context) error { - opts := []options.AppOption{ - options.WithConfigFile(ctx.String("config-file")), - options.WithJSONStringPreload(ctx.String("preload-models")), - options.WithYAMLConfigPreload(ctx.String("preload-models-config")), - options.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), - options.WithContextSize(ctx.Int("context-size")), - options.WithDebug(ctx.Bool("debug")), - options.WithImageDir(ctx.String("image-path")), - options.WithAudioDir(ctx.String("audio-path")), - options.WithF16(ctx.Bool("f16")), - options.WithStringGalleries(ctx.String("galleries")), - options.WithDisableMessage(false), - options.WithCors(ctx.Bool("cors")), - options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), - options.WithThreads(ctx.Int("threads")), - options.WithBackendAssets(backendAssets), - options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), - options.WithUploadLimitMB(ctx.Int("upload-limit")), - options.WithApiKeys(ctx.StringSlice("api-keys")), + opts := []datamodel.AppOption{ + datamodel.WithConfigFile(ctx.String("config-file")), + datamodel.WithJSONStringPreload(ctx.String("preload-models")), + datamodel.WithYAMLConfigPreload(ctx.String("preload-models-config")), + datamodel.WithModelPath(ctx.String("models-path")), + datamodel.WithContextSize(ctx.Int("context-size")), + datamodel.WithDebug(ctx.Bool("debug")), + datamodel.WithImageDir(ctx.String("image-path")), + datamodel.WithAudioDir(ctx.String("audio-path")), + datamodel.WithF16(ctx.Bool("f16")), + datamodel.WithStringGalleries(ctx.String("galleries")), + datamodel.WithDisableMessage(false), + datamodel.WithCors(ctx.Bool("cors")), + datamodel.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), + datamodel.WithThreads(ctx.Int("threads")), + datamodel.WithBackendAssets(backendAssets), + datamodel.WithBackendAssetsOutput(ctx.String("backend-assets-path")), + datamodel.WithUploadLimitMB(ctx.Int("upload-limit")), + datamodel.WithApiKeys(ctx.StringSlice("api-keys")), } idleWatchDog := ctx.Bool("enable-watchdog-idle") busyWatchDog := ctx.Bool("enable-watchdog-busy") if idleWatchDog || busyWatchDog { - opts = append(opts, options.EnableWatchDog) + opts = append(opts, datamodel.EnableWatchDog) if idleWatchDog { - opts = append(opts, options.EnableWatchDogIdleCheck) + opts = append(opts, datamodel.EnableWatchDogIdleCheck) dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout")) if err != nil { return err } - opts = append(opts, options.SetWatchDogIdleTimeout(dur)) + opts = append(opts, datamodel.SetWatchDogIdleTimeout(dur)) } if busyWatchDog { - opts = append(opts, options.EnableWatchDogBusyCheck) + opts = append(opts, datamodel.EnableWatchDogBusyCheck) dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout")) if err != nil { return err } - opts = append(opts, options.SetWatchDogBusyTimeout(dur)) + opts = append(opts, datamodel.SetWatchDogBusyTimeout(dur)) } } if ctx.Bool("parallel-requests") { - opts = append(opts, options.EnableParallelBackendRequests) + opts = append(opts, datamodel.EnableParallelBackendRequests) } if ctx.Bool("single-active-backend") { - opts = append(opts, options.EnableSingleBackend) + opts = append(opts, datamodel.EnableSingleBackend) } externalgRPC := ctx.StringSlice("external-grpc-backends") @@ -257,30 +263,42 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit for _, v := range externalgRPC { backend := v[:strings.IndexByte(v, ':')] uri := v[strings.IndexByte(v, ':')+1:] - opts = append(opts, options.WithExternalBackend(backend, uri)) + opts = append(opts, datamodel.WithExternalBackend(backend, uri)) } if ctx.Bool("autoload-galleries") { - opts = append(opts, options.EnableGalleriesAutoload) + opts = append(opts, datamodel.EnableGalleriesAutoload) } if ctx.Bool("preload-backend-only") { - _, _, err := api.Startup(opts...) + _, _, _, err := startup.Startup(opts...) return err } - metrics, err := metrics.SetupMetrics() + metrics, err := services.SetupMetrics() if err != nil { return err } - opts = append(opts, options.WithMetrics(metrics)) + opts = append(opts, datamodel.WithMetrics(metrics)) + + cl, ml, options, err := startup.Startup(opts...) + if err != nil { + return fmt.Errorf("failed basic startup tasks with error %s", err.Error()) + } + + closeConfigWatcherFn, err := startup.WatchConfigDirectory(ctx.String("localai-config-dir"), options) - app, err := api.App(opts...) + defer closeConfigWatcherFn() + if err != nil { + return fmt.Errorf("failed while watching configuration directory %s", ctx.String("localai-config-dir")) + } + + appHTTP, err := http.App(cl, ml, options) if err != nil { return err } - return app.Listen(ctx.String("address")) + return appHTTP.Listen(ctx.String("address")) }, Commands: []*cli.Command{ { @@ -378,16 +396,18 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit text := strings.Join(ctx.Args().Slice(), " ") - opts := &options.Option{ - Loader: model.NewModelLoader(ctx.String("models-path")), + opts := &datamodel.StartupOptions{ + ModelPath: ctx.String("models-path"), Context: context.Background(), AudioDir: outputDir, AssetsDestination: ctx.String("backend-assets-path"), } - defer opts.Loader.StopAllGRPC() + loader := model.NewModelLoader(opts.ModelPath) - filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, opts.Loader, opts) + defer loader.StopAllGRPC() + + filePath, _, err := backend.ModelTTS(backendOption, text, modelOption, loader, opts) if err != nil { return err } @@ -440,13 +460,15 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit language := ctx.String("language") threads := ctx.Int("threads") - opts := &options.Option{ - Loader: model.NewModelLoader(ctx.String("models-path")), + opts := &datamodel.StartupOptions{ + ModelPath: ctx.String("models-path"), Context: context.Background(), AssetsDestination: ctx.String("backend-assets-path"), } - cl := config.NewConfigLoader() + ml := model.NewModelLoader(opts.ModelPath) + + cl := services.NewConfigLoader() if err := cl.LoadConfigs(ctx.String("models-path")); err != nil { return err } @@ -458,9 +480,9 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit c.Threads = threads - defer opts.Loader.StopAllGRPC() + defer ml.StopAllGRPC() - tr, err := backend.ModelTranscription(filename, language, opts.Loader, c, opts) + tr, err := backend.ModelTranscription(filename, language, ml, c, opts) if err != nil { return err } diff --git a/api/config/config.go b/pkg/datamodel/config.go similarity index 60% rename from api/config/config.go rename to pkg/datamodel/config.go index 7ed7061af917..048aaf211bfe 100644 --- a/api/config/config.go +++ b/pkg/datamodel/config.go @@ -1,15 +1,11 @@ -package api_config +package datamodel import ( + "encoding/json" "fmt" - "io/fs" "os" - "path/filepath" - "strings" - "sync" "github.com/go-skynet/LocalAI/pkg/utils" - "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" ) @@ -141,11 +137,6 @@ type TemplateConfig struct { Functions string `yaml:"function"` } -type ConfigLoader struct { - configs map[string]Config - sync.Mutex -} - func (c *Config) SetFunctionCallString(s string) { c.functionCallString = s } @@ -182,11 +173,6 @@ func DefaultConfig(modelFile string) *Config { } } -func NewConfigLoader() *ConfigLoader { - return &ConfigLoader{ - configs: make(map[string]Config), - } -} func ReadConfigFile(file string) ([]*Config, error) { c := &[]*Config{} f, err := os.ReadFile(file) @@ -200,7 +186,7 @@ func ReadConfigFile(file string) ([]*Config, error) { return *c, nil } -func ReadConfig(file string) (*Config, error) { +func ReadSingleConfigFile(file string) (*Config, error) { c := &Config{} f, err := os.ReadFile(file) if err != nil { @@ -213,114 +199,189 @@ func ReadConfig(file string) (*Config, error) { return c, nil } -func (cm *ConfigLoader) LoadConfigFile(file string) error { - cm.Lock() - defer cm.Unlock() - c, err := ReadConfigFile(file) - if err != nil { - return fmt.Errorf("cannot load config file: %w", err) +func UpdateConfigFromOpenAIRequest(config *Config, input *OpenAIRequest) { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != 0 { + config.TopK = input.TopK + } + if input.TopP != 0 { + config.TopP = input.TopP } - for _, cc := range c { - cm.configs[cc.Name] = *cc + if input.Backend != "" { + config.Backend = input.Backend } - return nil -} -func (cm *ConfigLoader) LoadConfig(file string) error { - cm.Lock() - defer cm.Unlock() - c, err := ReadConfig(file) - if err != nil { - return fmt.Errorf("cannot read config file: %w", err) + if input.ClipSkip != 0 { + config.Diffusers.ClipSkip = input.ClipSkip } - cm.configs[c.Name] = *c - return nil -} + if input.ModelBaseName != "" { + config.AutoGPTQ.ModelBaseName = input.ModelBaseName + } -func (cm *ConfigLoader) GetConfig(m string) (Config, bool) { - cm.Lock() - defer cm.Unlock() - v, exists := cm.configs[m] - return v, exists -} + if input.NegativePromptScale != 0 { + config.NegativePromptScale = input.NegativePromptScale + } -func (cm *ConfigLoader) GetAllConfigs() []Config { - cm.Lock() - defer cm.Unlock() - var res []Config - for _, v := range cm.configs { - res = append(res, v) + if input.UseFastTokenizer { + config.UseFastTokenizer = input.UseFastTokenizer } - return res -} -func (cm *ConfigLoader) ListConfigs() []string { - cm.Lock() - defer cm.Unlock() - var res []string - for k := range cm.configs { - res = append(res, k) + if input.NegativePrompt != "" { + config.NegativePrompt = input.NegativePrompt } - return res -} -func (cm *ConfigLoader) Preload(modelPath string) error { - cm.Lock() - defer cm.Unlock() - - for i, config := range cm.configs { - modelURL := config.PredictionOptions.Model - modelURL = utils.ConvertURL(modelURL) - if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") { - // md5 of model name - md5Name := utils.MD5(modelURL) - - // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, md5Name)); err == os.ErrNotExist { - err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) { - log.Info().Msgf("Downloading %s: %s/%s (%.2f%%)", fileName, current, total, percent) - }) - if err != nil { - return err - } + if input.RopeFreqBase != 0 { + config.RopeFreqBase = input.RopeFreqBase + } + + if input.RopeFreqScale != 0 { + config.RopeFreqScale = input.RopeFreqScale + } + + if input.Grammar != "" { + config.Grammar = input.Grammar + } + + if input.Temperature != 0 { + config.Temperature = input.Temperature + } + + if input.Maxtokens != 0 { + config.Maxtokens = input.Maxtokens + } + + switch stop := input.Stop.(type) { + case string: + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } + case []interface{}: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) } + } + } - cc := cm.configs[i] - c := &cc - c.PredictionOptions.Model = md5Name - cm.configs[i] = *c + // Decode each request's message content + index := 0 + for i, m := range input.Messages { + switch content := m.Content.(type) { + case string: + input.Messages[i].StringContent = content + case []interface{}: + dat, _ := json.Marshal(content) + c := []Content{} + json.Unmarshal(dat, &c) + for _, pp := range c { + if pp.Type == "text" { + input.Messages[i].StringContent = pp.Text + } else if pp.Type == "image_url" { + // Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64: + base64, err := utils.GetBase64Image(pp.ImageURL.URL) + if err == nil { + input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff + // set a placeholder for each image + input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent + index++ + } else { + fmt.Print("Failed encoding image", err) + } + } + } } } - return nil -} -func (cm *ConfigLoader) LoadConfigs(path string) error { - cm.Lock() - defer cm.Unlock() - entries, err := os.ReadDir(path) - if err != nil { - return err + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch + } + + if input.F16 { + config.F16 = input.F16 + } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != 0 { + config.Seed = input.Seed + } + + if input.Mirostat != 0 { + config.LLMConfig.Mirostat = input.Mirostat } - files := make([]fs.FileInfo, 0, len(entries)) - for _, entry := range entries { - info, err := entry.Info() - if err != nil { - return err + + if input.MirostatETA != 0 { + config.LLMConfig.MirostatETA = input.MirostatETA + } + + if input.MirostatTAU != 0 { + config.LLMConfig.MirostatTAU = input.MirostatTAU + } + + if input.TypicalP != 0 { + config.TypicalP = input.TypicalP + } + + switch inputs := input.Input.(type) { + case string: + if inputs != "" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []interface{}: + for _, pp := range inputs { + switch i := pp.(type) { + case string: + config.InputStrings = append(config.InputStrings, i) + case []interface{}: + tokens := []int{} + for _, ii := range i { + tokens = append(tokens, int(ii.(float64))) + } + config.InputToken = append(config.InputToken, tokens) + } } - files = append(files, info) } - for _, file := range files { - // Skip templates, YAML and .keep files - if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") { - continue + + // Can be either a string or an object + switch fnc := input.FunctionCall.(type) { + case string: + if fnc != "" { + config.SetFunctionCallString(fnc) } - c, err := ReadConfig(filepath.Join(path, file.Name())) - if err == nil { - cm.configs[c.Name] = *c + case map[string]interface{}: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if e { + name = nn + } } + config.SetFunctionCallNameString(name) } - return nil + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []interface{}: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } } diff --git a/api/config/config_test.go b/pkg/datamodel/config_test.go similarity index 74% rename from api/config/config_test.go rename to pkg/datamodel/config_test.go index 4b00d587eff2..4c6d550bcebc 100644 --- a/api/config/config_test.go +++ b/pkg/datamodel/config_test.go @@ -1,11 +1,10 @@ -package api_config_test +package datamodel_test import ( "os" - . "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/core/services" + "github.com/go-skynet/LocalAI/pkg/datamodel" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -19,7 +18,7 @@ var _ = Describe("Test cases for config related functions", func() { Context("Test Read configuration functions", func() { configFile = os.Getenv("CONFIG_FILE") It("Test ReadConfigFile", func() { - config, err := ReadConfigFile(configFile) + config, err := datamodel.ReadConfigFile(configFile) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) // two configs in config.yaml @@ -28,12 +27,8 @@ var _ = Describe("Test cases for config related functions", func() { }) It("Test LoadConfigs", func() { - cm := NewConfigLoader() - opts := options.NewOptions() - modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH")) - options.WithModelLoader(modelLoader)(opts) - - err := cm.LoadConfigs(opts.Loader.ModelPath) + cm := services.NewConfigLoader() + err := cm.LoadConfigs(os.Getenv("MODELS_PATH")) Expect(err).To(BeNil()) Expect(cm.ListConfigs()).ToNot(BeNil()) diff --git a/pkg/datamodel/localai.go b/pkg/datamodel/localai.go new file mode 100644 index 000000000000..d8e042c69fb5 --- /dev/null +++ b/pkg/datamodel/localai.go @@ -0,0 +1,39 @@ +package datamodel + +import ( + "context" + + gopsutil "github.com/shirou/gopsutil/v3/process" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +type BackendMonitorRequest struct { + Model string `json:"model" yaml:"model"` +} + +type BackendMonitorResponse struct { + MemoryInfo *gopsutil.MemoryInfoStat + MemoryPercent float32 + CPUPercent float64 +} + +type TTSRequest struct { + Model string `json:"model" yaml:"model"` + Input string `json:"input" yaml:"input"` + Backend string `json:"backend" yaml:"backend"` +} + +type LocalAIMetrics struct { + Meter metric.Meter + ApiTimeMetric metric.Float64Histogram +} + +func (m *LocalAIMetrics) ObserveAPICall(method string, path string, duration float64) { + opts := metric.WithAttributes( + attribute.String("method", method), + attribute.String("path", path), + ) + m.ApiTimeMetric.Record(context.Background(), duration, opts) +} diff --git a/api/schema/openai.go b/pkg/datamodel/openai.go similarity index 97% rename from api/schema/openai.go rename to pkg/datamodel/openai.go index 6355ff63d5e2..76c7ee10f919 100644 --- a/api/schema/openai.go +++ b/pkg/datamodel/openai.go @@ -1,10 +1,8 @@ -package schema +package datamodel import ( "context" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/pkg/grammar" ) @@ -90,7 +88,7 @@ type ChatCompletionResponseFormat struct { } type OpenAIRequest struct { - config.PredictionOptions + PredictionOptions Context context.Context Cancel context.CancelFunc diff --git a/api/config/prediction.go b/pkg/datamodel/prediction.go similarity index 99% rename from api/config/prediction.go rename to pkg/datamodel/prediction.go index d2fbb1fa9687..685d909edd5c 100644 --- a/api/config/prediction.go +++ b/pkg/datamodel/prediction.go @@ -1,4 +1,4 @@ -package api_config +package datamodel type PredictionOptions struct { diff --git a/api/options/options.go b/pkg/datamodel/startup_options.go similarity index 68% rename from api/options/options.go rename to pkg/datamodel/startup_options.go index 127d06f0e30c..50750f55cb00 100644 --- a/api/options/options.go +++ b/pkg/datamodel/startup_options.go @@ -1,4 +1,4 @@ -package options +package datamodel import ( "context" @@ -6,16 +6,14 @@ import ( "encoding/json" "time" - "github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/pkg/gallery" - model "github.com/go-skynet/LocalAI/pkg/model" "github.com/rs/zerolog/log" ) -type Option struct { +type StartupOptions struct { Context context.Context ConfigFile string - Loader *model.ModelLoader + ModelPath string UploadLimitMB, Threads, ContextSize int F16 bool Debug, DisableMessage bool @@ -26,7 +24,7 @@ type Option struct { PreloadModelsFromPath string CORSAllowOrigins string ApiKeys []string - Metrics *metrics.Metrics + Metrics *LocalAIMetrics Galleries []gallery.Gallery @@ -44,12 +42,14 @@ type Option struct { WatchDogBusy bool WatchDog bool WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration + + LocalAIConfigDir string } -type AppOption func(*Option) +type AppOption func(*StartupOptions) -func NewOptions(o ...AppOption) *Option { - opt := &Option{ +func NewStartupOptions(o ...AppOption) *StartupOptions { + opt := &StartupOptions{ Context: context.Background(), UploadLimitMB: 15, Threads: 1, @@ -64,51 +64,51 @@ func NewOptions(o ...AppOption) *Option { } func WithCors(b bool) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.CORS = b } } -var EnableWatchDog = func(o *Option) { +var EnableWatchDog = func(o *StartupOptions) { o.WatchDog = true } -var EnableWatchDogIdleCheck = func(o *Option) { +var EnableWatchDogIdleCheck = func(o *StartupOptions) { o.WatchDog = true o.WatchDogIdle = true } -var EnableWatchDogBusyCheck = func(o *Option) { +var EnableWatchDogBusyCheck = func(o *StartupOptions) { o.WatchDog = true o.WatchDogBusy = true } func SetWatchDogBusyTimeout(t time.Duration) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.WatchDogBusyTimeout = t } } func SetWatchDogIdleTimeout(t time.Duration) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.WatchDogIdleTimeout = t } } -var EnableSingleBackend = func(o *Option) { +var EnableSingleBackend = func(o *StartupOptions) { o.SingleBackend = true } -var EnableParallelBackendRequests = func(o *Option) { +var EnableParallelBackendRequests = func(o *StartupOptions) { o.ParallelBackendRequests = true } -var EnableGalleriesAutoload = func(o *Option) { +var EnableGalleriesAutoload = func(o *StartupOptions) { o.AutoloadGalleries = true } func WithExternalBackend(name string, uri string) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { if o.ExternalGRPCBackends == nil { o.ExternalGRPCBackends = make(map[string]string) } @@ -117,25 +117,25 @@ func WithExternalBackend(name string, uri string) AppOption { } func WithCorsAllowOrigins(b string) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.CORSAllowOrigins = b } } func WithBackendAssetsOutput(out string) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.AssetsDestination = out } } func WithBackendAssets(f embed.FS) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.BackendAssets = f } } func WithStringGalleries(galls string) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { if galls == "" { log.Debug().Msgf("no galleries to load") o.Galleries = []gallery.Gallery{} @@ -150,96 +150,102 @@ func WithStringGalleries(galls string) AppOption { } func WithGalleries(galleries []gallery.Gallery) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.Galleries = append(o.Galleries, galleries...) } } func WithContext(ctx context.Context) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.Context = ctx } } func WithYAMLConfigPreload(configFile string) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.PreloadModelsFromPath = configFile } } func WithJSONStringPreload(configFile string) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.PreloadJSONModels = configFile } } func WithConfigFile(configFile string) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.ConfigFile = configFile } } -func WithModelLoader(loader *model.ModelLoader) AppOption { - return func(o *Option) { - o.Loader = loader +func WithModelPath(path string) AppOption { + return func(o *StartupOptions) { + o.ModelPath = path } } func WithUploadLimitMB(limit int) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.UploadLimitMB = limit } } func WithThreads(threads int) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.Threads = threads } } func WithContextSize(ctxSize int) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.ContextSize = ctxSize } } func WithF16(f16 bool) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.F16 = f16 } } func WithDebug(debug bool) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.Debug = debug } } func WithDisableMessage(disableMessage bool) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.DisableMessage = disableMessage } } func WithAudioDir(audioDir string) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.AudioDir = audioDir } } func WithImageDir(imageDir string) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.ImageDir = imageDir } } func WithApiKeys(apiKeys []string) AppOption { - return func(o *Option) { + return func(o *StartupOptions) { o.ApiKeys = apiKeys } } -func WithMetrics(meter *metrics.Metrics) AppOption { - return func(o *Option) { - o.Metrics = meter +func WithMetrics(metrics *LocalAIMetrics) AppOption { + return func(o *StartupOptions) { + o.Metrics = metrics + } +} + +func WithLocalAIConfigDir(configDir string) AppOption { + return func(o *StartupOptions) { + o.LocalAIConfigDir = configDir } } diff --git a/api/schema/whisper.go b/pkg/datamodel/whisper.go similarity index 55% rename from api/schema/whisper.go rename to pkg/datamodel/whisper.go index 41413c1f06ed..8844de81e7d6 100644 --- a/api/schema/whisper.go +++ b/pkg/datamodel/whisper.go @@ -1,8 +1,8 @@ -package schema +package datamodel import "time" -type Segment struct { +type WhisperSegment struct { Id int `json:"id"` Start time.Duration `json:"start"` End time.Duration `json:"end"` @@ -10,7 +10,7 @@ type Segment struct { Tokens []int `json:"tokens"` } -type Result struct { - Segments []Segment `json:"segments"` - Text string `json:"text"` +type WhisperResult struct { + Segments []WhisperSegment `json:"segments"` + Text string `json:"text"` } diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go index 7957ed59d638..4aeb3172fa6a 100644 --- a/pkg/gallery/gallery.go +++ b/pkg/gallery/gallery.go @@ -22,11 +22,11 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string, applyModel := func(model *GalleryModel) error { name = strings.ReplaceAll(name, string(os.PathSeparator), "__") - var config Config + var config InstallableModel if len(model.URL) > 0 { var err error - config, err = GetGalleryConfigFromURL(model.URL) + config, err = GetInstallableModelFromURL(model.URL) if err != nil { return err } @@ -36,7 +36,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string, if err != nil { return err } - config = Config{ + config = InstallableModel{ ConfigFile: string(reYamlConfig), Description: model.Description, License: model.License, diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index 9a1697981614..2e8770f17625 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -1,13 +1,9 @@ package gallery import ( - "crypto/sha256" "fmt" - "hash" - "io" "os" "path/filepath" - "strconv" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/imdario/mergo" @@ -41,9 +37,9 @@ prompt_templates: content: "" */ -// Config is the model configuration which contains all the model details +// InstallableModel is the model configuration which contains all the model details // This configuration is read from the gallery endpoint and is used to download and install the model -type Config struct { +type InstallableModel struct { Description string `yaml:"description"` License string `yaml:"license"` URLs []string `yaml:"urls"` @@ -64,8 +60,8 @@ type PromptTemplate struct { Content string `yaml:"content"` } -func GetGalleryConfigFromURL(url string) (Config, error) { - var config Config +func GetInstallableModelFromURL(url string) (InstallableModel, error) { + var config InstallableModel err := utils.GetURI(url, func(url string, d []byte) error { return yaml.Unmarshal(d, &config) }) @@ -76,7 +72,7 @@ func GetGalleryConfigFromURL(url string) (Config, error) { return config, nil } -func ReadConfigFile(filePath string) (*Config, error) { +func ReadInstallableModelFile(filePath string) (*InstallableModel, error) { // Read the YAML file yamlFile, err := os.ReadFile(filePath) if err != nil { @@ -84,7 +80,7 @@ func ReadConfigFile(filePath string) (*Config, error) { } // Unmarshal YAML data into a Config struct - var config Config + var config InstallableModel err = yaml.Unmarshal(yamlFile, &config) if err != nil { return nil, fmt.Errorf("failed to unmarshal YAML: %v", err) @@ -93,7 +89,7 @@ func ReadConfigFile(filePath string) (*Config, error) { return &config, nil } -func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { +func InstallModel(basePath, nameOverride string, config *InstallableModel, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64)) error { // Create base path if it doesn't exist err := os.MkdirAll(basePath, 0755) if err != nil { @@ -183,54 +179,3 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides return nil } - -type progressWriter struct { - fileName string - total int64 - written int64 - downloadStatus func(string, string, string, float64) - hash hash.Hash -} - -func (pw *progressWriter) Write(p []byte) (n int, err error) { - n, err = pw.hash.Write(p) - pw.written += int64(n) - - if pw.total > 0 { - percentage := float64(pw.written) / float64(pw.total) * 100 - //log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) - pw.downloadStatus(pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage) - } else { - pw.downloadStatus(pw.fileName, formatBytes(pw.written), "", 0) - } - - return -} - -func formatBytes(bytes int64) string { - const unit = 1024 - if bytes < unit { - return strconv.FormatInt(bytes, 10) + " B" - } - div, exp := int64(unit), 0 - for n := bytes / unit; n >= unit; n /= unit { - div *= unit - exp++ - } - return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) -} - -func calculateSHA(filePath string) (string, error) { - file, err := os.Open(filePath) - if err != nil { - return "", err - } - defer file.Close() - - hash := sha256.New() - if _, err := io.Copy(hash, file); err != nil { - return "", err - } - - return fmt.Sprintf("%x", hash.Sum(nil)), nil -} diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go index f454c6111aea..96ed17e06b6a 100644 --- a/pkg/gallery/models_test.go +++ b/pkg/gallery/models_test.go @@ -16,7 +16,7 @@ var _ = Describe("Model test", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) - c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) err = InstallModel(tempdir, "", c, map[string]interface{}{}, func(string, string, string, float64) {}) @@ -87,7 +87,7 @@ var _ = Describe("Model test", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) - c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) @@ -103,7 +103,7 @@ var _ = Describe("Model test", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) - c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}) @@ -129,7 +129,7 @@ var _ = Describe("Model test", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) - c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + c, err := ReadInstallableModelFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}) diff --git a/pkg/gallery/op.go b/pkg/gallery/op.go new file mode 100644 index 000000000000..873c356d3056 --- /dev/null +++ b/pkg/gallery/op.go @@ -0,0 +1,18 @@ +package gallery + +type GalleryOp struct { + Req GalleryModel + Id string + Galleries []Gallery + GalleryName string +} + +type GalleryOpStatus struct { + FileName string `json:"file_name"` + Error error `json:"error"` + Processed bool `json:"processed"` + Message string `json:"message"` + Progress float64 `json:"progress"` + TotalFileSize string `json:"file_size"` + DownloadedFileSize string `json:"downloaded_size"` +} diff --git a/pkg/gallery/request_test.go b/pkg/gallery/request_test.go index a9d54e325042..017167d908f9 100644 --- a/pkg/gallery/request_test.go +++ b/pkg/gallery/request_test.go @@ -10,7 +10,7 @@ var _ = Describe("Gallery API tests", func() { Context("requests", func() { It("parses github with a branch", func() { req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"} - e, err := GetGalleryConfigFromURL(req.URL) + e, err := GetInstallableModelFromURL(req.URL) Expect(err).ToNot(HaveOccurred()) Expect(e.Name).To(Equal("gpt4all-j")) }) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 739d1cbbe6bb..432c44edb478 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -6,7 +6,7 @@ import ( "fmt" "os" - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/pkg/datamodel" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" gopsutil "github.com/shirou/gopsutil/v3/process" ) @@ -53,8 +53,9 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error { return fmt.Errorf("unimplemented") } -func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) { - return schema.Result{}, fmt.Errorf("unimplemented") +// TODO CHECK THIS +func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (datamodel.WhisperResult, error) { + return datamodel.WhisperResult{}, fmt.Errorf("unimplemented") } func (llm *Base) TTS(*pb.TTSRequest) error { diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 9eab356d487c..bcdfbf2b9c62 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -7,7 +7,7 @@ import ( "sync" "time" - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/pkg/datamodel" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -223,7 +223,7 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp return client.TTS(ctx, in, opts...) } -func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.Result, error) { +func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*datamodel.WhisperResult, error) { if !c.parallel { c.opMutex.Lock() defer c.opMutex.Unlock() @@ -244,14 +244,14 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques if err != nil { return nil, err } - tresult := &schema.Result{} + tresult := &datamodel.WhisperResult{} for _, s := range res.Segments { tks := []int{} for _, t := range s.Tokens { tks = append(tks, int(t)) } tresult.Segments = append(tresult.Segments, - schema.Segment{ + datamodel.WhisperSegment{ Text: s.Text, Id: int(s.Id), Start: time.Duration(s.Start), diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index a76261c15ce9..2f41cde46762 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -1,7 +1,7 @@ package grpc import ( - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/pkg/datamodel" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" ) @@ -15,7 +15,7 @@ type LLM interface { Load(*pb.ModelOptions) error Embeddings(*pb.PredictOptions) ([]float32, error) GenerateImage(*pb.GenerateImageRequest) error - AudioTranscription(*pb.TranscriptRequest) (schema.Result, error) + AudioTranscription(*pb.TranscriptRequest) (datamodel.WhisperResult, error) TTS(*pb.TTSRequest) error TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) Status() (pb.StatusResponse, error) diff --git a/pkg/grpc/proto/backend.pb.go b/pkg/grpc/proto/backend.pb.go index b9569785eef6..2e4a2e9b22fa 100644 --- a/pkg/grpc/proto/backend.pb.go +++ b/pkg/grpc/proto/backend.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.1 -// protoc v3.6.1 +// protoc-gen-go v1.26.0 +// protoc v4.26.0 // source: backend.proto package proto diff --git a/pkg/grpc/proto/backend_grpc.pb.go b/pkg/grpc/proto/backend_grpc.pb.go index d41f77a61446..41a1ba55aadd 100644 --- a/pkg/grpc/proto/backend_grpc.pb.go +++ b/pkg/grpc/proto/backend_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.2.0 -// - protoc v3.6.1 +// - protoc-gen-go-grpc v1.3.0 +// - protoc v4.26.0 // source: backend.proto package proto @@ -18,6 +18,19 @@ import ( // Requires gRPC-Go v1.32.0 or later. const _ = grpc.SupportPackageIsVersion7 +const ( + Backend_Health_FullMethodName = "/backend.Backend/Health" + Backend_Predict_FullMethodName = "/backend.Backend/Predict" + Backend_LoadModel_FullMethodName = "/backend.Backend/LoadModel" + Backend_PredictStream_FullMethodName = "/backend.Backend/PredictStream" + Backend_Embedding_FullMethodName = "/backend.Backend/Embedding" + Backend_GenerateImage_FullMethodName = "/backend.Backend/GenerateImage" + Backend_AudioTranscription_FullMethodName = "/backend.Backend/AudioTranscription" + Backend_TTS_FullMethodName = "/backend.Backend/TTS" + Backend_TokenizeString_FullMethodName = "/backend.Backend/TokenizeString" + Backend_Status_FullMethodName = "/backend.Backend/Status" +) + // BackendClient is the client API for Backend service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. @@ -44,7 +57,7 @@ func NewBackendClient(cc grpc.ClientConnInterface) BackendClient { func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) { out := new(Reply) - err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_Health_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -53,7 +66,7 @@ func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...g func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) { out := new(Reply) - err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_Predict_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -62,7 +75,7 @@ func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts .. func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) { out := new(Result) - err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_LoadModel_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -70,7 +83,7 @@ func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts .. } func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) { - stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...) + stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], Backend_PredictStream_FullMethodName, opts...) if err != nil { return nil, err } @@ -103,7 +116,7 @@ func (x *backendPredictStreamClient) Recv() (*Reply, error) { func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { out := new(EmbeddingResult) - err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_Embedding_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -112,7 +125,7 @@ func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) { out := new(Result) - err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_GenerateImage_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -121,7 +134,7 @@ func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) { out := new(TranscriptResult) - err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_AudioTranscription_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -130,7 +143,7 @@ func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRe func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) { out := new(Result) - err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_TTS_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -139,7 +152,7 @@ func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.Ca func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*TokenizationResponse, error) { out := new(TokenizationResponse) - err := c.cc.Invoke(ctx, "/backend.Backend/TokenizeString", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_TokenizeString_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -148,7 +161,7 @@ func (c *backendClient) TokenizeString(ctx context.Context, in *PredictOptions, func (c *backendClient) Status(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*StatusResponse, error) { out := new(StatusResponse) - err := c.cc.Invoke(ctx, "/backend.Backend/Status", in, out, opts...) + err := c.cc.Invoke(ctx, Backend_Status_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -229,7 +242,7 @@ func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(inte } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/Health", + FullMethod: Backend_Health_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).Health(ctx, req.(*HealthMessage)) @@ -247,7 +260,7 @@ func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(int } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/Predict", + FullMethod: Backend_Predict_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).Predict(ctx, req.(*PredictOptions)) @@ -265,7 +278,7 @@ func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(i } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/LoadModel", + FullMethod: Backend_LoadModel_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions)) @@ -304,7 +317,7 @@ func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(i } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/Embedding", + FullMethod: Backend_Embedding_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions)) @@ -322,7 +335,7 @@ func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec fu } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/GenerateImage", + FullMethod: Backend_GenerateImage_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest)) @@ -340,7 +353,7 @@ func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/AudioTranscription", + FullMethod: Backend_AudioTranscription_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest)) @@ -358,7 +371,7 @@ func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interfa } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/TTS", + FullMethod: Backend_TTS_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).TTS(ctx, req.(*TTSRequest)) @@ -376,7 +389,7 @@ func _Backend_TokenizeString_Handler(srv interface{}, ctx context.Context, dec f } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/TokenizeString", + FullMethod: Backend_TokenizeString_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).TokenizeString(ctx, req.(*PredictOptions)) @@ -394,7 +407,7 @@ func _Backend_Status_Handler(srv interface{}, ctx context.Context, dec func(inte } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/backend.Backend/Status", + FullMethod: Backend_Status_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(BackendServer).Status(ctx, req.(*HealthMessage)) diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 3195fac913c1..f7f54192cf85 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -8,7 +8,7 @@ import ( "strings" "time" - grpc "github.com/go-skynet/LocalAI/pkg/grpc" + "github.com/go-skynet/LocalAI/pkg/grpc" "github.com/hashicorp/go-multierror" "github.com/phayes/freeport" "github.com/rs/zerolog/log" @@ -71,7 +71,7 @@ var AutoLoadBackends []string = []string{ // starts the grpcModelProcess for the backend, and returns a grpc client // It also loads the model -func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (ModelAddress, error) { +func (ml *ModelLoader) grpcModel(backend string, o *ModelOptions) func(string, string) (ModelAddress, error) { return func(modelName, modelFile string) (ModelAddress, error) { log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o) diff --git a/pkg/model/loader.go b/pkg/model/loader.go index d02f9e84c959..aafa313b068a 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -10,7 +10,7 @@ import ( "sync" "text/template" - grammar "github.com/go-skynet/LocalAI/pkg/grammar" + "github.com/go-skynet/LocalAI/pkg/grammar" "github.com/go-skynet/LocalAI/pkg/grpc" process "github.com/mudler/go-processmanager" "github.com/rs/zerolog/log" diff --git a/pkg/model/options.go b/pkg/model/options.go index 5748be9be59e..f7cfbe1ad372 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -6,7 +6,7 @@ import ( pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" ) -type Options struct { +type ModelOptions struct { backendString string model string threads uint32 @@ -23,14 +23,14 @@ type Options struct { parallelRequests bool } -type Option func(*Options) +type Option func(*ModelOptions) -var EnableParallelRequests = func(o *Options) { +var EnableParallelRequests = func(o *ModelOptions) { o.parallelRequests = true } func WithExternalBackend(name string, uri string) Option { - return func(o *Options) { + return func(o *ModelOptions) { if o.externalBackends == nil { o.externalBackends = make(map[string]string) } @@ -38,62 +38,81 @@ func WithExternalBackend(name string, uri string) Option { } } +// Currently, LocalAI isn't ready for backends to be yanked out from under it - so this is a little overcomplicated to allow non-overwriting updates +func WithExternalBackends(backends map[string]string, overwrite bool) Option { + return func(o *ModelOptions) { + if backends == nil { + return + } + if o.externalBackends == nil { + o.externalBackends = backends + return + } + for name, url := range backends { + _, exists := o.externalBackends[name] + if !exists || overwrite { + o.externalBackends[name] = url + } + } + } +} + func WithGRPCAttempts(attempts int) Option { - return func(o *Options) { + return func(o *ModelOptions) { o.grpcAttempts = attempts } } func WithGRPCAttemptsDelay(delay int) Option { - return func(o *Options) { + return func(o *ModelOptions) { o.grpcAttemptsDelay = delay } } func WithBackendString(backend string) Option { - return func(o *Options) { + return func(o *ModelOptions) { o.backendString = backend } } func WithModel(modelFile string) Option { - return func(o *Options) { + return func(o *ModelOptions) { o.model = modelFile } } func WithLoadGRPCLoadModelOpts(opts *pb.ModelOptions) Option { - return func(o *Options) { + return func(o *ModelOptions) { o.gRPCOptions = opts } } func WithThreads(threads uint32) Option { - return func(o *Options) { + return func(o *ModelOptions) { o.threads = threads } } func WithAssetDir(assetDir string) Option { - return func(o *Options) { + return func(o *ModelOptions) { o.assetDir = assetDir } } func WithContext(ctx context.Context) Option { - return func(o *Options) { + return func(o *ModelOptions) { o.context = ctx } } func WithSingleActiveBackend() Option { - return func(o *Options) { + return func(o *ModelOptions) { o.singleActiveBackend = true } } -func NewOptions(opts ...Option) *Options { - o := &Options{ +func NewOptions(opts ...Option) *ModelOptions { + o := &ModelOptions{ gRPCOptions: &pb.ModelOptions{}, context: context.Background(), grpcAttempts: 20, diff --git a/pkg/utils/file.go b/pkg/utils/file.go new file mode 100644 index 000000000000..fbeca6e5c9fd --- /dev/null +++ b/pkg/utils/file.go @@ -0,0 +1,81 @@ +package utils + +import ( + "bufio" + "encoding/base64" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + + "github.com/rs/zerolog/log" +) + +func CreateTempFileFromMultipartFile(file *multipart.FileHeader, tempDir string, tempPattern string) (string, error) { + + f, err := file.Open() + if err != nil { + return "", err + } + defer f.Close() + + // Create a temporary file in the requested directory: + outputFile, err := os.CreateTemp(tempDir, tempPattern) + if err != nil { + return "", err + } + defer outputFile.Close() + + if _, err := io.Copy(outputFile, f); err != nil { + log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, outputFile, err) + return "", err + } + + return outputFile.Name(), nil +} + +func CreateTempFileFromBase64(base64data string, tempDir string, tempPattern string) (string, error) { + if len(base64data) == 0 { + return "", fmt.Errorf("base64data empty?") + } + //base 64 decode the file and write it somewhere + // that we will cleanup + decoded, err := base64.StdEncoding.DecodeString(base64data) + if err != nil { + return "", err + } + // Create a temporary file in the requested directory: + outputFile, err := os.CreateTemp(tempDir, tempPattern) + if err != nil { + return "", err + } + defer outputFile.Close() + // write the base64 result + writer := bufio.NewWriter(outputFile) + _, err = writer.Write(decoded) + if err != nil { + return "", err + } + return outputFile.Name(), nil +} + +func CreateTempFileFromUrl(url string, tempDir string, tempPattern string) (string, error) { + // Get the data + resp, err := http.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // Create the file + out, err := os.CreateTemp(tempDir, tempPattern) + if err != nil { + return "", err + } + defer out.Close() + + // Write the body to file + _, err = io.Copy(out, resp.Body) + return out.Name(), err +} diff --git a/pkg/utils/uri.go b/pkg/utils/uri.go index 260e65a9a1a8..96d1d7cfc17b 100644 --- a/pkg/utils/uri.go +++ b/pkg/utils/uri.go @@ -3,6 +3,7 @@ package utils import ( "crypto/md5" "crypto/sha256" + "encoding/base64" "fmt" "hash" "io" @@ -71,6 +72,37 @@ func GetURI(url string, f func(url string, i []byte) error) error { return f(url, body) } +// this function check if the string is an URL, if it's an URL downloads the image in memory +// encodes it in base64 and returns the base64 string +func GetBase64Image(s string) (string, error) { + if strings.HasPrefix(s, "http") { + // download the image + resp, err := http.Get(s) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // read the image data into memory + data, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + // encode the image data in base64 + encoded := base64.StdEncoding.EncodeToString(data) + + // return the base64 string + return encoded, nil + } + + // if the string instead is prefixed with "data:image/jpeg;base64,", drop it + if strings.HasPrefix(s, "data:image/jpeg;base64,") { + return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil + } + return "", fmt.Errorf("not valid string") +} + func ConvertURL(s string) string { switch { case strings.HasPrefix(s, "huggingface://"): diff --git a/tests/integration/reflect_test.go b/tests/integration/reflect_test.go index c0fe7096a1d8..505ae9a56af0 100644 --- a/tests/integration/reflect_test.go +++ b/tests/integration/reflect_test.go @@ -3,16 +3,16 @@ package integration_test import ( "reflect" - config "github.com/go-skynet/LocalAI/api/config" - model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/datamodel" + "github.com/go-skynet/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("Integration Tests involving reflection in liue of code generation", func() { - Context("config.TemplateConfig and model.TemplateType must stay in sync", func() { + Context("datamodel.TemplateConfig and model.TemplateType must stay in sync", func() { - ttc := reflect.TypeOf(config.TemplateConfig{}) + ttc := reflect.TypeOf(datamodel.TemplateConfig{}) It("TemplateConfig and TemplateType should have the same number of valid values", func() { const lastValidTemplateType = model.IntegrationTestTemplate - 1