From f1f9d3a7779b682b3d93070bfb0015def7ea8753 Mon Sep 17 00:00:00 2001
From: kyriediculous
Date: Wed, 31 Jul 2024 22:23:48 +0200
Subject: [PATCH 01/10] cmd,core,server: llm pipeline with stream support

---
 cmd/livepeer/starter/starter.go |  16 ++++
 core/ai.go                      |   1 +
 core/capabilities.go            |   2 +
 core/orchestrator.go            |   9 ++
 go.mod                          |   2 +
 server/ai_http.go               |  52 ++++++++--
 server/ai_mediaserver.go        |  70 +++++++++++++-
 server/ai_process.go            | 163 ++++++++++++++++++++++++++++++++
 server/ai_process_test.go       |  38 ++++++++
 server/rpc.go                   |   1 +
 10 files changed, 347 insertions(+), 7 deletions(-)
 create mode 100644 server/ai_process_test.go

diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go
index 00c23ec3db..344c7a6874 100755
--- a/cmd/livepeer/starter/starter.go
+++ b/cmd/livepeer/starter/starter.go
@@ -1328,6 +1328,22 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
 			if *cfg.Network != "offchain" {
 				n.SetBasePriceForCap("default", core.Capability_SegmentAnything2, config.ModelID, autoPrice)
 			}
+			n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice)
+
+		case "llm-generate":
+			_, ok := capabilityConstraints[core.Capability_LlmGenerate]
+			if !ok {
+				aiCaps = append(aiCaps, core.Capability_LlmGenerate)
+				capabilityConstraints[core.Capability_LlmGenerate] = &core.PerCapabilityConstraints{
+					Models: make(map[string]*core.ModelConstraint),
+				}
+			}
+
+			capabilityConstraints[core.Capability_LlmGenerate].Models[config.ModelID] = modelConstraint
+
+			if *cfg.Network != "offchain" {
+				n.SetBasePriceForCap("default", core.Capability_LlmGenerate, config.ModelID, autoPrice)
+			}
 		}
 
 		if len(aiCaps) > 0 {
diff --git a/core/ai.go b/core/ai.go
index 31f331e49e..c637ccbb11 100644
--- a/core/ai.go
+++ b/core/ai.go
@@ -23,6 +23,7 @@ type AI interface {
 	Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
+	LlmGenerate(context.Context, worker.LlmGenerateFormdataRequestBody) (interface{}, error)
 	Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
 	Stop(context.Context) error
 	HasCapacity(pipeline, modelID string) bool
diff --git a/core/capabilities.go b/core/capabilities.go
index fc9e5217ba..2956e1f08d 100644
--- a/core/capabilities.go
+++ b/core/capabilities.go
@@ -79,6 +79,7 @@ const (
 	Capability_Upscale
 	Capability_AudioToText
 	Capability_SegmentAnything2
+	Capability_LlmGenerate
 )
 
 var CapabilityNameLookup = map[Capability]string{
@@ -116,6 +117,7 @@ var CapabilityNameLookup = map[Capability]string{
 	Capability_Upscale:          "Upscale",
 	Capability_AudioToText:      "Audio to text",
 	Capability_SegmentAnything2: "Segment anything 2",
+	Capability_LlmGenerate:      "LLM Generate",
 }
 
 var CapabilityTestLookup = map[Capability]CapabilityTest{
diff --git a/core/orchestrator.go b/core/orchestrator.go
index f8e343ae32..65b4e82294 100644
--- a/core/orchestrator.go
+++ b/core/orchestrator.go
@@ -134,6 +134,11 @@ func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSe
 	return orch.node.SegmentAnything2(ctx, req)
 }
 
+// Returns a *worker.LlmResponse for non-streaming requests, or a chan worker.LlmStreamChunk for streaming ones
+func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
+	return orch.node.llmGenerate(ctx, req)
+}
+
 func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error {
 	if orch.node == nil || orch.node.Recipient == nil {
 		return nil
@@ -1051,6 +1056,10 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVi
 	return &worker.ImageResponse{Images: videos}, nil
 }
 
+func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
+	return n.AIWorker.LlmGenerate(ctx, req)
+}
+
 func (rtm *RemoteTranscoderManager) transcoderResults(tcID int64, res *RemoteTranscoderResult) {
 	remoteChan, err := rtm.getTaskChan(tcID)
 	if err != nil {
diff --git a/go.mod b/go.mod
index f84f4a6b37..c32e9d596c 100644
--- a/go.mod
+++ b/go.mod
@@ -238,3 +238,5 @@ require (
 	lukechampine.com/blake3 v1.2.1 // indirect
 	rsc.io/tmplfunc v0.0.3 // indirect
 )
+
+replace github.com/livepeer/ai-worker => /Users/nico/livepool/ai-worker
diff --git a/server/ai_http.go b/server/ai_http.go
index 3f0bb97d9e..60a836482a 100644
--- a/server/ai_http.go
+++ b/server/ai_http.go
@@ -45,6 +45,7 @@ func startAIServer(lp lphttp) error {
 	lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))
 	lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
 	lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2()))
+	lp.transRPC.Handle("/llm-generate", oapiReqValidator(lp.LlmGenerate()))
 
 	return nil
 }
@@ -148,7 +149,7 @@ func (h *lphttp) AudioToText() http.Handler {
 			return
 		}
 
-		var req worker.GenAudioToTextMultipartRequestBody
+		var req worker.AudioToTextMultipartRequestBody
 		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
 			respondWithError(w, err.Error(), http.StatusInternalServerError)
 			return
@@ -158,7 +159,7 @@ func (h *lphttp) AudioToText() http.Handler {
 	})
 }
 
-func (h *lphttp) SegmentAnything2() http.Handler {
+func (h *lphttp) LlmGenerate() http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		orch := h.orchestrator
 
@@ -171,7 +172,7 @@ func (h *lphttp) SegmentAnything2() http.Handler {
 			return
 		}
 
-		var req worker.GenSegmentAnything2MultipartRequestBody
+		var req worker.LlmGenerateFormdataRequestBody
 		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
 			respondWithError(w, err.Error(), http.StatusInternalServerError)
 			return
@@ -324,6 +325,15 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 			return
 		}
 		outPixels = int64(config.Height) * int64(config.Width)
+	case worker.LlmGenerateFormdataRequestBody:
+		pipeline = "llm-generate"
+		cap = core.Capability_LlmGenerate
+		modelID = *v.ModelId
+		submitFn = func(ctx context.Context) (interface{}, error) {
+			return orch.LlmGenerate(ctx, v)
+		}
+
+		// TODO: handle tokens for pricing
 	default:
 		respondWithError(w, "Unknown request type", http.StatusBadRequest)
 		return
@@ -407,7 +417,37 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 		monitor.AIJobProcessed(ctx, pipeline, modelID, monitor.AIJobInfo{LatencyScore: latencyScore, PricePerUnit: pricePerAIUnit})
 	}
 
-	w.Header().Set("Content-Type", "application/json")
-	w.WriteHeader(http.StatusOK)
-	_ = json.NewEncoder(w).Encode(resp)
+	// Check if the response is a streaming response
+	if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
+		// Set headers for SSE
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.Header().Set("Cache-Control", "no-cache")
+		w.Header().Set("Connection", "keep-alive")
+
+		flusher, ok := w.(http.Flusher)
+		if !ok {
unsupported!", http.StatusInternalServerError) + return + } + + for chunk := range streamChan { + data, err := json.Marshal(chunk) + if err != nil { + clog.Errorf(ctx, "Error marshaling stream chunk: %v", err) + continue + } + + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + + if chunk.Done { + break + } + } + } else { + // Non-streaming response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + } } diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go index 078fa05ee9..4940531a25 100644 --- a/server/ai_mediaserver.go +++ b/server/ai_mediaserver.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" "time" @@ -70,7 +71,7 @@ func startAIMediaServer(ls *LivepeerServer) error { ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult()) ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText())) ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2())) - + ls.HTTPMux.Handle("/llm-generate", oapiReqValidator(ls.LlmGenerate())) return nil } @@ -428,6 +429,73 @@ func (ls *LivepeerServer) SegmentAnything2() http.Handler { }) } +func (ls *LivepeerServer) LlmGenerate() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + remoteAddr := getRemoteAddr(r) + ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) + requestID := string(core.RandomManifestID()) + ctx = clog.AddVal(ctx, "request_id", requestID) + + var req worker.LlmGenerateFormdataRequestBody + + multiRdr, err := r.MultipartReader() + if err != nil { + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + + if err := runtime.BindMultipart(&req, *multiRdr); err != nil { + respondJsonError(ctx, w, err, http.StatusBadRequest) + return + } + + clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v", req.Prompt, *req.ModelId) + + params := aiRequestParams{ + node: ls.LivepeerNode, + os: drivers.NodeStorage.NewSession(requestID), + sessManager: ls.AISessionManager, + } + + start := time.Now() + resp, err := processLlmGenerate(ctx, params, req) + if err != nil { + var e *ServiceUnavailableError + if errors.As(err, &e) { + respondJsonError(ctx, w, err, http.StatusServiceUnavailable) + return + } + respondJsonError(ctx, w, err, http.StatusInternalServerError) + return + } + + took := time.Since(start) + clog.V(common.VERBOSE).Infof(ctx, "Processed LlmGenerate request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took) + + if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok { + // Handle streaming response (SSE) + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + for chunk := range streamChan { + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + w.(http.Flusher).Flush() + if chunk.Done { + break + } + } + } else if llmResp, ok := resp.(*worker.LlmResponse); ok { + // Handle non-streaming response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(llmResp) + } else { + http.Error(w, "Unexpected response type", http.StatusInternalServerError) + } + }) +} + func (ls *LivepeerServer) ImageToVideoResult() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { remoteAddr := getRemoteAddr(r) diff --git a/server/ai_process.go b/server/ai_process.go index 249cc506b6..93e98fb7fa 100644 --- a/server/ai_process.go +++ 
diff --git a/server/ai_process.go b/server/ai_process.go
index 249cc506b6..93e98fb7fa 100644
--- a/server/ai_process.go
+++ b/server/ai_process.go
@@ -32,6 +32,7 @@ const defaultImageToVideoModelID = "stabilityai/stable-video-diffusion-img2vid-x
 const defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler"
 const defaultAudioToTextModelID = "openai/whisper-large-v3"
 const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large"
+const defaultLlmGenerateModelID = "meta-llama/llama-3.1-8B-Instruct"
 
 type ServiceUnavailableError struct {
 	err error
@@ -792,6 +793,159 @@ func submitAudioToText(ctx context.Context, params aiRequestParams, sess *AISess
 	return &res, nil
 }
 
+func CalculateLlmGenerateLatencyScore(took time.Duration, tokensUsed int) float64 {
+	if tokensUsed <= 0 {
+		return 0
+	}
+
+	return took.Seconds() / float64(tokensUsed)
+}
+
+func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
+	resp, err := processAIRequest(ctx, params, req)
+	if err != nil {
+		return nil, err
+	}
+
+	if req.Stream != nil && *req.Stream {
+		streamChan, ok := resp.(chan worker.LlmStreamChunk)
+		if !ok {
+			return nil, errors.New("unexpected response type for streaming request")
+		}
+		return streamChan, nil
+	}
+
+	llmResp, ok := resp.(*worker.LlmResponse)
+	if !ok {
+		return nil, errors.New("unexpected response type")
+	}
+
+	return llmResp, nil
+}
+
+func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
+	var buf bytes.Buffer
+	mw, err := worker.NewLlmGenerateMultipartWriter(&buf, req)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, nil)
+		}
+		return nil, err
+	}
+
+	client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	// TODO: calculate payment
+	setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, 0)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+	defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)
+
+	start := time.Now()
+	resp, err := client.LlmGenerateWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))
+	}
+
+	if req.Stream != nil && *req.Stream {
+		return handleSSEStream(ctx, resp.Body, sess, req, start)
+	}
+
+	return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
+}
+
+func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
+	streamChan := make(chan worker.LlmStreamChunk, 100)
+	go func() {
+		defer close(streamChan)
+		scanner := bufio.NewScanner(body)
+		var totalTokens int
+		for scanner.Scan() {
+			line := scanner.Text()
+			if strings.HasPrefix(line, "data: ") {
+				data := strings.TrimPrefix(line, "data: ")
+				if data == "[DONE]" {
+					streamChan <- worker.LlmStreamChunk{Done: true, TokensUsed: totalTokens}
+					break
+				}
+				var chunk worker.LlmStreamChunk
+				if err := json.Unmarshal([]byte(data), &chunk); err != nil {
+					clog.Errorf(ctx, "Error unmarshaling SSE data: %v", err)
+					continue
+				}
+				totalTokens += chunk.TokensUsed
+				streamChan <- chunk
+			}
+		}
+		if err := scanner.Err(); err != nil {
+			clog.Errorf(ctx, "Error reading SSE stream: %v", err)
+		}
+
+		took := time.Since(start)
+		sess.LatencyScore = CalculateLlmGenerateLatencyScore(took, totalTokens)
+
+		if monitor.Enabled {
+			var pricePerAIUnit float64
+			if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
+				pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
+			}
+			monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
+		}
+	}()
+
+	return streamChan, nil
+}
+
+func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateFormdataRequestBody, start time.Time) (*worker.LlmResponse, error) {
+	data, err := io.ReadAll(body)
+	if err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	var res worker.LlmResponse
+	if err := json.Unmarshal(data, &res); err != nil {
+		if monitor.Enabled {
+			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
+		}
+		return nil, err
+	}
+
+	took := time.Since(start)
+	sess.LatencyScore = CalculateLlmGenerateLatencyScore(took, res.TokensUsed)
+
+	if monitor.Enabled {
+		var pricePerAIUnit float64
+		if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
+			pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
+		}
+		monitor.AIRequestFinished(ctx, "llm-generate", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
+	}
+
+	return &res, nil
+}
+
 func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (interface{}, error) {
 	var cap core.Capability
 	var modelID string
@@ -852,6 +1006,15 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitSegmentAnything2(ctx, params, sess, v)
 		}
+	case worker.LlmGenerateFormdataRequestBody:
+		cap = core.Capability_LlmGenerate
+		modelID = defaultLlmGenerateModelID
+		if v.ModelId != nil {
+			modelID = *v.ModelId
+		}
+		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
+			return submitLlmGenerate(ctx, params, sess, v)
+		}
 	default:
 		return nil, fmt.Errorf("unsupported request type %T", req)
 	}
diff --git a/server/ai_process_test.go b/server/ai_process_test.go
new file mode 100644
index 0000000000..e584637ef2
--- /dev/null
+++ b/server/ai_process_test.go
@@ -0,0 +1,38 @@
+package server
+
+import (
+	"context"
+	"reflect"
+	"testing"
+
+	"github.com/livepeer/ai-worker/worker"
+)
+
+func Test_submitLlmGenerate(t *testing.T) {
+	type args struct {
+		ctx    context.Context
+		params aiRequestParams
+		sess   *AISession
+		req    worker.LlmGenerateFormdataRequestBody
+	}
+	tests := []struct {
+		name    string
+		args    args
+		want    interface{}
+		wantErr bool
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := submitLlmGenerate(tt.args.ctx, tt.args.params, tt.args.sess, tt.args.req)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("submitLlmGenerate() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("submitLlmGenerate() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
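The test table above is left empty (TODO), and submitLlmGenerate needs a live orchestrator session to exercise. One helper from this patch that can be tested deterministically is the latency-score calculation; a minimal sketch of such a seed test, assuming "time" is added to the test file's imports:

```go
// Seconds-per-token latency score: deterministic, no session required.
func TestCalculateLlmGenerateLatencyScore(t *testing.T) {
	// 2s spent producing 100 tokens => 0.02 seconds per token.
	if got := CalculateLlmGenerateLatencyScore(2*time.Second, 100); got != 0.02 {
		t.Errorf("CalculateLlmGenerateLatencyScore() = %v, want 0.02", got)
	}

	// Zero tokens must not divide by zero; the helper returns 0.
	if got := CalculateLlmGenerateLatencyScore(time.Second, 0); got != 0 {
		t.Errorf("CalculateLlmGenerateLatencyScore() = %v, want 0", got)
	}
}
```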
diff --git a/server/rpc.go b/server/rpc.go
index 6c1365ccd6..cfcde7532d 100644
--- a/server/rpc.go
+++ b/server/rpc.go
@@ -69,6 +69,7 @@ type Orchestrator interface {
 	Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
+	LlmGenerate(ctx context.Context, req worker.LlmGenerateFormdataRequestBody) (interface{}, error)
 }
 
 // Balance describes methods for a session's balance maintenance

From 9c5ac2bb1e456b08a197db581923c7f61e1d559c Mon Sep 17 00:00:00 2001
From: kyriediculous
Date: Thu, 1 Aug 2024 04:14:53 +0200
Subject: [PATCH 02/10] add basic pricing based on max out tokens

---
 server/ai_http.go    | 8 +++++++-
 server/ai_process.go | 8 ++++++--
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/server/ai_http.go b/server/ai_http.go
index 60a836482a..bbe04b964b 100644
--- a/server/ai_http.go
+++ b/server/ai_http.go
@@ -333,7 +333,13 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 			return orch.LlmGenerate(ctx, v)
 		}
 
-		// TODO: handle tokens for pricing
+		if v.MaxTokens == nil {
+			respondWithError(w, "MaxTokens not specified", http.StatusBadRequest)
+			return
+		}
+
+		// TODO: Improve pricing
+		outPixels = int64(*v.MaxTokens)
 	default:
 		respondWithError(w, "Unknown request type", http.StatusBadRequest)
 		return
diff --git a/server/ai_process.go b/server/ai_process.go
index 93e98fb7fa..ef713d3794 100644
--- a/server/ai_process.go
+++ b/server/ai_process.go
@@ -841,8 +841,12 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
 		return nil, err
 	}
 
-	// TODO: calculate payment
-	setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, 0)
+	// TODO: Improve pricing
+	if req.MaxTokens == nil {
+		req.MaxTokens = new(int)
+		*req.MaxTokens = 256
+	}
+	setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, int64(*req.MaxTokens))
 	if err != nil {
 		if monitor.Enabled {
 			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)

From 50e955ea9e68516c1b3133158abfd66795cfe8b1 Mon Sep 17 00:00:00 2001
From: kyriediculous
Date: Mon, 5 Aug 2024 20:37:19 +0200
Subject: [PATCH 03/10] temporary: replace ai-worker with
 Livepool-Io/ai-worker@llm

---
 go.mod | 2 +-
 go.sum | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/go.mod b/go.mod
index c32e9d596c..0ebd5fefd3 100644
--- a/go.mod
+++ b/go.mod
@@ -239,4 +239,4 @@ require (
 	rsc.io/tmplfunc v0.0.3 // indirect
 )
 
-replace github.com/livepeer/ai-worker => /Users/nico/livepool/ai-worker
+replace github.com/livepeer/ai-worker => github.com/Livepool-io/ai-worker v0.0.0-20240805181656-87bfe3f909eb
diff --git a/go.sum b/go.sum
index 8c7b8cd8cb..158a3a2cc4 100644
--- a/go.sum
+++ b/go.sum
@@ -58,6 +58,8 @@ github.com/DataDog/zstd v1.4.5 h1:EndNeuB0l9syBZhut0wns3gV1hL8zX8LIu6ZiVHWLIQ=
 github.com/DataDog/zstd v1.4.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
 github.com/Joker/hpp v1.0.0/go.mod h1:8x5n+M1Hp5hC0g8okX3sR3vFQwynaX/UgSOM9MeBKzY=
 github.com/Joker/jade v1.0.1-0.20190614124447-d475f43051e7/go.mod h1:6E6s8o2AE4KhCrqr6GRJjdC/gNfTdxkIXvuGZZda2VM=
+github.com/Livepool-io/ai-worker v0.0.0-20240805181656-87bfe3f909eb h1:wfspFHOAZcIH8kNQndKdkIsATELMyGPDv/Q3QXk80XA=
+github.com/Livepool-io/ai-worker v0.0.0-20240805181656-87bfe3f909eb/go.mod h1:Xlnb0nFG2VsGeMG9hZmReVQXeFt0Dv28ODiUT2ooyLE=
 github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
 github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
 github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA=

From 422bde9a144f4fc9ee3afaf83f36a7082e37d5ec Mon Sep 17 00:00:00 2001
From: kyriediculous
Date: Tue, 6 Aug 2024 04:28:10 +0200
Subject: [PATCH 04/10] fix: llm channel receive, reading response body race
 condition

---
 server/ai_http.go        | 2 +-
 server/ai_mediaserver.go | 2 +-
 server/ai_process.go     | 3 ++-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/server/ai_http.go b/server/ai_http.go
index bbe04b964b..11bc4bbb42 100644
--- a/server/ai_http.go
+++ b/server/ai_http.go
@@ -424,7 +424,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 	}
 
 	// Check if the response is a streaming response
-	if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
+	if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok {
 		// Set headers for SSE
 		w.Header().Set("Content-Type", "text/event-stream")
 		w.Header().Set("Cache-Control", "no-cache")
diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go
index 4940531a25..8f1e58ac9a 100644
--- a/server/ai_mediaserver.go
+++ b/server/ai_mediaserver.go
@@ -449,7 +449,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler {
 			return
 		}
 
-		clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v", req.Prompt, *req.ModelId)
+		clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, *req.Stream)
 
 		params := aiRequestParams{
 			node:        ls.LivepeerNode,
diff --git a/server/ai_process.go b/server/ai_process.go
index ef713d3794..dbebd4474e 100644
--- a/server/ai_process.go
+++ b/server/ai_process.go
@@ -863,7 +863,6 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
 		}
 		return nil, err
 	}
-	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
 		body, _ := io.ReadAll(resp.Body)
@@ -880,6 +880,7 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r
 	streamChan := make(chan worker.LlmStreamChunk, 100)
 	go func() {
 		defer close(streamChan)
+		defer body.Close()
 		scanner := bufio.NewScanner(body)
 		var totalTokens int
 		for scanner.Scan() {
@@ -921,6 +921,7 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r
 
 func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateFormdataRequestBody, start time.Time) (*worker.LlmResponse, error) {
 	data, err := io.ReadAll(body)
+	defer body.Close()
 	if err != nil {
 		if monitor.Enabled {
 			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
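A note on the `<-chan` change in ai_http.go above: a Go type assertion matches the dynamic type exactly, so asserting `<-chan worker.LlmStreamChunk` only succeeds if the value stored in the interface really is a receive-only channel — presumably what the updated ai-worker now returns, though this patch series does not show that side. A self-contained illustration of the pitfall:

```go
package main

import "fmt"

func main() {
	bidi := make(chan int)
	var v interface{} = bidi

	// Asserting a directed channel type against a bidirectional dynamic type
	// fails, even though the conversion chan int -> <-chan int is legal.
	_, ok := v.(<-chan int)
	fmt.Println(ok) // false

	// Once the receive-only conversion is stored, the assertion succeeds.
	v = (<-chan int)(bidi)
	_, ok = v.(<-chan int)
	fmt.Println(ok) // true
}
```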
temporary: replace ai-worker with
 Livepool-Io/ai-worker@llm

---
 go.mod | 2 +-
 go.sum | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/go.mod b/go.mod
index 0ebd5fefd3..661157a80c 100644
--- a/go.mod
+++ b/go.mod
@@ -239,4 +239,4 @@ require (
 	rsc.io/tmplfunc v0.0.3 // indirect
 )
 
-replace github.com/livepeer/ai-worker => github.com/Livepool-io/ai-worker v0.0.0-20240805181656-87bfe3f909eb
+replace github.com/livepeer/ai-worker => github.com/Livepool-io/ai-worker v0.0.0-20240806021536-468d65dca834
diff --git a/go.sum b/go.sum
index 158a3a2cc4..3760766b36 100644
--- a/go.sum
+++ b/go.sum
@@ -58,8 +58,8 @@ github.com/DataDog/zstd v1.4.5 h1:EndNeuB0l9syBZhut0wns3gV1hL8zX8LIu6ZiVHWLIQ=
 github.com/DataDog/zstd v1.4.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
 github.com/Joker/hpp v1.0.0/go.mod h1:8x5n+M1Hp5hC0g8okX3sR3vFQwynaX/UgSOM9MeBKzY=
 github.com/Joker/jade v1.0.1-0.20190614124447-d475f43051e7/go.mod h1:6E6s8o2AE4KhCrqr6GRJjdC/gNfTdxkIXvuGZZda2VM=
-github.com/Livepool-io/ai-worker v0.0.0-20240805181656-87bfe3f909eb h1:wfspFHOAZcIH8kNQndKdkIsATELMyGPDv/Q3QXk80XA=
-github.com/Livepool-io/ai-worker v0.0.0-20240805181656-87bfe3f909eb/go.mod h1:Xlnb0nFG2VsGeMG9hZmReVQXeFt0Dv28ODiUT2ooyLE=
+github.com/Livepool-io/ai-worker v0.0.0-20240806021536-468d65dca834 h1:pinX9lPOmYxkGGa6MQUolPqmdIySICvk03uxfMgYUmk=
+github.com/Livepool-io/ai-worker v0.0.0-20240806021536-468d65dca834/go.mod h1:Xlnb0nFG2VsGeMG9hZmReVQXeFt0Dv28ODiUT2ooyLE=
 github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
 github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
 github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA=

From 7c53f04ff6ee650d3c52b85b901bda3b539b4739 Mon Sep 17 00:00:00 2001
From: Brad P
Date: Thu, 19 Sep 2024 15:54:42 -0500
Subject: [PATCH 06/10] more updates needed for rebasing

---
 cmd/livepeer/starter/starter.go |  2 +-
 core/ai.go                      |  2 +-
 core/orchestrator.go            |  4 ++--
 server/ai_http.go               | 29 ++++++++++++++++++++++++++---
 server/ai_mediaserver.go        |  2 +-
 server/ai_process.go            | 12 ++++++------
 server/ai_process_test.go       |  2 +-
 server/rpc.go                   |  2 +-
 8 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go
index 344c7a6874..69c4bdf2ac 100755
--- a/cmd/livepeer/starter/starter.go
+++ b/cmd/livepeer/starter/starter.go
@@ -1334,7 +1334,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
 			_, ok := capabilityConstraints[core.Capability_LlmGenerate]
 			if !ok {
 				aiCaps = append(aiCaps, core.Capability_LlmGenerate)
-				capabilityConstraints[core.Capability_LlmGenerate] = &core.PerCapabilityConstraints{
+				capabilityConstraints[core.Capability_LlmGenerate] = &core.CapabilityConstraints{
 					Models: make(map[string]*core.ModelConstraint),
 				}
 			}
diff --git a/core/ai.go b/core/ai.go
index c637ccbb11..ebc7e3dfc5 100644
--- a/core/ai.go
+++ b/core/ai.go
@@ -23,7 +23,7 @@ type AI interface {
 	Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
-	LlmGenerate(context.Context, worker.LlmGenerateFormdataRequestBody) (interface{}, error)
+	LlmGenerate(context.Context, worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error)
 	Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
 	Stop(context.Context) error
 	HasCapacity(pipeline, modelID string) bool
diff --git a/core/orchestrator.go b/core/orchestrator.go
index 65b4e82294..e3685fc7f3 100644
--- a/core/orchestrator.go
+++ b/core/orchestrator.go
@@ -135,7 +135,7 @@ func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSe
 }
 
 // Returns a *worker.LlmResponse for non-streaming requests, or a chan worker.LlmStreamChunk for streaming ones
-func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
+func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
 	return orch.node.llmGenerate(ctx, req)
 }
 
@@ -1056,7 +1056,7 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVi
 	return &worker.ImageResponse{Images: videos}, nil
 }
 
-func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
+func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
 	return n.AIWorker.LlmGenerate(ctx, req)
 }
 
diff --git a/server/ai_http.go b/server/ai_http.go
index 11bc4bbb42..7238e38cd8 100644
--- a/server/ai_http.go
+++ b/server/ai_http.go
@@ -149,7 +149,30 @@ func (h *lphttp) AudioToText() http.Handler {
 			return
 		}
 
-		var req worker.AudioToTextMultipartRequestBody
+		var req worker.GenAudioToTextMultipartRequestBody
 		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
 			respondWithError(w, err.Error(), http.StatusInternalServerError)
 			return
 		}
 
 		handleAIRequest(ctx, w, r, orch, req)
 	})
 }
 
+func (h *lphttp) SegmentAnything2() http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		orch := h.orchestrator
+
+		remoteAddr := getRemoteAddr(r)
+		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
+
+		multiRdr, err := r.MultipartReader()
+		if err != nil {
+			respondWithError(w, err.Error(), http.StatusBadRequest)
+			return
+		}
+
+		var req worker.GenSegmentAnything2MultipartRequestBody
+		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
+			respondWithError(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		handleAIRequest(ctx, w, r, orch, req)
+	})
+}
+
@@ -172,7 +195,7 @@ func (h *lphttp) LlmGenerate() http.Handler {
 			return
 		}
 
-		var req worker.LlmGenerateFormdataRequestBody
+		var req worker.LlmGenerateLlmGeneratePostFormdataRequestBody
 		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
 			respondWithError(w, err.Error(), http.StatusInternalServerError)
 			return
@@ -325,7 +348,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 			return
 		}
 		outPixels = int64(config.Height) * int64(config.Width)
-	case worker.LlmGenerateFormdataRequestBody:
+	case worker.LlmGenerateLlmGeneratePostFormdataRequestBody:
 		pipeline = "llm-generate"
 		cap = core.Capability_LlmGenerate
 		modelID = *v.ModelId
diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go
index 8f1e58ac9a..7b708ffd54 100644
--- a/server/ai_mediaserver.go
+++ b/server/ai_mediaserver.go
@@ -436,7 +436,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler {
 		requestID := string(core.RandomManifestID())
 		ctx = clog.AddVal(ctx, "request_id", requestID)
 
-		var req worker.LlmGenerateFormdataRequestBody
+		var req worker.LlmGenerateLlmGeneratePostFormdataRequestBody
 
 		multiRdr, err := r.MultipartReader()
 		if err != nil {
diff --git a/server/ai_process.go b/server/ai_process.go
index dbebd4474e..8404ed401e 100644
--- a/server/ai_process.go
+++ b/server/ai_process.go
@@ -801,7 +801,7 @@ func CalculateLlmGenerateLatencyScore(took time.Duration, tokensUsed int) float6
 	return took.Seconds() / float64(tokensUsed)
 }
 
-func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
+func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
 	resp, err := processAIRequest(ctx, params, req)
 	if err != nil {
 		return nil, err
@@ -823,7 +823,7 @@ func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.
 	return llmResp, nil
 }
 
-func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.LlmGenerateFormdataRequestBody) (interface{}, error) {
+func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
 	var buf bytes.Buffer
 	mw, err := worker.NewLlmGenerateMultipartWriter(&buf, req)
 	if err != nil {
@@ -856,7 +856,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
 	defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)
 
 	start := time.Now()
-	resp, err := client.LlmGenerateWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
+	resp, err := client.LlmGenerateLlmGeneratePostWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
 	if err != nil {
 		if monitor.Enabled {
 			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
@@ -876,7 +876,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
 	return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
 }
 
-func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
+func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
 	streamChan := make(chan worker.LlmStreamChunk, 100)
 	go func() {
 		defer close(streamChan)
@@ -919,7 +919,7 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r
 	return streamChan, nil
 }
 
-func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateFormdataRequestBody, start time.Time) (*worker.LlmResponse, error) {
+func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody, start time.Time) (*worker.LlmResponse, error) {
 	data, err := io.ReadAll(body)
 	defer body.Close()
 	if err != nil {
@@ -1011,7 +1011,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitSegmentAnything2(ctx, params, sess, v)
 		}
-	case worker.LlmGenerateFormdataRequestBody:
+	case worker.LlmGenerateLlmGeneratePostFormdataRequestBody:
 		cap = core.Capability_LlmGenerate
 		modelID = defaultLlmGenerateModelID
 		if v.ModelId != nil {
diff --git a/server/ai_process_test.go b/server/ai_process_test.go
index e584637ef2..e64382fb38 100644
--- a/server/ai_process_test.go
+++ b/server/ai_process_test.go
@@ -13,7 +13,7 @@ func Test_submitLlmGenerate(t *testing.T) {
 		ctx    context.Context
 		params aiRequestParams
 		sess   *AISession
-		req    worker.LlmGenerateFormdataRequestBody
+		req    worker.LlmGenerateLlmGeneratePostFormdataRequestBody
 	}
 	tests := []struct {
 		name    string
diff --git a/server/rpc.go b/server/rpc.go
index cfcde7532d..8eccffc328 100644
--- a/server/rpc.go
+++ b/server/rpc.go
@@ -69,7 +69,7 @@ type Orchestrator interface {
 	Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
-	LlmGenerate(ctx context.Context, req worker.LlmGenerateFormdataRequestBody) (interface{}, error)
+	LlmGenerate(ctx context.Context, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error)
 }
 
 // Balance describes methods for a session's balance maintenance

From 009c52c226f61b37b4bec4d41162fcefdd15e589 Mon Sep 17 00:00:00 2001
From: Brad P
Date: Thu, 19 Sep 2024 21:40:58 -0500
Subject: [PATCH 07/10] fix seg fault on log line if parameter not included in
 request

---
 server/ai_mediaserver.go | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go
index 7b708ffd54..88cb233ba0 100644
--- a/server/ai_mediaserver.go
+++ b/server/ai_mediaserver.go
@@ -449,7 +449,11 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler {
 			return
 		}
 
-		clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, *req.Stream)
+		streamResponse := false
+		if *req.Stream {
+			streamResponse = *req.Stream
+		}
+		clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, streamResponse)
 
 		params := aiRequestParams{
 			node:        ls.LivepeerNode,

From bf072d46c84ab2ce40dc6120f58531dbb0ba3e11 Mon Sep 17 00:00:00 2001
From: Brad P
Date: Fri, 20 Sep 2024 07:29:18 -0500
Subject: [PATCH 08/10] fix rebasing error

---
 cmd/livepeer/starter/starter.go | 1 -
 1 file changed, 1 deletion(-)

diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go
index 69c4bdf2ac..6f3802b1ad 100755
--- a/cmd/livepeer/starter/starter.go
+++ b/cmd/livepeer/starter/starter.go
@@ -1328,7 +1328,6 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
 			if *cfg.Network != "offchain" {
 				n.SetBasePriceForCap("default", core.Capability_SegmentAnything2, config.ModelID, autoPrice)
 			}
-			n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice)
 
 		case "llm-generate":
 			_, ok := capabilityConstraints[core.Capability_LlmGenerate]

From b77ec57d0440fad1e4c12f54a34012fee8f06b34 Mon Sep 17 00:00:00 2001
From: Brad P
Date: Fri, 20 Sep 2024 22:07:04 -0500
Subject: [PATCH 09/10] update for new codegen and update to the segfault fix

---
 core/ai.go               |  2 +-
 core/orchestrator.go     |  4 ++--
 server/ai_http.go        |  4 ++--
 server/ai_mediaserver.go |  5 +++--
 server/ai_process.go     | 12 ++++++------
 server/rpc.go            |  2 +-
 6 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/core/ai.go b/core/ai.go
index ebc7e3dfc5..0b7d223419 100644
--- a/core/ai.go
+++ b/core/ai.go
@@ -23,7 +23,7 @@ type AI interface {
 	Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
-	LlmGenerate(context.Context, worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error)
+	LlmGenerate(context.Context, worker.GenLlmFormdataRequestBody) (interface{}, error)
 	Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
 	Stop(context.Context) error
 	HasCapacity(pipeline, modelID string) bool
diff --git a/core/orchestrator.go b/core/orchestrator.go
index e3685fc7f3..aff5303968 100644
--- a/core/orchestrator.go
+++ b/core/orchestrator.go
@@ -135,7 +135,7 @@ func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSe
 }
 
 // Returns a *worker.LlmResponse for non-streaming requests, or a chan worker.LlmStreamChunk for streaming ones
-func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
+func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
 	return orch.node.llmGenerate(ctx, req)
 }
 
@@ -1056,7 +1056,7 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVi
 	return &worker.ImageResponse{Images: videos}, nil
 }
 
-func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
+func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
 	return n.AIWorker.LlmGenerate(ctx, req)
 }
 
diff --git a/server/ai_http.go b/server/ai_http.go
index 7238e38cd8..e814b9d92e 100644
--- a/server/ai_http.go
+++ b/server/ai_http.go
@@ -195,7 +195,7 @@ func (h *lphttp) LlmGenerate() http.Handler {
 			return
 		}
 
-		var req worker.LlmGenerateLlmGeneratePostFormdataRequestBody
+		var req worker.GenLlmFormdataRequestBody
 		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
 			respondWithError(w, err.Error(), http.StatusInternalServerError)
 			return
@@ -348,7 +348,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
 			return
 		}
 		outPixels = int64(config.Height) * int64(config.Width)
-	case worker.LlmGenerateLlmGeneratePostFormdataRequestBody:
+	case worker.GenLlmFormdataRequestBody:
 		pipeline = "llm-generate"
 		cap = core.Capability_LlmGenerate
 		modelID = *v.ModelId
diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go
index 88cb233ba0..e4c7d6176b 100644
--- a/server/ai_mediaserver.go
+++ b/server/ai_mediaserver.go
@@ -436,7 +436,7 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler {
 		requestID := string(core.RandomManifestID())
 		ctx = clog.AddVal(ctx, "request_id", requestID)
 
-		var req worker.LlmGenerateLlmGeneratePostFormdataRequestBody
+		var req worker.GenLlmFormdataRequestBody
 
 		multiRdr, err := r.MultipartReader()
 		if err != nil {
@@ -450,9 +450,10 @@ func (ls *LivepeerServer) LlmGenerate() http.Handler {
 		}
 
 		streamResponse := false
-		if *req.Stream {
+		if req.Stream != nil {
 			streamResponse = *req.Stream
 		}
+
 		clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v stream=%v", req.Prompt, *req.ModelId, streamResponse)
 
 		params := aiRequestParams{
diff --git a/server/ai_process.go b/server/ai_process.go
index 8404ed401e..af4704069a 100644
--- a/server/ai_process.go
+++ b/server/ai_process.go
@@ -801,7 +801,7 @@ func CalculateLlmGenerateLatencyScore(took time.Duration, tokensUsed int) float6
 	return took.Seconds() / float64(tokensUsed)
 }
 
-func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
+func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
 	resp, err := processAIRequest(ctx, params, req)
 	if err != nil {
 		return nil, err
@@ -823,7 +823,7 @@ func processLlmGenerate(ctx context.Context, params aiRequestParams, req worker.
 	return llmResp, nil
 }
 
-func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error) {
+func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
 	var buf bytes.Buffer
 	mw, err := worker.NewLlmGenerateMultipartWriter(&buf, req)
 	if err != nil {
@@ -856,7 +856,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
 	defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)
 
 	start := time.Now()
-	resp, err := client.LlmGenerateLlmGeneratePostWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
+	resp, err := client.GenLlmWithBody(ctx, mw.FormDataContentType(), &buf, setHeaders)
 	if err != nil {
 		if monitor.Enabled {
 			monitor.AIRequestError(err.Error(), "llm-generate", *req.ModelId, sess.OrchestratorInfo)
@@ -876,7 +876,7 @@ func submitLlmGenerate(ctx context.Context, params aiRequestParams, sess *AISess
 	return handleNonStreamingResponse(ctx, resp.Body, sess, req, start)
 }
 
-func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
+func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLlmFormdataRequestBody, start time.Time) (chan worker.LlmStreamChunk, error) {
 	streamChan := make(chan worker.LlmStreamChunk, 100)
 	go func() {
 		defer close(streamChan)
@@ -919,7 +919,7 @@ func handleSSEStream(ctx context.Context, body io.ReadCloser, sess *AISession, r
 	return streamChan, nil
 }
 
-func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody, start time.Time) (*worker.LlmResponse, error) {
+func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *AISession, req worker.GenLlmFormdataRequestBody, start time.Time) (*worker.LlmResponse, error) {
 	data, err := io.ReadAll(body)
 	defer body.Close()
 	if err != nil {
@@ -1011,7 +1011,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
 		submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
 			return submitSegmentAnything2(ctx, params, sess, v)
 		}
-	case worker.LlmGenerateLlmGeneratePostFormdataRequestBody:
+	case worker.GenLlmFormdataRequestBody:
 		cap = core.Capability_LlmGenerate
 		modelID = defaultLlmGenerateModelID
 		if v.ModelId != nil {
diff --git a/server/rpc.go b/server/rpc.go
index 8eccffc328..4732567701 100644
--- a/server/rpc.go
+++ b/server/rpc.go
@@ -69,7 +69,7 @@ type Orchestrator interface {
 	Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
 	AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
 	SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
-	LlmGenerate(ctx context.Context, req worker.LlmGenerateLlmGeneratePostFormdataRequestBody) (interface{}, error)
+	LlmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error)
 }
 
 // Balance describes methods for a session's balance maintenance

From a3d74628fea5eccf948493a0f11e1c4776ab2e2a Mon Sep 17 00:00:00 2001
From: Brad P
Date: Thu, 19 Sep 2024 15:45:30 -0500
Subject: [PATCH 10/10] update ai-worker to rebased version

---
 go.mod | 2 +-
 go.sum | 6 ++----
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/go.mod b/go.mod
index 661157a80c..d11b85fab3 100644
--- a/go.mod
+++ b/go.mod
@@ -239,4 +239,4 @@ require (
 	rsc.io/tmplfunc v0.0.3 // indirect
 )
 
-replace github.com/livepeer/ai-worker => github.com/Livepool-io/ai-worker v0.0.0-20240806021536-468d65dca834
+replace github.com/livepeer/ai-worker => github.com/ad-astra-video/ai-worker v0.0.0-20240921034803-5d83b83b7a1c
diff --git a/go.sum b/go.sum
index 3760766b36..3173327fae 100644
--- a/go.sum
+++ b/go.sum
@@ -58,8 +58,6 @@ github.com/DataDog/zstd v1.4.5 h1:EndNeuB0l9syBZhut0wns3gV1hL8zX8LIu6ZiVHWLIQ=
 github.com/DataDog/zstd v1.4.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
 github.com/Joker/hpp v1.0.0/go.mod h1:8x5n+M1Hp5hC0g8okX3sR3vFQwynaX/UgSOM9MeBKzY=
 github.com/Joker/jade v1.0.1-0.20190614124447-d475f43051e7/go.mod h1:6E6s8o2AE4KhCrqr6GRJjdC/gNfTdxkIXvuGZZda2VM=
-github.com/Livepool-io/ai-worker v0.0.0-20240806021536-468d65dca834 h1:pinX9lPOmYxkGGa6MQUolPqmdIySICvk03uxfMgYUmk=
-github.com/Livepool-io/ai-worker v0.0.0-20240806021536-468d65dca834/go.mod h1:Xlnb0nFG2VsGeMG9hZmReVQXeFt0Dv28ODiUT2ooyLE=
 github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
 github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
 github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA=
@@ -74,6 +72,8 @@ github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDO
 github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8=
 github.com/VictoriaMetrics/fastcache v1.12.1 h1:i0mICQuojGDL3KblA7wUNlY5lOK6a4bwt3uRKnkZU40=
 github.com/VictoriaMetrics/fastcache v1.12.1/go.mod h1:tX04vaqcNoQeGLD+ra5pU5sWkuxnzWhEzLwhP9w653o=
+github.com/ad-astra-video/ai-worker v0.0.0-20240921034803-5d83b83b7a1c h1:P1cDtj2uFXuYa1A68NXcocGxvcLt7J/XbjYKVH4LUJ4=
+github.com/ad-astra-video/ai-worker v0.0.0-20240921034803-5d83b83b7a1c/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
 github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
@@ -625,8 +625,6 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
 github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
 github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
 github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
-github.com/livepeer/ai-worker v0.5.0 h1:dgO6j9QVFPOq9omIcgB1YmgVSlhV94BMb6QO4WUocX8=
-github.com/livepeer/ai-worker v0.5.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
 github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
 github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
 github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
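Patch 02 above prices an LLM job by treating MaxTokens as the unit count handed to prepareAIPayment (the orchestrator side mirrors this with outPixels = int64(*v.MaxTokens)), with the per-unit price derived the same way the monitoring code derives it for pixels: PricePerUnit / PixelsPerUnit. A back-of-the-envelope sketch of that arithmetic, using assumed example numbers rather than values from the patches:

```go
package main

import "fmt"

func main() {
	// Assumed example values; not taken from the patches.
	pricePerUnit := int64(70) // wei per unit bundle, from the orchestrator's PriceInfo
	pixelsPerUnit := int64(1) // units per bundle, from the orchestrator's PriceInfo
	maxTokens := int64(256)   // default applied in submitLlmGenerate when unset

	// The gateway escrows maxTokens units at the orchestrator's unit price,
	// regardless of how many tokens the model ultimately produces.
	fee := maxTokens * pricePerUnit / pixelsPerUnit
	fmt.Printf("escrowed fee: %d wei for up to %d tokens\n", fee, maxTokens)
}
```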