Skip to content

Commit

Permalink
fix pt1
Browse files Browse the repository at this point in the history
  • Loading branch information
dave-gray101 committed Jan 4, 2024
1 parent 49208f7 commit 6d6b066
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 72 deletions.
48 changes: 32 additions & 16 deletions core/services/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package services

import (
"errors"
"fmt"
"io/fs"
"os"
Expand Down Expand Up @@ -93,35 +94,50 @@ func (cm *ConfigLoader) LoadConfigs(path string) error {
return nil
}

// TODO: Does this belong under ConfigLoader?
func (cl *ConfigLoader) Preload(modelPath string) error {
cl.Lock()
defer cl.Unlock()
// Preload prepare models if they are not local but url or huggingface repositories
func (cm *ConfigLoader) Preload(modelPath string) error {
cm.Lock()
defer cm.Unlock()

status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
}

log.Info().Msgf("Preloading models from %s", modelPath)

for _, config := range cm.configs {

// Download files and verify their SHA
for _, file := range config.DownloadFiles {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)

if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
return err
}
// Create file path
filePath := filepath.Join(modelPath, file.Filename)

if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
return err
}
}

for i, config := range cl.configs {
modelURL := config.PredictionOptions.Model
modelURL = utils.ConvertURL(modelURL)
if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") {

if utils.LooksLikeURL(modelURL) {
// md5 of model name
md5Name := utils.MD5(modelURL)

// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); err == os.ErrNotExist {
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) {
log.Info().Msgf("Downloading %s: %s/%s (%.2f%%)", fileName, current, total, percent)
})
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
if err != nil {
return err
}
}

cc := cl.configs[i]
c := &cc
c.PredictionOptions.Model = md5Name
cl.configs[i] = *c
}
}
return nil
}

Check failure on line 141 in core/services/config.go

View workflow job for this annotation

GitHub Actions / build-macOS (avx2)

missing return

Check failure on line 141 in core/services/config.go

View workflow job for this annotation

GitHub Actions / build-linux (avx, -DLLAMA_AVX2=OFF)

missing return

func (cl *ConfigLoader) LoadConfigFile(file string) error {
Expand Down
49 changes: 0 additions & 49 deletions pkg/datamodel/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ package datamodel

import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"

"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
)

Expand Down Expand Up @@ -401,49 +398,3 @@ func UpdateConfigFromOpenAIRequest(config *Config, input *OpenAIRequest) {
}

}

// Preload prepare models if they are not local but url or huggingface repositories
func (cm *ConfigLoader) Preload(modelPath string) error {
cm.Lock()
defer cm.Unlock()

status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
}

log.Info().Msgf("Preloading models from %s", modelPath)

for i, config := range cm.configs {

// Download files and verify their SHA
for _, file := range config.DownloadFiles {
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)

if err := utils.VerifyPath(file.Filename, modelPath); err != nil {
return err
}
// Create file path
filePath := filepath.Join(modelPath, file.Filename)

if err := utils.DownloadFile(file.URI, filePath, file.SHA256, status); err != nil {
return err
}
}

modelURL := config.PredictionOptions.Model
modelURL = utils.ConvertURL(modelURL)

if utils.LooksLikeURL(modelURL) {
// md5 of model name
md5Name := utils.MD5(modelURL)

// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", status)
if err != nil {
return err
}
}
}
}
}
2 changes: 1 addition & 1 deletion pkg/datamodel/startup_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func NewStartupOptions(o ...AppOption) *StartupOptions {
}

func WithModelsURL(urls ...string) AppOption {
return func(o *Option) {
return func(o *StartupOptions) {
o.ModelsURL = urls
}
}
Expand Down
50 changes: 44 additions & 6 deletions pkg/utils/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package utils
import (
"crypto/md5"
"crypto/sha256"
"encoding/base64"
"fmt"
"hash"
"io"
"net/http"
"os"
"path/filepath"
"slices"
"strconv"
"strings"

Expand All @@ -23,6 +25,16 @@ const (
GithubURI2 = "github://"
)

func getRecognizedURIPrefixes() []string {
return []string{
HuggingFacePrefix,
HTTPPrefix,
HTTPSPrefix,
GithubURI,
GithubURI2,
}
}

func GetURI(url string, f func(url string, i []byte) error) error {
url = ConvertURL(url)

Expand Down Expand Up @@ -60,13 +72,8 @@ func GetURI(url string, f func(url string, i []byte) error) error {
return f(url, body)
}

// TODO: Refactor to use a slice of constants and slices.Contains() for easier maintenance?
func LooksLikeURL(s string) bool {
return strings.HasPrefix(s, HTTPPrefix) ||
strings.HasPrefix(s, HTTPSPrefix) ||
strings.HasPrefix(s, HuggingFacePrefix) ||
strings.HasPrefix(s, GithubURI) ||
strings.HasPrefix(s, GithubURI2)
return slices.Contains(getRecognizedURIPrefixes(), s)
}

func ConvertURL(s string) string {
Expand Down Expand Up @@ -242,6 +249,37 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
return nil
}

// this function check if the string is an URL, if it's an URL downloads the image in memory
// encodes it in base64 and returns the base64 string
func GetBase64Image(s string) (string, error) {
if strings.HasPrefix(s, "http") {
// download the image
resp, err := http.Get(s)
if err != nil {
return "", err
}
defer resp.Body.Close()

// read the image data into memory
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}

// encode the image data in base64
encoded := base64.StdEncoding.EncodeToString(data)

// return the base64 string
return encoded, nil
}

// if the string instead is prefixed with "data:image/jpeg;base64,", drop it
if strings.HasPrefix(s, "data:image/jpeg;base64,") {
return strings.ReplaceAll(s, "data:image/jpeg;base64,", ""), nil
}
return "", fmt.Errorf("not valid string")
}

type progressWriter struct {
fileName string
total int64
Expand Down

0 comments on commit 6d6b066

Please sign in to comment.