Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: update worker classes #3171

Merged
merged 6 commits on Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ai/file_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func NewFileWorker(files map[string]string) *FileWorker {
return &FileWorker{files: files}
}

func (w *FileWorker) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
func (w *FileWorker) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["text-to-image"]
if !ok {
return nil, errors.New("text-to-image response file not found")
Expand All @@ -36,7 +36,7 @@ func (w *FileWorker) TextToImage(ctx context.Context, req worker.TextToImageJSON
return &resp, nil
}

func (w *FileWorker) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
func (w *FileWorker) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["image-to-image"]
if !ok {
return nil, errors.New("image-to-image response file not found")
Expand All @@ -55,7 +55,7 @@ func (w *FileWorker) ImageToImage(ctx context.Context, req worker.ImageToImageMu
return &resp, nil
}

func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) {
func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) {
fname, ok := w.files["image-to-video"]
if !ok {
return nil, errors.New("image-to-video response file not found")
Expand All @@ -74,7 +74,7 @@ func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.ImageToVideoMu
return &resp, nil
}

func (w *FileWorker) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
func (w *FileWorker) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["upscale"]
if !ok {
return nil, errors.New("upscale response file not found")
Expand Down
12 changes: 6 additions & 6 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ import (
var errPipelineNotAvailable = errors.New("pipeline not available")

type AI interface {
TextToImage(context.Context, worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(context.Context, worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(context.Context, worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error)
SegmentAnything2(context.Context, worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
TextToImage(context.Context, worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(context.Context, worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(pipeline, modelID string) bool
Expand Down
24 changes: 12 additions & 12 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,27 +110,27 @@ func (orch *orchestrator) TranscoderResults(tcID int64, res *RemoteTranscoderRes
orch.node.TranscoderManager.transcoderResults(tcID, res)
}

func (orch *orchestrator) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
func (orch *orchestrator) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
return orch.node.textToImage(ctx, req)
}

func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.imageToImage(ctx, req)
}

func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.imageToVideo(ctx, req)
}

func (orch *orchestrator) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
func (orch *orchestrator) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.upscale(ctx, req)
}

func (orch *orchestrator) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
func (orch *orchestrator) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
return orch.node.AudioToText(ctx, req)
}

func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return orch.node.SegmentAnything2(ctx, req)
}

Expand Down Expand Up @@ -951,27 +951,27 @@ func (n *LivepeerNode) serveTranscoder(stream net.Transcoder_RegisterTranscoderS
}
}

func (n *LivepeerNode) textToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
func (n *LivepeerNode) textToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.TextToImage(ctx, req)
}

func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.ImageToImage(ctx, req)
}

func (n *LivepeerNode) upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
func (n *LivepeerNode) upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.Upscale(ctx, req)
}

func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
return n.AIWorker.AudioToText(ctx, req)
}

func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.SegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return n.AIWorker.SegmentAnything2(ctx, req)
}

func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// We might support generating more than one video in the future (i.e. multiple input images/prompts)
numVideos := 1

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/golang/protobuf v1.5.4
github.com/jaypipes/ghw v0.10.0
github.com/jaypipes/pcidb v1.0.0
github.com/livepeer/ai-worker v0.2.0
github.com/livepeer/ai-worker v0.5.0
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b
github.com/livepeer/livepeer-data v0.7.5-0.20231004073737-06f1f383fb18
github.com/livepeer/lpms v0.0.0-20240819180416-f87352959b85
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,8 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
github.com/livepeer/ai-worker v0.2.0 h1:u6m3nVQnisqWn2nMhgTgKLQby7VDAEp5xmeHt7Res/Y=
github.com/livepeer/ai-worker v0.2.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/ai-worker v0.5.0 h1:dgO6j9QVFPOq9omIcgB1YmgVSlhV94BMb6QO4WUocX8=
github.com/livepeer/ai-worker v0.5.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
Expand Down
36 changes: 18 additions & 18 deletions server/ai_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (h *lphttp) TextToImage() http.Handler {
remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

var req worker.TextToImageJSONRequestBody
var req worker.GenTextToImageJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondWithError(w, err.Error(), http.StatusBadRequest)
return
Expand All @@ -79,7 +79,7 @@ func (h *lphttp) ImageToImage() http.Handler {
return
}

var req worker.ImageToImageMultipartRequestBody
var req worker.GenImageToImageMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
Expand All @@ -102,7 +102,7 @@ func (h *lphttp) ImageToVideo() http.Handler {
return
}

var req worker.ImageToVideoMultipartRequestBody
var req worker.GenImageToVideoMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
Expand All @@ -125,7 +125,7 @@ func (h *lphttp) Upscale() http.Handler {
return
}

var req worker.UpscaleMultipartRequestBody
var req worker.GenUpscaleMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
Expand All @@ -148,7 +148,7 @@ func (h *lphttp) AudioToText() http.Handler {
return
}

var req worker.AudioToTextMultipartRequestBody
var req worker.GenAudioToTextMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
Expand All @@ -171,7 +171,7 @@ func (h *lphttp) SegmentAnything2() http.Handler {
return
}

var req worker.SegmentAnything2MultipartRequestBody
var req worker.GenSegmentAnything2MultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
Expand Down Expand Up @@ -202,7 +202,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
var outPixels int64

switch v := req.(type) {
case worker.TextToImageJSONRequestBody:
case worker.GenTextToImageJSONRequestBody:
pipeline = "text-to-image"
cap = core.Capability_TextToImage
modelID = *v.ModelId
Expand All @@ -226,7 +226,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
}

outPixels = height * width * numImages
case worker.ImageToImageMultipartRequestBody:
case worker.GenImageToImageMultipartRequestBody:
pipeline = "image-to-image"
cap = core.Capability_ImageToImage
modelID = *v.ModelId
Expand All @@ -251,7 +251,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
}

