
Add Goodfire API Provider Support #1161

Open
menhguin wants to merge 32 commits into main

Conversation

@menhguin commented Jan 20, 2025

Add Goodfire API Provider Support

Overview

This PR introduces support for the Goodfire API, enabling the use of Meta's Llama models through Goodfire's inference service. The implementation provides basic chat completion functionality while maintaining compatibility with the existing evaluation framework.

Over the next few weeks, I expect to add complex mechanistic interpretability techniques (feature search, feature inspection, feature steering) as shown in the Goodfire AI documentation. For now, this PR seeks to cover basic chat completion and standardisation in line with other model providers (since I have to keep merging new commits every day).

Critical Implementation Details

  1. Core Provider Implementation (inspect_ai/src/inspect_ai/model/_providers/goodfire.py):

    • Implements GoodfireAPI class with synchronous API handling
    • Key methods: generate(), _to_goodfire_message(), connection management
    • Constants for defaults: DEFAULT_MAX_TOKENS=4096, DEFAULT_TEMPERATURE=0.7
    • Model mapping for supported variants in MODEL_MAP (see the sketch after this list)
  2. Known Limitations:

    • MMLU few-shot evaluation issue:
      • Zero-shot works correctly (~0.57 accuracy)
      • Few-shot fails (~0.1 accuracy) due to strict format following; we note this is largely a consequence of using Llama 3 instruct models rather than a Goodfire-specific issue
      • Model outputs bare letters instead of "Answer: A" format
      • Affects inspect_evals/src/inspect_evals/mmlu/mmlu_5_shot.py
    • Synchronous API in async framework:
      • Blocks event loop during generation
      • Affects progress bar updates
      • Located in generate() method
  3. API Differences (vs OpenAI/Anthropic):

    • Parameter naming:
      • Uses max_completion_tokens vs max_tokens
      • Different default values
    • Response handling:
      • Dictionary-based vs object-based responses
      • Manual extraction required for content/usage
      • No finish_reason field
    • Message handling:
      • Tool messages converted to user messages
      • Limited role support
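
To make the above concrete, here is a minimal sketch of the provider shape (a simplification only: the real class derives from inspect's ModelAPI base class with different signatures, and the goodfire.Client construction is assumed from the Goodfire SDK docs):

import os
import goodfire

DEFAULT_MAX_TOKENS = 4096
DEFAULT_TEMPERATURE = 0.7

# Placeholder identity mapping; the real MODEL_MAP may translate inspect
# model names to Goodfire variant identifiers.
MODEL_MAP = {
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct",
}

class GoodfireAPI:
    def __init__(self, model_name: str) -> None:
        self.model_name = model_name
        self.client = goodfire.Client(api_key=os.environ["GOODFIRE_API_KEY"])

    def _to_goodfire_message(self, message) -> dict[str, str]:
        # Goodfire supports a limited set of roles, so tool messages are
        # downgraded to user messages.
        role = message.role if message.role in ("system", "user", "assistant") else "user"
        return {"role": role, "content": message.text}

    def generate(self, input, config) -> dict:
        params = {
            "model": MODEL_MAP[self.model_name],
            "messages": [self._to_goodfire_message(m) for m in input],
            # Goodfire uses max_completion_tokens rather than max_tokens.
            "max_completion_tokens": config.max_tokens or DEFAULT_MAX_TOKENS,
            "temperature": config.temperature or DEFAULT_TEMPERATURE,
        }
        # Synchronous call; the response is dict-like, so content and usage
        # are extracted manually and there is no finish_reason field.
        return self.client.chat.completions.create(**params)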

Required Configuration

  1. Environment Setup:

    GOODFIRE_API_KEY=<key>
    GOODFIRE_BASE_URL=<optional>
    
    pip install goodfire
    
  2. Model Support:

    • Currently supports:
      • meta-llama/Meta-Llama-3.1-8B-Instruct
      • meta-llama/Llama-3.3-70B-Instruct
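
For example, the provider could then be used like this (the goodfire/ prefix is assumed from the provider registration in providers.py):

from inspect_ai.model import get_model

# Assumes the provider is registered under the "goodfire" prefix.
model = get_model("goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct")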

Pending Improvements (Prioritized)

  1. Critical:

    • Fix few-shot evaluation format handling
    • Implement proper async operation
    • Add progress tracking solution
  2. Important:

    • Add streaming support when available
    • Implement tool calls support
    • Enhance error handling
  3. Nice to Have:

    • Add feature analysis support
    • Expand model support
    • Add caching strategy

Testing Status

  1. Verified:

    • Basic chat completion
    • Zero-shot evaluations
    • Usage statistics collection
    • Parameter validation
  2. Known Issues:

    • Few-shot format compliance
    • Progress tracking during long runs
    • Type hints causing linter errors

Breaking Changes

None. This should not affect the use of other model providers, and effort has been taken to ensure standardisation. Code changes have been isolated to:

  • src/inspect_ai/model/_providers/goodfire.py for the core provider implementation
  • src/inspect_ai/model/_providers/providers.py to register Goodfire as a model provider
  • src/inspect_ai/model/_generate_config.py for certain Goodfire-specific generation options (though do feel free to test this to make sure it doesn't affect any other providers)

So far, the model seems to generate and score similarly to vLLM-hosted Llama 8B-Instruct on GPQA, GSM8K and MMLU.

Conclusion

Once again, I will be improving and building on this initial chat generation implementation in the coming weeks with more advanced mech interp functions. If you come across issues in other evals, do let me know.

- Introduced `GoodfireConfig` dataclass for Goodfire-specific settings in `_generate_config.py`.
- Implemented `GoodfireAPI` class in a new file `_providers/goodfire.py` to handle interactions with the Goodfire API.
- Registered the Goodfire API provider in `_providers/providers.py`, including error handling for dependency imports.
- Updated `GenerateConfig` to include Goodfire configuration options.
…th version verification and improved error handling

- Added support for minimum version requirement for the Goodfire API.
- Introduced supported model literals and updated model name handling.
- Improved API key retrieval logic with environment variable checks.
- Enhanced client initialization to include base URL handling.
- Updated maximum completion tokens to 4096 for better performance.
- Refined message conversion to handle tool messages appropriately.
- Removed unsupported feature analysis configuration.

This commit improves the robustness and usability of the Goodfire API integration.
…odfireAPI generate method for improved error handling and parameter management

- Enhanced the generate method to use a try-except block for better error logging.
- Consolidated API request parameters into a dictionary for cleaner code.
- Added handling for usage statistics in the output if available.
- Improved message conversion process for better clarity and maintainability.

This update increases the robustness of the Goodfire API integration and enhances error reporting.
@menhguin commented Jan 20, 2025

OH YES almost forgot: You need to pip install goodfire as well, but I'm unsure where to add this.

Collaborator

@jjallaire left a comment

Fantastic! So happy to see this and excited to see it built out further. Left some feedback in the review. Some additional comments:

  1. Saw there was a note on streaming support -- currently we don't use streaming in our model interfaces so I don't think this will be required (but perhaps there is a scenario I'm not thinking of?)

  2. Saw your note on caching -- would the built-in caching work? (We cache ModelOutput instances based on a key that hopefully reflects the full range of possible inputs.)

In terms of adding mech interp stuff, we've had initial discussions with a few others in the field on how to do this. At some point I think we'd like to define some common data structures that can go in ModelOutput but we aren't there yet. In the meantime, you should add any mech interp data to the metadata field of ModelOutput (using whatever schema you want). Later we can try to bring some of this back into something that is shared by multiple mech interp back ends.

src/inspect_ai/model/_generate_config.py (resolved review thread)
raise environment_prerequisite_error("Goodfire", GOODFIRE_API_KEY)

# Format and validate model name
if not model_name.startswith("meta-llama/"):
Collaborator

Maybe better to force the user to be explicit here from the get-go (as once you get more namespaces you'll want to clearly disambiguate). When we started we allowed just a plain gpt-4 or claude-3, but as more providers came online there were conflicts, so we went back to requiring the fully namespaced name. You know this stack better than I do, though, so take this as a suggestion only.

Author

@menhguin commented Jan 22, 2025

Response to Comment 2 (re: Model Name Namespacing)
In src/inspect_ai/model/_providers/goodfire.py:

  • Removed auto-prefixing code:
if not model_name.startswith("meta-llama/"):
    self.model_name = f"meta-llama/{model_name}"
  • Added model validation (lines 119-122):
supported_models = list(get_args(SUPPORTED_MODELS))
if self.model_name not in supported_models:
    raise ValueError(f"Model {self.model_name} not supported. Supported models: {supported_models}")

The change enforces fully namespaced model names by:

  1. Removing the auto-prefixing of "meta-llama/"
  2. Validating against the explicit list of supported models
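
For context, a hypothetical SUPPORTED_MODELS definition consistent with the models listed in this PR (the actual literal is defined in goodfire.py):

from typing import Literal, get_args

# Hypothetical definition, consistent with the two models listed in this PR.
SUPPORTED_MODELS = Literal[
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.3-70B-Instruct",
]

supported_models = list(get_args(SUPPORTED_MODELS))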

src/inspect_ai/model/_providers/providers.py (two resolved review threads)
return GoodfireAPI
except ImportError as e:
logger.error("[PROVIDER] Failed to import goodfire", exc_info=True)
raise pip_dependency_error("Goodfire API", ["goodfire"]) from e
Collaborator

As with other providers, please use constants here e.g. raise pip_dependency_error(FEATURE, [PACKAGE])

src/inspect_ai/model/_providers/goodfire.py (resolved review thread)

return output

except Exception as e:
Collaborator

This will be logged elsewhere so we can remove it.

Author

Removed.

}

# Make API request and convert response to dict
response = self.client.chat.completions.create(**params) # type: ignore
Collaborator

You noted this in your PR comments, but this absolutely has to be converted to async (as the sync version will hold up everything else in the process). If the goodfire client doesn't have an async version then you should be able to just call asyncio.to_thread and await that.
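
For reference, a minimal sketch of the to_thread approach (_build_params and _to_model_output are hypothetical helpers standing in for the parameter assembly and response conversion already in this PR):

import asyncio

async def generate(self, input, tools, tool_choice, config):
    params = self._build_params(input, config)
    # Run the blocking Goodfire call in a worker thread so the event loop
    # (and the task UI) stays responsive.
    response = await asyncio.to_thread(
        self.client.chat.completions.create, **params
    )
    return self._to_model_output(response)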

Author

Goodfire has no async, and I have not been able to get this method to work at the moment. Will try again once more over the weekend.

Collaborator

We can live without async but it will slow things down immeasurably and also make the task UI unresponsive. Definitely worth some time to get this to work properly!

src/inspect_ai/model/_providers/goodfire.py (resolved review thread)
"""Whether to collapse consecutive assistant messages."""
return True

def tools_required(self) -> bool:
Collaborator

You should remove properties that just return the default (the tools ones + max_connections)

Author

Response to Comment 10 (re: Default Properties)

  • ✅ Removed redundant default-returning properties:
    • Removed collapse_user_messages() method that just returned True
    • Removed collapse_assistant_messages() method that just returned True
  • ✅ Using model-specific token limits in max_tokens() method (lines 203-211):
@override
def max_tokens(self) -> int | None:
    """Return maximum tokens supported by model."""
    # Model-specific limits
    if "llama-3.3-70b" in self.model_name.lower():
        return 4096
    elif "llama-3.1-8b" in self.model_name.lower():
        return 4096
    return DEFAULT_MAX_TOKENS

@jjallaire
Collaborator

OH YES almost forgot: You need to pip install goodfire as well, but I'm unsure where to add this.

You can add this to the dev config of [project.optional-dependencies] in pyproject.toml

@menhguin
Author

The proposed changes here seem reasonable. I will attempt to implement all of them by Friday morning-ish UK time.

Some of the more ... awkward design choices were me trying to patch dozens of little mismatches between Inspect and the Goodfire API (different function names, output formats, Goodfire's special functions). I tried to clean it up + standardise with the rest of Inspect, but evidently I missed a few things.

I'll try the metadata approach afterwards. Figuring out which mech interp function to allow and how is gonna be ... tricky. Do you have any reference examples where a model provider supports more than just text generation via Inspect? Even logits/logprob view might be a helpful reference.

@jjallaire
Collaborator

The proposed changes here seem reasonable. I will attempt to implement all of them by Friday morning-ish UK time.

Great, thanks! (and feel free to ping me w/ any questions in the meantime)

I'll try the metadata approach afterwards. Figuring out which mech interp function to allow and how is gonna be ... tricky. Do you have any reference examples where a model provider supports more than just text generation via Inspect? Even logits/logprob view might be a helpful reference.

Yes, several model providers (OpenAI, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM) support Logprobs:

class TopLogprob(BaseModel):
    """List of the most likely tokens and their log probability, at this token position."""

    token: str
    """The top-kth token represented as a string."""

    logprob: float
    """The log probability value of the model for the top-kth token."""

    bytes: list[int] | None = Field(default=None)
    """The top-kth token represented as a byte array (a list of integers)."""


class Logprob(BaseModel):
    """Log probability for a token."""

    token: str
    """The predicted token represented as a string."""

    logprob: float
    """The log probability value of the model for the predicted token."""

    bytes: list[int] | None = Field(default=None)
    """The predicted token represented as a byte array (a list of integers)."""

    top_logprobs: list[TopLogprob] | None = Field(default=None)
    """If the `top_logprobs` argument is greater than 0, this will contain an ordered list of the top K most likely tokens and their log probabilities."""


class Logprobs(BaseModel):
    """Log probability information for a completion choice."""

    content: list[Logprob]
    """a (num_generated_tokens,) length list containing the individual log probabilities for each generated token."""

Eventually I'd like to have some standard fields like this for mech interp payloads (so that readers of logs can benefit from some uniformity). Until we work out these schemas, I would put your own data structures in ModelOutput.metadata; then we can ideally learn from them and work towards standardization over time.
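
For example, a sketch of the metadata approach (the feature payload and its schema here are purely illustrative placeholders, not an agreed-upon format):

from inspect_ai.model import ModelOutput

# Placeholder values standing in for a real completion and whatever
# feature-activation data the Goodfire API returns.
completion_text = "Paris"
feature_activations = [{"label": "capital cities", "activation": 0.42}]

output = ModelOutput.from_content(
    model="goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct",
    content=completion_text,
)
# Attach provider-specific mech interp data via the metadata field,
# using an ad hoc schema for now.
output.metadata = {"goodfire": {"features": feature_activations}}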

…e package dependency and remove GoodfireConfig class from GenerateConfig. Enhance goodfire provider with version verification.
…se runtime-safe string. Add note and TODO for potential issue in Goodfire's repo.
…replace content-based output with choices array including ChatCompletionChoice. This change improves the handling of chat responses by encapsulating message content within a structured choice format.
…sage conversion, and implement connection pooling methods. This update improves code organization by consolidating model argument storage and adds new methods for message processing and rate limit checks, while ensuring compatibility with existing functionality.
…out and retry logic, and improve rate limit detection. This update introduces a new method for handling API errors, ensuring graceful degradation during rate limits, and enhances the overall robustness of the API interaction. Additionally, new constants for maximum retries and timeout values have been defined to improve configuration management.
…baseline settings, streamline model name validation, and remove redundant message collapsing methods. This enhances the API's configuration and simplifies the code structure.
@menhguin
Author

All changes should be done! I might add a more comprehensive overview of challenges/limitations later. The only specifically recommended change I couldn't implement was the async conversion.

error_msg = str(ex).lower()

# Only handle rate limits and return other errors as-is
if self.is_rate_limit(ex):
Collaborator

We don't want to handle rate limits here; rather, we want these exceptions to propagate up so our higher-level retry handler can catch them and retry. The handle_error() used by other models is used almost exclusively for handling context window overflows and refusals (which are frustratingly reported via exception rather than in a structured way). This function exists essentially to take these conditions and turn them into ModelOutput with a StopReason of model_length or content_filter.
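
For illustration, a sketch of that kind of mapping as a provider method (the matched error strings are assumptions, not Goodfire's actual messages):

from inspect_ai.model import ModelOutput

def handle_error(self, ex: Exception) -> ModelOutput | Exception:
    error_msg = str(ex).lower()
    # Map context-window overflows and refusals to structured stop reasons.
    if "context length" in error_msg:
        return ModelOutput.from_content(
            model=self.model_name, content=str(ex), stop_reason="model_length"
        )
    if "content filter" in error_msg:
        return ModelOutput.from_content(
            model=self.model_name, content=str(ex), stop_reason="content_filter"
        )
    # Everything else (including rate limits) propagates so the higher-level
    # retry handler can catch and retry it.
    return ex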

config: GenerateConfig,
*,
cache: bool = True,
) -> ModelOutput:
Collaborator

If you return this instead:

tuple[ModelOutput | Exception, ModelCall]

Then you can include ModelCall information, which will be used in the viewer to show the underlying payload of the request to the Goodfire API (indispensable for debugging!)
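
For example (reusing the hypothetical helpers from the async sketch above, and assuming ModelCall.create accepts the raw request and response payloads as dicts):

import asyncio
from inspect_ai.model import ModelCall, ModelOutput

async def generate(
    self, input, tools, tool_choice, config
) -> tuple[ModelOutput | Exception, ModelCall]:
    params = self._build_params(input, config)
    response = await asyncio.to_thread(
        self.client.chat.completions.create, **params
    )
    # Record the underlying request/response so the viewer can show them.
    call = ModelCall.create(request=params, response=response)
    return self._to_model_output(response), call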
