Skip to content

Commit

Permalink
refactor: update worker classes (#3171)
Browse files Browse the repository at this point in the history
This commit ensures that the go-livepeer code uses the new worker classes that were defined in livepeer/ai-worker#191.
  • Loading branch information
rickstaa authored Sep 13, 2024
1 parent 14f7783 commit ffb1922
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 95 deletions.
8 changes: 4 additions & 4 deletions ai/file_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func NewFileWorker(files map[string]string) *FileWorker {
return &FileWorker{files: files}
}

func (w *FileWorker) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
func (w *FileWorker) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["text-to-image"]
if !ok {
return nil, errors.New("text-to-image response file not found")
Expand All @@ -36,7 +36,7 @@ func (w *FileWorker) TextToImage(ctx context.Context, req worker.TextToImageJSON
return &resp, nil
}

func (w *FileWorker) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
func (w *FileWorker) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["image-to-image"]
if !ok {
return nil, errors.New("image-to-image response file not found")
Expand All @@ -55,7 +55,7 @@ func (w *FileWorker) ImageToImage(ctx context.Context, req worker.ImageToImageMu
return &resp, nil
}

func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) {
func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) {
fname, ok := w.files["image-to-video"]
if !ok {
return nil, errors.New("image-to-video response file not found")
Expand All @@ -74,7 +74,7 @@ func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.ImageToVideoMu
return &resp, nil
}

func (w *FileWorker) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
func (w *FileWorker) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["upscale"]
if !ok {
return nil, errors.New("upscale response file not found")
Expand Down
12 changes: 6 additions & 6 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ import (
var errPipelineNotAvailable = errors.New("pipeline not available")

type AI interface {
TextToImage(context.Context, worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(context.Context, worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(context.Context, worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error)
SegmentAnything2(context.Context, worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
TextToImage(context.Context, worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(context.Context, worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(pipeline, modelID string) bool
Expand Down
24 changes: 12 additions & 12 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,27 +110,27 @@ func (orch *orchestrator) TranscoderResults(tcID int64, res *RemoteTranscoderRes
orch.node.TranscoderManager.transcoderResults(tcID, res)
}

func (orch *orchestrator) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
func (orch *orchestrator) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
return orch.node.textToImage(ctx, req)
}

func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.imageToImage(ctx, req)
}

func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.imageToVideo(ctx, req)
}

func (orch *orchestrator) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
func (orch *orchestrator) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.upscale(ctx, req)
}

func (orch *orchestrator) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
func (orch *orchestrator) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
return orch.node.AudioToText(ctx, req)
}

func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return orch.node.SegmentAnything2(ctx, req)
}

Expand Down Expand Up @@ -951,27 +951,27 @@ func (n *LivepeerNode) serveTranscoder(stream net.Transcoder_RegisterTranscoderS
}
}

func (n *LivepeerNode) textToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
func (n *LivepeerNode) textToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.TextToImage(ctx, req)
}

func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.ImageToImage(ctx, req)
}

func (n *LivepeerNode) upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
func (n *LivepeerNode) upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.Upscale(ctx, req)
}

func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
return n.AIWorker.AudioToText(ctx, req)
}

func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return n.AIWorker.SegmentAnything2(ctx, req)
}

func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// We might support generating more than one video in the future (i.e. multiple input images/prompts)
numVideos := 1

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/golang/protobuf v1.5.4
github.com/jaypipes/ghw v0.10.0
github.com/jaypipes/pcidb v1.0.0
github.com/livepeer/ai-worker v0.2.0
github.com/livepeer/ai-worker v0.5.0
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
github.com/livepeer/lpms v0.0.0-20240819180416-f87352959b85
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
github.com/livepeer/ai-worker v0.2.0 h1:u6m3nVQnisqWn2nMhgTgKLQby7VDAEp5xmeHt7Res/Y=
github.com/livepeer/ai-worker v0.2.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/ai-worker v0.5.0 h1:dgO6j9QVFPOq9omIcgB1YmgVSlhV94BMb6QO4WUocX8=
github.com/livepeer/ai-worker v0.5.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
Expand Down
36 changes: 18 additions & 18 deletions server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (h *lphttp) TextToImage() http.Handler {
remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

var req worker.TextToImageJSONRequestBody
var req worker.GenTextToImageJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
Expand All @@ -79,7 +79,7 @@ func (h *lphttp) ImageToImage() http.Handler {
return
}

var req worker.ImageToImageMultipartRequestBody
var req worker.GenImageToImageMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
Expand All @@ -102,7 +102,7 @@ func (h *lphttp) ImageToVideo() http.Handler {
return
}

var req worker.ImageToVideoMultipartRequestBody
var req worker.GenImageToVideoMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
Expand All @@ -125,7 +125,7 @@ func (h *lphttp) Upscale() http.Handler {
return
}

var req worker.UpscaleMultipartRequestBody
var req worker.GenUpscaleMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
Expand All @@ -148,7 +148,7 @@ func (h *lphttp) AudioToText() http.Handler {
return
}

var req worker.AudioToTextMultipartRequestBody
var req worker.GenAudioToTextMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
Expand All @@ -171,7 +171,7 @@ func (h *lphttp) SegmentAnything2() http.Handler {
return
}

var req worker.SegmentAnything2MultipartRequestBody
var req worker.GenSegmentAnything2MultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
Expand Down Expand Up @@ -202,7 +202,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
var outPixels int64

switch v := req.(type) {
case worker.TextToImageJSONRequestBody:
case worker.GenTextToImageJSONRequestBody:
pipeline = "text-to-image"
cap = core.Capability_TextToImage
modelID = *v.ModelId
Expand All @@ -226,7 +226,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
}

outPixels = height * width * numImages
case worker.ImageToImageMultipartRequestBody:
case worker.GenImageToImageMultipartRequestBody:
pipeline = "image-to-image"
cap = core.Capability_ImageToImage
modelID = *v.ModelId
Expand All @@ -251,7 +251,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
}

