Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Livepool llm rebase #3178

Closed
15 changes: 15 additions & 0 deletions cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,21 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_SegmentAnything2, config.ModelID, autoPrice)
}

case "llm-generate":
_, ok := capabilityConstraints[core.Capability_LlmGenerate]
if !ok {
aiCaps = append(aiCaps, core.Capability_LlmGenerate)
capabilityConstraints[core.Capability_LlmGenerate] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}

capabilityConstraints[core.Capability_LlmGenerate].Models[config.ModelID] = modelConstraint

if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_LlmGenerate, config.ModelID, autoPrice)
}
}

if len(aiCaps) > 0 {
Expand Down
1 change: 1 addition & 0 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type AI interface {
Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
LlmGenerate(context.Context, worker.GenLlmFormdataRequestBody) (interface{}, error)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kyriediculous, @ad-astra-video can we rename this to LLMGenerate?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes sounds good

Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(pipeline, modelID string) bool
Expand Down
2 changes: 2 additions & 0 deletions core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ const (
Capability_Upscale
Capability_AudioToText
Capability_SegmentAnything2
Capability_LlmGenerate
)

var CapabilityNameLookup = map[Capability]string{
Expand Down Expand Up @@ -116,6 +117,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_Upscale: "Upscale",
Capability_AudioToText: "Audio to text",
Capability_SegmentAnything2: "Segment anything 2",
Capability_LlmGenerate: "LLM Generate",
}

var CapabilityTestLookup = map[Capability]CapabilityTest{
Expand Down
9 changes: 9 additions & 0 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSe
return orch.node.SegmentAnything2(ctx, req)
}

// LlmGenerate forwards an LLM generation request to the node's AI worker.
// The result is intentionally typed interface{}: callers type-switch on it,
// receiving either a complete response or a channel of stream chunks when the
// request asked for streaming (see the SSE handling in the HTTP handlers).
// NOTE(review): reviewers on this PR asked to rename this to LLMGenerate to
// follow Go initialism conventions.
func (orch *orchestrator) LlmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
	return orch.node.llmGenerate(ctx, req)
}

func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error {
if orch.node == nil || orch.node.Recipient == nil {
return nil
Expand Down Expand Up @@ -1051,6 +1056,10 @@ func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVi
return &worker.ImageResponse{Images: videos}, nil
}

// llmGenerate delegates an LLM generation request to the node's AI worker.
// The interface{} return mirrors the AI worker's contract: the value is either
// a full response or a stream channel, and the HTTP layer type-switches on it.
func (n *LivepeerNode) llmGenerate(ctx context.Context, req worker.GenLlmFormdataRequestBody) (interface{}, error) {
	return n.AIWorker.LlmGenerate(ctx, req)
}

func (rtm *RemoteTranscoderManager) transcoderResults(tcID int64, res *RemoteTranscoderResult) {
remoteChan, err := rtm.getTaskChan(tcID)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,5 @@ require (
lukechampine.com/blake3 v1.2.1 // indirect
rsc.io/tmplfunc v0.0.3 // indirect
)

replace github.com/livepeer/ai-worker => github.com/ad-astra-video/ai-worker v0.0.0-20240921034803-5d83b83b7a1c
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDO
github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8=
github.com/VictoriaMetrics/fastcache v1.12.1 h1:i0mICQuojGDL3KblA7wUNlY5lOK6a4bwt3uRKnkZU40=
github.com/VictoriaMetrics/fastcache v1.12.1/go.mod h1:tX04vaqcNoQeGLD+ra5pU5sWkuxnzWhEzLwhP9w653o=
github.com/ad-astra-video/ai-worker v0.0.0-20240921034803-5d83b83b7a1c h1:P1cDtj2uFXuYa1A68NXcocGxvcLt7J/XbjYKVH4LUJ4=
github.com/ad-astra-video/ai-worker v0.0.0-20240921034803-5d83b83b7a1c/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
Expand Down Expand Up @@ -623,8 +625,6 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
github.com/livepeer/ai-worker v0.5.0 h1:dgO6j9QVFPOq9omIcgB1YmgVSlhV94BMb6QO4WUocX8=
github.com/livepeer/ai-worker v0.5.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
Expand Down
75 changes: 72 additions & 3 deletions server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func startAIServer(lp lphttp) error {
lp.transRPC.Handle("/upscale", oapiReqValidator(lp.Upscale()))
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2()))
lp.transRPC.Handle("/llm-generate", oapiReqValidator(lp.LlmGenerate()))

return nil
}
Expand Down Expand Up @@ -181,6 +182,29 @@ func (h *lphttp) SegmentAnything2() http.Handler {
})
}

// LlmGenerate returns the orchestrator-side handler for /llm-generate.
// It binds the multipart form body into a worker request and hands it to
// handleAIRequest, which performs payment checks and dispatches to the worker.
func (h *lphttp) LlmGenerate() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		orch := h.orchestrator

		remoteAddr := getRemoteAddr(r)
		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

		multiRdr, err := r.MultipartReader()
		if err != nil {
			respondWithError(w, err.Error(), http.StatusBadRequest)
			return
		}

		var req worker.GenLlmFormdataRequestBody
		// A body that fails to bind is a malformed client request, not a
		// server fault; report 400 (matches the mediaserver-side handler).
		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
			respondWithError(w, err.Error(), http.StatusBadRequest)
			return
		}

		handleAIRequest(ctx, w, r, orch, req)
	})
}

