Skip to content

Commit

Permalink
allow upload metadata from cli
Browse files Browse the repository at this point in the history
  • Loading branch information
etai-shuchatowitz committed Nov 14, 2024
1 parent 53de7fb commit ea7bf65
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
9 changes: 9 additions & 0 deletions cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ const (
packageFlagType = "type"
packageFlagDestination = "destination"
packageFlagPath = "path"
packageFlagFramework = "model-framework"

packageMetadataFlagFramework = "model_framework"

authApplicationFlagName = "application-name"
authApplicationFlagApplicationID = "application-id"
Expand Down Expand Up @@ -1944,6 +1947,12 @@ This won't work unless you have an existing installation of our GitHub app on yo
Required: true,
Usage: "type of the requested package, can be: " + strings.Join(packageTypes, ", "),
},
&cli.StringFlag{
Name: packageFlagFramework,
Required: false,
Usage: "framework for an ml_model being uploaded, can be: " +
strings.Join(modelFrameworks, ", ") + ", Required if packages if of type `ml_model`",
},
},
Action: PackageUploadAction,
},
Expand Down
31 changes: 30 additions & 1 deletion cli/packages.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"path"
"path/filepath"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -183,13 +184,25 @@ func PackageUploadAction(c *cli.Context) error {
return err
}

if err := validatePackageUploadRequest(c); err != nil {
return err
}

resp, err := client.uploadPackage(
c.String(generalFlagOrgID),
c.String(packageFlagName),
c.String(packageFlagVersion),
c.String(packageFlagType),
c.Path(packageFlagPath),
nil,
&structpb.Struct{
Fields: map[string]*structpb.Value{
packageMetadataFlagFramework: &structpb.Value{
Kind: &structpb.Value_StringValue{
StringValue: c.String(packageFlagFramework),
},
},
},
},
)
if err != nil {
return err
Expand Down Expand Up @@ -274,3 +287,19 @@ func getNextPackageUploadRequest(file *os.File) (*packagespb.CreatePackageReques
func (m *moduleID) ToDetailURL(baseURL string, packageType PackageType) string {
return fmt.Sprintf("https://%s/%s/%s/%s", baseURL, strings.ReplaceAll(string(packageType), "_", "-"), m.prefix, m.name)
}

func validatePackageUploadRequest(c *cli.Context) error {
packageType := c.String(packageFlagType)

if packageType == "ml_model" {
if c.String(packageFlagFramework) == "" {
return errors.New("must pass in a model-framework if package is of type `ml_model`")
}

if !slices.Contains(modelFrameworks, c.String(packageFlagFramework)) {
return errors.New("framework must be of type " + strings.Join(modelFrameworks, ", "))
}
}

return nil
}

0 comments on commit ea7bf65

Please sign in to comment.