Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add alternative model support including Gemini Pro, Grok-1, and LLaMa variants #160

Closed
Closed
77 changes: 71 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,53 @@ pip install -r requirements.txt

### Supported Models and API Keys

We support a wide variety of models, including open-weight and API-only models. In general, we recommend using only frontier models above the capability of the original GPT-4. To see a full list of supported models, see [here](https://github.com/SakanaAI/AI-Scientist/blob/main/ai_scientist/llm.py).
We support a wide variety of models, including open-weight and API-only models. In general, we recommend using only frontier models above the capability of the original GPT-4. Below is a comprehensive list of supported models and their variants.

#### OpenAI API (GPT-4o, GPT-4o-mini, o1 models)
## Available Models

By default, this uses the `OPENAI_API_KEY` environment variable.
AI-Scientist supports multiple model providers and variants:

#### Anthropic API (Claude Sonnet 3.5)
### Claude Models
- Claude 3.5 Sonnet (via Anthropic API)
- Claude 3.5 Sonnet (via Amazon Bedrock)
- Claude 3.5 Sonnet (via Vertex AI)

### GPT Models
- GPT-4o and variants (via OpenAI API)
- GPT-4o-mini and variants (via OpenAI API)
- o1 models and variants (via OpenAI API)

### LLaMa Models
- LLaMa 3.3 70B (via OpenRouter API)
- LLaMa 3.3 70B Local (via Ollama)
- LLaMa 3.2 1B Local (via Ollama, for resource-constrained environments)
- LLaMa 3.1 8B Local (via Ollama, optimized for segmented templates)

### Additional Models
- Gemini Pro (via Google Cloud)
- Grok-1 (via xAI)
- DeepSeek Coder V2 (via DeepSeek API)

## Model Performance and Template Compatibility

### Performance Tiers
- Tier 1 (Full Capability): LLaMa 3.3, GPT-4o, Claude 3.5
- Tier 2 (Standard): LLaMa 3.1, GPT-3.5
- Tier 3 (Resource-Constrained): LLaMa 3.2 1B

### Template Formats
AI-Scientist supports two template editing modes:
- **Diff Mode**: Default for high-capability models (Tier 1)
- **Whole Mode**: Optimized for resource-constrained models (Tier 2 & 3)

### Template Segmentation
For improved compatibility with resource-constrained models:
- Segmented templates split papers into manageable sections
- Recommended for LLaMa 3.1 8B and LLaMa 3.2 1B
- Helps prevent edit mode termination issues
- Improves reliability for paper generation tasks

For detailed configuration of each model type, see the sections below.

By default, this uses the `ANTHROPIC_API_KEY` environment variable.

Expand Down Expand Up @@ -122,9 +162,34 @@ export VERTEXAI_PROJECT="PROJECT_ID" # for Aider/LiteLLM call

By default, this uses the `DEEPSEEK_API_KEY` environment variable.

#### OpenRouter API (Llama3.1)
#### OpenRouter API (LLaMa Models)

By default, this uses the `OPENROUTER_API_KEY` environment variable. Supported models:
- LLaMa 3.3 70B: High-performance model suitable for complex research tasks
- LLaMa 3.1: Mid-tier model for general research tasks

#### Local Models via Ollama

For local model execution without API keys, AI-Scientist supports running models through Ollama:

1. Install Ollama:
```bash
curl https://ollama.ai/install.sh | sh
```

2. Pull the LLaMa model:
```bash
ollama pull llama2
```

3. Start the Ollama server:
```bash
ollama serve
```

4. Use the local model by specifying "llama3.3-70b-local" as the model identifier in your experiments.

By default, this uses the `OPENROUTER_API_KEY` environment variable.
Note: Local model performance may vary based on your system's resources. The Ollama server provides an OpenAI-compatible endpoint at `http://localhost:11434/v1`.

#### Semantic Scholar API (Literature Search)

Expand Down
121 changes: 83 additions & 38 deletions ai_scientist/generate_ideas.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,18 @@ def generate_ideas(
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"
print(json_output)
try:
json_output = extract_json_between_markers(text)
if json_output is None:
print("Failed to extract JSON from LLM output")
continue
print(json_output)
except ValueError as e:
print(f"Error extracting JSON: {e}")
continue
except Exception as e:
print(f"Unexpected error while extracting JSON: {e}")
continue

# Iteratively improve task.
if num_reflections > 1:
Expand All @@ -148,11 +157,18 @@ def generate_ideas(
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert (
json_output is not None
), "Failed to extract JSON from LLM output"
print(json_output)
try:
json_output = extract_json_between_markers(text)
if json_output is None:
print("Failed to extract JSON from LLM output")
continue
print(json_output)
except ValueError as e:
print(f"Error extracting JSON: {e}")
continue
except Exception as e:
print(f"Unexpected error while extracting JSON: {e}")
continue

if "I am done" in text:
print(f"Idea generation converged after {j + 2} iterations.")
Expand Down Expand Up @@ -229,9 +245,18 @@ def generate_next_idea(
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"
print(json_output)
try:
json_output = extract_json_between_markers(text)
if json_output is None:
print("Failed to extract JSON from LLM output")
continue
print(json_output)
except ValueError as e:
print(f"Error extracting JSON: {e}")
continue
except Exception as e:
print(f"Unexpected error while extracting JSON: {e}")
continue

# Iteratively improve task.
if num_reflections > 1:
Expand All @@ -247,11 +272,18 @@ def generate_next_idea(
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert (
json_output is not None
), "Failed to extract JSON from LLM output"
print(json_output)
try:
json_output = extract_json_between_markers(text)
if json_output is None:
print("Failed to extract JSON from LLM output")
continue
print(json_output)
except ValueError as e:
print(f"Error extracting JSON: {e}")
continue
except Exception as e:
print(f"Unexpected error while extracting JSON: {e}")
continue

if "I am done" in text:
print(
Expand Down Expand Up @@ -409,29 +441,42 @@ def check_idea_novelty(
break

## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"

## SEARCH FOR PAPERS
query = json_output["Query"]
papers = search_for_papers(query, result_limit=10)
if papers is None:
papers_str = "No papers found."

paper_strings = []
for i, paper in enumerate(papers):
paper_strings.append(
"""{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format(
i=i,
title=paper["title"],
authors=paper["authors"],
venue=paper["venue"],
year=paper["year"],
cites=paper["citationCount"],
abstract=paper["abstract"],
try:
json_output = extract_json_between_markers(text)
if json_output is None:
print("Failed to extract JSON from LLM output")
continue

## SEARCH FOR PAPERS
query = json_output["Query"]
papers = search_for_papers(query, result_limit=10)
if papers is None:
papers_str = "No papers found."

paper_strings = []
for i, paper in enumerate(papers):
paper_strings.append(
"""{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format(
i=i,
title=paper["title"],
authors=paper["authors"],
venue=paper["venue"],
year=paper["year"],
cites=paper["citationCount"],
abstract=paper["abstract"],
)
)
)
papers_str = "\n\n".join(paper_strings)
papers_str = "\n\n".join(paper_strings)

except ValueError as e:
print(f"Error extracting JSON: {e}")
continue
except KeyError as e:
print(f"Missing required field in JSON: {e}")
continue
except Exception as e:
print(f"Unexpected error while extracting JSON: {e}")
continue

except Exception as e:
print(f"Error: {e}")
Expand Down
Loading