Add Goodfire API Provider Support #1161
base: main
Conversation
- Introduced `GoodfireConfig` dataclass for Goodfire-specific settings in `_generate_config.py`.
- Implemented `GoodfireAPI` class in a new file `_providers/goodfire.py` to handle interactions with the Goodfire API.
- Registered the Goodfire API provider in `_providers/providers.py`, including error handling for dependency imports.
- Updated `GenerateConfig` to include Goodfire configuration options.
…th version verification and improved error handling
- Added support for minimum version requirement for the Goodfire API.
- Introduced supported model literals and updated model name handling.
- Improved API key retrieval logic with environment variable checks.
- Enhanced client initialization to include base URL handling.
- Updated maximum completion tokens to 4096 for better performance.
- Refined message conversion to handle tool messages appropriately.
- Removed unsupported feature analysis configuration.

This commit improves the robustness and usability of the Goodfire API integration.
…odfireAPI generate method for improved error handling and parameter management
- Enhanced the generate method to use a try-except block for better error logging.
- Consolidated API request parameters into a dictionary for cleaner code.
- Added handling for usage statistics in the output if available.
- Improved message conversion process for better clarity and maintainability.

This update increases the robustness of the Goodfire API integration and enhances error reporting.
OH YES almost forgot: You need to
Fantastic! So happy to see this and excited to see it built out further. Left some feedback in the review. Some additional comments:
- Saw there was a note on streaming support -- currently we don't use streaming in our model interfaces so I don't think this will be required (but perhaps there is a scenario I'm not thinking of?)
- Saw your note on caching -- would the built-in caching work? (We cache `ModelOutput` instances based on a key that hopefully reflects the full range of possible inputs.)
In terms of adding mech interp stuff, we've had initial discussions with a few others in the field on how to do this. At some point I think we'd like to define some common data structures that can go in `ModelOutput`, but we aren't there yet. In the meantime, you should add any mech interp data to the `metadata` field of `ModelOutput` (using whatever schema you want). Later we can try to bring some of this back into something that is shared by multiple mech interp backends.
```python
raise environment_prerequisite_error("Goodfire", GOODFIRE_API_KEY)

# Format and validate model name
if not model_name.startswith("meta-llama/"):
```
Maybe better to force the user to be explicit here from the get-go (as once you get more namespaces you'll want to clearly disambiguate). When we started we allowed just a plain `gpt-4` or `claude-3`, but as more providers came online there were conflicts, so we went back to requiring the fully namespaced name. You know this stack better than I do, though, so take this as a suggestion only.
Response to Comment 2 (re: Model Name Namespacing)
In `src/inspect_ai/model/_providers/goodfire.py`:
- Removed auto-prefixing code:

```python
if not model_name.startswith("meta-llama/"):
    self.model_name = f"meta-llama/{model_name}"
```
- Added model validation (lines 119-122):

```python
supported_models = list(get_args(SUPPORTED_MODELS))
if self.model_name not in supported_models:
    raise ValueError(
        f"Model {self.model_name} not supported. Supported models: {supported_models}"
    )
```
The change enforces fully namespaced model names by:
- Removing the auto-prefixing of "meta-llama/"
- Validating against the explicit list of supported models
```python
        return GoodfireAPI
    except ImportError as e:
        logger.error("[PROVIDER] Failed to import goodfire", exc_info=True)
        raise pip_dependency_error("Goodfire API", ["goodfire"]) from e
```
As with other providers, please use constants here, e.g. `raise pip_dependency_error(FEATURE, [PACKAGE])`.
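As a rough sketch of that suggestion (the constant names, the `pip_dependency_error` import path, and the relative import are assumptions mirroring the pattern visible in `providers.py`, not verified code):

```python
# Sketch only: register the provider using module-level constants, as suggested above.
# FEATURE/PACKAGE names and the pip_dependency_error import path are assumptions.
from inspect_ai._util.error import pip_dependency_error

FEATURE = "Goodfire API"
PACKAGE = "goodfire"


def goodfire() -> type:
    try:
        # relative import, as inside _providers/providers.py
        from .goodfire import GoodfireAPI

        return GoodfireAPI
    except ImportError as e:
        raise pip_dependency_error(FEATURE, [PACKAGE]) from e
```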
```python
    return output

except Exception as e:
```
This will be logged elsewhere, so we can remove it.
Removed.
```python
}

# Make API request and convert response to dict
response = self.client.chat.completions.create(**params)  # type: ignore
```
You noted this in your PR comments, but this absolutely has to be converted to async (as the sync version will hold up everything else in the process). If the Goodfire client doesn't have an async version then you should be able to just call `asyncio.to_thread` and await that.
Goodfire has no async, and I have not been able to get this method to work at the moment. Will try again once more over the weekend.
We can live without async but it will slow things down immeasurably and also make the task UI unresponsive. Definitely worth some time to get this to work properly!
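For reference, a minimal sketch of the `asyncio.to_thread` approach suggested above, assuming the synchronous `client.chat.completions.create` call shown in the diff (untested against the real Goodfire SDK):

```python
import asyncio
from typing import Any


async def acreate_completion(client: Any, params: dict[str, Any]) -> Any:
    # Run the blocking Goodfire SDK call in a worker thread so the event loop
    # (and the task UI) stays responsive while waiting on the network.
    return await asyncio.to_thread(client.chat.completions.create, **params)
```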
"""Whether to collapse consecutive assistant messages.""" | ||
return True | ||
|
||
def tools_required(self) -> bool: |
You should remove properties that just return the default (the tools ones + `max_connections`).
Response to Comment 10 (re: Default Properties)
- ✅ Removed redundant default-returning properties:
  - Removed `collapse_user_messages()` method that just returned `True`
  - Removed `collapse_assistant_messages()` method that just returned `True`
- ✅ Using model-specific token limits in `max_tokens()` method (lines 203-211):
```python
@override
def max_tokens(self) -> int | None:
    """Return maximum tokens supported by model."""
    # Model-specific limits
    if "llama-3.3-70b" in self.model_name.lower():
        return 4096
    elif "llama-3.1-8b" in self.model_name.lower():
        return 4096
    return DEFAULT_MAX_TOKENS
```
You can add this to the
The proposed changes here seem reasonable. I will attempt to implement all of them by Friday morning-ish UK time. Some of the more ... awkward design choices were me trying to patch dozens of little mismatches between Inspect and the Goodfire API (different function names, output formats, Goodfire's special functions). I tried to clean it up and standardise with the rest of Inspect, but evidently I missed a few things. I'll try the metadata approach afterwards. Figuring out which mech interp functions to allow, and how, is gonna be ... tricky. Do you have any reference examples where a model provider supports more than just text generation via Inspect? Even a logits/logprob view might be a helpful reference.
Great, thanks! (and feel free to ping me w/ any questions in the meantime)
Yes, several model providers (OpenAI, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM) support Logprobs: see `src/inspect_ai/model/_model_output.py`, lines 39 to 72 (at 124d837).
Eventually I'd like to have some standard fields like this for mech interp payloads (so that readers of logs can benefit from some uniformity). Absent working out these schemas, I would put your own data structures in `metadata`.
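For example, a rough sketch of stashing provider-specific mech interp data in `metadata` (the `goodfire_features` schema below is invented purely for illustration; only the existence of the `metadata` field is taken from the discussion above):

```python
from inspect_ai.model import ChatCompletionChoice, ChatMessageAssistant, ModelOutput

completion_text = "Paris is the capital of France."  # placeholder completion

output = ModelOutput(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    choices=[
        ChatCompletionChoice(
            message=ChatMessageAssistant(content=completion_text),
            stop_reason="stop",
        )
    ],
    # provider-specific mech interp payload; schema chosen by the provider for now
    metadata={
        "goodfire_features": [
            {"feature_id": "f_123", "label": "capital cities", "activation": 0.82},
        ]
    },
)
```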
…e package dependency and remove GoodfireConfig class from GenerateConfig. Enhance goodfire provider with version verification.
…se runtime-safe string. Add note and TODO for potential issue in Goodfire's repo.
…replace content-based output with choices array including ChatCompletionChoice. This change improves the handling of chat responses by encapsulating message content within a structured choice format.
…sage conversion, and implement connection pooling methods. This update improves code organization by consolidating model argument storage and adds new methods for message processing and rate limit checks, while ensuring compatibility with existing functionality.
…out and retry logic, and improve rate limit detection. This update introduces a new method for handling API errors, ensuring graceful degradation during rate limits, and enhances the overall robustness of the API interaction. Additionally, new constants for maximum retries and timeout values have been defined to improve configuration management.
…baseline settings, streamline model name validation, and remove redundant message collapsing methods. This enhances the API's configuration and simplifies the code structure.
All changes should be done! I might add a more comprehensive overview of challenges/limitations later. The only specifically recommended change I couldn't implement was async.
```python
error_msg = str(ex).lower()

# Only handle rate limits and return other errors as-is
if self.is_rate_limit(ex):
```
We don't want to handle rate limits here; rather, we want these exceptions to propagate up so our higher-level retry handler can catch them and retry. The `handle_error()` used by other models is used almost exclusively for handling context window overflows and refusals (which are frustratingly reported via exception, not in a structured way). This function exists essentially to take these conditions and turn them into `ModelOutput` with a `StopReason` of `model_length` or `content_filter`.
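A rough sketch of that pattern, for illustration only (a method sketch, not the PR's code; the error-string checks are assumptions about how Goodfire reports these conditions):

```python
from inspect_ai.model import ModelOutput


def handle_error(self, ex: Exception) -> ModelOutput | Exception:
    message = str(ex).lower()
    # translate context-window overflows into a ModelOutput with stop_reason "model_length"
    if "context" in message and ("length" in message or "window" in message):
        return ModelOutput.from_content(
            model=self.model_name, content=str(ex), stop_reason="model_length"
        )
    # translate refusals / content filtering into stop_reason "content_filter"
    if "content filter" in message or "refused" in message:
        return ModelOutput.from_content(
            model=self.model_name, content=str(ex), stop_reason="content_filter"
        )
    # everything else (including rate limits) propagates so the retry handler sees it
    return ex
```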
```python
    config: GenerateConfig,
    *,
    cache: bool = True,
) -> ModelOutput:
```
If you return this instead:
`tuple[ModelOutput | Exception, ModelCall]`
then you can include `ModelCall` information, which will be used in the viewer to show the underlying payload of the request to the Goodfire API (indispensable for debugging!)
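A minimal, self-contained sketch of that return shape (not the actual Goodfire implementation; it assumes `ModelCall.create(request=..., response=...)` and `ModelOutput.from_content(...)` as helpers, and the function name is made up):

```python
from typing import Any

from inspect_ai.model import ModelCall, ModelOutput


def build_result(
    model_name: str,
    completion: str,
    request: dict[str, Any],
    response: dict[str, Any],
) -> tuple[ModelOutput | Exception, ModelCall]:
    # package the generated output together with the raw request/response payloads
    # so the viewer can display the underlying Goodfire API call
    output = ModelOutput.from_content(model=model_name, content=completion)
    call = ModelCall.create(request=request, response=response)
    return output, call
```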
Add Goodfire API Provider Support
Overview
This PR introduces support for the Goodfire API, enabling the use of Meta's Llama models through Goodfire's inference service. The implementation provides basic chat completion functionality while maintaining compatibility with the existing evaluation framework.
Over the next few weeks, I expect to add more complex mechanistic interpretability techniques (feature search, feature inspection, feature steering) as shown in the Goodfire AI documentation. For now, this PR seeks to cover basic chat completion and standardisation in line with other model providers (since I have to keep merging new commits every day).
Critical Implementation Details
Core Provider Implementation (`inspect_ai/src/inspect_ai/model/_providers/goodfire.py`):
- `GoodfireAPI` class with synchronous API handling
- `generate()`, `_to_goodfire_message()`, connection management
- `DEFAULT_MAX_TOKENS=4096`, `DEFAULT_TEMPERATURE=0.7`
- `MODEL_MAP`

Known Limitations:
- `inspect_evals/src/inspect_evals/mmlu/mmlu_5_shot.py`
- `generate()` method

API Differences (vs OpenAI/Anthropic, see the parameter-mapping sketch below):
- `max_completion_tokens` vs `max_tokens`
- `finish_reason` field
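To illustrate the first of these differences, here is a rough sketch of how a provider might translate Inspect's `GenerateConfig` fields into the Goodfire-style parameter names; the helper function and its name are illustrative, not the PR's actual code:

```python
# Illustrative sketch only: map Inspect's GenerateConfig fields onto the
# parameter names the Goodfire API expects (per the API-differences note above).
from inspect_ai.model import GenerateConfig


def completion_params(config: GenerateConfig) -> dict:
    params: dict = {}
    if config.max_tokens is not None:
        # Goodfire uses max_completion_tokens rather than OpenAI-style max_tokens
        params["max_completion_tokens"] = config.max_tokens
    if config.temperature is not None:
        params["temperature"] = config.temperature
    return params
```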
Required Configuration
Environment Setup:
Model Support:
meta-llama/Meta-Llama-3.1-8B-Instruct
meta-llama/Llama-3.3-70B-Instruct
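A hypothetical usage sketch for selecting one of these models through the new provider; the `goodfire/` model prefix and the environment-variable flow are assumptions, so check `providers.py` for the actual registered name:

```python
# Hypothetical usage sketch: select a supported model via the Goodfire provider.
# The "goodfire/" prefix is an assumption; the API key value is a placeholder.
import os

from inspect_ai.model import get_model

os.environ.setdefault("GOODFIRE_API_KEY", "<your-goodfire-api-key>")
model = get_model("goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct")
```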
Pending Improvements (Prioritized)
Critical:
Important:
Nice to Have:
Testing Status
Verified:
Known Issues:
Breaking Changes
None. This should not affect the use of other model providers, and effort has been taken to ensure standardisation. Code changes have been isolated to:
- `/src/inspect_ai/model/_providers/goodfire.py` for the core implementation script
- `src/inspect_ai/model/_providers/providers.py` to register Goodfire as a model provider
- `src/inspect_ai/model/_generate_config.py` for certain Goodfire-specific generation functions (though do feel free to test this to make sure it doesn't affect any other providers)

So far, the model seems to generate and score similarly to vLLM-hosted Llama 8B-Instruct on GPQA, GSM8K and MMLU.
Conclusion
Once again, I will be improving and building on this initial chat generation implementation in the coming weeks with more advanced mech interp functions. If you come across issues in other evals, do let me know.