Add Goodfire API Provider Support #1161

Open

menhguin wants to merge 33 commits into base: main

Changes from 19 commits (of 33 total)

Commits:
ee9a647 GSM8K Test Script (menhguin, Jan 17, 2025)
31f7d08 Merge branch 'main' of https://github.com/UKGovernmentBEIS/inspect_ai (menhguin, Jan 17, 2025)
8397a93 Add Goodfire API provider and configuration support (menhguin, Jan 17, 2025)
17bb100 to fix tmr (menhguin, Jan 17, 2025)
e2eb13a Merge branch 'UKGovernmentBEIS:main' into main (menhguin, Jan 18, 2025)
63b1ad1 Merge branch 'UKGovernmentBEIS:main' into main (menhguin, Jan 18, 2025)
8c0a4d8 registry.py full revert (menhguin, Jan 18, 2025)
b466dea i reverted model.py and it worked??? (menhguin, Jan 18, 2025)
a7cc7ad [why tf did that increase the score] Enhance Goodfire API provider wi… (menhguin, Jan 18, 2025)
7503d70 Merge branch 'UKGovernmentBEIS:main' into main (menhguin, Jan 20, 2025)
89a9ec0 generate config reset (menhguin, Jan 20, 2025)
1f9a9ca [latest working script with standardisation improvements] Refactor Go… (menhguin, Jan 20, 2025)
65da3b1 [further standardisation changes] (menhguin, Jan 20, 2025)
254e0a8 Delete gsm8k_example.py (menhguin, Jan 20, 2025)
15eda0a Update _model.py w space lol (menhguin, Jan 20, 2025)
65f507f Update registry.py w space lol (menhguin, Jan 20, 2025)
2af6a9a Update registry.py (menhguin, Jan 20, 2025)
89da81b Update _model.py (menhguin, Jan 20, 2025)
b628884 Merge branch 'UKGovernmentBEIS:main' into main (menhguin, Jan 20, 2025)
dfa4d0e Merge branch 'UKGovernmentBEIS:main' into main (menhguin, Jan 22, 2025)
45d5609 [changes not in goodfire.py] Update Goodfire integration: add goodfir… (menhguin, Jan 22, 2025)
06d8972 [model args implementation] (menhguin, Jan 22, 2025)
0f15aea [#7 remove exception logging] (menhguin, Jan 22, 2025)
68ddb3b last stable state (menhguin, Jan 22, 2025)
0796eac Fix type hint mismatch in GoodfireAPI: update variant assignment to u… (menhguin, Jan 22, 2025)
67cd044 [implement #9] Enhance GoodfireAPI output structure: replace content-… (menhguin, Jan 22, 2025)
0cd262f Refactor GoodfireAPI: streamline model argument handling, enhance mes… (menhguin, Jan 22, 2025)
3aa4d0a Enhance GoodfireAPI: Add error handling for API calls, implement time… (menhguin, Jan 22, 2025)
6d6adf0 revert all changes to _generate_config (menhguin, Jan 22, 2025)
7597cc3 Remove version verification from goodfire.py as it was duplicated fro… (menhguin, Jan 22, 2025)
77b1f16 Refactor GoodfireAPI: Update default temperature and top_p values to … (menhguin, Jan 22, 2025)
831f8de Merge branch 'main' into main (jjallaire, Jan 23, 2025)
004b982 Merge branch 'UKGovernmentBEIS:main' into main (menhguin, Jan 23, 2025)
19 changes: 19 additions & 0 deletions src/inspect_ai/model/_generate_config.py
@@ -1,9 +1,11 @@
from contextvars import ContextVar
from copy import deepcopy
from typing import Literal, Union
from dataclasses import dataclass, field

from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from inspect_ai._util.content import Content


class GenerateConfigArgs(TypedDict, total=False):
@@ -79,6 +81,20 @@ class GenerateConfigArgs(TypedDict, total=False):
"""Constrains effort on reasoning for reasoning models. Open AI o1 models only."""


@dataclass
class GoodfireConfig:
"""Goodfire-specific configuration."""

variant_name: str | None = None
"""Name of the Goodfire variant to use."""

feature_analysis: bool = False
"""Whether to enable Goodfire feature analysis."""

feature_threshold: float = 0.5
"""Threshold for feature importance in analysis."""


class GenerateConfig(BaseModel):
"""Base class for model generation configs."""

@@ -151,6 +167,9 @@ class GenerateConfig(BaseModel):
reasoning_effort: Literal["low", "medium", "high"] | None = Field(default=None)
"""Constrains effort on reasoning for reasoning models. Open AI o1 models only."""

goodfire: GoodfireConfig | None = Field(default=None)
"""Goodfire-specific configuration. Only used when using Goodfire models."""

def merge(
self, other: Union["GenerateConfig", GenerateConfigArgs]
) -> "GenerateConfig":
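For context, a brief sketch of how the new goodfire field might be set on a GenerateConfig (field names are taken from the diff above; the import path and variant name are assumptions for illustration):

from inspect_ai.model import GenerateConfig
from inspect_ai.model._generate_config import GoodfireConfig  # module path per this diff

config = GenerateConfig(
    temperature=0.7,
    goodfire=GoodfireConfig(
        variant_name="my-variant",   # hypothetical Goodfire variant
        feature_analysis=False,      # not yet supported per the provider docstring
        feature_threshold=0.5,
    ),
)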
244 changes: 244 additions & 0 deletions src/inspect_ai/model/_providers/goodfire.py
@@ -0,0 +1,244 @@
import logging
from typing import Any, Dict, List, Optional, TypedDict, Union, cast, Literal, get_args
from typing_extensions import TypeAlias
import os

import goodfire
from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
from goodfire.variants.variants import SUPPORTED_MODELS, Variant
from goodfire.api.chat.client import ChatAPI, ChatCompletion
from goodfire.api.features.client import FeaturesAPI

from inspect_ai._util.error import pip_dependency_error
from inspect_ai._util.version import verify_required_version
from inspect_ai._util.content import Content, ContentText
from inspect_ai.tool import ToolChoice, ToolInfo

from .._model import ModelAPI
from .._model_output import ModelOutput, ModelUsage
from .._chat_message import (
ChatMessage,
ChatMessageAssistant,
ChatMessageSystem,
ChatMessageTool,
ChatMessageUser,
)
from .._generate_config import GenerateConfig
from .._call_tools import Tool
from .util import environment_prerequisite_error, model_base_url

logger = logging.getLogger(__name__)

# Constants
GOODFIRE_API_KEY = "GOODFIRE_API_KEY"
MIN_VERSION = "0.2.5"
DEFAULT_BASE_URL = "https://api.goodfire.ai"
DEFAULT_MAX_TOKENS = 4096
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.95
DEFAULT_MAX_CONNECTIONS = 10

# Supported model mapping
MODEL_MAP = {
"meta-llama/Meta-Llama-3-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct",
}

class GoodfireAPI(ModelAPI):
"""Goodfire API provider.

This provider implements the Goodfire API for LLM inference. It supports:
- Chat completions with standard message formats
- Basic parameter controls (temperature, top_p, etc.)
- Usage statistics tracking

Does not currently support:
- Tool calls
- Feature analysis
- Streaming responses
"""

def __init__(
self,
model_name: str,
base_url: str | None = None,
api_key: str | None = None,
api_key_vars: list[str] = [],
config: GenerateConfig = GenerateConfig(),
**kwargs: Any,
) -> None:
"""Initialize the Goodfire API provider.

Args:
model_name: Name of the model to use
base_url: Optional custom API base URL
api_key: Optional API key (will check env vars if not provided)
api_key_vars: Additional env vars to check for API key
config: Generation config options
"""
super().__init__(
model_name=model_name,
base_url=base_url,
api_key=api_key,
api_key_vars=[GOODFIRE_API_KEY],
config=config,
**kwargs,
)

verify_required_version("Goodfire API", "goodfire", MIN_VERSION)

# Get API key from environment if not provided
if not self.api_key:
self.api_key = os.environ.get(GOODFIRE_API_KEY)
if not self.api_key:
raise environment_prerequisite_error("Goodfire", GOODFIRE_API_KEY)

# Format and validate model name
if not model_name.startswith("meta-llama/"):
Collaborator:

Maybe better to force the user to be explicit here from the get-go (as once you get more namespaces you'll want to clearly disambiguate). When we started we allowed just a plain gpt-4 or claude-3, but as more providers came online there were conflicts, so we went back to requiring the fully namespaced name. You know this stack better than I do, though, so take this as a suggestion only.

menhguin (Author), Jan 22, 2025:

Response to Comment 2 (re: Model Name Namespacing)
In src/inspect_ai/model/_providers/goodfire.py:

  • Removed auto-prefixing code:
if not model_name.startswith("meta-llama/"):
    self.model_name = f"meta-llama/{model_name}"
  • Added model validation (lines 119-122):
supported_models = list(get_args(SUPPORTED_MODELS))
if self.model_name not in supported_models:
    raise ValueError(f"Model {self.model_name} not supported. Supported models: {supported_models}")

The change enforces fully namespaced model names by:

  1. Removing the auto-prefixing of "meta-llama/"
  2. Validating against the explicit list of supported models

model_name = f"meta-llama/{model_name}"

supported_models = list(get_args(SUPPORTED_MODELS))
if model_name not in supported_models:
raise ValueError(f"Model {model_name} not supported. Supported models: {supported_models}")

# Initialize client
base_url_val = model_base_url(base_url, "GOODFIRE_BASE_URL")
assert isinstance(base_url_val, str) or base_url_val is None
self.client = goodfire.Client(
api_key=self.api_key,
base_url=base_url_val or DEFAULT_BASE_URL,
)
self.model_name = model_name

# Initialize variant
variant_model = MODEL_MAP.get(model_name, "meta-llama/Meta-Llama-3.1-8B-Instruct")
self.variant = Variant(variant_model)

# Feature analysis not yet supported
self.feature_analysis = False
self.feature_threshold = 0.5

async def generate(
self,
input: List[ChatMessage],
tools: List[ToolInfo],
tool_choice: ToolChoice,
config: GenerateConfig,
*,
cache: bool = True,
) -> ModelOutput:
Collaborator:

If you return this instead:

tuple[ModelOutput | Exception, ModelCall]

Then you can include ModelCall information, which will be used in the viewer to show the underlying payload of the request to the Goodfire API (indispensable for debugging!)
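A rough sketch of that return shape applied to this provider (assuming ModelCall.create(request=..., response=...) and the .._model_call module, as used by other Inspect providers; untested):

from .._model_call import ModelCall  # assumed import path

async def generate(
    self,
    input: List[ChatMessage],
    tools: List[ToolInfo],
    tool_choice: ToolChoice,
    config: GenerateConfig,
) -> tuple[ModelOutput | Exception, ModelCall]:
    messages = [self._to_goodfire_message(msg) for msg in input]
    params = {"model": self.model_name, "messages": messages, "stream": False}
    response = self.client.chat.completions.create(**params)
    response_dict = response.model_dump()
    output = ModelOutput.from_content(
        model=self.model_name,
        content=response_dict["choices"][0]["message"]["content"],
        stop_reason="stop",
    )
    # ModelCall captures the raw request/response payloads for the log viewer
    return output, ModelCall.create(request=params, response=response_dict)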

"""Generate output from the model.

Args:
input: List of chat messages for the conversation
tools: Available tools (not currently supported)
tool_choice: Tool selection directive (not currently supported)
config: Generation parameters
cache: Whether to use response caching

Returns:
ModelOutput containing the generated response and usage statistics
"""
try:
# Convert messages and prepare request params
messages = [self._to_goodfire_message(msg) for msg in input]
params = {
"model": self.model_name,
"messages": messages,
"max_completion_tokens": int(config.max_tokens) if config.max_tokens is not None else DEFAULT_MAX_TOKENS,
"temperature": float(config.temperature) if config.temperature is not None else DEFAULT_TEMPERATURE,
"top_p": float(config.top_p) if config.top_p is not None else DEFAULT_TOP_P,
"stream": False,
}

# Make API request and convert response to dict
response = self.client.chat.completions.create(**params) # type: ignore
Collaborator:

You noted this in your PR comments, but this absolutely has to be converted to async (as the sync version will hold up everything else in the process). If the goodfire client doesn't have an async version then you should be able to just call asyncio.to_thread and await that.

menhguin (Author):

Goodfire has no async, and I have not been able to get this method to work at the moment. Will try again once more over the weekend.

Collaborator:

We can live without async but it will slow things down immeasurably and also make the task UI unresponsive. Definitely worth some time to get this to work properly!
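A minimal sketch of the asyncio.to_thread workaround (assuming the synchronous client call shown in the diff; untested):

import asyncio

# Run the blocking Goodfire SDK call in a worker thread so the event
# loop (and the task UI) stays responsive while the request completes.
response = await asyncio.to_thread(
    self.client.chat.completions.create,
    **params,
)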

response_dict = response.model_dump()

# Create output with main content
output = ModelOutput.from_content(
model=self.model_name,
content=response_dict["choices"][0]["message"]["content"],
stop_reason="stop", # Goodfire doesn't provide finish_reason
)

# Add usage statistics if available
if "usage" in response_dict:
output.usage = ModelUsage(
input_tokens=response_dict["usage"]["prompt_tokens"],
output_tokens=response_dict["usage"]["completion_tokens"],
total_tokens=response_dict["usage"]["total_tokens"],
)

return output

except Exception as e:
Collaborator:

This will be logged elsewhere so we can remove.

menhguin (Author):

Removed.

logger.error(f"Error in generate: {str(e)}", exc_info=True)
raise

def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
"""Convert an Inspect message to a Goodfire message format.

Special handling:
- Tool messages are converted to user messages (not yet supported)
- Tool call info is preserved in the message content for future compatibility
"""
role: Literal["system", "user", "assistant"] = "user"
if isinstance(message, ChatMessageSystem):
role = "system"
elif isinstance(message, ChatMessageUser):
role = "user"
elif isinstance(message, ChatMessageAssistant):
role = "assistant"
elif isinstance(message, ChatMessageTool):
role = "user" # Convert tool messages to user messages
else:
raise ValueError(f"Unknown message type: {type(message)}")

content = str(message.content)
if isinstance(message, ChatMessageTool):
content = f"Tool {message.function}: {content}"

return cast(GoodfireChatMessage, {
"role": role,
"content": content,
})

@property
def name(self) -> str:
"""Get provider name."""
return "goodfire"

def max_tokens(self) -> Optional[int]:
"""Return maximum tokens supported by model."""
return DEFAULT_MAX_TOKENS

def max_connections(self) -> int:
"""Return maximum concurrent connections."""
return DEFAULT_MAX_CONNECTIONS

def connection_key(self) -> str:
"""Return key for connection pooling."""
return f"goodfire:{self.api_key}"

def is_rate_limit(self, ex: BaseException) -> bool:
"""Check if exception is due to rate limiting."""
return "rate_limit" in str(ex).lower()

def collapse_user_messages(self) -> bool:
"""Whether to collapse consecutive user messages."""
return True

def collapse_assistant_messages(self) -> bool:
"""Whether to collapse consecutive assistant messages."""
return True

def tools_required(self) -> bool:
Collaborator:

You should remove properties that just return the default (the tools ones + max_connections).

menhguin (Author):

Response to Comment 10 (re: Default Properties)

  • ✅ Removed redundant default-returning properties:
    • Removed collapse_user_messages() method that just returned True
    • Removed collapse_assistant_messages() method that just returned True
  • ✅ Using model-specific token limits in max_tokens() method (lines 203-211):
@override
def max_tokens(self) -> int | None:
    """Return maximum tokens supported by model."""
    # Model-specific limits
    if "llama-3.3-70b" in self.model_name.lower():
        return 4096
    elif "llama-3.1-8b" in self.model_name.lower():
        return 4096
    return DEFAULT_MAX_TOKENS

"""Whether tools are required."""
return False

def tool_result_images(self) -> bool:
"""Whether tool results can contain images."""
return False

# Remove duplicate registration since it's handled in providers.py
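For reference, a rough end-to-end sketch of how the provider might be exercised once registered (assumes Inspect's usual provider/model naming and the public get_model API; untested against this branch):

import asyncio

from inspect_ai.model import GenerateConfig, get_model

async def main() -> None:
    # the "goodfire/" prefix selects the provider registered in providers.py
    model = get_model(
        "goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct",
        config=GenerateConfig(temperature=0.7, max_tokens=256),
    )
    output = await model.generate("What is 12 * 12?")
    print(output.completion)

asyncio.run(main())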
20 changes: 19 additions & 1 deletion src/inspect_ai/model/_providers/providers.py
@@ -1,16 +1,20 @@
import os
import logging

from inspect_ai._util.error import pip_dependency_error
from inspect_ai._util.version import verify_required_version

from .._model import ModelAPI
from .._registry import modelapi
from .._registry import modelapi, modelapi_register
from .goodfire import GoodfireAPI

# Defer importing model api classes until they are actually used
# (this allows the package to load without the optional deps)
# Note that some api providers (e.g. Cloudflare, AzureAI) don't
# strictly require this treatment but we do it anyway for uniformity,

logger = logging.getLogger(__name__)


@modelapi(name="groq")
def groq() -> type[ModelAPI]:
@@ -239,6 +243,20 @@ def mockllm() -> type[ModelAPI]:
return MockLLM


@modelapi(name="goodfire")
def goodfire() -> type[ModelAPI]:
"""Get Goodfire API provider."""
logger.debug("[PROVIDER] Registering goodfire provider")
try:
import goodfire
verify_required_version("Goodfire API", "goodfire", "0.1.0")
logger.debug("[PROVIDER] Successfully imported goodfire and verified version")
return GoodfireAPI
except ImportError as e:
logger.error("[PROVIDER] Failed to import goodfire", exc_info=True)
raise pip_dependency_error("Goodfire API", ["goodfire"]) from e
Collaborator:

As with other providers, please use constants here, e.g. raise pip_dependency_error(FEATURE, [PACKAGE]).
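A sketch of the suggested constants pattern (mirroring the FEATURE/PACKAGE convention visible in validate_openai_client below; the exact values are assumptions):

FEATURE = "Goodfire API"
PACKAGE = "goodfire"
MIN_VERSION = "0.2.5"  # assumed; should match the pin in goodfire.py

@modelapi(name="goodfire")
def goodfire() -> type[ModelAPI]:
    try:
        import goodfire  # noqa: F401
        verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
        from .goodfire import GoodfireAPI
        return GoodfireAPI
    except ImportError as e:
        raise pip_dependency_error(FEATURE, [PACKAGE]) from e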



def validate_openai_client(feature: str) -> None:
FEATURE = feature
PACKAGE = "openai"