Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(gallery): support model deletion #2173

Merged
merged 3 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 55 additions & 9 deletions core/config/backend_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,36 @@ func (c *BackendConfig) ShouldCallSpecificFunction() bool {
return len(c.functionCallNameString) > 0
}

// MMProjFileName returns the filename of the MMProj file
// If the MMProj is a URL, it will return the MD5 of the URL which is the filename
func (c *BackendConfig) MMProjFileName() string {
modelURL := downloader.ConvertURL(c.MMProj)
if downloader.LooksLikeURL(modelURL) {
return utils.MD5(modelURL)
}

return c.MMProj
}

func (c *BackendConfig) IsMMProjURL() bool {
return downloader.LooksLikeURL(downloader.ConvertURL(c.MMProj))
}

func (c *BackendConfig) IsModelURL() bool {
return downloader.LooksLikeURL(downloader.ConvertURL(c.Model))
}

// ModelFileName returns the filename of the model
// If the model is a URL, it will return the MD5 of the URL which is the filename
func (c *BackendConfig) ModelFileName() string {
modelURL := downloader.ConvertURL(c.Model)
if downloader.LooksLikeURL(modelURL) {
return utils.MD5(modelURL)
}

return c.Model
}

func (c *BackendConfig) FunctionToCall() string {
if c.functionCallNameString != "" &&
c.functionCallNameString != "none" && c.functionCallNameString != "auto" {
Expand Down Expand Up @@ -532,26 +562,41 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
}
}

modelURL := config.PredictionOptions.Model
modelURL = downloader.ConvertURL(modelURL)
// If the model is an URL, expand it, and download the file
if config.IsModelURL() {
modelFileName := config.ModelFileName()
modelURL := downloader.ConvertURL(config.Model)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
if err != nil {
return err
}
}

if downloader.LooksLikeURL(modelURL) {
// md5 of model name
md5Name := utils.MD5(modelURL)
cc := cl.configs[i]
c := &cc
c.PredictionOptions.Model = modelFileName
cl.configs[i] = *c
}

if config.IsMMProjURL() {
modelFileName := config.MMProjFileName()
modelURL := downloader.ConvertURL(config.MMProj)
// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", 0, 0, status)
if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) {
err := downloader.DownloadFile(modelURL, filepath.Join(modelPath, modelFileName), "", 0, 0, status)
if err != nil {
return err
}
}

cc := cl.configs[i]
c := &cc
c.PredictionOptions.Model = md5Name
c.MMProj = modelFileName
cl.configs[i] = *c
}

if cl.configs[i].Name != "" {
glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name))
}
Expand Down Expand Up @@ -586,7 +631,8 @@ func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...C
}
for _, file := range files {
// Skip templates, YAML and .keep files
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") {
if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") ||
strings.HasPrefix(file.Name(), ".") {
continue
}
c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...)
Expand Down
43 changes: 33 additions & 10 deletions core/http/elements/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const (
NoImage = "https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg"
)

func DoneProgress(uid string) string {
func DoneProgress(uid, text string) string {
return elem.Div(
attrs.Props{},
elem.H3(
Expand All @@ -23,7 +23,7 @@ func DoneProgress(uid string) string {
"tabindex": "-1",
"autofocus": "",
},
elem.Text("Installation completed"),
elem.Text(text),
),
).Render()
}
Expand Down Expand Up @@ -60,7 +60,7 @@ func ProgressBar(progress string) string {
).Render()
}

