diff --git a/.gitignore b/.gitignore
index 25b38a0c18..901106d56f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,3 +21,5 @@ openai-key.txt
 # Ignore run_experiments.sh results
 evals/elsuite/**/logs/
 evals/elsuite/**/outputs/
+
+.cache/
\ No newline at end of file
diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py
index a8927dda4c..286b16d440 100644
--- a/evals/cli/oaieval.py
+++ b/evals/cli/oaieval.py
@@ -224,7 +224,7 @@ def to_number(x: str) -> Union[int, float, str]:
         **extra_eval_params,
     )
     result = eval.run(recorder)
-    add_token_usage_to_result(result, recorder)
+    # add_token_usage_to_result(result, recorder)
     recorder.record_final_report(result)

     if not (args.dry_run or args.local_run):
diff --git a/evals/elsuite/audio/eval.py b/evals/elsuite/audio/eval.py
index b5139bc297..a3a551aba9 100644
--- a/evals/elsuite/audio/eval.py
+++ b/evals/elsuite/audio/eval.py
@@ -52,6 +52,7 @@ def __init__(
         self.max_tokens = max_tokens
         self._recorder = None
         self.max_audio_duration = max_audio_duration
+        self.non_error_samples = 0  # Count of samples that didn't error out

     @property
     def recorder(self):
@@ -61,6 +62,10 @@ def eval_sample(self, sample: Sample, rng):
         prompt = self._build_prompt(sample, self.text_only)
         kwargs = self._get_completion_kwargs(sample)
         sampled = self._do_completion(prompt, **kwargs)
+        if not sampled.startswith("ERROR:"):  # Only count non-error samples
+            self.non_error_samples += 1
+        else:
+            logger.warning(f"Sample errored; non-error samples so far: {self.non_error_samples}")
         return self._compute_metrics(sample, sampled)

     def _keep_sample(self, sample):
@@ -86,6 +91,7 @@ def record_sampling_wrapper(*args, **kwargs):
         samples = [s for s in samples if self._keep_sample(s)]
         self._recorder = recorder
         self.eval_all_samples(recorder, samples)
+        logger.info(f"Number of successfully processed samples (non-errors): {self.non_error_samples}")
         return self._compute_corpus_metrics()

     def _load_dataset(self):
diff --git a/evals/registry/eval_sets/audio.yaml b/evals/registry/eval_sets/audio.yaml
index daba301bad..34c832b03d 100644
--- a/evals/registry/eval_sets/audio.yaml
+++ b/evals/registry/eval_sets/audio.yaml
@@ -22,17 +22,34 @@ audio-core:
     - audio-translate-covost-tr_en
     - audio-translate-covost-sv-SE_en

+audio-translate-core:
+  evals:
+    - audio-transcribe
+    - audio-translate-covost-en_de
+    - audio-translate-covost-en_ca
+    - audio-translate-covost-en_ar
+    - audio-translate-covost-en_tr
+    - audio-translate-covost-en_sv-SE
+    - audio-translate-covost-ru_en
+    - audio-translate-covost-es_en
+    - audio-translate-covost-zh-CN_en
+    - audio-translate-covost-sv-SE_en
+    - audio-translate-covost-tr_en
+
+
 transcript-core:
   evals:
     - transcript-transcribe
-    - transcript-translate-covost-zh-CN_en
-    - transcript-translate-covost-ru_en
-    - transcript-translate-covost-es_en
+    - transcript-translate-covost-en_de
     - transcript-translate-covost-en_ca
     - transcript-translate-covost-en_ar
-    - transcript-translate-covost-en_de
-    - transcript-translate-covost-tr_en
+    - transcript-translate-covost-en_tr
+    - transcript-translate-covost-en_sv-SE
+    - transcript-translate-covost-ru_en
+    - transcript-translate-covost-es_en
+    - transcript-translate-covost-zh-CN_en
     - transcript-translate-covost-sv-SE_en
+    - transcript-translate-covost-tr_en

 audio-translate:
   evals:
diff --git a/evals/registry/solvers/defaults.yaml b/evals/registry/solvers/defaults.yaml
index 2d59aa68ab..3d0ad2d38d 100644
--- a/evals/registry/solvers/defaults.yaml
+++ b/evals/registry/solvers/defaults.yaml
@@ -151,6 +151,15 @@ classification/cot/gpt-4:

 # generation tasks

+generation/direct/gpt-4o: + class: evals.solvers.providers.openai.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4o + extra_options: + temperature: 1 + max_tokens: 512 + generation/direct/gpt-4-turbo-preview: class: evals.solvers.providers.openai.openai_solver:OpenAISolver args: diff --git a/evals/registry/solvers/fireworks.yaml b/evals/registry/solvers/fireworks.yaml new file mode 100644 index 0000000000..8fa5907d7f --- /dev/null +++ b/evals/registry/solvers/fireworks.yaml @@ -0,0 +1,17 @@ +generation/direct/llama-31-70b-instruct: + class: evals.solvers.providers.fireworks.fireworks_solver:FireworksSolver + args: + completion_fn_options: + model: accounts/fireworks/models/llama-v3p1-70b-instruct#accounts/fixie/deployments/d2630f9a + extra_options: + temperature: 0 + max_tokens: 512 + +generation/direct/llama-31-8b-instruct: + class: evals.solvers.providers.fireworks.fireworks_solver:FireworksSolver + args: + completion_fn_options: + model: accounts/fireworks/models/llama-v3p1-8b-instruct#accounts/fixie/deployments/682dbb7b + extra_options: + temperature: 0 + max_tokens: 512 diff --git a/evals/registry/solvers/realtime.yaml b/evals/registry/solvers/realtime.yaml index c59b0612bd..24721a0f42 100644 --- a/evals/registry/solvers/realtime.yaml +++ b/evals/registry/solvers/realtime.yaml @@ -6,5 +6,5 @@ generation/realtime/gpt-4o-realtime: completion_fn_options: model: gpt-4o-realtime-preview-2024-10-01 extra_options: - temperature: 1 + temperature: 0 max_tokens: 512 diff --git a/evals/registry/solvers/whisper.yaml b/evals/registry/solvers/whisper.yaml index 4ae3c8a625..e9c3b52f00 100644 --- a/evals/registry/solvers/whisper.yaml +++ b/evals/registry/solvers/whisper.yaml @@ -3,6 +3,56 @@ # text transcript to the downstream solver. The Whisper solver can be used with # any solver that accepts text input. 
+ +generation/whisper/fireworks-llama-31-70b: + class: evals.solvers.providers.openai.whisper_solver:WhisperCascadedSolver + args: + solver: + class: evals.solvers.providers.fireworks.fireworks_solver:FireworksSolver + args: + completion_fn_options: + # model: accounts/fireworks/models/llama-v3p1-70b-instruct + model: accounts/fireworks/models/llama-v3p1-70b-instruct#accounts/fixie/deployments/d2630f9a + extra_options: + temperature: 0 + max_tokens: 512 + +generation/whisper/fireworks-llama-31-8b: + class: evals.solvers.providers.openai.whisper_solver:WhisperCascadedSolver + args: + solver: + class: evals.solvers.providers.fireworks.fireworks_solver:FireworksSolver + args: + completion_fn_options: + model: accounts/fireworks/models/llama-v3p1-8b-instruct#accounts/fixie/deployments/682dbb7b + extra_options: + temperature: 0 + max_tokens: 512 + +generation/whisper/groq-llama-31-8b: + class: evals.solvers.providers.openai.whisper_solver:WhisperCascadedSolver + args: + solver: + class: evals.solvers.providers.groq.groq_solver:GroqSolver + args: + completion_fn_options: + model: llama-3.1-8b-instant + extra_options: + temperature: 0 + max_tokens: 512 + +generation/whisper/groq-llama-31-70b: + class: evals.solvers.providers.openai.whisper_solver:WhisperCascadedSolver + args: + solver: + class: evals.solvers.providers.groq.groq_solver:GroqSolver + args: + completion_fn_options: + model: llama-3.1-70b-versatile + extra_options: + temperature: 0 + max_tokens: 512 + generation/whisper/gpt-4-turbo: class: evals.solvers.providers.openai.whisper_solver:WhisperCascadedSolver args: diff --git a/evals/solvers/providers/fireworks/fireworks_solver.py b/evals/solvers/providers/fireworks/fireworks_solver.py new file mode 100644 index 0000000000..755af92adc --- /dev/null +++ b/evals/solvers/providers/fireworks/fireworks_solver.py @@ -0,0 +1,10 @@ +from evals.solvers.providers.openai.third_party_solver import ThirdPartySolver + + +class FireworksSolver(ThirdPartySolver): + def __init__(self, **kwargs): + super().__init__( + api_base="https://api.fireworks.ai/inference/v1", + api_key_env_var="FIREWORKS_API_KEY", + **kwargs + ) diff --git a/evals/solvers/providers/openai/realtime_solver.py b/evals/solvers/providers/openai/realtime_solver.py index 0fb78a6348..a98526ad46 100644 --- a/evals/solvers/providers/openai/realtime_solver.py +++ b/evals/solvers/providers/openai/realtime_solver.py @@ -4,6 +4,8 @@ import json import os from typing import Any, Dict, Optional, Union +import time +import websockets.exceptions import librosa import numpy as np @@ -37,11 +39,15 @@ def __init__( completion_fn_options: Dict[str, Any], postprocessors: list[str] = [], registry: Any = None, + max_retries: int = 3, + initial_retry_delay: float = 1.0, ): super().__init__(postprocessors=postprocessors) if "model" not in completion_fn_options: raise ValueError("OpenAISolver requires a model to be specified.") self.completion_fn_options = completion_fn_options + self.max_retries = max_retries + self.initial_retry_delay = initial_retry_delay @property def model(self) -> str: @@ -70,15 +76,30 @@ def _api_key(self) -> Optional[str]: """The API key to use for the API""" return os.getenv("OPENAI_API_KEY") + async def _ws_completion_with_retry(self, messages): + retry_count = 0 + while True: + try: + return await self._ws_completion(messages) + except (websockets.exceptions.WebSocketException, ConnectionError) as e: + retry_count += 1 + if retry_count > self.max_retries: + raise RuntimeError(f"Failed after {self.max_retries} retries") from e + + 
# Exponential backoff with jitter + delay = self.initial_retry_delay * (2 ** (retry_count - 1)) + jitter = delay * 0.1 * (2 * np.random.random() - 1) + await asyncio.sleep(delay + jitter) + + print(f"Retry {retry_count}/{self.max_retries} after error: {str(e)}") + def _solve(self, task_state: TaskState, **kwargs) -> SolverResult: raw_msgs = [ {"role": "system", "content": task_state.task_description}, ] + [msg.to_dict() for msg in task_state.messages] - completion_result = asyncio.run(self._ws_completion(raw_msgs)) + completion_result = asyncio.run(self._ws_completion_with_retry(raw_msgs)) completion_content = completion_result["output"][0]["content"] completion_item = completion_content[0] - # When the model yields text output, the text portion is contained in "text"; - # when the model yields audio output, the text portion is contained in "transcript". completion_output = completion_item.get("text") or completion_item.get("transcript") solver_result = SolverResult(completion_output, raw_completion_result=completion_result) return solver_result diff --git a/scripts/translation_sample_mining.py b/scripts/translation_sample_mining.py new file mode 100644 index 0000000000..48194a01c9 --- /dev/null +++ b/scripts/translation_sample_mining.py @@ -0,0 +1,464 @@ +import json +import os +from typing import List, Dict +import requests +from datasets import load_dataset +from openai import OpenAI +import time +import io +import soundfile as sf +import subprocess +import base64 +import sacrebleu +import multiprocessing as mp +from pathlib import Path +import shutil +from functools import partial +import numpy as np +import hashlib +from tenacity import retry, stop_after_attempt, wait_exponential + +AUDIO_PLACEHOLDER = "<|reserved_special_token_0|>" +TASK_PROMPT = f"Please translate the text to {{language}}. Your response should only include the {{language}} translation, without any additional words. \n\n{AUDIO_PLACEHOLDER}" + +# Initialize API clients +openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) +FIREWORKS_API_KEY = os.environ.get("FIREWORKS_API_KEY") +FIXIE_API_KEY = os.environ.get("ULTRAVOX_API_KEY") + +def load_covost_dataset(language_pair: str = "en_ca", split: str = "test") -> List[Dict]: + """Load CoVoST dataset for the specified language pair in streaming mode.""" + dataset = load_dataset( + "fixie-ai/covost2", + name=language_pair, + split=split, + streaming=True + ) + return dataset + +def translate_text_fireworks(text: str, target_language: str) -> str: + """Translate text using Fireworks API directly.""" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {FIREWORKS_API_KEY}" + } + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": f"Please translate this text to {target_language}. 
Your response should only include the {target_language} translation, without any additional words:\n\n{text}"} + ] + + data = { + "model": "accounts/fireworks/models/llama-v3p1-70b-instruct#accounts/fixie/deployments/d2630f9a", + "messages": messages, + "temperature": 0.1, + "max_tokens": 1024 + } + + response = requests.post( + "https://api.fireworks.ai/inference/v1/chat/completions", + headers=headers, + json=data + ) + + if response.status_code == 200: + return response.json()["choices"][0]["message"]["content"].strip() + else: + raise Exception(f"Fireworks API error: {response.text}") + +def get_auth_token() -> str: + """Get authentication token by impersonating service account.""" + cmd = [ + "gcloud", "auth", "print-identity-token", + "--impersonate-service-account=developer-service-account@fixie-frame.iam.gserviceaccount.com", + "--audiences=https://api.ultravox.ai", + "--include-email" + ] + + try: + token = subprocess.check_output(cmd).decode('utf-8').strip() + return token + except subprocess.CalledProcessError as e: + raise Exception(f"Failed to get auth token: {str(e)}") + +def translate_audio_fixie(audio_data: dict, target_language: str) -> str: + """Translate audio using Fixie/Ultravox API from raw audio data.""" + auth_token = get_auth_token() + + headers = { + "Authorization": f"Bearer {auth_token}", + "Content-Type": "application/json" + } + + # Get the original audio data + audio_array = audio_data['array'] + sampling_rate = audio_data['sampling_rate'] + + # Calculate number of samples for 1.5 seconds of silence + silence_samples = int(1.5 * sampling_rate) + + # Create silence padding and concatenate with original audio at both ends + silence_padding = np.zeros(silence_samples, dtype=audio_array.dtype) + padded_audio = np.concatenate([silence_padding, audio_array, silence_padding]) + + # Convert padded audio to base64 + audio_buffer = io.BytesIO() + sf.write(audio_buffer, padded_audio, sampling_rate, format='wav') + audio_buffer.seek(0) + audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8') + data_url = f"data:audio/wav;base64,{audio_base64}" + + # Make the chat completion request + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": { + "url": data_url + } + }, + { + "type": "text", + "text": TASK_PROMPT.format(language=target_language) + } + ] + } + ] + + data = { + "model": "fixie-ai/ultravox-70B-dev", + "messages": messages + } + + response = requests.post( + "https://api.ultravox.ai/api/chat/completions", + headers=headers, + json=data + ) + + if response.status_code == 200: + return response.json()["choices"][0]["message"]["content"].strip() + else: + raise Exception(f"Fixie API error: {response.text}") + +def evaluate_translation(reference: str, hypothesis: str) -> float: + """Evaluate translation using sentence BLEU score.""" + return sacrebleu.sentence_bleu(hypothesis, [reference]).score + +def calculate_bleu(reference: str, hypothesis: str) -> float: + """Calculate BLEU score between reference and hypothesis.""" + return sacrebleu.sentence_bleu(hypothesis, [reference]).score + +def transcribe_audio_whisper(audio_data: dict) -> str: + """Transcribe audio using OpenAI's Whisper API from raw audio data.""" + # Convert the audio array to a temporary file-like object + audio_array = audio_data['array'] + sampling_rate = audio_data['sampling_rate'] + + # Create an in-memory bytes buffer + audio_buffer = io.BytesIO() + # Write the audio data to the 
buffer in WAV format + sf.write(audio_buffer, audio_array, sampling_rate, format='wav') + # Seek to the beginning of the buffer + audio_buffer.seek(0) + + # Send to Whisper API + transcript = openai_client.audio.transcriptions.create( + model="whisper-1", + file=("audio.wav", audio_buffer, "audio/wav"), + ) + print(transcript) + return transcript.text + +def get_stable_hash(audio_array: np.ndarray) -> str: + """Generate a stable hash from audio array content.""" + # Convert array to bytes in a consistent way + array_bytes = audio_array.tobytes() + + # Use SHA-256 for consistent hashing + hasher = hashlib.sha256() + hasher.update(array_bytes) + + # Return first 16 characters of hex digest + return hasher.hexdigest()[:16] + +@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) +def get_gpt4_judgment(reference: str, whisper_fireworks: str, fixie: str, target_language: str) -> dict: + """Get GPT-4's judgment on which translation is better.""" + prompt = f"""You are an expert translator and linguistic evaluator. Please compare two machine translations against a reference translation and determine which one is better. Consider accuracy, fluency, and preservation of meaning. + +Reference ({target_language}): {reference} + +Translation A: {whisper_fireworks} +Translation B: {fixie} + +Please respond in the following JSON format only: +{{ + "winner": "A" or "B" or "tie", + "explanation": "brief explanation of the decision" +}}""" + + response = openai_client.chat.completions.create( + model="gpt-4", + messages=[ + {"role": "system", "content": "You are a translation evaluation expert. Respond only in the specified JSON format."}, + {"role": "user", "content": prompt} + ], + temperature=0.1, + ) + + try: + return json.loads(response.choices[0].message.content) + except json.JSONDecodeError: + return {"winner": "error", "explanation": "Failed to parse GPT-4 response"} + +def process_single_sample(sample, cache_dir: str, target_language: str = "Catalan"): + """Process a single sample with caching.""" + try: + # Create cache directories + audio_cache_dir = Path(cache_dir) / "audio" + results_cache_dir = Path(cache_dir) / "results" + audio_cache_dir.mkdir(parents=True, exist_ok=True) + results_cache_dir.mkdir(parents=True, exist_ok=True) + + # Generate unique filenames using stable hash + audio_data = sample['audio'] + audio_hash = get_stable_hash(audio_data['array']) + cache_path = audio_cache_dir / f"{audio_hash}.wav" + results_path = results_cache_dir / f"{audio_hash}.json" + print(f"Processing sample {audio_hash}") + + # Check if results already exist + if results_path.exists(): + print("Loading results from cache") + with open(results_path, 'r') as f: + return json.load(f) + + # Cache the audio file if it doesn't exist + if not cache_path.exists(): + sf.write(str(cache_path), audio_data['array'], audio_data['sampling_rate']) + + # Get Whisper transcription + whisper_transcription = transcribe_audio_whisper(audio_data) + ground_truth_transcription = sample['sentence'] + reference_translation = sample['translation'] + + # Calculate BLEU score for transcription + transcription_bleu = calculate_bleu(ground_truth_transcription, whisper_transcription) + + # Get translations + whisper_fireworks_translation = translate_text_fireworks( + whisper_transcription, + target_language + ) + + fixie_translation = translate_audio_fixie( + audio_data, + target_language + ) + + # Evaluate translations using BLEU + whisper_fireworks_bleu = evaluate_translation(reference_translation, 
whisper_fireworks_translation) + fixie_bleu = evaluate_translation(reference_translation, fixie_translation) + + # Add GPT-4 judgment + gpt4_judgment = get_gpt4_judgment( + reference_translation, + whisper_fireworks_translation, + fixie_translation, + target_language + ) + + result = { + "cache_path": str(cache_path), + "ground_truth_transcription": ground_truth_transcription, + "whisper_transcription": whisper_transcription, + "transcription_bleu": transcription_bleu, + "reference_translation": reference_translation, + "whisper_fireworks_translation": whisper_fireworks_translation, + "fixie_translation": fixie_translation, + "whisper_fireworks_bleu": whisper_fireworks_bleu, + "fixie_bleu": fixie_bleu, + "bleu_delta": fixie_bleu - whisper_fireworks_bleu, + "gpt4_judgment": gpt4_judgment, + "gpt4_winner": "fixie" if gpt4_judgment["winner"] == "B" else + "whisper_fireworks" if gpt4_judgment["winner"] == "A" else + gpt4_judgment["winner"] + } + + # Cache the results + with open(results_path, 'w') as f: + json.dump(result, f) + + return result + + except Exception as e: + print(f"Error processing sample: {str(e)}") + return None + +def print_detailed_example(result: dict, idx: int, category: str): + """Print detailed information for a single example.""" + print(f"\nExample {idx + 1} ({category})") + print("=" * 80) + print(f"BLEU Delta: {result['bleu_delta']:.4f}") + print(f"Audio Path: {result['cache_path']}") + print("\nTranscription Comparison:") + print(f"Ground Truth: {result['ground_truth_transcription']}") + print(f"Whisper: {result['whisper_transcription']}") + print(f"Transcription BLEU: {result['transcription_bleu']:.4f}") + print("\nTranslation Comparison:") + print(f"Reference: {result['reference_translation']}") + print(f"Whisper+Fireworks: {result['whisper_fireworks_translation']}") + print(f"Whisper+Fireworks BLEU: {result['whisper_fireworks_bleu']:.4f}") + print(f"Fixie: {result['fixie_translation']}") + print(f"Fixie BLEU: {result['fixie_bleu']:.4f}") + print(f"GPT-4 Winner: {result['gpt4_winner']}") + print(f"GPT-4 Explanation: {result['gpt4_judgment']['explanation']}") + print("=" * 80) + +def main(): + dataset = load_covost_dataset() + max_samples = 500 + cache_dir = ".cache" + + # List of hashes to rerun + hashes_to_rerun = { + + } + + # Collect samples + samples = [] + for idx, sample in enumerate(dataset): + if idx >= max_samples: + break + + if hashes_to_rerun: + # Use the stable hash function + audio_hash = get_stable_hash(sample['audio']['array']) + print(f"Sample {idx} hash: {audio_hash}") # Debug print + if audio_hash in hashes_to_rerun: + samples.append(sample) + print(f"Found sample with matching hash: {audio_hash}") + else: + samples.append(sample) + + if hashes_to_rerun: + print(f"Found {len(samples)} samples to reprocess") + + # Delete existing cache files for these hashes + audio_cache_dir = Path(cache_dir) / "audio" + results_cache_dir = Path(cache_dir) / "results" + for hash_val in hashes_to_rerun: + # Remove cached results + result_file = results_cache_dir / f"{hash_val}.json" + if result_file.exists(): + result_file.unlink() + + # Remove cached audio + audio_file = audio_cache_dir / f"{hash_val}.wav" + if audio_file.exists(): + audio_file.unlink() + + # Process samples in parallel + with mp.Pool(processes=mp.cpu_count() - 1) as pool: + process_fn = partial(process_single_sample, cache_dir=cache_dir) + results = list(pool.imap(process_fn, samples)) + + # Filter out None results (failed processes) + results = [r for r in results if r is not None] + + # Sort 
results by BLEU score delta (Fixie - Whisper+Fireworks) + results.sort(key=lambda x: x['bleu_delta'], reverse=True) + + + # Print summary statistics + if results: + whisper_fireworks_avg = np.mean([r["whisper_fireworks_bleu"] for r in results]) + fixie_avg = np.mean([r["fixie_bleu"] for r in results]) + transcription_bleu_avg = np.mean([r["transcription_bleu"] for r in results]) + bleu_delta_avg = np.mean([r["bleu_delta"] for r in results]) + + print("\nResults Summary:") + print("=" * 80) + print(f"Total samples processed: {len(results)}") + print(f"Average Transcription BLEU Score: {transcription_bleu_avg:.4f}") + print(f"Whisper + Fireworks Average BLEU: {whisper_fireworks_avg:.4f}") + print(f"Fixie (Ultravox) Average BLEU: {fixie_avg:.4f}") + print(f"Average BLEU Delta (Fixie - Whisper+Fireworks): {bleu_delta_avg:.4f}") + + # Filter results + fixie_better_results = [r for r in results if r['bleu_delta'] > 0] + fireworks_better_results = [r for r in results if r['bleu_delta'] < 0] + tied_results = [r for r in results if r['bleu_delta'] == 0] + + # Print top 20 examples where Fixie performed better + print(f"\nTop 20 examples where Fixie performed better (out of {len(fixie_better_results)} total wins):") + print("=" * 80) + for idx, result in enumerate(fixie_better_results[:20]): + print_detailed_example(result, idx, "Fixie Better") + + # Print top 20 examples where Fireworks performed better + print(f"\nTop 20 examples where Whisper+Fireworks performed better (out of {len(fireworks_better_results)} total wins):") + print("=" * 80) + for idx, result in enumerate(sorted(fireworks_better_results, key=lambda x: x['bleu_delta'])[:20]): + print_detailed_example(result, idx, "Whisper+Fireworks Better") + + # Print up to 20 tied examples + print(f"\nUp to 20 examples where scores were tied (out of {len(tied_results)} total ties):") + print("=" * 80) + for idx, result in enumerate(tied_results[:20]): + print_detailed_example(result, idx, "Tied Score") + + # Print summary of wins + print("\nWin Statistics:") + print("=" * 80) + print(f"Fixie wins: {len(fixie_better_results)} samples") + print(f"Whisper+Fireworks wins: {len(fireworks_better_results)} samples") + print(f"Ties: {len(tied_results)} samples") + + # Print average BLEU scores for wins + if fixie_better_results: + fixie_win_avg = np.mean([r['bleu_delta'] for r in fixie_better_results]) + print(f"Average BLEU delta when Fixie wins: {fixie_win_avg:.4f}") + if fireworks_better_results: + fireworks_win_avg = np.mean([abs(r['bleu_delta']) for r in fireworks_better_results]) + print(f"Average BLEU delta when Whisper+Fireworks wins: {fireworks_win_avg:.4f}") + + # Add GPT-4 judgment statistics + gpt4_fixie_wins = [r for r in results if r['gpt4_winner'] == 'fixie'] + gpt4_whisper_wins = [r for r in results if r['gpt4_winner'] == 'whisper_fireworks'] + gpt4_ties = [r for r in results if r['gpt4_winner'] == 'tie'] + + print("\nGPT-4 Judgment Statistics:") + print("=" * 80) + print(f"GPT-4 Fixie wins: {len(gpt4_fixie_wins)} samples") + print(f"GPT-4 Whisper+Fireworks wins: {len(gpt4_whisper_wins)} samples") + print(f"GPT-4 Ties: {len(gpt4_ties)} samples") + + # Print top 20 examples where GPT-4 chose Fixie + print(f"\nTop 20 examples where GPT-4 chose Fixie as winner:") + print("=" * 80) + for idx, result in enumerate(gpt4_fixie_wins[:20]): + print_detailed_example(result, idx, "GPT-4 Fixie Win") + + # Print top 20 examples where GPT-4 chose Whisper+Fireworks + print(f"\nTop 20 examples where GPT-4 chose Whisper+Fireworks as winner:") + print("=" * 
80) + for idx, result in enumerate(gpt4_whisper_wins[:20]): + print_detailed_example(result, idx, "GPT-4 Whisper+Fireworks Win") + + # Calculate agreement between BLEU and GPT-4 + bleu_gpt4_agreement = len([r for r in results if + (r['bleu_delta'] > 0 and r['gpt4_winner'] == 'fixie') or + (r['bleu_delta'] < 0 and r['gpt4_winner'] == 'whisper_fireworks') or + (r['bleu_delta'] == 0 and r['gpt4_winner'] == 'tie') + ]) + + agreement_percentage = (bleu_gpt4_agreement / len(results)) * 100 + print(f"\nBLEU and GPT-4 Agreement: {agreement_percentage:.2f}%") + +if __name__ == "__main__": + main() \ No newline at end of file
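
Note on the new whisper.yaml entries: they reference evals.solvers.providers.groq.groq_solver:GroqSolver, which is not part of this diff. Assuming it mirrors the FireworksSolver added above (same ThirdPartySolver base class), a minimal sketch would look like the following; the Groq base URL and GROQ_API_KEY variable name are assumptions, not taken from this change.

# Hypothetical sketch only -- not included in this diff. Mirrors
# evals/solvers/providers/fireworks/fireworks_solver.py; the base URL and
# env var name below are assumed, not confirmed by this change.
from evals.solvers.providers.openai.third_party_solver import ThirdPartySolver


class GroqSolver(ThirdPartySolver):
    def __init__(self, **kwargs):
        super().__init__(
            api_base="https://api.groq.com/openai/v1",  # assumed OpenAI-compatible endpoint
            api_key_env_var="GROQ_API_KEY",             # assumed env var name
            **kwargs
        )

If the in-repo GroqSolver differs, the whisper.yaml entries above are unaffected; they only depend on the class path and completion_fn_options.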
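
For reference on the retry behavior added to realtime_solver.py: with the default kwargs (initial_retry_delay=1.0, max_retries=3), the sleep before retries 1, 2, and 3 is roughly 1 s, 2 s, and 4 s, each perturbed by up to +/-10% jitter. A standalone sketch of the same schedule, mirroring the expressions in the diff rather than a separate implementation:

# Standalone sketch of the backoff-with-jitter schedule used in the diff.
import numpy as np

initial_retry_delay = 1.0  # default from the new __init__ kwargs in realtime_solver.py
max_retries = 3

for retry_count in range(1, max_retries + 1):
    delay = initial_retry_delay * (2 ** (retry_count - 1))  # 1.0, 2.0, 4.0
    jitter = delay * 0.1 * (2 * np.random.random() - 1)     # uniform in [-10%, +10%] of delay
    print(f"retry {retry_count}: sleep {delay + jitter:.2f}s")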