Skip to content

Commit

Permalink
Merge pull request #497 from nerdalert/vllm-status-db
Browse files Browse the repository at this point in the history
Add any served models in a column in the jobs table
  • Loading branch information
vishnoianil authored Jan 26, 2025
2 parents 94ddd85 + 39dbf3c commit fa2c1ed
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 33 deletions.
3 changes: 3 additions & 0 deletions api-server/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ go.work.sum
# env file
.env

# binary
api-server

# app specific
logs/
jobs.json
108 changes: 96 additions & 12 deletions api-server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"bytes"
"database/sql"
"encoding/json"
"fmt"
"github.com/gorilla/mux"
Expand Down Expand Up @@ -243,12 +244,21 @@ func (srv *ILabServer) getVllmStatusHandler(w http.ResponseWriter, r *http.Reque
return
}

srv.jobIDsMutex.RLock()
jobID, ok := srv.servedModelJobIDs[modelName]
srv.jobIDsMutex.RUnlock()
// Directly query the DB for the job associated with this model
var jobID string
err = srv.db.QueryRow(`
SELECT job_id
FROM jobs
WHERE served_model_name = ? AND status = 'running'
LIMIT 1
`, modelName).Scan(&jobID)

if !ok {
srv.log.Infof("WTF jobid not found for model '%s'", modelName)
if err == sql.ErrNoRows {
srv.log.Infof("No running job found for model '%s'", modelName)
_ = json.NewEncoder(w).Encode(map[string]string{"status": "loading"})
return
} else if err != nil {
srv.log.Errorf("Error querying job for model '%s': %v", modelName, err)
_ = json.NewEncoder(w).Encode(map[string]string{"status": "loading"})
return
}
Expand Down Expand Up @@ -629,6 +639,26 @@ func (srv *ILabServer) runVllmContainerHandler(
gpuIndex int, hostVolume, containerVolume string,
w http.ResponseWriter,
) {
// Check if a job is already running for the requested model
existingJob, err := srv.getRunningJobByModel(servedModelName)
if err != nil {
srv.log.Errorf("Error checking existing jobs for model '%s': %v", servedModelName, err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
if existingJob != nil {
srv.log.Infof("A job is already running for model '%s' with job_id: %s", servedModelName, existingJob.JobID)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "already_running",
"job_id": existingJob.JobID,
"message": fmt.Sprintf("Model '%s' is already being served.", servedModelName),
})
return
}

srv.log.Infof("No existing job found for model '%s'. Starting a new job.", servedModelName)

cmdArgs := []string{
"run", "--rm",
fmt.Sprintf("--device=nvidia.com/gpu=%d", gpuIndex),
Expand Down Expand Up @@ -681,13 +711,14 @@ func (srv *ILabServer) runVllmContainerHandler(

// Create a Job record and store it in the DB
newJob := &Job{
JobID: jobID,
Cmd: "podman",
Args: cmdArgs,
Status: "running",
PID: cmd.Process.Pid,
LogFile: logFilePath,
StartTime: time.Now(),
JobID: jobID,
Cmd: "podman",
Args: cmdArgs,
Status: "running",
PID: cmd.Process.Pid,
LogFile: logFilePath,
StartTime: time.Now(),
ServedModelName: servedModelName,
}
if err := srv.createJob(newJob); err != nil {
srv.log.Errorf("Failed to create job in DB for %s: %v", jobID, err)
Expand Down Expand Up @@ -859,6 +890,59 @@ func (srv *ILabServer) serveModelHandler(modelPath, port string, w http.Response
_ = json.NewEncoder(w).Encode(map[string]string{"status": "model process started", "job_id": jobID})
}

// getRunningJobByModel looks up a job whose served_model_name equals
// servedModelName and whose status is 'running'.
//
// It returns (nil, nil) when no matching row exists, and (nil, err) when the
// row scan or the decoding of the JSON-encoded args column fails. The query
// uses LIMIT 1 without an ORDER BY, so if several rows match, which one is
// returned is left to the database.
//
// Errors are returned to the caller without being logged here, so that each
// failure is reported once, by whoever handles it.
func (srv *ILabServer) getRunningJobByModel(servedModelName string) (*Job, error) {
	var (
		job          Job
		argsJSON     string
		startTimeStr sql.NullString
		endTimeStr   sql.NullString
		// branch is declared as plain TEXT in the jobs table, so it may be
		// NULL; scanning NULL straight into a string would fail the lookup.
		branch sql.NullString
	)

	row := srv.db.QueryRow(`
SELECT job_id, cmd, args, status, pid, log_file, start_time, end_time, branch, served_model_name
FROM jobs
WHERE served_model_name = ? AND status = 'running'
LIMIT 1
`, servedModelName)

	err := row.Scan(
		&job.JobID,
		&job.Cmd,
		&argsJSON,
		&job.Status,
		&job.PID,
		&job.LogFile,
		&startTimeStr,
		&endTimeStr,
		&branch,
		&job.ServedModelName,
	)
	switch {
	case err == sql.ErrNoRows:
		// No running job for this model is not an error for callers.
		return nil, nil
	case err != nil:
		return nil, err
	}

	// A NULL branch is treated the same as an empty one.
	job.Branch = branch.String

	if err := json.Unmarshal([]byte(argsJSON), &job.Args); err != nil {
		// %w keeps the underlying JSON error available to errors.Is/As.
		return nil, fmt.Errorf("failed to unmarshal Args for job '%s': %w", job.JobID, err)
	}

	// Timestamps are best-effort: an unparsable value leaves the field at its
	// zero value, but is reported instead of being dropped silently.
	if startTimeStr.Valid {
		if t, err := time.Parse(time.RFC3339, startTimeStr.String); err == nil {
			job.StartTime = t
		} else {
			srv.log.Warnf("Ignoring unparsable start_time '%s' for job '%s': %v", startTimeStr.String, job.JobID, err)
		}
	}
	if endTimeStr.Valid && endTimeStr.String != "" {
		if t, err := time.Parse(time.RFC3339, endTimeStr.String); err == nil {
			job.EndTime = &t
		} else {
			srv.log.Warnf("Ignoring unparsable end_time '%s' for job '%s': %v", endTimeStr.String, job.JobID, err)
		}
	}

	return &job, nil
}

// listServedModelJobIDsHandler is a debug endpoint that lists the current model-to-jobID mappings.
func (srv *ILabServer) listServedModelJobIDsHandler(w http.ResponseWriter, r *http.Request) {
srv.jobIDsMutex.RLock()
Expand Down
14 changes: 9 additions & 5 deletions api-server/jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ func (srv *ILabServer) initDB() {
log_file TEXT,
start_time TEXT,
end_time TEXT,
branch TEXT
branch TEXT,
served_model_name TEXT
);
`
_, err = srv.db.Exec(createTableSQL)
Expand All @@ -58,8 +59,8 @@ func (srv *ILabServer) createJob(job *Job) error {
endTimeStr = &s
}
_, err = srv.db.Exec(`
INSERT INTO jobs (job_id, cmd, args, status, pid, log_file, start_time, end_time, branch)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO jobs (job_id, cmd, args, status, pid, log_file, start_time, end_time, branch, served_model_name)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
job.JobID,
job.Cmd,
Expand All @@ -70,6 +71,7 @@ func (srv *ILabServer) createJob(job *Job) error {
job.StartTime.Format(time.RFC3339),
endTimeStr,
job.Branch,
job.ServedModelName,
)
if err != nil {
return fmt.Errorf("failed to insert job: %v", err)
Expand All @@ -79,7 +81,7 @@ func (srv *ILabServer) createJob(job *Job) error {

// getJob fetches a single job by job_id.
func (srv *ILabServer) getJob(jobID string) (*Job, error) {
row := srv.db.QueryRow("SELECT job_id, cmd, args, status, pid, log_file, start_time, end_time, branch FROM jobs WHERE job_id = ?", jobID)
row := srv.db.QueryRow("SELECT job_id, cmd, args, status, pid, log_file, start_time, end_time, branch, served_model_name FROM jobs WHERE job_id = ?", jobID)

var j Job
var argsJSON string
Expand All @@ -95,6 +97,7 @@ func (srv *ILabServer) getJob(jobID string) (*Job, error) {
&startTimeStr,
&endTimeStr,
&j.Branch,
&j.ServedModelName,
)
if err == sql.ErrNoRows {
return nil, nil // not found
Expand Down Expand Up @@ -133,7 +136,7 @@ func (srv *ILabServer) updateJob(job *Job) error {
}
_, err = srv.db.Exec(`
UPDATE jobs
SET cmd = ?, args = ?, status = ?, pid = ?, log_file = ?, start_time = ?, end_time = ?, branch = ?
SET cmd = ?, args = ?, status = ?, pid = ?, log_file = ?, start_time = ?, end_time = ?, branch = ?, served_model_name = ?
WHERE job_id = ?
`,
job.Cmd,
Expand All @@ -144,6 +147,7 @@ func (srv *ILabServer) updateJob(job *Job) error {
job.StartTime.Format(time.RFC3339),
endTimeStr,
job.Branch,
job.ServedModelName,
job.JobID,
)
if err != nil {
Expand Down
71 changes: 55 additions & 16 deletions api-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,16 @@ type Data struct {

// Job represents a background job, including train/generate/pipeline/vllm-run jobs.
type Job struct {
JobID string `json:"job_id"`
Cmd string `json:"cmd"`
Args []string `json:"args"`
Status string `json:"status"` // "running", "finished", "failed"
PID int `json:"pid"`
LogFile string `json:"log_file"`
StartTime time.Time `json:"start_time"`
EndTime *time.Time `json:"end_time,omitempty"`
Branch string `json:"branch"`
JobID string `json:"job_id"`
Cmd string `json:"cmd"`
Args []string `json:"args"`
Status string `json:"status"` // "running", "finished", "failed"
PID int `json:"pid"`
LogFile string `json:"log_file"`
StartTime time.Time `json:"start_time"`
EndTime *time.Time `json:"end_time,omitempty"`
Branch string `json:"branch"`
ServedModelName string `json:"served_model_name"`

// Lock is not serialized; it protects updates to the Job in memory.
Lock sync.Mutex `json:"-"`
Expand Down Expand Up @@ -94,7 +95,7 @@ type ILabServer struct {
useVllm bool
pipelineType string
debugEnabled bool
homeDir string // New field added
homeDir string

// Logger
logger *zap.Logger
Expand All @@ -119,12 +120,7 @@ type ILabServer struct {
modelCache ModelCache
}

// -----------------------------------------------------------------------------
// main(), flags and Cobra
// -----------------------------------------------------------------------------

func main() {
// We create an instance of ILabServer to hold all state and methods.
srv := &ILabServer{
baseModel: "instructlab/granite-7b-lab",
servedModelJobIDs: make(map[string]string),
Expand All @@ -135,7 +131,6 @@ func main() {
Use: "ilab-server",
Short: "ILab Server Application",
Run: func(cmd *cobra.Command, args []string) {
// Now that flags are set, run the server method on the struct.
srv.runServer(cmd, args)
},
}
Expand Down Expand Up @@ -248,6 +243,8 @@ func (srv *ILabServer) runServer(cmd *cobra.Command, args []string) {
// Initialize the model cache
srv.initializeModelCache()

srv.reconstructServedModelJobIDs()

// Create the logs directory if it doesn't exist
err = os.MkdirAll("logs", os.ModePerm)
if err != nil {
Expand Down Expand Up @@ -348,6 +345,48 @@ func (srv *ILabServer) refreshModelCache() {
srv.log.Infof("Model cache refreshed at %v with %d models.", srv.modelCache.Time, len(models))
}

// reconstructServedModelJobIDs repopulates the in-memory servedModelJobIDs map
// from the jobs table.
//
// It selects every row with cmd = 'podman' and status = 'running' and, for
// rows whose served_model_name is "pre-train" or "post-train", records
// served_model_name -> job_id in the map under jobIDsMutex. Rows with any
// other name are logged and skipped, as are rows that fail to scan. A query
// failure is logged and the map is left untouched.
func (srv *ILabServer) reconstructServedModelJobIDs() {
	srv.log.Info("Reconstructing servedModelJobIDs from the database...")

	const runningVllmJobsQuery = `
SELECT job_id, served_model_name
FROM jobs
WHERE cmd = 'podman' AND status = 'running'
`
	result, queryErr := srv.db.Query(runningVllmJobsQuery)
	if queryErr != nil {
		srv.log.Errorf("Error querying running vLLM jobs: %v", queryErr)
		return
	}
	defer result.Close()

	for result.Next() {
		var id, model string
		if scanErr := result.Scan(&id, &model); scanErr != nil {
			srv.log.Errorf("Error scanning row: %v", scanErr)
			continue
		}

		switch model {
		case "pre-train", "post-train":
			// Only these two names are accepted as map keys.
			srv.jobIDsMutex.Lock()
			srv.servedModelJobIDs[model] = id
			srv.jobIDsMutex.Unlock()
			srv.log.Infof("Mapped model '%s' to job_id '%s'", model, id)
		default:
			srv.log.Warnf("Invalid served_model_name '%s' for job_id '%s'", model, id)
		}
	}

	// Report any error that ended the iteration early.
	if iterErr := result.Err(); iterErr != nil {
		srv.log.Errorf("Error iterating over rows: %v", iterErr)
	}

	srv.log.Info("Reconstruction of servedModelJobIDs completed.")
}

// -----------------------------------------------------------------------------
// Start Generate Data Job
// -----------------------------------------------------------------------------
Expand Down

0 comments on commit fa2c1ed

Please sign in to comment.