Skip to content

Commit

Permalink
rework fetchModelName to work by endpoint
Browse files Browse the repository at this point in the history
this change allows us to use different model names for the precheckEndpoint and precheckScoringEndpoint

Signed-off-by: greg pereira <[email protected]>
  • Loading branch information
Gregory-Pereira committed May 18, 2024
1 parent 5d8d28a commit e20843c
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 12 deletions.
5 changes: 2 additions & 3 deletions ui/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ func (api *ApiServer) runIlabChatCommand(question, context string) (string, erro
cmd = exec.Command("echo", cmdArgs...)
api.logger.Infof("Running in test mode: %s", commandStr)
} else {
modelName, err := api.fetchModelName(true)
modelName, err := api.fetchModelName(true, api.preCheckEndpointURL)
if err != nil {
api.logger.Errorf("Failed to fetch model name: %v", err)
return "failed to retrieve the model name", err
Expand Down Expand Up @@ -382,9 +382,8 @@ func setupLogger(debugMode bool) *zap.SugaredLogger {

// fetchModelName hits the defined precheck endpoint with "/models" appended to extract the model name.
// If fullName is true, it returns the entire ID value; if false, it returns the parsed out name after the double hyphens.
func (api *ApiServer) fetchModelName(fullName bool) (string, error) {
func (api *ApiServer) fetchModelName(fullName bool, endpoint string) (string, error) {
// Ensure the endpoint URL ends with "/models"
endpoint := api.preCheckEndpointURL
if !strings.HasSuffix(endpoint, "/") {
endpoint += "/"
}
Expand Down
25 changes: 19 additions & 6 deletions worker/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ var generateCmd = &cobra.Command{

func (w *Worker) runPrecheckScoring(precheckPRAnswers []string, precheckEndpointAnswers []string, precheckPRQuestions []string, lab string, outputDir string, preCheckScoringModelName string) error {
if len(precheckPRAnswers) != len(precheckEndpointAnswers) {
errMsg := "PR questions a Endpoint answers returned a different number of entries, something went wrong."
errMsg := "PR answers and Endpoint answers returned a different number of entries, something went wrong"
w.logger.Error(errMsg)
return fmt.Errorf(errMsg)
}
Expand Down Expand Up @@ -638,7 +638,7 @@ func (w *Worker) processJob() {
// sdg-svc does not have a models endpoint as yet
if jobType != jobSDG && PreCheckEndpointURL != localEndpoint {
var err error
modelName, err = w.fetchModelName(true)
modelName, err = w.fetchModelName(true, w.precheckEndpoint)
if err != nil {
w.logger.Errorf("Failed to fetch model name: %v", err)
modelName = "unknown"
Expand Down Expand Up @@ -683,7 +683,21 @@ func (w *Worker) processJob() {
w.reportJobError(err)
return
}
err = w.runPrecheckScoring(precheckPRAnswers, precheckEndpointAnswers, precheckPRQuestions, lab, outputDir, modelName)

var scoringModelName string
// sdg-svc does not have a models endpoint as yet
if jobType == jobPreCheck && w.precheckScoringEndpoint != localEndpoint {
var err error
scoringModelName, err = w.fetchModelName(true, w.precheckScoringEndpoint)
if err != nil {
w.logger.Errorf("Failed to fetch model name: %v", err)
scoringModelName = "unknown"
}
} else {
scoringModelName = w.getModelNameFromConfig() // will default to standard precheck model
}

err = w.runPrecheckScoring(precheckPRAnswers, precheckEndpointAnswers, precheckPRQuestions, lab, outputDir, scoringModelName)
if err != nil {
sugar.Errorf("Could not run scoring on result of precheck: %v", err)
w.reportJobError(err)
Expand Down Expand Up @@ -975,9 +989,8 @@ func (w *Worker) getModelNameFromConfig() string {

// fetchModelName hits the defined precheckEndpoint with "/models" appended to extract the model name.
// If fullName is true, it returns the entire ID value; if false, it returns the parsed out name after the double hyphens.
func (w *Worker) fetchModelName(fullName bool) (string, error) {
func (w *Worker) fetchModelName(fullName bool, endpoint string) (string, error) {
// Ensure the endpoint URL ends with "/models"
endpoint := w.precheckEndpoint
if !strings.HasSuffix(endpoint, "/") {
endpoint += "/"
}
Expand Down Expand Up @@ -1073,7 +1086,7 @@ func (w *Worker) determineModelName(jobType string) string {

// precheck is the only case we use a remote OpenAI endpoint right now
if PreCheckEndpointURL != localEndpoint && jobType == jobPreCheck {
modelName, err := w.fetchModelName(false)
modelName, err := w.fetchModelName(false, w.precheckEndpoint)
if err != nil {
w.logger.Errorf("Failed to fetch model name: %v", err)
return "unknown"
Expand Down
6 changes: 3 additions & 3 deletions worker/cmd/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,12 @@ func TestFetchModelName(t *testing.T) {
20,
)

modelName, err := w.fetchModelName(false)
modelName, err := w.fetchModelName(false, w.precheckEndpoint)
assert.NoError(t, err, "fetchModelName should not return an error")
expectedModelName := "Mixtral-8x7B-Instruct-v0.1"
assert.Equal(t, expectedModelName, modelName, "The model name should be extracted correctly")

modelName, err = w.fetchModelName(true)
modelName, err = w.fetchModelName(true, w.precheckEndpoint)
assert.NoError(t, err, "fetchModelName should not return an error")
expectedModelName = "/shared_model_storage/transformers_cache/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/5c79a376139be989ef1838f360bf4f1f256d7aec"
assert.Equal(t, expectedModelName, modelName, "The model name should be extracted correctly")
Expand Down Expand Up @@ -222,7 +222,7 @@ func TestFetchModelNameWithInvalidObject(t *testing.T) {
"dummy-ca-cert-path.pem",
20,
)
modelName, err := w.fetchModelName(false)
modelName, err := w.fetchModelName(false, w.precheckEndpoint)

// Verify that an error was returned due to the invalid "object" field
assert.Error(t, err, "fetchModelName should return an error for invalid object field")
Expand Down

0 comments on commit e20843c

Please sign in to comment.