Skip to content

Commit

Permalink
Merge pull request #40 from mlfoundations/etashg/lint
Browse files Browse the repository at this point in the history
Etashg/lint
  • Loading branch information
EtashGuha authored Dec 20, 2024
2 parents 32fb732 + d1b4fc1 commit eb65d76
Show file tree
Hide file tree
Showing 16 changed files with 132 additions and 138 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/black.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Lint workflow: run the official Black formatter check on every push and PR.
# NOTE(review): indentation below is restored to canonical GitHub Actions
# structure — the scraped source had lost YAML nesting, which is semantic.
name: Lint

on: [push, pull_request]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      # Check out the repository so Black can see the source tree.
      - uses: actions/checkout@v4
      # psf/black@stable runs `black --check --diff` by default and fails
      # the job when any file would be reformatted.
      - uses: psf/black@stable
66 changes: 27 additions & 39 deletions create_csv_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

engine, SessionMaker = create_db_engine()


@contextmanager
def session_scope():
"""
Expand All @@ -25,54 +26,42 @@ def session_scope():
finally:
session.close()


def get_model_score(name: str, model_id: uuid.UUID, annotator_model: str) -> float:
with session_scope() as session:
rows = session.query(EvalResult).filter_by(model_id=model_id).all()
if not rows:
return None
for row in rows:
eval_setting = session.query(EvalSetting).filter_by(
id=row.eval_setting_id
).first()
if eval_setting and name == eval_setting.name and eval_setting.parameters['annotator_model'] == annotator_model:
eval_setting = session.query(EvalSetting).filter_by(id=row.eval_setting_id).first()
if (
eval_setting
and name == eval_setting.name
and eval_setting.parameters["annotator_model"] == annotator_model
):
return float(row.score)
return None


def get_model_name(model_id: uuid.UUID) -> str:
    """Look up the human-readable name for a model.

    Args:
        model_id: UUID primary key of the Model row.

    Returns:
        The model's name, or None when no row matches the given id.
    """
    with session_scope() as session:
        record = session.query(Model).filter_by(id=model_id).first()
        if record is None:
            return None
        return record.name


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description='Generate CSV of model evaluation scores')
parser.add_argument(
'--model-ids',
required=True,
nargs='+',
help='List of model UUIDs to evaluate'
)
parser.add_argument(
'--eval-tasks',
required=True,
nargs='+',
help='List of evaluation task names'
)
parser.add_argument(
'--annotator-model',
required=True,
help='Annotator model to filter results'
)
parser.add_argument(
'--output',
default='model_scores.csv',
help='Output CSV filename (default: model_scores.csv)'
)
parser = argparse.ArgumentParser(description="Generate CSV of model evaluation scores")
parser.add_argument("--model-ids", required=True, nargs="+", help="List of model UUIDs to evaluate")
parser.add_argument("--eval-tasks", required=True, nargs="+", help="List of evaluation task names")
parser.add_argument("--annotator-model", required=True, help="Annotator model to filter results")
parser.add_argument("--output", default="model_scores.csv", help="Output CSV filename (default: model_scores.csv)")
return parser.parse_args()


def generate_eval_csv(model_ids: List[str], eval_tasks: List[str], annotator_model: str, output_file: str) -> None:
"""
Generate CSV file with model evaluation scores.
Args:
model_ids: List of model UUID strings
eval_tasks: List of evaluation task names
Expand All @@ -87,7 +76,7 @@ def generate_eval_csv(model_ids: List[str], eval_tasks: List[str], annotator_mod
sys.exit(1)

# Prepare CSV headers
headers = ['model_id', 'model_name'] + eval_tasks
headers = ["model_id", "model_name"] + eval_tasks

# Collect data for each model
rows = []
Expand All @@ -97,21 +86,18 @@ def generate_eval_csv(model_ids: List[str], eval_tasks: List[str], annotator_mod
print(f"Warning: Model not found for ID {model_id}", file=sys.stderr)
continue

row = {
'model_id': str(model_id),
'model_name': model_name
}

row = {"model_id": str(model_id), "model_name": model_name}

# Get scores for each eval task
for task in eval_tasks:
score = get_model_score(task, model_id, annotator_model)
row[task] = score if score is not None else 'N/A'
row[task] = score if score is not None else "N/A"

rows.append(row)

# Write to CSV
try:
with open(output_file, 'w', newline='') as csvfile:
with open(output_file, "w", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=headers)
writer.writeheader()
writer.writerows(rows)
Expand All @@ -120,9 +106,11 @@ def generate_eval_csv(model_ids: List[str], eval_tasks: List[str], annotator_mod
print(f"Error writing to CSV file: {e}", file=sys.stderr)
sys.exit(1)


def main():
    """CLI entry point: parse arguments and write the evaluation-score CSV."""
    cli = parse_args()
    generate_eval_csv(
        cli.model_ids,
        cli.eval_tasks,
        cli.annotator_model,
        cli.output,
    )

if __name__ == '__main__':
main()

if __name__ == "__main__":
main()
9 changes: 3 additions & 6 deletions eval/chat_benchmarks/MTBench/fastchat/data/clean_sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Usage:
python3 -m fastchat.data.clean_sharegpt --in sharegpt_html.json --out sharegpt_clean.json
"""

import argparse
from concurrent.futures import ProcessPoolExecutor
import json
Expand All @@ -19,9 +20,7 @@

div_pattern = re.compile("<div.*?>")
span_pattern = re.compile("<span.*?>")
code_lang_pattern = re.compile(
"```\s*" + "(.*?)" + "(?:Copy code)+" + "(.+?)" + "\s*?```", re.DOTALL
)
code_lang_pattern = re.compile("```\s*" + "(.*?)" + "(?:Copy code)+" + "(.+?)" + "\s*?```", re.DOTALL)
code_lang_format = "```\g<1>\n\g<2>\n```"
regenerate_pattern = re.compile("\d+ / \d+")
copy_chars_pattern = re.compile("Copy\d+ chars / \d+ words")
Expand Down Expand Up @@ -155,9 +154,7 @@ def clean_html_all(content, begin, end):
content = content[begin:end]
processed = []
with ProcessPoolExecutor() as executor:
for result in tqdm(
executor.map(clean_html_one_sample, content), total=len(content)
):
for result in tqdm(executor.map(clean_html_one_sample, content), total=len(content)):
processed.append(result)

visited = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Usage: python3 -m fastchat.data.extract_gpt4_only --in sharegpt.json
"""

import argparse
import json

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Usage: python3 -m fastchat.data.extract_single_round --in sharegpt.json
"""

import argparse
import json

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
python3 -m fastchat.data.filter_wrong_format --in input.json --out output.json
"""

import argparse
import json
import re
Expand Down
8 changes: 2 additions & 6 deletions eval/chat_benchmarks/MTBench/fastchat/data/get_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def tokenize_one_sample(c):
def tokenize_dataset(content):
processed = []
with ProcessPoolExecutor() as executor:
for result in tqdm(
executor.map(tokenize_one_sample, content), total=len(content)
):
for result in tqdm(executor.map(tokenize_one_sample, content), total=len(content)):
processed.append(result)

return processed
Expand Down Expand Up @@ -59,9 +57,7 @@ def compute_stats(content):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--in-file", type=str)
parser.add_argument(
"--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
)
parser.add_argument("--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
args = parser.parse_args()

content = json.load(open(args.in_file, "r"))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Hardcoded question and answers.
"""

import json


Expand Down
1 change: 1 addition & 0 deletions eval/chat_benchmarks/MTBench/fastchat/data/inspect_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Usage:
python3 -m fastchat.data.inspect_data --in sharegpt_20230322_clean_lang_split.json
"""

import argparse
import json
import random
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Requirement:
pip3 install polyglot pyicu pycld2
"""

import argparse
import json
import re
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Requirement:
pip3 install transformers tqdm
"""

import argparse
import json
import traceback
Expand All @@ -15,9 +16,7 @@
from tqdm import tqdm


def replace_special_tokens(
tokenizer: transformers.PreTrainedTokenizer, text: str
) -> str:
def replace_special_tokens(tokenizer: transformers.PreTrainedTokenizer, text: str) -> str:
if not text:
return text

Expand Down
10 changes: 2 additions & 8 deletions eval/chat_benchmarks/MTBench/fastchat/data/prepare_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,14 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prefix", type=str, default="~/datasets/sharegpt_20230521")
parser.add_argument(
"--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
)
parser.add_argument("--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
parser.add_argument("--seq-len", type=int, default=4096)
args = parser.parse_args()

in_prefix = args.prefix
model_path = args.model_name_or_path
seq_len = args.seq_len
prefix = (
f"{in_prefix}_{seq_len}".replace("4096", "4k")
.replace("8192", "8k")
.replace("16384", "16k")
)
prefix = f"{in_prefix}_{seq_len}".replace("4096", "4k").replace("8192", "8k").replace("16384", "16k")

cmd_list = [
f"python3 -m fastchat.data.clean_sharegpt --in {in_prefix}_html.json --out {prefix}_clean.json",
Expand Down
1 change: 1 addition & 0 deletions eval/chat_benchmarks/MTBench/fastchat/data/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Usage: python3 -m fastchat.data.sample --in sharegpt.json --out sampled.json
"""

import argparse
import json

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
--out sharegpt_split.json \
--model-name-or-path $<model-name>
"""

import argparse
from concurrent.futures import ProcessPoolExecutor
import json
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Usage: python3 -m fastchat.data.split_train_test --in sharegpt.json
"""

import argparse
import json

Expand Down
Loading

0 comments on commit eb65d76

Please sign in to comment.