Added support for Google Gemini 1.5 pro/flash models #1190

Open · wants to merge 9 commits into main
5 changes: 5 additions & 0 deletions .env.template
@@ -3,3 +3,8 @@
 # OPENAI_API_KEY=Your personal OpenAI API key from https://platform.openai.com/account/api-keys
 OPENAI_API_KEY=...
 ANTHROPIC_API_KEY=...
+# GOOGLE_API_KEY=Your personal GOOGLE API key from https://aistudio.google.com/app/apikey
+GOOGLE_API_KEY=...
+
+# If not set Model Name defaults to gpt-4o
+MODEL_NAME=gemini-1.5-pro-latest
19 changes: 18 additions & 1 deletion gpt_engineer/applications/cli/main.py
@@ -89,6 +89,20 @@ def load_env_if_needed():
         if os.getenv("ANTHROPIC_API_KEY") is None:
             load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env"))
 
+    if os.getenv("GOOGLE_API_KEY") is None:
+        load_dotenv()
+        if os.getenv("GOOGLE_API_KEY") is None:
+            load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env"))


+
+def model_env():
+    if os.getenv("MODEL_NAME") is None:
+        load_dotenv()
+        if os.getenv("MODEL_NAME") is None:
+            load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env"))
+
+    return os.getenv("MODEL_NAME", default="gpt-4o")

Review comment (Collaborator): Is there a way we could "merge" the model_env() and load_env_if_needed() functions, to simplify the logic here a bit? Or what was the idea behind adding this function?

Reply (@rtmcrc, author, Aug 8, 2024): Good idea; do whatever you consider efficient. I made it a separate function so that it is not called twice: load_env_if_needed() is called on line 460 and model_env() on line 298. I have no clue why they were placed so far apart; maybe there was a reason, I didn't look into it much. 🤷‍♂️


 def concatenate_paths(base_path, sub_path):
     # Compute the relative path from base_path to sub_path
@@ -281,7 +295,10 @@ def format_installed_packages(packages):
 def main(
     project_path: str = typer.Argument(".", help="path"),
     model: str = typer.Option(
-        os.environ.get("MODEL_NAME", "gpt-4o"), "--model", "-m", help="model id string"
+        os.environ.get("MODEL_NAME", model_env()),
+        "--model",
+        "-m",
+        help="model id string",
     ),
     temperature: float = typer.Option(
         0.1,
10 changes: 10 additions & 0 deletions gpt_engineer/core/ai.py
@@ -36,6 +36,7 @@
     messages_to_dict,
 )
 from langchain_anthropic import ChatAnthropic
+from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 
 from gpt_engineer.core.token_usage import TokenUsageLog
@@ -111,6 +112,7 @@ def __init__(
             ("vision-preview" in model_name)
             or ("gpt-4-turbo" in model_name and "preview" not in model_name)
             or ("claude" in model_name)
+            or ("gemini" in model_name)
         )
         self.llm = self._create_chat_model()
         self.token_usage_log = TokenUsageLog(model_name)
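For illustration, the updated `self.vision` expression can be read as a pure helper over the model name; `is_vision_model` is a hypothetical name, not part of the PR.

```python
def is_vision_model(model_name: str) -> bool:
    # Hypothetical standalone mirror of the updated self.vision expression:
    # gemini models are now treated as vision-capable, alongside claude and
    # the gpt-4 vision/turbo variants.
    return (
        ("vision-preview" in model_name)
        or ("gpt-4-turbo" in model_name and "preview" not in model_name)
        or ("claude" in model_name)
        or ("gemini" in model_name)
    )
```

Note the asymmetry: `gpt-4-turbo-preview` is excluded, while any model name containing `gemini` or `claude` matches.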
@@ -362,6 +364,14 @@ def _create_chat_model(self) -> BaseChatModel:
                 streaming=self.streaming,
                 max_tokens_to_sample=4096,
             )
+        elif "gemini" in self.model_name:
+            return ChatGoogleGenerativeAI(
+                model=self.model_name,
+                temperature=self.temperature,
+                streaming=self.streaming,
+                google_api_key=os.getenv("GOOGLE_API_KEY"),
+                callbacks=[StreamingStdOutCallbackHandler()],
+            )
         elif self.vision:
             return ChatOpenAI(
                 model=self.model_name,
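To summarize the dispatch order that results from this hunk: the substring checks in `_create_chat_model` run top-down, so a "gemini" model is routed to `ChatGoogleGenerativeAI` before the vision and default OpenAI branches are considered. The sketch below is a hypothetical simplification (function name, string return values, and the omission of the Azure branch are all assumptions made so it runs without the langchain packages or API keys).

```python
def select_backend(model_name: str, vision: bool = False) -> str:
    # Hypothetical summary of _create_chat_model()'s branch order after
    # this PR; returns the backend class name instead of constructing it.
    if "claude" in model_name:
        return "ChatAnthropic"
    if "gemini" in model_name:
        return "ChatGoogleGenerativeAI"
    if vision:
        return "ChatOpenAI (vision settings)"
    return "ChatOpenAI"
```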