diff --git a/cli/app.go b/cli/app.go index 83e1e3777c4..f9685a7d626 100644 --- a/cli/app.go +++ b/cli/app.go @@ -111,6 +111,9 @@ const ( packageFlagType = "type" packageFlagDestination = "destination" packageFlagPath = "path" + packageFlagFramework = "model-framework" + + packageMetadataFlagFramework = "model_framework" authApplicationFlagName = "application-name" authApplicationFlagApplicationID = "application-id" @@ -1944,6 +1947,12 @@ This won't work unless you have an existing installation of our GitHub app on yo Required: true, Usage: "type of the requested package, can be: " + strings.Join(packageTypes, ", "), }, + &cli.StringFlag{ + Name: packageFlagFramework, + Required: false, + Usage: "framework for an ml_model being uploaded, can be: " + + strings.Join(modelFrameworks, ", ") + ", Required if packages if of type `ml_model`", + }, }, Action: PackageUploadAction, }, diff --git a/cli/packages.go b/cli/packages.go index 2b76e15ec6f..63e1f28a334 100644 --- a/cli/packages.go +++ b/cli/packages.go @@ -8,6 +8,7 @@ import ( "os" "path" "path/filepath" + "slices" "strings" "time" @@ -183,13 +184,25 @@ func PackageUploadAction(c *cli.Context) error { return err } + if err := validatePackageUploadRequest(c); err != nil { + return err + } + resp, err := client.uploadPackage( c.String(generalFlagOrgID), c.String(packageFlagName), c.String(packageFlagVersion), c.String(packageFlagType), c.Path(packageFlagPath), - nil, + &structpb.Struct{ + Fields: map[string]*structpb.Value{ + packageMetadataFlagFramework: &structpb.Value{ + Kind: &structpb.Value_StringValue{ + StringValue: c.String(packageFlagFramework), + }, + }, + }, + }, ) if err != nil { return err @@ -274,3 +287,19 @@ func getNextPackageUploadRequest(file *os.File) (*packagespb.CreatePackageReques func (m *moduleID) ToDetailURL(baseURL string, packageType PackageType) string { return fmt.Sprintf("https://%s/%s/%s/%s", baseURL, strings.ReplaceAll(string(packageType), "_", "-"), m.prefix, m.name) } + +func validatePackageUploadRequest(c *cli.Context) error { + packageType := c.String(packageFlagType) + + if packageType == "ml_model" { + if c.String(packageFlagFramework) == "" { + return errors.New("must pass in a model-framework if package is of type `ml_model`") + } + + if !slices.Contains(modelFrameworks, c.String(packageFlagFramework)) { + return errors.New("framework must be of type " + strings.Join(modelFrameworks, ", ")) + } + } + + return nil +}