outPixels = int64(config.Height) * int64(config.Width) * numImages
case worker.UpscaleMultipartRequestBody:
case worker.GenUpscaleMultipartRequestBody:
pipeline = "upscale"
cap = core.Capability_Upscale
modelID = *v.ModelId
Expand All @@ -270,7 +270,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}
outPixels = int64(config.Height) * int64(config.Width)
case worker.ImageToVideoMultipartRequestBody:
case worker.GenImageToVideoMultipartRequestBody:
pipeline = "image-to-video"
cap = core.Capability_ImageToVideo
modelID = *v.ModelId
Expand All @@ -291,7 +291,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
frames := int64(25)

outPixels = height * width * int64(frames)
case worker.AudioToTextMultipartRequestBody:
case worker.GenAudioToTextMultipartRequestBody:
pipeline = "audio-to-text"
cap = core.Capability_AudioToText
modelID = *v.ModelId
Expand All @@ -305,7 +305,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}
outPixels *= 1000 // Convert to milliseconds
case worker.SegmentAnything2MultipartRequestBody:
case worker.GenSegmentAnything2MultipartRequestBody:
pipeline = "segment-anything-2"
cap = core.Capability_SegmentAnything2
modelID = *v.ModelId
Expand Down Expand Up @@ -382,20 +382,20 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
if monitor.Enabled {
var latencyScore float64
switch v := req.(type) {
case worker.TextToImageJSONRequestBody:
case worker.GenTextToImageJSONRequestBody:
latencyScore = CalculateTextToImageLatencyScore(took, v, outPixels)
case worker.ImageToImageMultipartRequestBody:
case worker.GenImageToImageMultipartRequestBody:
latencyScore = CalculateImageToImageLatencyScore(took, v, outPixels)
case worker.ImageToVideoMultipartRequestBody:
case worker.GenImageToVideoMultipartRequestBody:
latencyScore = CalculateImageToVideoLatencyScore(took, v, outPixels)
case worker.UpscaleMultipartRequestBody:
case worker.GenUpscaleMultipartRequestBody:
latencyScore = CalculateUpscaleLatencyScore(took, v, outPixels)
case worker.AudioToTextMultipartRequestBody:
case worker.GenAudioToTextMultipartRequestBody:
durationSeconds, err := common.CalculateAudioDuration(v.Audio)
if err == nil {
latencyScore = CalculateAudioToTextLatencyScore(took, durationSeconds)
}
case worker.SegmentAnything2MultipartRequestBody:
case worker.GenSegmentAnything2MultipartRequestBody:
latencyScore = CalculateSegmentAnything2LatencyScore(took, outPixels)
}

Expand Down
12 changes: 6 additions & 6 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (ls *LivepeerServer) TextToImage() http.Handler {
requestID := string(core.RandomManifestID())
ctx = clog.AddVal(ctx, "request_id", requestID)

var req worker.TextToImageJSONRequestBody
var req worker.GenTextToImageJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
Expand Down Expand Up @@ -129,7 +129,7 @@ func (ls *LivepeerServer) ImageToImage() http.Handler {
return
}

var req worker.ImageToImageMultipartRequestBody
var req worker.GenImageToImageMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
Expand Down Expand Up @@ -177,7 +177,7 @@ func (ls *LivepeerServer) ImageToVideo() http.Handler {
return
}

var req worker.ImageToVideoMultipartRequestBody
var req worker.GenImageToVideoMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
Expand Down Expand Up @@ -287,7 +287,7 @@ func (ls *LivepeerServer) Upscale() http.Handler {
return
}

var req worker.UpscaleMultipartRequestBody
var req worker.GenUpscaleMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
Expand Down Expand Up @@ -335,7 +335,7 @@ func (ls *LivepeerServer) AudioToText() http.Handler {
return
}

var req worker.AudioToTextMultipartRequestBody
var req worker.GenAudioToTextMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
Expand Down Expand Up @@ -388,7 +388,7 @@ func (ls *LivepeerServer) SegmentAnything2() http.Handler {
return
}

var req worker.SegmentAnything2MultipartRequestBody
var req worker.GenSegmentAnything2MultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
Expand Down
Loading

0 comments on commit ffb1922

Please sign in to comment.