func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
payment, err := getPayment(r.Header.Get(paymentHeader))
if err != nil {
Expand Down Expand Up @@ -324,6 +348,21 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}
outPixels = int64(config.Height) * int64(config.Width)
case worker.GenLlmFormdataRequestBody:
pipeline = "llm-generate"
cap = core.Capability_LlmGenerate
modelID = *v.ModelId
submitFn = func(ctx context.Context) (interface{}, error) {
return orch.LlmGenerate(ctx, v)
}

if v.MaxTokens == nil {
respondWithError(w, "MaxTokens not specified", http.StatusBadRequest)
return
}

// TODO: Improve pricing
outPixels = int64(*v.MaxTokens)
default:
respondWithError(w, "Unknown request type", http.StatusBadRequest)
return
Expand Down Expand Up @@ -407,7 +446,37 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
monitor.AIJobProcessed(ctx, pipeline, modelID, monitor.AIJobInfo{LatencyScore: latencyScore, PricePerUnit: pricePerAIUnit})
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
// Check if the response is a streaming response
if streamChan, ok := resp.(<-chan worker.LlmStreamChunk); ok {
// Set headers for SSE
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")

flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
return
}

for chunk := range streamChan {
data, err := json.Marshal(chunk)
if err != nil {
clog.Errorf(ctx, "Error marshaling stream chunk: %v", err)
continue
}

fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()

if chunk.Done {
break
}
}
} else {
// Non-streaming response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
}
}
75 changes: 74 additions & 1 deletion server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"time"

Expand Down Expand Up @@ -70,7 +71,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult())
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText()))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2()))

ls.HTTPMux.Handle("/llm-generate", oapiReqValidator(ls.LlmGenerate()))
return nil
}

Expand Down Expand Up @@ -428,6 +429,78 @@ func (ls *LivepeerServer) SegmentAnything2() http.Handler {
})
}

// LlmGenerate handles gateway-side /llm-generate requests. It binds the
// multipart form body, runs the request through the AI session manager via
// processLlmGenerate, and writes either a Server-Sent Events stream (when the
// request set stream=true) or a single JSON response.
func (ls *LivepeerServer) LlmGenerate() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		remoteAddr := getRemoteAddr(r)
		ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
		requestID := string(core.RandomManifestID())
		ctx = clog.AddVal(ctx, "request_id", requestID)

		var req worker.GenLlmFormdataRequestBody

		multiRdr, err := r.MultipartReader()
		if err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
			respondJsonError(ctx, w, err, http.StatusBadRequest)
			return
		}

		// ModelId is a pointer in the generated form-body type and may be
		// absent; guard it before logging to avoid a nil dereference.
		modelID := ""
		if req.ModelId != nil {
			modelID = *req.ModelId
		}

		streamResponse := false
		if req.Stream != nil {
			streamResponse = *req.Stream
		}

		clog.V(common.VERBOSE).Infof(ctx, "Received LlmGenerate request prompt=%v model_id=%v stream=%v", req.Prompt, modelID, streamResponse)

		params := aiRequestParams{
			node:        ls.LivepeerNode,
			os:          drivers.NodeStorage.NewSession(requestID),
			sessManager: ls.AISessionManager,
		}

		start := time.Now()
		resp, err := processLlmGenerate(ctx, params, req)
		if err != nil {
			var e *ServiceUnavailableError
			if errors.As(err, &e) {
				respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
				return
			}
			respondJsonError(ctx, w, err, http.StatusInternalServerError)
			return
		}

		took := time.Since(start)
		clog.V(common.VERBOSE).Infof(ctx, "Processed LlmGenerate request prompt=%v model_id=%v took=%v", req.Prompt, modelID, took)

		if streamChan, ok := resp.(chan worker.LlmStreamChunk); ok {
			// Streaming response: emit each chunk as a Server-Sent Event.
			w.Header().Set("Content-Type", "text/event-stream")
			w.Header().Set("Cache-Control", "no-cache")
			w.Header().Set("Connection", "keep-alive")

			// SSE requires per-chunk flushing; fail cleanly rather than
			// panicking on writers that do not implement http.Flusher
			// (matches the orchestrator-side handler).
			flusher, ok := w.(http.Flusher)
			if !ok {
				http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
				return
			}

			for chunk := range streamChan {
				data, err := json.Marshal(chunk)
				if err != nil {
					clog.Errorf(ctx, "Error marshaling stream chunk: %v", err)
					continue
				}
				fmt.Fprintf(w, "data: %s\n\n", data)
				flusher.Flush()
				if chunk.Done {
					break
				}
			}
		} else if llmResp, ok := resp.(*worker.LlmResponse); ok {
			// Non-streaming response: single JSON payload.
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(llmResp); err != nil {
				clog.Errorf(ctx, "Error encoding LlmGenerate response: %v", err)
			}
		} else {
			http.Error(w, "Unexpected response type", http.StatusInternalServerError)
		}
	})
}

func (ls *LivepeerServer) ImageToVideoResult() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
Expand Down
Loading
Loading