func StartProgressBar(uid, progress string) string {
func StartProgressBar(uid, progress, text string) string {
if progress == "" {
progress = "0"
}
Expand All @@ -77,7 +77,7 @@ func StartProgressBar(uid, progress string) string {
"tabindex": "-1",
"autofocus": "",
},
elem.Text("Installing"),
elem.Text(text),
// This is a simple example of how to use the HTMLX library to create a progress bar that updates every 600ms.
elem.Div(attrs.Props{
"hx-get": "/browse/job/progress/" + uid,
Expand Down Expand Up @@ -106,14 +106,33 @@ func cardSpan(text, icon string) elem.Node {
func ListModels(models []*gallery.GalleryModel, installing *xsync.SyncedMap[string, string]) string {
//StartProgressBar(uid, "0")
modelsElements := []elem.Node{}
span := func(s string) elem.Node {
return elem.Span(
// span := func(s string) elem.Node {
// return elem.Span(
// attrs.Props{
// "class": "float-right inline-block bg-green-500 text-white py-1 px-3 rounded-full text-xs",
// },
// elem.Text(s),
// )
// }
deleteButton := func(m *gallery.GalleryModel) elem.Node {
return elem.Button(
attrs.Props{
"class": "float-right inline-block bg-green-500 text-white py-1 px-3 rounded-full text-xs",
"data-twe-ripple-init": "",
"data-twe-ripple-color": "light",
"class": "float-right inline-block rounded bg-red-800 px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-red-accent-300 hover:shadow-red-2 focus:bg-red-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-red-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
"hx-swap": "outerHTML",
// post the Model ID as param
"hx-post": "/browse/delete/model/" + m.Name,
},
elem.Text(s),
elem.I(
attrs.Props{
"class": "fa-solid fa-cancel pr-2",
},
),
elem.Text("Delete"),
)
}

installButton := func(m *gallery.GalleryModel) elem.Node {
return elem.Button(
attrs.Props{
Expand Down Expand Up @@ -202,10 +221,14 @@ func ListModels(models []*gallery.GalleryModel, installing *xsync.SyncedMap[stri
elem.If(
currentlyInstalling,
elem.Node( // If currently installing, show progress bar
elem.Raw(StartProgressBar(installing.Get(galleryID), "0")),
elem.Raw(StartProgressBar(installing.Get(galleryID), "0", "Installing")),
), // Otherwise, show install button (if not installed) or display "Installed"
elem.If(m.Installed,
span("Installed"),
//elem.Node(elem.Div(
// attrs.Props{},
// span("Installed"), deleteButton(m),
// )),
deleteButton(m),
installButton(m),
),
),
Expand Down
21 changes: 21 additions & 0 deletions core/http/endpoints/localai/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,27 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fibe
}
}

func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelName := c.Params("name")

mgs.galleryApplier.C <- gallery.GalleryOp{
Delete: true,
GalleryName: modelName,
}

uuid, err := uuid.NewUUID()
if err != nil {
return err
}

return c.JSON(struct {
ID string `json:"uuid"`
StatusURL string `json:"status"`
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()})
}
}

func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing models from galleries: %+v", mgs.galleries)
Expand Down
2 changes: 2 additions & 0 deletions core/http/routes/localai.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ func RegisterLocalAIRoutes(app *fiber.App,

modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.ModelPath, galleryService)
app.Post("/models/apply", auth, modelGalleryEndpointService.ApplyModelGalleryEndpoint())
app.Post("/models/delete/:name", auth, modelGalleryEndpointService.DeleteModelGalleryEndpoint())

app.Get("/models/available", auth, modelGalleryEndpointService.ListModelFromGalleryEndpoint())
app.Get("/models/galleries", auth, modelGalleryEndpointService.ListModelGalleriesEndpoint())
app.Post("/models/galleries", auth, modelGalleryEndpointService.AddModelGalleryEndpoint())
Expand Down
44 changes: 42 additions & 2 deletions core/http/routes/ui.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ func RegisterUIRoutes(app *fiber.App,
return c.SendString(elements.ListModels(filteredModels, installingModels))
})

/*

Install routes

*/

// This route is used when the "Install" button is pressed, we submit here a new job to the gallery service
// https://htmx.org/examples/progress-bar/
app.Post("/browse/install/model/:id", auth, func(c *fiber.Ctx) error {
Expand All @@ -89,7 +95,33 @@ func RegisterUIRoutes(app *fiber.App,
galleryService.C <- op
}()

return c.SendString(elements.StartProgressBar(uid, "0"))
return c.SendString(elements.StartProgressBar(uid, "0", "Installation"))
})

// This route is used when the "Install" button is pressed, we submit here a new job to the gallery service
// https://htmx.org/examples/progress-bar/
app.Post("/browse/delete/model/:id", auth, func(c *fiber.Ctx) error {
galleryID := strings.Clone(c.Params("id")) // note: strings.Clone is required for multiple requests!

id, err := uuid.NewUUID()
if err != nil {
return err
}

uid := id.String()

installingModels.Set(galleryID, uid)

op := gallery.GalleryOp{
Id: uid,
Delete: true,
GalleryName: galleryID,
}
go func() {
galleryService.C <- op
}()

return c.SendString(elements.StartProgressBar(uid, "0", "Deletion"))
})

// Display the job current progress status
Expand Down Expand Up @@ -118,12 +150,20 @@ func RegisterUIRoutes(app *fiber.App,
// this route is hit when the job is done, and we display the
// final state (for now just displays "Installation completed")
app.Get("/browse/job/:uid", auth, func(c *fiber.Ctx) error {

status := galleryService.GetStatus(c.Params("uid"))

for _, k := range installingModels.Keys() {
if installingModels.Get(k) == c.Params("uid") {
installingModels.Delete(k)
}
}

return c.SendString(elements.DoneProgress(c.Params("uid")))
displayText := "Installation completed"
if status.Deletion {
displayText = "Deletion completed"
}

return c.SendString(elements.DoneProgress(c.Params("uid"), displayText))
})
}
57 changes: 46 additions & 11 deletions core/services/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"os"
"path/filepath"
"strings"
"sync"

Expand Down Expand Up @@ -84,18 +85,47 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader
}

var err error
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.GalleryName != "" {
if strings.Contains(op.GalleryName, "@") {
err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)

// delete a model
if op.Delete {
modelConfig := &config.BackendConfig{}
// Galleryname is the name of the model in this case
dat, err := os.ReadFile(filepath.Join(g.modelPath, op.GalleryName+".yaml"))
if err != nil {
updateError(err)
continue
}
err = yaml.Unmarshal(dat, modelConfig)
if err != nil {
updateError(err)
continue
}

files := []string{}
// Remove the model from the config
if modelConfig.Model != "" {
files = append(files, modelConfig.ModelFileName())
}
} else if op.ConfigURL != "" {
startup.PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL)
err = cl.Preload(g.modelPath)

if modelConfig.MMProj != "" {
files = append(files, modelConfig.MMProjFileName())
}

err = gallery.DeleteModelFromSystem(g.modelPath, op.GalleryName, files)
} else {
err = prepareModel(g.modelPath, op.Req, cl, progressCallback)
// if the request contains a gallery name, we apply the gallery from the gallery list
if op.GalleryName != "" {
if strings.Contains(op.GalleryName, "@") {
err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryName, g.modelPath, op.Req, progressCallback)
}
} else if op.ConfigURL != "" {
startup.PreloadModelsConfigurations(op.ConfigURL, g.modelPath, op.ConfigURL)
err = cl.Preload(g.modelPath)
} else {
err = prepareModel(g.modelPath, op.Req, cl, progressCallback)
}
}

if err != nil {
Expand All @@ -116,7 +146,12 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader
continue
}

g.UpdateStatus(op.Id, &gallery.GalleryOpStatus{Processed: true, Message: "completed", Progress: 100})
g.UpdateStatus(op.Id,
&gallery.GalleryOpStatus{
Deletion: op.Delete,
Processed: true,
Message: "completed",
Progress: 100})
}
}
}()
Expand Down
Loading