Skip to content

Commit

Permalink
Merge branch 'ai-video' into ai-video-fix-selection-pr
Browse files Browse the repository at this point in the history
  • Loading branch information
ad-astra-video authored Sep 24, 2024
2 parents b965778 + d1b4dec commit 6dae336
Show file tree
Hide file tree
Showing 42 changed files with 978 additions and 412 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ jobs:

upload:
name: Upload artifacts to google bucket
if: github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository
permissions:
contents: "read"
id-token: "write"
Expand Down
18 changes: 18 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
# Changelog

## v0.7.9

- [#3165](https://github.com/livepeer/go-livepeer/pull/3165) Add node version and orch addr to transcoded metadata

### Features ⚒

#### Broadcaster

- [#3158](https://github.com/livepeer/go-livepeer/pull/3158) Add a metric tag for Orchestrator version

### Bug Fixes 🐞

#### Broadcaster

- [#3164](https://github.com/livepeer/go-livepeer/pull/3164) Fix media compatibility check
- [#3166](https://github.com/livepeer/go-livepeer/pull/3166) Clean up inactive sessions
- [#3086](https://github.com/livepeer/go-livepeer/pull/3086) Clear known sessions with inadequate latency scores

## v0.7.8

### Features ⚒
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.7.8-ai.1
0.7.9-ai.1
8 changes: 4 additions & 4 deletions ai/file_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func NewFileWorker(files map[string]string) *FileWorker {
return &FileWorker{files: files}
}

func (w *FileWorker) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
func (w *FileWorker) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["text-to-image"]
if !ok {
return nil, errors.New("text-to-image response file not found")
Expand All @@ -36,7 +36,7 @@ func (w *FileWorker) TextToImage(ctx context.Context, req worker.TextToImageJSON
return &resp, nil
}

func (w *FileWorker) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
func (w *FileWorker) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["image-to-image"]
if !ok {
return nil, errors.New("image-to-image response file not found")
Expand All @@ -55,7 +55,7 @@ func (w *FileWorker) ImageToImage(ctx context.Context, req worker.ImageToImageMu
return &resp, nil
}

func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) {
func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error) {
fname, ok := w.files["image-to-video"]
if !ok {
return nil, errors.New("image-to-video response file not found")
Expand All @@ -74,7 +74,7 @@ func (w *FileWorker) ImageToVideo(ctx context.Context, req worker.ImageToVideoMu
return &resp, nil
}

func (w *FileWorker) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
func (w *FileWorker) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
fname, ok := w.files["upscale"]
if !ok {
return nil, errors.New("upscale response file not found")
Expand Down
1 change: 1 addition & 0 deletions cmd/livepeer/livepeer.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ func parseLivepeerConfig() starter.LivepeerConfig {
cfg.MaxPricePerCapability = flag.String("maxPricePerCapability", *cfg.MaxPricePerCapability, `json list of prices per capability/model or path to json config file. Use "model_id": "default" to price all models in a pipeline the same. Example: {"capabilities_prices": [{"pipeline": "text-to-image", "model_id": "stabilityai/sd-turbo", "price_per_unit": 1000, "pixels_per_unit": 1}, {"pipeline": "upscale", "model_id": "default", price_per_unit": 1200, "pixels_per_unit": 1}]}`)
cfg.IgnoreMaxPriceIfNeeded = flag.Bool("ignoreMaxPriceIfNeeded", *cfg.IgnoreMaxPriceIfNeeded, "Set to true to allow exceeding max price condition if there is no O that meets this requirement")
cfg.MinPerfScore = flag.Float64("minPerfScore", *cfg.MinPerfScore, "The minimum orchestrator's performance score a broadcaster is willing to accept")
cfg.DiscoveryTimeout = flag.Duration("discoveryTimeout", *cfg.DiscoveryTimeout, "Time to wait for orchestrators to return info to be included in transcoding sessions for manifest (default = 500ms)")

// Transcoding:
cfg.Orchestrator = flag.Bool("orchestrator", *cfg.Orchestrator, "Set to true to be an orchestrator")
Expand Down
34 changes: 29 additions & 5 deletions cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ type LivepeerConfig struct {
MaxPricePerCapability *string
IgnoreMaxPriceIfNeeded *bool
MinPerfScore *float64
DiscoveryTimeout *time.Duration
MaxSessions *string
CurrentManifest *bool
Nvidia *string
Expand Down Expand Up @@ -189,6 +190,7 @@ func DefaultLivepeerConfig() LivepeerConfig {
defaultOrchPerfStatsURL := ""
defaultRegion := ""
defaultMinPerfScore := 0.0
defaultDiscoveryTimeout := 500 * time.Millisecond
defaultCurrentManifest := false
defaultNvidia := ""
defaultNetint := ""
Expand Down Expand Up @@ -287,6 +289,7 @@ func DefaultLivepeerConfig() LivepeerConfig {
OrchPerfStatsURL: &defaultOrchPerfStatsURL,
Region: &defaultRegion,
MinPerfScore: &defaultMinPerfScore,
DiscoveryTimeout: &defaultDiscoveryTimeout,
CurrentManifest: &defaultCurrentManifest,
Nvidia: &defaultNvidia,
Netint: &defaultNetint,
Expand Down Expand Up @@ -877,6 +880,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
glog.Errorf("Error setting up orchestrator: %v", err)
return
}
n.RecipientAddr = recipientAddr.Hex()

sigVerifier := &pm.DefaultSigVerifier{}
validator := pm.NewValidator(sigVerifier, timeWatcher)
Expand Down Expand Up @@ -1227,10 +1231,11 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
}
}

// If the config contains a URL we call Warm() anyway because AIWorker will just register
// the endpoint for an external container
if config.Warm || config.URL != "" {
// Register external container endpoint if URL is provided.
endpoint := worker.RunnerEndpoint{URL: config.URL, Token: config.Token}

// Warm the AI worker container or register the endpoint.
if err := n.AIWorker.Warm(ctx, config.Pipeline, config.ModelID, endpoint, config.OptimizationFlags); err != nil {
glog.Errorf("Error AI worker warming %v container: %v", config.Pipeline, err)
return
Expand Down Expand Up @@ -1313,6 +1318,20 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_AudioToText, config.ModelID, autoPrice)
}
case "segment-anything-2":
_, ok := capabilityConstraints[core.Capability_SegmentAnything2]
if !ok {
aiCaps = append(aiCaps, core.Capability_SegmentAnything2)
capabilityConstraints[core.Capability_SegmentAnything2] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}

capabilityConstraints[core.Capability_SegmentAnything2].Models[config.ModelID] = modelConstraint

if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_SegmentAnything2, config.ModelID, autoPrice)
}
}

if len(aiCaps) > 0 {
Expand Down Expand Up @@ -1404,7 +1423,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
if *cfg.Network != "offchain" {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
dbOrchPoolCache, err := discovery.NewDBOrchestratorPoolCache(ctx, n, timeWatcher, orchBlacklist)
dbOrchPoolCache, err := discovery.NewDBOrchestratorPoolCache(ctx, n, timeWatcher, orchBlacklist, *cfg.DiscoveryTimeout)
if err != nil {
exit("Could not create orchestrator pool with DB cache: %v", err)
}
Expand All @@ -1419,9 +1438,9 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
glog.Exit("Error setting orch webhook URL ", err)
}
glog.Info("Using orchestrator webhook URL ", whurl)
n.OrchestratorPool = discovery.NewWebhookPool(bcast, whurl)
n.OrchestratorPool = discovery.NewWebhookPool(bcast, whurl, *cfg.DiscoveryTimeout)
} else if len(orchURLs) > 0 {
n.OrchestratorPool = discovery.NewOrchestratorPool(bcast, orchURLs, common.Score_Trusted, orchBlacklist)
n.OrchestratorPool = discovery.NewOrchestratorPool(bcast, orchURLs, common.Score_Trusted, orchBlacklist, *cfg.DiscoveryTimeout)
}

if n.OrchestratorPool == nil {
Expand Down Expand Up @@ -1481,6 +1500,11 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
if err != nil {
glog.Exit("Error getting service URI: ", err)
}

if *cfg.Network != "offchain" && !common.ValidateServiceURI(suri) {
glog.Warning("**Warning -serviceAddr is a not a public address or hostname; this is not recommended for onchain networks**")
}

n.SetServiceURI(suri)
// if http addr is not provided, listen to all ifaces
// take the port to listen to from the service URI
Expand Down
6 changes: 6 additions & 0 deletions common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"math/big"
"math/rand"
"mime"
"net/url"
"regexp"
"sort"
"strconv"
Expand Down Expand Up @@ -565,3 +566,8 @@ func CalculateAudioDuration(audio types.File) (int64, error) {

return duration, nil
}

// ValidateServiceURI reports whether serviceURI is usable as an advertised
// service address. A URI is rejected when it is nil or when its host is the
// unspecified IPv4 address 0.0.0.0; every other host is accepted.
//
// The comparison is made against the exact hostname (port stripped) rather
// than a substring of host:port, so addresses that merely contain the text
// "0.0.0.0" (for example 10.0.0.0 or 0.0.0.0.example.com) are not rejected.
func ValidateServiceURI(serviceURI *url.URL) bool {
	// A nil URI cannot be advertised; report it as invalid instead of panicking.
	if serviceURI == nil {
		return false
	}
	// Hostname drops any ":port" suffix, leaving only the host to compare.
	return serviceURI.Hostname() != "0.0.0.0"
}
36 changes: 36 additions & 0 deletions common/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"math"
"math/big"
"net/url"
"strconv"
"strings"
"testing"
Expand Down Expand Up @@ -483,3 +484,38 @@ func TestParseAccelDevices_CustomSelection(t *testing.T) {
assert.Equal(ids[1], "3")
assert.Equal(ids[2], "1")
}
// TestValidateServiceURI verifies that ValidateServiceURI accepts ordinary
// and loopback addresses and rejects URIs whose host is 0.0.0.0.
func TestValidateServiceURI(t *testing.T) {
	tests := []struct {
		uri   string
		valid bool
	}{
		// Valid service URIs
		{"https://8.8.8.8:8935", true},
		{"https://127.0.0.1:8935", true},
		// Invalid service URIs
		{"http://0.0.0.0", false},
		{"https://0.0.0.0", false},
	}

	for _, tt := range tests {
		serviceURI, err := url.Parse(tt.uri)
		if err != nil {
			// Stop here: url.Parse returns a nil URL on failure, and passing
			// that on to ValidateServiceURI would not test anything useful.
			t.Fatalf("Failed to parse service URI %q: %v", tt.uri, err)
		}

		if got := ValidateServiceURI(serviceURI); got != tt.valid {
			t.Errorf("ValidateServiceURI(%q) = %v, want %v", tt.uri, got, tt.valid)
		}
	}
}
11 changes: 6 additions & 5 deletions core/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ import (
var errPipelineNotAvailable = errors.New("pipeline not available")

type AI interface {
TextToImage(context.Context, worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(context.Context, worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(context.Context, worker.ImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error)
TextToImage(context.Context, worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error)
ImageToImage(context.Context, worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error)
ImageToVideo(context.Context, worker.GenImageToVideoMultipartRequestBody) (*worker.VideoResponse, error)
Upscale(context.Context, worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error)
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(pipeline, modelID string) bool
Expand Down
3 changes: 3 additions & 0 deletions core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ const (
Capability_ImageToVideo
Capability_Upscale
Capability_AudioToText
Capability_SegmentAnything2
)

var CapabilityNameLookup = map[Capability]string{
Expand Down Expand Up @@ -114,6 +115,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_ImageToVideo: "Image to video",
Capability_Upscale: "Upscale",
Capability_AudioToText: "Audio to text",
Capability_SegmentAnything2: "Segment anything 2",
}

var CapabilityTestLookup = map[Capability]CapabilityTest{
Expand Down Expand Up @@ -204,6 +206,7 @@ func OptionalCapabilities() []Capability {
Capability_ImageToVideo,
Capability_Upscale,
Capability_AudioToText,
Capability_SegmentAnything2,
}
}

Expand Down
1 change: 1 addition & 0 deletions core/livepeernode.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ type LivepeerNode struct {
// Transcoder public fields
SegmentChans map[ManifestID]SegmentChan
Recipient pm.Recipient
RecipientAddr string
SelectionAlgorithm common.SelectionAlgorithm
OrchestratorPool common.OrchestratorPool
OrchPerfScore *common.PerfScore
Expand Down
40 changes: 30 additions & 10 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,26 +110,30 @@ func (orch *orchestrator) TranscoderResults(tcID int64, res *RemoteTranscoderRes
orch.node.TranscoderManager.transcoderResults(tcID, res)
}

func (orch *orchestrator) TextToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
func (orch *orchestrator) TextToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
return orch.node.textToImage(ctx, req)
}

func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
func (orch *orchestrator) ImageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.imageToImage(ctx, req)
}

func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
func (orch *orchestrator) ImageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.imageToVideo(ctx, req)
}

func (orch *orchestrator) Upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
func (orch *orchestrator) Upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
return orch.node.upscale(ctx, req)
}

func (orch *orchestrator) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
func (orch *orchestrator) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
return orch.node.AudioToText(ctx, req)
}

// SegmentAnything2 hands the request to the orchestrator's node by calling
// orch.node.SegmentAnything2 and returns that call's response and error as-is.
func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return orch.node.SegmentAnything2(ctx, req)
}

func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error {
if orch.node == nil || orch.node.Recipient == nil {
return nil
Expand Down Expand Up @@ -778,6 +782,17 @@ func (n *LivepeerNode) transcodeSeg(ctx context.Context, config transcodeConfig,
}
md.Fname = url

orchId := "offchain"
if n.RecipientAddr != "" {
orchId = n.RecipientAddr
}
if isRemote {
// huge hack to thread the orch id down to the transcoder
md.Metadata = map[string]string{"orchId": orchId}
} else {
md.Metadata = MakeMetadata(orchId)
}

//Do the transcoding
start := time.Now()
tData, err := transcoder.Transcode(ctx, md)
Expand Down Expand Up @@ -947,23 +962,27 @@ func (n *LivepeerNode) serveTranscoder(stream net.Transcoder_RegisterTranscoderS
}
}

func (n *LivepeerNode) textToImage(ctx context.Context, req worker.TextToImageJSONRequestBody) (*worker.ImageResponse, error) {
func (n *LivepeerNode) textToImage(ctx context.Context, req worker.GenTextToImageJSONRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.TextToImage(ctx, req)
}

func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.ImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
func (n *LivepeerNode) imageToImage(ctx context.Context, req worker.GenImageToImageMultipartRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.ImageToImage(ctx, req)
}

func (n *LivepeerNode) upscale(ctx context.Context, req worker.UpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
func (n *LivepeerNode) upscale(ctx context.Context, req worker.GenUpscaleMultipartRequestBody) (*worker.ImageResponse, error) {
return n.AIWorker.Upscale(ctx, req)
}

func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.AudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error) {
return n.AIWorker.AudioToText(ctx, req)
}

func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.ImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// SegmentAnything2 passes the request to the node's AI worker by calling
// n.AIWorker.SegmentAnything2 and returns that call's response and error as-is.
func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return n.AIWorker.SegmentAnything2(ctx, req)
}

func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// We might support generating more than one video in the future (i.e. multiple input images/prompts)
numVideos := 1

Expand Down Expand Up @@ -1108,6 +1127,7 @@ func (rt *RemoteTranscoder) Transcode(logCtx context.Context, md *SegTranscoding
msg := &net.NotifySegment{
Url: fname,
TaskId: taskID,
OrchId: md.Metadata["orchId"],
SegData: segData,
// Triggers failure on Os that don't know how to use SegData
Profiles: []byte("invalid"),
Expand Down
Loading

0 comments on commit 6dae336

Please sign in to comment.