diff --git a/pkg/providers/clients/errors.go b/pkg/providers/clients/errors.go index bf4635c7..deaf00bc 100644 --- a/pkg/providers/clients/errors.go +++ b/pkg/providers/clients/errors.go @@ -8,6 +8,7 @@ import ( var ( ErrProviderUnavailable = errors.New("provider is not available") + ErrUnauthorized = errors.New("API key is wrong or not set") ErrChatStreamNotImplemented = errors.New("streaming chat API is not implemented for provider") ) diff --git a/pkg/providers/lang.go b/pkg/providers/lang.go index 9fa78281..93a78698 100644 --- a/pkg/providers/lang.go +++ b/pkg/providers/lang.go @@ -110,26 +110,25 @@ func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatRequest return nil, err } - streamResultC := make(chan *clients.ChatStreamResult) - - go func() { - defer close(streamResultC) + startedAt := time.Now() + err = stream.Open() + chunkLatency := time.Since(startedAt) - startedAt := time.Now() - err = stream.Open() - chunkLatency := time.Since(startedAt) + // the first chunk latency + m.chatStreamLatency.Add(float64(chunkLatency)) - // the first chunk latency - m.chatStreamLatency.Add(float64(chunkLatency)) + if err != nil { + // if connection was not even open, we should not send our clients any messages about this failure - if err != nil { - m.healthTracker.TrackErr(err) + m.healthTracker.TrackErr(err) - // if connection was not even open, we should not send our clients any messages about this failure + return nil, err + } - return - } + streamResultC := make(chan *clients.ChatStreamResult) + go func() { + defer close(streamResultC) defer stream.Close() for { diff --git a/pkg/providers/openai/errors.go b/pkg/providers/openai/errors.go index 49fdc412..94ef5418 100644 --- a/pkg/providers/openai/errors.go +++ b/pkg/providers/openai/errors.go @@ -55,6 +55,10 @@ func (m *ErrorMapper) Map(resp *http.Response) error { return clients.NewRateLimitError(&cooldownDelay) } + if resp.StatusCode == http.StatusUnauthorized { + return clients.ErrUnauthorized + } + // Server & client errors result in the same error to keep gateway resilient return clients.ErrProviderUnavailable } diff --git a/pkg/routers/health/tracker.go b/pkg/routers/health/tracker.go index e6135d54..f49e310b 100644 --- a/pkg/routers/health/tracker.go +++ b/pkg/routers/health/tracker.go @@ -8,24 +8,32 @@ import ( // Tracker tracks errors and general health of model provider type Tracker struct { - errBudget *TokenBucket - rateLimit *RateLimitTracker + unauthorized bool + errBudget *TokenBucket + rateLimit *RateLimitTracker } func NewTracker(budget *ErrorBudget) *Tracker { return &Tracker{ - rateLimit: NewRateLimitTracker(), - errBudget: NewTokenBucket(budget.TimePerTokenMicro(), budget.Budget()), + unauthorized: false, + rateLimit: NewRateLimitTracker(), + errBudget: NewTokenBucket(budget.TimePerTokenMicro(), budget.Budget()), } } func (t *Tracker) Healthy() bool { - return !t.rateLimit.Limited() && t.errBudget.HasTokens() + return !t.unauthorized && !t.rateLimit.Limited() && t.errBudget.HasTokens() } func (t *Tracker) TrackErr(err error) { var rateLimitErr *clients.RateLimitError + if errors.Is(err, clients.ErrUnauthorized) { + t.unauthorized = true + + return + } + if errors.As(err, &rateLimitErr) { t.rateLimit.SetLimited(rateLimitErr.UntilReset()) diff --git a/pkg/routers/router.go b/pkg/routers/router.go index 78eaa672..81ddb47d 100644 --- a/pkg/routers/router.go +++ b/pkg/routers/router.go @@ -152,6 +152,7 @@ func (r *LangRouter) ChatStream( langModel := model.(providers.LangModel) modelRespC, err := langModel.ChatStream(ctx, req) + if err != nil { r.tel.L().Error( "Lang model failed to create streaming chat request",