diff --git a/research_bench/crossbench.py b/research_bench/crossbench.py
index 4efba6c2..09f98bbc 100644
--- a/research_bench/crossbench.py
+++ b/research_bench/crossbench.py
@@ -5,12 +5,10 @@
 from typing import Any, Dict, List, Set
 
 from tqdm import tqdm
-
 from utils import get_paper_by_arxiv_id, process_paper, save_benchmark
 
 
 def get_arxiv_ids(input: str) -> List[str]:
-
     with open(input, 'r', encoding='utf-8') as f:
         urls = [line.strip() for line in f if line.strip()]
 
@@ -40,7 +38,7 @@ def process_arxiv_ids(
             existing_arxiv_ids.add(arxiv_id)
             save_benchmark(benchmark, output)
         else:
-            print(f"Paper with arXiv ID {arxiv_id} not found.")
+            print(f'Paper with arXiv ID {arxiv_id} not found.')
 
     return benchmark
 
@@ -51,13 +49,13 @@ def parse_args():
         '--input',
         type=str,
         required=True,
-        help='Path to the input file containing arXiv URLs.'
+        help='Path to the input file containing arXiv URLs.',
     )
     parser.add_argument(
         '--output',
         type=str,
         default='./benchmark/crossbench.json',
-        help='Output file path.'
+        help='Output file path.',
     )
 
     return parser.parse_args()
diff --git a/research_bench/mlbench.py b/research_bench/mlbench.py
index 75fe412d..8d9eac87 100644
--- a/research_bench/mlbench.py
+++ b/research_bench/mlbench.py
@@ -4,7 +4,6 @@
 from typing import Any, Dict, List, Set
 
 from tqdm import tqdm
-
 from utils import get_paper_by_keyword, process_paper, save_benchmark
 
 
@@ -18,7 +17,9 @@ def process_keywords(
 
     for keyword in keywords:
         print(f"Fetching papers for keyword: '{keyword}'")
-        papers = get_paper_by_keyword(keyword, existing_arxiv_ids, max_papers_per_keyword)
+        papers = get_paper_by_keyword(
+            keyword, existing_arxiv_ids, max_papers_per_keyword
+        )
 
         for paper in tqdm(papers, desc=f"Processing papers for '{keyword}'"):
             paper_data = process_paper(paper)
@@ -49,19 +50,19 @@ def parse_args():
             'explainable AI',
             'automated machine learning',
         ],
-        help='List of keywords to search for.'
+        help='List of keywords to search for.',
     )
     parser.add_argument(
         '--max_papers_per_keyword',
         type=int,
         default=10,
-        help='Maximum number of papers per keyword.'
+        help='Maximum number of papers per keyword.',
     )
     parser.add_argument(
         '--output',
         type=str,
         default='./benchmark/mlbench.json',
-        help='Output file path.'
+        help='Output file path.',
     )
 
     return parser.parse_args()
diff --git a/research_bench/scripts/crossbench_paper_links.txt b/research_bench/scripts/crossbench_paper_links.txt
index fd144581..83cdd012 100644
--- a/research_bench/scripts/crossbench_paper_links.txt
+++ b/research_bench/scripts/crossbench_paper_links.txt
@@ -17,4 +17,4 @@ https://arxiv.org/abs/2401.14818
 https://arxiv.org/abs/2401.14656
 https://arxiv.org/abs/2401.11052
 https://arxiv.org/abs/2311.12410
-https://arxiv.org/abs/2311.10776
\ No newline at end of file
+https://arxiv.org/abs/2311.10776
diff --git a/research_bench/utils.py b/research_bench/utils.py
index b1b081be..09ff3523 100644
--- a/research_bench/utils.py
+++ b/research_bench/utils.py
@@ -20,25 +20,27 @@ def get_references(arxiv_id: str, max_retries: int = 5) -> List[Dict[str, Any]]:
         response = requests.get(url, params=params, headers=headers)
         if response.status_code == 200:
             data = response.json()
-            return [ref['citedPaper'] for ref in data.get('data', []) if 'citedPaper' in ref]
+            return [
+                ref['citedPaper'] for ref in data.get('data', []) if 'citedPaper' in ref
+            ]
         else:
-            wait_time = 2 ** attempt
-            print(f"Error {response.status_code} fetching references for {arxiv_id}. Retrying in {wait_time}s...")
+            wait_time = 2**attempt
+            print(
+                f'Error {response.status_code} fetching references for {arxiv_id}. Retrying in {wait_time}s...'
+            )
             time.sleep(wait_time)  # Exponential backoff
 
-    print(f"Failed to fetch references for {arxiv_id} after {max_retries} attempts.")
+    print(f'Failed to fetch references for {arxiv_id} after {max_retries} attempts.')
     return []
 
 
 def get_paper_by_keyword(
-    keyword: str,
-    existing_arxiv_ids: Set[str],
-    max_papers: int = 10
+    keyword: str, existing_arxiv_ids: Set[str], max_papers: int = 10
 ) -> List[arxiv.Result]:
     query = f'all:"{keyword}" AND (cat:cs.AI OR cat:cs.LG)'
     search = arxiv.Search(
         query=query,
         max_results=max_papers * 2,  # Fetch extra to account for duplicates
-        sort_by=arxiv.SortCriterion.SubmittedDate
+        sort_by=arxiv.SortCriterion.SubmittedDate,
     )
 
     papers = []
@@ -56,7 +58,7 @@ def save_benchmark(benchmark: Dict[str, Any], output_path: str) -> None:
     os.makedirs(os.path.dirname(output_path), exist_ok=True)
     with open(output_path, 'w', encoding='utf-8') as f:
         json.dump(benchmark, f, indent=4, ensure_ascii=False)
-    print(f"Benchmark saved to {output_path}")
+    print(f'Benchmark saved to {output_path}')
 
 
 def get_paper_by_arxiv_id(arxiv_id: str) -> Optional[arxiv.Result]:
@@ -65,7 +67,7 @@ def get_paper_by_arxiv_id(arxiv_id: str) -> Optional[arxiv.Result]:
         results = list(search.results())
         return results[0] if results else None
     except Exception as e:
-        print(f"Error fetching paper {arxiv_id}: {e}")
+        print(f'Error fetching paper {arxiv_id}: {e}')
        return None
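Note on the retry change in research_bench/utils.py: the get_references hunk only reflows the exponential-backoff loop to Black's style (single quotes, 2**attempt, wrapped call); the behaviour is unchanged. For reference, a minimal standalone sketch of that pattern, where the helper name fetch_with_backoff and the bare requests.get call are illustrative assumptions rather than code from this repository:

    import time

    import requests


    def fetch_with_backoff(url: str, params: dict, max_retries: int = 5) -> dict:
        # Hypothetical helper mirroring the backoff pattern in get_references();
        # not part of research_bench/utils.py.
        for attempt in range(max_retries):
            response = requests.get(url, params=params)
            if response.status_code == 200:
                return response.json()
            wait_time = 2**attempt  # 1, 2, 4, 8, 16 seconds
            print(f'Error {response.status_code}; retrying in {wait_time}s...')
            time.sleep(wait_time)
        return {}

Note that 2 ** attempt and 2**attempt are the same expression; the hunk merely normalises spacing around the power operator.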