From 6d6b066400bb8f10bface657ad364b7fd3360c6b Mon Sep 17 00:00:00 2001
From: Dave Lee
Date: Thu, 4 Jan 2024 01:41:56 -0500
Subject: [PATCH] fix pt1

---
 core/services/config.go          | 48 ++++++++++++++++++++++--------
 pkg/datamodel/config.go          | 49 -------------------------------
 pkg/datamodel/startup_options.go |  2 +-
 pkg/utils/uri.go                 | 50 ++++++++++++++++++++++++++++----
 4 files changed, 81 insertions(+), 68 deletions(-)

diff --git a/core/services/config.go b/core/services/config.go
index 1e76a8543f97..989737d491c3 100644
--- a/core/services/config.go
+++ b/core/services/config.go
@@ -1,6 +1,7 @@
 package services
 
 import (
+	"errors"
 	"fmt"
 	"io/fs"
 	"os"
@@ -93,35 +94,58 @@ func (cm *ConfigLoader) LoadConfigs(path string) error {
 	return nil
 }
 
-// TODO: Does this belong under ConfigLoader?
-func (cl *ConfigLoader) Preload(modelPath string) error {
-	cl.Lock()
-	defer cl.Unlock()
+// Preload downloads every config's DownloadFiles into modelPath and, when the model is a URL,
+// downloads it under its MD5 name and stores that name back into the config. It stops at the first error.
+func (cm *ConfigLoader) Preload(modelPath string) error {
+	cm.Lock()
+	defer cm.Unlock()
+
+	status := func(fileName, current, total string, percent float64) {
+		utils.DisplayDownloadFunction(fileName, current, total, percent)
+	}
+
+	log.Info().Msgf("Preloading models from %s", modelPath)
+
+	for i, config := range cm.configs {
+
+		// Download files and verify their SHA
+		for _, file := range config.DownloadFiles {
+			log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
+
+			if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
+				return err
+			}
+			// Create file path
+			filePath := filepath.Join(modelPath, file.Filename)
+
+			if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
+				return err
+			}
+		}
 
-	for i, config := range cl.configs {
 		modelURL := config.PredictionOptions.Model
 		modelURL = utils.ConvertURL(modelURL)
-		if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") {
+
+		if utils.LooksLikeURL(modelURL) {
 			// md5 of model name
 			md5Name := utils.MD5(modelURL)
 
 			// check if file exists
-			if _, err := os.Stat(filepath.Join(modelPath, md5Name)); err == os.ErrNotExist {
-				err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) {
-					log.Info().Msgf("Downloading %s: %s/%s (%.2f%%)", fileName, current, total, percent)
-				})
+			if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
+				err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
 				if err != nil {
 					return err
 				}
 			}
 
-			cc := cl.configs[i]
+			// Point the config at the downloaded file (named by MD5) instead of the remote URL.
+			cc := cm.configs[i]
 			c := &cc
 			c.PredictionOptions.Model = md5Name
-			cl.configs[i] = *c
+			cm.configs[i] = *c
 		}
 	}
 	return nil
 }
 
 func (cl *ConfigLoader) LoadConfigFile(file string) error {
diff --git a/pkg/datamodel/config.go b/pkg/datamodel/config.go
index 1f85b8d6723c..21d5cc922b01 100644
--- a/pkg/datamodel/config.go
+++ b/pkg/datamodel/config.go
@@ -2,13 +2,10 @@ package datamodel
 
 import (
 	"encoding/json"
-	"errors"
 	"fmt"
 	"os"
-	"path/filepath"
 
 	"github.com/go-skynet/LocalAI/pkg/utils"
-	"github.com/rs/zerolog/log"
 	"gopkg.in/yaml.v3"
 )
 
@@ -401,49 +398,3 @@ func UpdateConfigFromOpenAIRequest(config *Config, input *OpenAIRequest) {
 	}
 
 }
