Add retry to realtime solver #36

Draft · wants to merge 7 commits into base: main
2 changes: 2 additions & 0 deletions .gitignore
@@ -21,3 +21,5 @@ openai-key.txt
# Ignore run_experiments.sh results
evals/elsuite/**/logs/
evals/elsuite/**/outputs/

.cache/
2 changes: 1 addition & 1 deletion evals/cli/oaieval.py
@@ -224,7 +224,7 @@ def to_number(x: str) -> Union[int, float, str]:
**extra_eval_params,
)
result = eval.run(recorder)
- add_token_usage_to_result(result, recorder)
+ # add_token_usage_to_result(result, recorder)
recorder.record_final_report(result)

if not (args.dry_run or args.local_run):
6 changes: 6 additions & 0 deletions evals/elsuite/audio/eval.py
@@ -52,6 +52,7 @@ def __init__(
self.max_tokens = max_tokens
self._recorder = None
self.max_audio_duration = max_audio_duration
self.non_error_samples = 0 # Count of samples that didn't error out

@property
def recorder(self):
@@ -61,6 +62,10 @@ def eval_sample(self, sample: Sample, rng):
prompt = self._build_prompt(sample, self.text_only)
kwargs = self._get_completion_kwargs(sample)
sampled = self._do_completion(prompt, **kwargs)
if not sampled.startswith("ERROR:"): # Only count non-error samples
self.non_error_samples += 1
else:
print(f"Current error count: {self.non_error_samples}")
return self._compute_metrics(sample, sampled)

def _keep_sample(self, sample):
@@ -86,6 +91,7 @@ def record_sampling_wrapper(*args, **kwargs):
samples = [s for s in samples if self._keep_sample(s)]
self._recorder = recorder
self.eval_all_samples(recorder, samples)
logger.info(f"Number of successfully processed samples (non-errors): {self.non_error_samples}")
return self._compute_corpus_metrics()

def _load_dataset(self):
27 changes: 22 additions & 5 deletions evals/registry/eval_sets/audio.yaml
@@ -22,17 +22,34 @@ audio-core:
- audio-translate-covost-tr_en
- audio-translate-covost-sv-SE_en

audio-translate-core:
evals:
- audio-transcribe
- audio-translate-covost-en_de
- audio-translate-covost-en_ca
- audio-translate-covost-en_ar
- audio-translate-covost-en_tr
- audio-translate-covost-en_sv-SE
- audio-translate-covost-ru_en
- audio-translate-covost-es_en
- audio-translate-covost-zh-CN_en
- audio-translate-covost-sv-SE_en
- audio-translate-covost-tr_en


transcript-core:
evals:
- transcript-transcribe
- transcript-translate-covost-zh-CN_en
- transcript-translate-covost-ru_en
- transcript-translate-covost-es_en
- transcript-translate-covost-en_de
- transcript-translate-covost-en_ca
- transcript-translate-covost-en_ar
- transcript-translate-covost-en_de
- transcript-translate-covost-tr_en
- transcript-translate-covost-en_tr
- transcript-translate-covost-en_sv-SE
- transcript-translate-covost-ru_en
- transcript-translate-covost-es_en
- transcript-translate-covost-zh-CN_en
- transcript-translate-covost-sv-SE_en
- transcript-translate-covost-tr_en

audio-translate:
evals:
9 changes: 9 additions & 0 deletions evals/registry/solvers/defaults.yaml
@@ -151,6 +151,15 @@ classification/cot/gpt-4:

# generation tasks

generation/direct/gpt-4o:
class: evals.solvers.providers.openai.openai_solver:OpenAISolver
args:
completion_fn_options:
model: gpt-4o
extra_options:
temperature: 1
max_tokens: 512

generation/direct/gpt-4-turbo-preview:
class: evals.solvers.providers.openai.openai_solver:OpenAISolver
args:
17 changes: 17 additions & 0 deletions evals/registry/solvers/fireworks.yaml
@@ -0,0 +1,17 @@
generation/direct/llama-31-70b-instruct:
class: evals.solvers.providers.fireworks.fireworks_solver:FireworksSolver
args:
completion_fn_options:
model: accounts/fireworks/models/llama-v3p1-70b-instruct#accounts/fixie/deployments/d2630f9a
extra_options:
temperature: 0
max_tokens: 512

generation/direct/llama-31-8b-instruct:
class: evals.solvers.providers.fireworks.fireworks_solver:FireworksSolver
args:
completion_fn_options:
model: accounts/fireworks/models/llama-v3p1-8b-instruct#accounts/fixie/deployments/682dbb7b
extra_options:
temperature: 0
max_tokens: 512
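
Usage note (an assumption, not part of this diff): with FIREWORKS_API_KEY set in the environment, these registry entries should be runnable like any other solver name, along the lines of

oaieval generation/direct/llama-31-8b-instruct transcript-transcribe

where transcript-transcribe is one of the evals listed in the transcript-core set above.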
2 changes: 1 addition & 1 deletion evals/registry/solvers/realtime.yaml
@@ -6,5 +6,5 @@ generation/realtime/gpt-4o-realtime:
completion_fn_options:
model: gpt-4o-realtime-preview-2024-10-01
extra_options:
- temperature: 1
+ temperature: 0
max_tokens: 512
50 changes: 50 additions & 0 deletions evals/registry/solvers/whisper.yaml
@@ -3,6 +3,56 @@
# text transcript to the downstream solver. The Whisper solver can be used with
# any solver that accepts text input.


generation/whisper/fireworks-llama-31-70b:
class: evals.solvers.providers.openai.whisper_solver:WhisperCascadedSolver
args:
solver:
class: evals.solvers.providers.fireworks.fireworks_solver:FireworksSolver
args:
completion_fn_options:
# model: accounts/fireworks/models/llama-v3p1-70b-instruct
model: accounts/fireworks/models/llama-v3p1-70b-instruct#accounts/fixie/deployments/d2630f9a
extra_options:
temperature: 0
max_tokens: 512

generation/whisper/fireworks-llama-31-8b:
class: evals.solvers.providers.openai.whisper_solver:WhisperCascadedSolver
args:
solver:
class: evals.solvers.providers.fireworks.fireworks_solver:FireworksSolver
args:
completion_fn_options:
model: accounts/fireworks/models/llama-v3p1-8b-instruct#accounts/fixie/deployments/682dbb7b
extra_options:
temperature: 0
max_tokens: 512

generation/whisper/groq-llama-31-8b:
class: evals.solvers.providers.openai.whisper_solver:WhisperCascadedSolver
args:
solver:
class: evals.solvers.providers.groq.groq_solver:GroqSolver
args:
completion_fn_options:
model: llama-3.1-8b-instant
extra_options:
temperature: 0
max_tokens: 512

generation/whisper/groq-llama-31-70b:
class: evals.solvers.providers.openai.whisper_solver:WhisperCascadedSolver
args:
solver:
class: evals.solvers.providers.groq.groq_solver:GroqSolver
args:
completion_fn_options:
model: llama-3.1-70b-versatile
extra_options:
temperature: 0
max_tokens: 512

generation/whisper/gpt-4-turbo:
class: evals.solvers.providers.openai.whisper_solver:WhisperCascadedSolver
args:
10 changes: 10 additions & 0 deletions evals/solvers/providers/fireworks/fireworks_solver.py
@@ -0,0 +1,10 @@
from evals.solvers.providers.openai.third_party_solver import ThirdPartySolver


class FireworksSolver(ThirdPartySolver):
def __init__(self, **kwargs):
super().__init__(
api_base="https://api.fireworks.ai/inference/v1",
api_key_env_var="FIREWORKS_API_KEY",
**kwargs
)
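
For illustration only (assumed from the fireworks.yaml registry entries above, not shown in this PR): the args block of a registry entry is passed to this constructor as keyword arguments, so building the solver directly would look roughly like:

from evals.solvers.providers.fireworks.fireworks_solver import FireworksSolver

# Hypothetical direct construction mirroring the fireworks.yaml entry;
# normally the evals registry instantiates the class from the YAML args.
# Requires FIREWORKS_API_KEY in the environment.
solver = FireworksSolver(
    completion_fn_options={
        "model": "accounts/fireworks/models/llama-v3p1-8b-instruct#accounts/fixie/deployments/682dbb7b",
        "extra_options": {"temperature": 0, "max_tokens": 512},
    }
)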
27 changes: 24 additions & 3 deletions evals/solvers/providers/openai/realtime_solver.py
@@ -4,6 +4,8 @@
import json
import os
from typing import Any, Dict, Optional, Union
import time
import websockets.exceptions

import librosa
import numpy as np
@@ -37,11 +39,15 @@ def __init__(
completion_fn_options: Dict[str, Any],
postprocessors: list[str] = [],
registry: Any = None,
max_retries: int = 3,
initial_retry_delay: float = 1.0,
):
super().__init__(postprocessors=postprocessors)
if "model" not in completion_fn_options:
raise ValueError("OpenAISolver requires a model to be specified.")
self.completion_fn_options = completion_fn_options
self.max_retries = max_retries
self.initial_retry_delay = initial_retry_delay

@property
def model(self) -> str:
@@ -70,15 +76,30 @@ def _api_key(self) -> Optional[str]:
"""The API key to use for the API"""
return os.getenv("OPENAI_API_KEY")

async def _ws_completion_with_retry(self, messages):
retry_count = 0
while True:
try:
return await self._ws_completion(messages)
except (websockets.exceptions.WebSocketException, ConnectionError) as e:
retry_count += 1
if retry_count > self.max_retries:
raise RuntimeError(f"Failed after {self.max_retries} retries") from e

# Exponential backoff with jitter
delay = self.initial_retry_delay * (2 ** (retry_count - 1))
jitter = delay * 0.1 * (2 * np.random.random() - 1)
await asyncio.sleep(delay + jitter)

print(f"Retry {retry_count}/{self.max_retries} after error: {str(e)}")

def _solve(self, task_state: TaskState, **kwargs) -> SolverResult:
raw_msgs = [
{"role": "system", "content": task_state.task_description},
] + [msg.to_dict() for msg in task_state.messages]
- completion_result = asyncio.run(self._ws_completion(raw_msgs))
+ completion_result = asyncio.run(self._ws_completion_with_retry(raw_msgs))
completion_content = completion_result["output"][0]["content"]
completion_item = completion_content[0]
# When the model yields text output, the text portion is contained in "text";
# when the model yields audio output, the text portion is contained in "transcript".
completion_output = completion_item.get("text") or completion_item.get("transcript")
solver_result = SolverResult(completion_output, raw_completion_result=completion_result)
return solver_result
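
For reference, a minimal sketch (not part of the diff) of the delay schedule _ws_completion_with_retry produces with the defaults max_retries=3 and initial_retry_delay=1.0, using the same backoff and ±10% jitter formula as above:

import numpy as np

initial_retry_delay = 1.0
max_retries = 3

for retry_count in range(1, max_retries + 1):
    # Exponential backoff: 1.0s, 2.0s, 4.0s before jitter
    delay = initial_retry_delay * (2 ** (retry_count - 1))
    # Jitter of up to +/-10% of the delay, matching the solver code
    jitter = delay * 0.1 * (2 * np.random.random() - 1)
    print(f"retry {retry_count}/{max_retries}: sleep ~{delay + jitter:.2f}s")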