Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(bench): rewrite bench functions #748

Merged
merged 4 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
68 changes: 68 additions & 0 deletions research_bench/crossbench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import argparse
import re
from typing import Any, Dict, List, Set

from tqdm import tqdm
from utils import get_paper_by_arxiv_id, process_paper, save_benchmark


def get_arxiv_ids(input: str) -> List[str]:
    """Extract arXiv identifiers from a text file of arXiv abstract URLs.

    The file is read as UTF-8, one entry per line. Blank lines and lines
    that do not contain an ``arxiv.org/abs/<id>`` link are ignored.

    Args:
        input: Path to the text file holding the URLs.

    Returns:
        The identifiers in file order. Duplicates are kept, and a version
        suffix such as ``v2`` is preserved when the URL carries one.
    """
    # Old-style identifiers contain a slash (``hep-th/9901001``), so they
    # get their own alternative. Otherwise the ID ends at the next slash,
    # whitespace, query string (``?``) or fragment (``#``) so that URL
    # decorations never end up inside the identifier.
    id_pattern = re.compile(
        r'arxiv\.org/abs/'
        r'([a-z-]+(?:\.[A-Z]{2})?/\d{7}(?:v\d+)?|[^\s/?#]+)'
    )

    arxiv_ids = []
    with open(input, 'r', encoding='utf-8') as f:
        for line in f:
            match = id_pattern.search(line.strip())
            if match:
                arxiv_ids.append(match.group(1))
    return arxiv_ids


def process_arxiv_ids(
    arxiv_ids: List[str],
    output: str,
) -> Dict[str, Any]:
    """Fetch and process each arXiv ID, checkpointing the result to disk.

    Args:
        arxiv_ids: Identifiers to look up; repeats of an ID that was already
            processed successfully are skipped.
        output: Path handed to ``save_benchmark`` after every stored paper.

    Returns:
        A mapping from each processed paper's ``'arxiv_id'`` field to its
        processed data.
    """
    collected: Dict[str, Any] = {}
    done: Set[str] = set()

    for arxiv_id in tqdm(arxiv_ids, desc='Processing arXiv IDs'):
        if arxiv_id in done:
            continue

        paper = get_paper_by_arxiv_id(arxiv_id)
        if not paper:
            print(f'Paper with arXiv ID {arxiv_id} not found.')
            continue

        entry = process_paper(paper)
        collected[entry['arxiv_id']] = entry
        # Only successful lookups are remembered, so an ID that was not
        # found is looked up again if it reappears in the list.
        done.add(arxiv_id)
        # Written after every paper so progress survives an interruption.
        save_benchmark(collected, output)

    return collected


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the crossbench script.

    Returns:
        A namespace with ``input`` (required path to the file of arXiv URLs)
        and ``output`` (result path, default
        ``./benchmark/crossbench.json``).
    """
    cli = argparse.ArgumentParser(description='Process arXiv URLs.')

    # (flag, add_argument keyword options), registered in order.
    option_table = (
        (
            '--input',
            {
                'type': str,
                'required': True,
                'help': 'Path to the input file containing arXiv URLs.',
            },
        ),
        (
            '--output',
            {
                'type': str,
                'default': './benchmark/crossbench.json',
                'help': 'Output file path.',
            },
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def main() -> None:
    """Read arXiv IDs from the ``--input`` file and build the benchmark."""
    options = parse_args()
    process_arxiv_ids(get_arxiv_ids(options.input), options.output)


# Run the pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
75 changes: 75 additions & 0 deletions research_bench/mlbench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import argparse
from typing import Any, Dict, List, Set

from tqdm import tqdm
from utils import get_paper_by_keyword, process_paper, save_benchmark


def process_keywords(
    keywords: List[str],
    max_papers_per_keyword: int,
    output: str,
) -> Dict[str, Any]:
    """Collect and process papers for each keyword, saving after each one.

    Args:
        keywords: Search keywords, handled in the given order.
        max_papers_per_keyword: Limit forwarded to ``get_paper_by_keyword``.
        output: Path handed to ``save_benchmark`` once per keyword.

    Returns:
        A mapping from each processed paper's ``'arxiv_id'`` field to its
        processed data; a later paper with the same ID replaces the earlier
        entry.
    """
    collected: Dict[str, Any] = {}
    # The same set object is passed to every lookup. This function never adds
    # to it itself -- presumably get_paper_by_keyword records the IDs it
    # returns (TODO confirm in utils).
    seen_ids: Set[str] = set()

    for keyword in keywords:
        print(f"Fetching papers for keyword: '{keyword}'")
        found = get_paper_by_keyword(keyword, seen_ids, max_papers_per_keyword)

        progress = tqdm(found, desc=f"Processing papers for '{keyword}'")
        for entry in map(process_paper, progress):
            collected[entry['arxiv_id']] = entry

        # Checkpoint once the keyword's papers are all processed.
        save_benchmark(collected, output)
    return collected


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the mlbench script.

    Returns:
        A namespace with ``keywords`` (one or more search terms),
        ``max_papers_per_keyword`` (int, default 10) and ``output`` (result
        path, default ``./benchmark/mlbench.json``).
    """
    # Search terms used when --keywords is not given on the command line.
    default_keywords = [
        'reinforcement learning',
        'large language model',
        'diffusion model',
        'graph neural network',
        'deep learning',
        'representation learning',
        'transformer',
        'federated learning',
        'generative model',
        'self-supervised learning',
        'vision language model',
        'explainable AI',
        'automated machine learning',
    ]

    cli = argparse.ArgumentParser(description='Create an AI paper benchmark.')
    cli.add_argument(
        '--keywords',
        **{
            'type': str,
            'nargs': '+',
            'default': default_keywords,
            'help': 'List of keywords to search for.',
        },
    )
    cli.add_argument(
        '--max_papers_per_keyword',
        **{
            'type': int,
            'default': 10,
            'help': 'Maximum number of papers per keyword.',
        },
    )
    cli.add_argument(
        '--output',
        **{
            'type': str,
            'default': './benchmark/mlbench.json',
            'help': 'Output file path.',
        },
    )
    return cli.parse_args()


def main() -> None:
    """Build the keyword-based benchmark from the command-line options."""
    options = parse_args()
    process_keywords(
        options.keywords,
        options.max_papers_per_keyword,
        options.output,
    )


# Run the pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
159 changes: 0 additions & 159 deletions research_bench/scripts/create_crossbench.py

This file was deleted.

1 change: 1 addition & 0 deletions research_bench/scripts/create_crossbench.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#!/bin/sh
# Build the crossbench benchmark by running crossbench.py on the link list.
# Every path below is relative, so switch to this script's own directory
# first; otherwise the command only works when launched from that directory.
cd "$(dirname "$0")" || exit 1
python ../crossbench.py --input ./crossbench_paper_links.txt --output ../benchmark/crossbench.json
Loading
Loading