Skip to content

Commit

Permalink
fix pre-commit error
Browse files Browse the repository at this point in the history
  • Loading branch information
lwaekfjlk committed Oct 8, 2024
1 parent 8ff76b4 commit 5cd1895
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 21 deletions.
8 changes: 3 additions & 5 deletions research_bench/crossbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
from typing import Any, Dict, List, Set

from tqdm import tqdm

from utils import get_paper_by_arxiv_id, process_paper, save_benchmark


def get_arxiv_ids(input: str) -> List[str]:

with open(input, 'r', encoding='utf-8') as f:
urls = [line.strip() for line in f if line.strip()]

Expand Down Expand Up @@ -40,7 +38,7 @@ def process_arxiv_ids(
existing_arxiv_ids.add(arxiv_id)
save_benchmark(benchmark, output)
else:
print(f"Paper with arXiv ID {arxiv_id} not found.")
print(f'Paper with arXiv ID {arxiv_id} not found.')

return benchmark

Expand All @@ -51,13 +49,13 @@ def parse_args():
'--input',
type=str,
required=True,
help='Path to the input file containing arXiv URLs.'
help='Path to the input file containing arXiv URLs.',
)
parser.add_argument(
'--output',
type=str,
default='./benchmark/crossbench.json',
help='Output file path.'
help='Output file path.',
)
return parser.parse_args()

Expand Down
11 changes: 6 additions & 5 deletions research_bench/mlbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, Dict, List, Set

from tqdm import tqdm

from utils import get_paper_by_keyword, process_paper, save_benchmark


Expand All @@ -18,7 +17,9 @@ def process_keywords(

for keyword in keywords:
print(f"Fetching papers for keyword: '{keyword}'")
papers = get_paper_by_keyword(keyword, existing_arxiv_ids, max_papers_per_keyword)
papers = get_paper_by_keyword(
keyword, existing_arxiv_ids, max_papers_per_keyword
)

for paper in tqdm(papers, desc=f"Processing papers for '{keyword}'"):
paper_data = process_paper(paper)
Expand Down Expand Up @@ -49,19 +50,19 @@ def parse_args():
'explainable AI',
'automated machine learning',
],
help='List of keywords to search for.'
help='List of keywords to search for.',
)
parser.add_argument(
'--max_papers_per_keyword',
type=int,
default=10,
help='Maximum number of papers per keyword.'
help='Maximum number of papers per keyword.',
)
parser.add_argument(
'--output',
type=str,
default='./benchmark/mlbench.json',
help='Output file path.'
help='Output file path.',
)
return parser.parse_args()

Expand Down
2 changes: 1 addition & 1 deletion research_bench/scripts/crossbench_paper_links.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ https://arxiv.org/abs/2401.14818
https://arxiv.org/abs/2401.14656
https://arxiv.org/abs/2401.11052
https://arxiv.org/abs/2311.12410
https://arxiv.org/abs/2311.10776
https://arxiv.org/abs/2311.10776
22 changes: 12 additions & 10 deletions research_bench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,27 @@ def get_references(arxiv_id: str, max_retries: int = 5) -> List[Dict[str, Any]]:
response = requests.get(url, params=params, headers=headers)
if response.status_code == 200:
data = response.json()
return [ref['citedPaper'] for ref in data.get('data', []) if 'citedPaper' in ref]
return [
ref['citedPaper'] for ref in data.get('data', []) if 'citedPaper' in ref
]
else:
wait_time = 2 ** attempt
print(f"Error {response.status_code} fetching references for {arxiv_id}. Retrying in {wait_time}s...")
wait_time = 2**attempt
print(
f'Error {response.status_code} fetching references for {arxiv_id}. Retrying in {wait_time}s...'
)
time.sleep(wait_time) # Exponential backoff
print(f"Failed to fetch references for {arxiv_id} after {max_retries} attempts.")
print(f'Failed to fetch references for {arxiv_id} after {max_retries} attempts.')
return []


def get_paper_by_keyword(
keyword: str,
existing_arxiv_ids: Set[str],
max_papers: int = 10
keyword: str, existing_arxiv_ids: Set[str], max_papers: int = 10
) -> List[arxiv.Result]:
query = f'all:"{keyword}" AND (cat:cs.AI OR cat:cs.LG)'
search = arxiv.Search(
query=query,
max_results=max_papers * 2, # Fetch extra to account for duplicates
sort_by=arxiv.SortCriterion.SubmittedDate
sort_by=arxiv.SortCriterion.SubmittedDate,
)

papers = []
Expand All @@ -56,7 +58,7 @@ def save_benchmark(benchmark: Dict[str, Any], output_path: str) -> None:
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(benchmark, f, indent=4, ensure_ascii=False)
print(f"Benchmark saved to {output_path}")
print(f'Benchmark saved to {output_path}')


def get_paper_by_arxiv_id(arxiv_id: str) -> Optional[arxiv.Result]:
Expand All @@ -65,7 +67,7 @@ def get_paper_by_arxiv_id(arxiv_id: str) -> Optional[arxiv.Result]:
results = list(search.results())
return results[0] if results else None
except Exception as e:
print(f"Error fetching paper {arxiv_id}: {e}")
print(f'Error fetching paper {arxiv_id}: {e}')
return None


Expand Down

0 comments on commit 5cd1895

Please sign in to comment.