-
-// Preload prepare models if they are not local but url or huggingface repositories
-func (cm *ConfigLoader) Preload(modelPath string) error {
-	cm.Lock()
-	defer cm.Unlock()
-
-	status := func(fileName, current, total string, percent float64) {
-		utils.DisplayDownloadFunction(fileName, current, total, percent)
-	}
-
-	log.Info().Msgf("Preloading models from %s", modelPath)
-
-	for i, config := range cm.configs {
-
-		// Download files and verify their SHA
-		for _, file := range config.DownloadFiles {
-			log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
-
-			if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
-				return err
-			}
-			// Create file path
-			filePath := filepath.Join(modelPath, file.Filename)
-
-			if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
-				return err
-			}
-		}
-
-		modelURL := config.PredictionOptions.Model
-		modelURL = utils.ConvertURL(modelURL)
-
-		if utils.LooksLikeURL(modelURL) {
-			// md5 of model name
-			md5Name := utils.MD5(modelURL)
-
-			// check if file exists
-			if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
-				err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
-				if err != nil {
-					return err
-				}
-			}
-		}
-	}
-}
diff --git a/pkg/datamodel/startup_options.go b/pkg/datamodel/startup_options.go
index 64f21200a25e..5725df28c1af 100644
--- a/pkg/datamodel/startup_options.go
+++ b/pkg/datamodel/startup_options.go
@@ -67,7 +67,7 @@ func NewStartupOptions(o ...AppOption) *StartupOptions {
 }
 
 func WithModelsURL(urls ...string) AppOption {
-	return func(o *Option) {
+	return func(o *StartupOptions) {
 		o.ModelsURL = urls
 	}
 }
diff --git a/pkg/utils/uri.go b/pkg/utils/uri.go
index d6cfdf4be4bd..45e842bd1dcf 100644
--- a/pkg/utils/uri.go
+++ b/pkg/utils/uri.go
@@ -3,12 +3,14 @@ package utils
 import (
 	"crypto/md5"
 	"crypto/sha256"
+	"encoding/base64"
 	"fmt"
 	"hash"
 	"io"
 	"net/http"
 	"os"
 	"path/filepath"
+	"slices"
 	"strconv"
 	"strings"
 
@@ -23,6 +25,16 @@ const (
 	GithubURI2 = "github://"
 )
 
+func getRecognizedURIPrefixes() []string {
+	return []string{
+		HuggingFacePrefix,
+		HTTPPrefix,
+		HTTPSPrefix,
+		GithubURI,
+		GithubURI2,
+	}
+}
+
 func GetURI(url string, f func(url string, i []byte) error) error {
 	url = ConvertURL(url)
 
@@ -60,13 +72,8 @@ func GetURI(url string, f func(url string, i []byte) error) error {
 	return f(url, body)
 }
 
-// TODO: Refactor to use a slice of constants and slices.Contains() for easier maintenance?
func LooksLikeURL(s string) bool { - return strings.HasPrefix(s, HTTPPrefix) || - strings.HasPrefix(s, HTTPSPrefix) || - strings.HasPrefix(s, HuggingFacePrefix) || - strings.HasPrefix(s, GithubURI) || - strings.HasPrefix(s, GithubURI2) + return slices.Contains(getRecognizedURIPrefixes(), s) } func ConvertURL(s string) string { @@ -242,6 +249,37 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string, return nil } +// this function check if the string is an URL, if it's an URL downloads the image in memory +// encodes it in base64 and returns the base64 string +func GetBase64Image(s string) (string, error) { + if strings.HasPrefix(s, "http") { + // download the image + resp, err := http.Get(s) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // read the image data into memory + data, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + // encode the image data in base64 + encoded := base64.StdEncoding.EncodeToString(data) + + // return the base64 string + return encoded, nil + } + + // if the string instead is prefixed with "data:image/jpeg;base64,", drop it + if strings.HasPrefix(s, "data:image/jpeg;base64,") { + return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil + } + return "", fmt.Errorf("not valid string") +} + type progressWriter struct { fileName string total int64