diff --git a/examples/GraphDef-model/model/graphdef.pb b/examples/GraphDef-model/model/model.graphdef similarity index 100% rename from examples/GraphDef-model/model/graphdef.pb rename to examples/GraphDef-model/model/model.graphdef diff --git a/pkg/model/format.go b/pkg/model/format.go index 3f8ce1c3..9c3e2375 100644 --- a/pkg/model/format.go +++ b/pkg/model/format.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "os" "path" + "strings" ) // Format is the definition of model format. @@ -77,48 +78,55 @@ func (f Format) ValidateDirectory(rootPath string) error { return nil } +func ValidateError(modelPath string, modelName string, modelNum int32) error { + if modelNum != 1 { + return fmt.Errorf("Expected one %v file in %v directory, but found %v .", modelName, modelPath, modelNum) + } + return nil +} + func (f Format) validateForSavedModel(modelPath string, files []os.FileInfo) error { - var pbFileFlag bool - var variablesDirFlag bool + var pbFileNum int32 + var variablesDirNum int32 for _, file := range files { - if path.Ext(file.Name()) == ".pb" { - pbFileFlag = true + if file.Name() == "saved_model.pb" { + pbFileNum++ } if file.IsDir() && file.Name() == "variables" { - variablesDirFlag = true + variablesDirNum++ } } - if !pbFileFlag { - return fmt.Errorf("there are no *.pb file in %v directory", modelPath) + if e := ValidateError(modelPath, "saved_model.pb", pbFileNum); e != nil { + return e } - if !variablesDirFlag { - return fmt.Errorf("there are no variables dir in %v directory", modelPath) + if e := ValidateError(modelPath, "variables", variablesDirNum); e != nil { + return e } return nil } func (f Format) validateForONNX(modelPath string, files []os.FileInfo) error { - var onnxFileFlag bool + var onnxFileNum int32 for _, file := range files { if path.Ext(file.Name()) == ".onnx" { - onnxFileFlag = true + onnxFileNum++ } } - if !onnxFileFlag { - return fmt.Errorf("there are no *.onnx file in %v directory", modelPath) + if e := ValidateError(modelPath, "*.onnx", onnxFileNum); e != nil { + return e } return nil } func (f Format) validateForH5(modelPath string, files []os.FileInfo) error { - var h5FileFlag bool + var h5FileNum int32 for _, file := range files { if path.Ext(file.Name()) == ".h5" { - h5FileFlag = true + h5FileNum++ } } - if !h5FileFlag { - return fmt.Errorf("there are no *.h5 file in %v directory", modelPath) + if e := ValidateError(modelPath, "*.h5", h5FileNum); e != nil { + return e } return nil } @@ -135,141 +143,140 @@ func (f Format) validateForPMML(modelPath string, files []os.FileInfo) error { } func (f Format) validateForCaffeModel(modelPath string, files []os.FileInfo) error { - var caffeModelFileFlag bool - var prototxtFileFlag bool + var caffeModelFileNum int32 + var prototxtFileNum int32 for _, file := range files { if path.Ext(file.Name()) == ".caffemodel" { - caffeModelFileFlag = true + caffeModelFileNum++ } if path.Ext(file.Name()) == ".prototxt" { - prototxtFileFlag = true + prototxtFileNum++ } } - if !caffeModelFileFlag { - return fmt.Errorf("there are no *.caffemodel file in %v directory", modelPath) + if e := ValidateError(modelPath, "*.caffemodel", caffeModelFileNum); e != nil { + return e } - if !prototxtFileFlag { - return fmt.Errorf("there are no *.prototxt file in %v directory", modelPath) + if e := ValidateError(modelPath, "*.prototxt", prototxtFileNum); e != nil { + return e } return nil } func (f Format) validateForNetDef(modelPath string, files []os.FileInfo) error { - var initFileFlag bool - var predictFileFlag bool + var initFileNum int32 + var predictFileNum int32 for _, file := range files { if file.Name() == "init_net.pb" { - initFileFlag = true + initFileNum++ } if file.Name() == "predict_net.pb" { - predictFileFlag = true + predictFileNum++ } } - if !initFileFlag { - return fmt.Errorf("there are no init_net.pb file in %v directory", modelPath) + if e := ValidateError(modelPath, "init_net.pb", initFileNum); e != nil { + return e } - if !predictFileFlag { - return fmt.Errorf("there are no predict_net.pb file in %v directory", modelPath) + if e := ValidateError(modelPath, "predict_net.pb", predictFileNum); e != nil { + return e } return nil } -func (f Format) validateForMXNetParams(modelPath string, files []os.FileInfo) error { - var jsonFileFlag bool - var paramsFileFlag bool +func (f Format) validateForMXNETParams(modelPath string, files []os.FileInfo) error { + var jsonFileNum int32 + var paramsFileNum int32 for _, file := range files { - if path.Ext(file.Name()) == ".json" { - jsonFileFlag = true + if strings.HasSuffix(file.Name(), "symbol.json") { + jsonFileNum++ } if path.Ext(file.Name()) == ".params" { - paramsFileFlag = true + paramsFileNum++ } } - if !jsonFileFlag { - return fmt.Errorf("there are no *.json file in %v directory", modelPath) + if e := ValidateError(modelPath, "*symbol.json", jsonFileNum); e != nil { + return e } - if !paramsFileFlag { - return fmt.Errorf("there are no *.params file in %v directory", modelPath) + if e := ValidateError(modelPath, "*.params", paramsFileNum); e != nil { + return e } return nil } func (f Format) validateForTorchScript(modelPath string, files []os.FileInfo) error { - var ptFileFlag bool + var ptFileNum int32 for _, file := range files { if path.Ext(file.Name()) == ".pt" { - ptFileFlag = true + ptFileNum++ } } - if !ptFileFlag { - return fmt.Errorf("there are no *.pt file in %v directory", modelPath) + if e := ValidateError(modelPath, "*.pt", ptFileNum); e != nil { + return e } return nil } func (f Format) validateForGraphDef(modelPath string, files []os.FileInfo) error { - var pbFileFlag bool + var graphdefFileNum int32 for _, file := range files { - if path.Ext(file.Name()) == ".pb" { - pbFileFlag = true - break + if path.Ext(file.Name()) == ".graphdef" { + graphdefFileNum++ } } - if !pbFileFlag { - return fmt.Errorf("there are no *.pb file in %v directory", modelPath) + if e := ValidateError(modelPath, "*.graphdef", graphdefFileNum); e != nil { + return e } return nil } func (f Format) validateForTensorRT(modelPath string, files []os.FileInfo) error { - var tensorrtFileFlag bool + var tensorrtFileNum int32 for _, file := range files { - if path.Ext(file.Name()) == ".plan" { - tensorrtFileFlag = true + if path.Ext(file.Name()) == ".plan" || path.Ext(file.Name()) == ".engine" { + tensorrtFileNum++ } } - if !tensorrtFileFlag { - return fmt.Errorf("there are no *.plan file in %v directory", modelPath) + if e := ValidateError(modelPath, "*.plan or *.engine", tensorrtFileNum); e != nil { + return e } return nil } func (f Format) validateForSKLearn(modelPath string, files []os.FileInfo) error { - var sklearnFileFlag bool + var sklearnFileNum int32 for _, file := range files { if path.Ext(file.Name()) == ".joblib" { - sklearnFileFlag = true + sklearnFileNum++ } } - if !sklearnFileFlag { - return fmt.Errorf("there are no *.joblib file in %v directory", modelPath) + if e := ValidateError(modelPath, "*.joblib", sklearnFileNum); e != nil { + return e } return nil } func (f Format) validateForXGBoost(modelPath string, files []os.FileInfo) error { - var xgboostFileFlag bool + var xgboostFileNum int32 for _, file := range files { if path.Ext(file.Name()) == ".xgboost" { - xgboostFileFlag = true + xgboostFileNum++ } } - if !xgboostFileFlag { - return fmt.Errorf("there are no *.xgboost file in %v directory", modelPath) + if e := ValidateError(modelPath, "*.xgboost", xgboostFileNum); e != nil { + return e } return nil } func (f Format) validateForMLflow(modelPath string, files []os.FileInfo) error { - var isMLflowFile bool + var MLflowFileNum int32 for _, file := range files { if file.Name() == "MLmodel" { // assuming that user would not fool the tool - isMLflowFile = true + MLflowFileNum++ } } - if !isMLflowFile { - return fmt.Errorf("there are no MLmodel file in %v, directory", modelPath) + if e := ValidateError(modelPath, "MLmodel", MLflowFileNum); e != nil { + return e } return nil }