diff --git a/research_bench/benchmark/.gitkeep b/research_bench/benchmark/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/research_bench/crossbench.py b/research_bench/crossbench.py new file mode 100644 index 00000000..25e45edf --- /dev/null +++ b/research_bench/crossbench.py @@ -0,0 +1,68 @@ +import argparse +import re +from typing import Any, Dict, List, Set + +from tqdm import tqdm +from utils import get_paper_by_arxiv_id, process_paper, save_benchmark + + +def get_arxiv_ids(input: str) -> List[str]: + with open(input, 'r', encoding='utf-8') as f: + urls = [line.strip() for line in f if line.strip()] + + arxiv_ids = [] + for url in urls: + match = re.search(r'arxiv\.org/abs/([^\s/]+)', url) + if match: + arxiv_ids.append(match.group(1)) + return arxiv_ids + + +def process_arxiv_ids( + arxiv_ids: List[str], + output: str, +) -> Dict[str, Any]: + benchmark = {} + existing_arxiv_ids: Set[str] = set() + + for arxiv_id in tqdm(arxiv_ids, desc='Processing arXiv IDs'): + if arxiv_id in existing_arxiv_ids: + continue + + paper = get_paper_by_arxiv_id(arxiv_id) + if paper: + paper_data = process_paper(paper) + benchmark[paper_data['arxiv_id']] = paper_data + existing_arxiv_ids.add(arxiv_id) + save_benchmark(benchmark, output) + else: + print(f'Paper with arXiv ID {arxiv_id} not found.') + + return benchmark + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description='Process arXiv URLs.') + parser.add_argument( + '--input', + type=str, + required=True, + help='Path to the input file containing arXiv URLs.', + ) + parser.add_argument( + '--output', + type=str, + default='./benchmark/crossbench.json', + help='Output file path.', + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + arxiv_ids = get_arxiv_ids(args.input) + process_arxiv_ids(arxiv_ids, args.output) + + +if __name__ == '__main__': + main() diff --git a/research_bench/mlbench.py b/research_bench/mlbench.py new file mode 100644 index 
00000000..627ff7e6 --- /dev/null +++ b/research_bench/mlbench.py @@ -0,0 +1,75 @@ +import argparse +from typing import Any, Dict, List, Set + +from tqdm import tqdm +from utils import get_paper_by_keyword, process_paper, save_benchmark + + +def process_keywords( + keywords: List[str], + max_papers_per_keyword: int, + output: str, +) -> Dict[str, Any]: + benchmark = {} + existing_arxiv_ids: Set[str] = set() + + for keyword in keywords: + print(f"Fetching papers for keyword: '{keyword}'") + papers = get_paper_by_keyword( + keyword, existing_arxiv_ids, max_papers_per_keyword + ) + + for paper in tqdm(papers, desc=f"Processing papers for '{keyword}'"): + paper_data = process_paper(paper) + benchmark[paper_data['arxiv_id']] = paper_data + + save_benchmark(benchmark, output) + return benchmark + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description='Create an AI paper benchmark.') + parser.add_argument( + '--keywords', + type=str, + nargs='+', + default=[ + 'reinforcement learning', + 'large language model', + 'diffusion model', + 'graph neural network', + 'deep learning', + 'representation learning', + 'transformer', + 'federated learning', + 'generative model', + 'self-supervised learning', + 'vision language model', + 'explainable AI', + 'automated machine learning', + ], + help='List of keywords to search for.', + ) + parser.add_argument( + '--max_papers_per_keyword', + type=int, + default=10, + help='Maximum number of papers per keyword.', + ) + parser.add_argument( + '--output', + type=str, + default='./benchmark/mlbench.json', + help='Output file path.', + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + keywords = args.keywords + process_keywords(keywords, args.max_papers_per_keyword, args.output) + + +if __name__ == '__main__': + main() diff --git a/research_bench/scripts/create_crossbench.py b/research_bench/scripts/create_crossbench.py deleted file mode 100644 index cddb0361..00000000 --- 
a/research_bench/scripts/create_crossbench.py +++ /dev/null @@ -1,159 +0,0 @@ -import json -import os -import re -import time -from typing import Any, Dict, List, Optional, Union, cast - -import requests -from tqdm import tqdm - - -def extract_arxiv_id_from_url(url: str) -> Optional[str]: - match = re.search(r'arxiv\.org/abs/(\d{4}\.\d{5})', url) - if match: - return match.group(1) - else: - return None - - -def get_paper_info(arxiv_id: str, max_retry: int = 5) -> Optional[Dict[str, Any]]: - paper_id = f'ARXIV:{arxiv_id}' - url = f'https://api.semanticscholar.org/graph/v1/paper/{paper_id}' - fields = 'title,authors,year,venue,abstract' - headers = {'x-api-key': 'FfOnoChxCS2vGorFNV4sQB7KdzzRalp9ygKzAGf8'} - params: Dict[str, Union[str, int]] = { - 'fields': fields, - } - - for attempt in range(max_retry): - response = requests.get(url, params=params, headers=headers) - if response.status_code == 200: - return cast(Dict[str, Any], response.json()) - else: - time.sleep(5) - return None - - -def get_references(arxiv_id: str, max_retry: int = 5) -> Optional[List[Dict[str, Any]]]: - paper_id = f'ARXIV:{arxiv_id}' - url = f'https://api.semanticscholar.org/graph/v1/paper/{paper_id}/references' - fields = 'title,abstract,year,venue,authors,externalIds,url,referenceCount,citationCount,influentialCitationCount,isOpenAccess,fieldsOfStudy' - params: Dict[str, Union[str, int]] = { - 'fields': fields, - 'limit': 100, - } - - references: List[Any] = [] - offset = 0 - - while True: - params['offset'] = offset - for attempt in range(max_retry): - response = requests.get(url, params=params) - if response.status_code == 200: - data = response.json() - if 'data' not in data or not data['data']: - return references - for ref in data['data']: - references.append(ref['citedPaper']) - if len(data['data']) < 100: - return references - offset += 100 - break - else: - time.sleep(5) - else: - return references if references else None - - -def process_paper(arxiv_id: str) -> Optional[Dict[str, 
Any]]: - paper_info = get_paper_info(arxiv_id) - if not paper_info: - return None - - references = get_references(arxiv_id) - if references is None: - references = [] - - processed_references = [] - for ref in references: - processed_ref = { - 'title': ref.get('title'), - 'abstract': ref.get('abstract'), - 'year': ref.get('year'), - 'venue': ref.get('venue'), - 'authors': [author.get('name') for author in ref.get('authors', [])], - 'externalIds': ref.get('externalIds'), - 'url': ref.get('url'), - 'referenceCount': ref.get('referenceCount'), - 'citationCount': ref.get('citationCount'), - 'influentialCitationCount': ref.get('influentialCitationCount'), - 'isOpenAccess': ref.get('isOpenAccess'), - 'fieldsOfStudy': ref.get('fieldsOfStudy'), - } - processed_references.append(processed_ref) - - paper_data = { - 'paper_title': paper_info.get('title'), - 'arxiv_id': arxiv_id, - 'authors': [author.get('name') for author in paper_info.get('authors', [])], - 'year': paper_info.get('year'), - 'venue': paper_info.get('venue'), - 'abstract': paper_info.get('abstract'), - 'references': processed_references, - } - - return paper_data - - -def process_arxiv_links(input_text: str, output_file: str) -> None: - urls = input_text.strip().split('\n') - - output_data = {} - for url in tqdm(urls, desc='Processing arXiv links'): - arxiv_id = extract_arxiv_id_from_url(url) - if not arxiv_id: - continue - - if arxiv_id in output_data: - continue - - paper_data = process_paper(arxiv_id) - if paper_data: - output_data[paper_data['paper_title']] = paper_data - time.sleep(5) - else: - print(f'can handle arXiv ID: {arxiv_id}') - - with open(output_file, 'w', encoding='utf-8') as f: - json.dump(output_data, f, ensure_ascii=False, indent=4) - - -def main() -> None: - input_text = """https://arxiv.org/abs/2402.13448 -https://arxiv.org/abs/2409.19100 -https://arxiv.org/abs/2409.20252 -https://arxiv.org/abs/2409.20506 -https://arxiv.org/abs/2409.20044 -https://arxiv.org/abs/2409.19864 
-https://arxiv.org/abs/2409.18815 -https://arxiv.org/abs/2407.10990 -https://arxiv.org/abs/2404.08001 -https://arxiv.org/abs/2402.09588 -https://arxiv.org/abs/2409.09825 -https://arxiv.org/abs/2403.14801 -https://arxiv.org/abs/2403.07144 -https://arxiv.org/abs/2401.04155 -https://arxiv.org/abs/2409.15675 -https://arxiv.org/abs/2401.14818 -https://arxiv.org/abs/2401.14656 -https://arxiv.org/abs/2401.11052 -https://arxiv.org/abs/2311.12410 -https://arxiv.org/abs/2311.10776 -""" - output_file = '../benchmark/cross_bench.json' - process_arxiv_links(input_text, output_file) - - -if __name__ == '__main__': - main() diff --git a/research_bench/scripts/create_crossbench.sh b/research_bench/scripts/create_crossbench.sh new file mode 100644 index 00000000..a4f19ecf --- /dev/null +++ b/research_bench/scripts/create_crossbench.sh @@ -0,0 +1 @@ +python ../crossbench.py --input ./crossbench_paper_links.txt --output ../benchmark/crossbench.json diff --git a/research_bench/scripts/create_mlbench.py b/research_bench/scripts/create_mlbench.py deleted file mode 100644 index 08e3b4b1..00000000 --- a/research_bench/scripts/create_mlbench.py +++ /dev/null @@ -1,231 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import json -import logging -import os -import time -from typing import Any, Dict, List, Optional, Set, Union - -import arxiv -import requests -from tqdm import tqdm - -# --------------------- -# Configuration Section -# --------------------- - -AI_KEYWORDS = [ - 'reinforcement learning', - 'large language model', - 'diffusion model', - 'graph neural network', - 'deep learning', - 'representation learning', - 'transformer', - 'federate learning', - 'generative model', - 'self-supervised learning', - 'vision language model', - 'explainable ai', - 'automated machine learning', -] - -# Semantic Scholar API configuration -SEMANTIC_SCHOLAR_API_URL = 'https://api.semanticscholar.org/graph/v1/paper/' - -# Logging configuration -LOG_FILE = 'create_bench.log' -logging.basicConfig( - 
filename=LOG_FILE, - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - filemode='w', -) -console = logging.StreamHandler() -console.setLevel(logging.INFO) -formatter = logging.Formatter('%(levelname)s - %(message)s') -console.setFormatter(formatter) -logging.getLogger().addHandler(console) - -# --------------------- -# Helper Functions -# --------------------- - - -def fetch_papers_for_keyword( - keyword: str, existing_arxiv_ids: Set[str], max_papers: int = 10 -) -> List[arxiv.Result]: - search_query = f'all:"{keyword}" AND (cat:cs.AI OR cat:cs.LG)' - logging.info(f"Searching arXiv for keyword: '{keyword}' with query: {search_query}") - - search = arxiv.Search( - query=search_query, - max_results=max_papers * 5, # Fetch extra to account for duplicates - sort_by=arxiv.SortCriterion.SubmittedDate, - sort_order=arxiv.SortOrder.Descending, - ) - - papers = [] - for paper in search.results(): - if paper.get_short_id() not in existing_arxiv_ids: - papers.append(paper) - existing_arxiv_ids.add(paper.get_short_id()) - if len(papers) >= max_papers * 5: - break - - logging.info(f"Fetched {len(papers)} papers for keyword: '{keyword}'") - return papers - - -def fetch_references( - arxiv_id: str, max_retry: int = 5 -) -> Optional[List[Dict[str, Any]]]: - if 'v' in arxiv_id: - arxiv_id = arxiv_id.split('v')[0] - url = f'{SEMANTIC_SCHOLAR_API_URL}ARXIV:{arxiv_id}/references' - params: Dict[str, Union[int, str]] = { - 'limit': 100, - 'fields': 'title,abstract,year,venue,authors,externalIds,url,referenceCount,citationCount,influentialCitationCount,isOpenAccess,fieldsOfStudy', - } - - for attempt in range(1, max_retry + 1): - try: - response = requests.get(url, params=params, timeout=10) - if response.status_code == 200: - data = response.json() - references = [ - ref['citedPaper'] - for ref in data.get('data', []) - if 'citedPaper' in ref - ] - return [process_reference(ref) for ref in references] - else: - logging.warning( - f'Attempt {attempt}: Failed to 
fetch references for ARXIV:{arxiv_id} - Status Code: {response.status_code}' - ) - except requests.RequestException as e: - logging.warning( - f'Attempt {attempt}: Error fetching references for ARXIV:{arxiv_id} - {e}' - ) - time.sleep(3) # Wait before retrying - - logging.error( - f'Failed to fetch references for ARXIV:{arxiv_id} after {max_retry} attempts.' - ) - return None - - -def process_reference(ref: Dict[str, Any]) -> Dict[str, Any]: - return { - 'title': ref.get('title', ''), - 'abstract': ref.get('abstract', ''), - 'year': ref.get('year', 0), - 'venue': ref.get('venue', ''), - 'authors': [author.get('name', '') for author in ref.get('authors', [])], - 'externalIds': ref.get('externalIds', {}), - 'url': ref.get('url', ''), - 'referenceCount': ref.get('referenceCount', 0), - 'citationCount': ref.get('citationCount', 0), - 'influentialCitationCount': ref.get('influentialCitationCount', 0), - 'isOpenAccess': ref.get('isOpenAccess', False), - 'fieldsOfStudy': ref.get('fieldsOfStudy', []), - } - - -def create_benchmark( - keywords: List[str], max_papers_per_keyword: int = 10 -) -> Dict[str, Any]: - benchmark = {} - existing_arxiv_ids: Set[str] = set() - - for keyword in keywords: - papers = fetch_papers_for_keyword( - keyword, existing_arxiv_ids, max_papers_per_keyword - ) - paper_count_per_keyword = 0 - for paper in tqdm( - papers, desc=f"Processing keyword: '{keyword}'", unit='paper' - ): - arxiv_id = paper.get_short_id() - if not arxiv_id: - logging.warning( - f"Paper '{paper.title}' does not have a valid arXiv ID. Skipping." - ) - continue - if paper.title in benchmark: - logging.warning( - f"Paper '{paper.title}' already exists in the benchmark. Skipping." 
- ) - continue - authors = [author.name for author in paper.authors] - references = fetch_references(arxiv_id) - if references is None: - logging.warning( - f"Failed to fetch references for paper: '{paper.title}'" - ) - continue - if len(references) == 0: - logging.warning(f"No references found for paper: '{paper.title}'") - continue - paper_count_per_keyword += 1 - if paper_count_per_keyword > max_papers_per_keyword: - break - benchmark[paper.title] = { - 'paper_title': paper.title, - 'arxiv_id': arxiv_id, - 'keyword': keyword, - 'authors': authors, - 'references': references, - } - - return benchmark - - -def save_benchmark(benchmark: Dict[str, Any], output_path: str) -> None: - os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(output_path, 'w', encoding='utf-8') as f: - json.dump(benchmark, f, ensure_ascii=False, indent=4) - logging.info(f'Benchmark saved to {output_path}') - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description='Create a benchmark dataset of AI-related arXiv papers.' 
- ) - - parser.add_argument( - '--max_papers_per_keyword', - type=int, - default=10, - help='Number of papers to fetch per keyword (default: 10).', - ) - - parser.add_argument( - '--output', - type=str, - default='./benchmark/benchmark.json', - help='Path to save the benchmark JSON file (default: ./benchmark/benchmark.json).', - ) - - return parser.parse_args() - - -def main() -> None: - args = parse_args() - - logging.info('Starting benchmark creation...') - logging.info(f'Max papers per keyword: {args.max_papers_per_keyword}') - logging.info(f'Output path: {args.output}') - - benchmark = create_benchmark( - keywords=AI_KEYWORDS, - max_papers_per_keyword=args.max_papers_per_keyword, - ) - - save_benchmark(benchmark, args.output) - logging.info('Benchmark creation completed successfully.') - - -if __name__ == '__main__': - main() diff --git a/research_bench/scripts/create_mlbench.sh b/research_bench/scripts/create_mlbench.sh new file mode 100644 index 00000000..b0244a6b --- /dev/null +++ b/research_bench/scripts/create_mlbench.sh @@ -0,0 +1 @@ +python ../mlbench.py --max_papers_per_keyword 10 --output ../benchmark/mlbench.json diff --git a/research_bench/scripts/crossbench_paper_links.txt b/research_bench/scripts/crossbench_paper_links.txt new file mode 100644 index 00000000..83cdd012 --- /dev/null +++ b/research_bench/scripts/crossbench_paper_links.txt @@ -0,0 +1,20 @@ +https://arxiv.org/abs/2402.13448 +https://arxiv.org/abs/2409.19100 +https://arxiv.org/abs/2409.20252 +https://arxiv.org/abs/2409.20506 +https://arxiv.org/abs/2409.20044 +https://arxiv.org/abs/2409.19864 +https://arxiv.org/abs/2409.18815 +https://arxiv.org/abs/2407.10990 +https://arxiv.org/abs/2404.08001 +https://arxiv.org/abs/2402.09588 +https://arxiv.org/abs/2409.09825 +https://arxiv.org/abs/2403.14801 +https://arxiv.org/abs/2403.07144 +https://arxiv.org/abs/2401.04155 +https://arxiv.org/abs/2409.15675 +https://arxiv.org/abs/2401.14818 +https://arxiv.org/abs/2401.14656 
+https://arxiv.org/abs/2401.11052 +https://arxiv.org/abs/2311.12410 +https://arxiv.org/abs/2311.10776 diff --git a/research_bench/scripts/run_create_bench.sh b/research_bench/scripts/run_create_bench.sh deleted file mode 100644 index 0c4da110..00000000 --- a/research_bench/scripts/run_create_bench.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -# ------------------------------------------------------------------- -# run_create_bench.sh -# ------------------------------------------------------------------- -# Description: -# This script runs the create_bench.py Python script to generate -# a benchmark dataset of AI-related arXiv papers based on specified -# keywords and saves the output to a JSON file. -# -# Usage: -# ./run_create_bench.sh -# -# ------------------------------------------------------------------- - -# Define the output path for the benchmark JSON file -OUTPUT_PATH="../benchmark/mlbench.json" - -echo "Output will be saved to: $OUTPUT_PATH" - -python3 create_bench.py --max_papers_per_keyword 10 --output "$OUTPUT_PATH" - -python3 create_crossbench.py - -echo "Crossbench created" diff --git a/research_bench/utils.py b/research_bench/utils.py new file mode 100644 index 00000000..09ff3523 --- /dev/null +++ b/research_bench/utils.py @@ -0,0 +1,85 @@ +# utils.py + +import json +import os +import time +from typing import Any, Dict, List, Optional, Set + +import arxiv +import requests + +SEMANTIC_SCHOLAR_API_URL = 'https://api.semanticscholar.org/graph/v1/paper/' + + +def get_references(arxiv_id: str, max_retries: int = 5) -> List[Dict[str, Any]]: + url = f'{SEMANTIC_SCHOLAR_API_URL}ARXIV:{arxiv_id}/references' + params = {'limit': 100} + headers = {'User-Agent': 'PaperProcessor/1.0'} + + for attempt in range(max_retries): + response = requests.get(url, params=params, headers=headers) + if response.status_code == 200: + data = response.json() + return [ + ref['citedPaper'] for ref in data.get('data', []) if 'citedPaper' in ref + ] + else: + wait_time = 
2**attempt + print( + f'Error {response.status_code} fetching references for {arxiv_id}. Retrying in {wait_time}s...' + ) + time.sleep(wait_time) # Exponential backoff + print(f'Failed to fetch references for {arxiv_id} after {max_retries} attempts.') + return [] + + +def get_paper_by_keyword( + keyword: str, existing_arxiv_ids: Set[str], max_papers: int = 10 +) -> List[arxiv.Result]: + query = f'all:"{keyword}" AND (cat:cs.AI OR cat:cs.LG)' + search = arxiv.Search( + query=query, + max_results=max_papers * 2, # Fetch extra to account for duplicates + sort_by=arxiv.SortCriterion.SubmittedDate, + ) + + papers = [] + for paper in search.results(): + short_id = paper.get_short_id() + if short_id not in existing_arxiv_ids: + papers.append(paper) + existing_arxiv_ids.add(short_id) + if len(papers) >= max_papers: + break + return papers + + +def save_benchmark(benchmark: Dict[str, Any], output_path: str) -> None: + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(benchmark, f, indent=4, ensure_ascii=False) + print(f'Benchmark saved to {output_path}') + + +def get_paper_by_arxiv_id(arxiv_id: str) -> Optional[arxiv.Result]: + try: + search = arxiv.Search(id_list=[arxiv_id]) + results = list(search.results()) + return results[0] if results else None + except Exception as e: + print(f'Error fetching paper {arxiv_id}: {e}') + return None + + +def process_paper(paper: arxiv.Result) -> Dict[str, Any]: + arxiv_id = paper.get_short_id() + references = get_references(arxiv_id) + return { + 'title': paper.title, + 'arxiv_id': arxiv_id, + 'authors': [author.name for author in paper.authors], + 'abstract': paper.summary, + 'published': paper.published.isoformat(), + 'updated': paper.updated.isoformat(), + 'references': references, + }