outPixels = int64(config.Height) * int64(config.Width) * numImages
case worker.UpscaleMultipartRequestBody:
case worker.GenUpscaleMultipartRequestBody:
pipeline = "upscale"
cap = core.Capability_Upscale
modelID = *v.ModelId
Expand All @@ -270,7 +270,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}
outPixels = int64(config.Height) * int64(config.Width)
case worker.ImageToVideoMultipartRequestBody:
case worker.GenImageToVideoMultipartRequestBody:
pipeline = "image-to-video"
cap = core.Capability_ImageToVideo
modelID = *v.ModelId
Expand All @@ -291,7 +291,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
frames := int64(25)

outPixels = height * width * int64(frames)
case worker.AudioToTextMultipartRequestBody:
case worker.GenAudioToTextMultipartRequestBody:
pipeline = "audio-to-text"
cap = core.Capability_AudioToText
modelID = *v.ModelId
Expand All @@ -305,7 +305,7 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}
outPixels *= 1000 // Convert to milliseconds
case worker.SegmentAnything2MultipartRequestBody:
case worker.GenSegmentAnything2MultipartRequestBody:
pipeline = "segment-anything-2"
cap = core.Capability_SegmentAnything2
modelID = *v.ModelId
Expand Down Expand Up @@ -382,20 +382,20 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
if monitor.Enabled {
var latencyScore float64
switch v := req.(type) {
case worker.TextToImageJSONRequestBody:
case worker.GenTextToImageJSONRequestBody:
latencyScore = CalculateTextToImageLatencyScore(took, v, outPixels)
case worker.ImageToImageMultipartRequestBody:
case worker.GenImageToImageMultipartRequestBody:
latencyScore = CalculateImageToImageLatencyScore(took, v, outPixels)
case worker.ImageToVideoMultipartRequestBody:
case worker.GenImageToVideoMultipartRequestBody:
latencyScore = CalculateImageToVideoLatencyScore(took, v, outPixels)
case worker.UpscaleMultipartRequestBody:
case worker.GenUpscaleMultipartRequestBody:
latencyScore = CalculateUpscaleLatencyScore(took, v, outPixels)
case worker.AudioToTextMultipartRequestBody:
case worker.GenAudioToTextMultipartRequestBody:
durationSeconds, err := common.CalculateAudioDuration(v.Audio)
if err == nil {
latencyScore = CalculateAudioToTextLatencyScore(took, durationSeconds)
}
case worker.SegmentAnything2MultipartRequestBody:
case worker.GenSegmentAnything2MultipartRequestBody:
latencyScore = CalculateSegmentAnything2LatencyScore(took, outPixels)
}

Expand Down
12 changes: 6 additions & 6 deletions server/ai_mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (ls *LivepeerServer) TextToImage() http.Handler {
requestID := string(core.RandomManifestID())
ctx = clog.AddVal(ctx, "request_id", requestID)

var req worker.TextToImageJSONRequestBody
var req worker.GenTextToImageJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
Expand Down Expand Up @@ -129,7 +129,7 @@ func (ls *LivepeerServer) ImageToImage() http.Handler {
return
}

var req worker.ImageToImageMultipartRequestBody
var req worker.GenImageToImageMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
Expand Down Expand Up @@ -177,7 +177,7 @@ func (ls *LivepeerServer) ImageToVideo() http.Handler {
return
}

var req worker.ImageToVideoMultipartRequestBody
var req worker.GenImageToVideoMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
Expand Down Expand Up @@ -287,7 +287,7 @@ func (ls *LivepeerServer) Upscale() http.Handler {
return
}

var req worker.UpscaleMultipartRequestBody
var req worker.GenUpscaleMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
Expand Down Expand Up @@ -335,7 +335,7 @@ func (ls *LivepeerServer) AudioToText() http.Handler {
return
}

var req worker.AudioToTextMultipartRequestBody
var req worker.GenAudioToTextMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
Expand Down Expand Up @@ -388,7 +388,7 @@ func (ls *LivepeerServer) SegmentAnything2() http.Handler {
return
}

var req worker.SegmentAnything2MultipartRequestBody
var req worker.GenSegmentAnything2MultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
Expand Down
Loading
Loading