Skip to content

Commit

Permalink
feat(config): more configurable services
Browse files Browse the repository at this point in the history
  • Loading branch information
efugier committed Nov 9, 2023
1 parent 958d698 commit 62b15d1
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 17 deletions.
26 changes: 16 additions & 10 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use toml::Value;

#[derive(Debug, Deserialize)]
pub struct ServiceConfig {
#[serde(skip_serializing)] // internal use only
pub api_key: String,
pub url: String,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct Prompt {
Expand Down Expand Up @@ -32,20 +38,20 @@ fn resolve_config_path() -> PathBuf {
}
}

pub fn get_api_key(service: &str) -> String {
pub fn get_service_config(service: &str) -> ServiceConfig {
let api_keys_path = resolve_config_path().join(API_KEYS_FILE);
let content = fs::read_to_string(&api_keys_path)
.unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", api_keys_path, error));
let value: Value = content.parse().expect("Failed to parse TOML");

// Extract the API key from the TOML table.
let api_key = value
.get("API_KEYS")
.expect("API_KEYS section not found")
.get(service)
.unwrap_or_else(|| panic!("No api key found for service {}.", &service));
let mut service_configs: HashMap<String, ServiceConfig> = toml::from_str(&content).unwrap();

api_key.to_string()
service_configs.remove(service).unwrap_or_else(|| {
panic!(
"Prompt {} not found, availables ones are: {:?}",
service,
service_configs.keys().collect::<Vec<_>>()
)
})
}

pub fn get_prompts() -> HashMap<String, Prompt> {
Expand Down
6 changes: 3 additions & 3 deletions src/input_processing.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::config::{get_api_key, Prompt, PLACEHOLDER_TOKEN};
use crate::config::{get_service_config, Prompt, PLACEHOLDER_TOKEN};
use crate::request::{make_authenticated_request, OpenAiResponse};
use std::io::{Read, Result, Write};

Expand Down Expand Up @@ -51,8 +51,8 @@ pub fn process_input_with_request<R: Read, W: Write>(
for message in prompt.messages.iter_mut() {
message.content = message.content.replace(PLACEHOLDER_TOKEN, &input)
}
let api_key = get_api_key(&prompt.service);
let response: OpenAiResponse = make_authenticated_request(&api_key, prompt)
let service_config = get_service_config(&prompt.service);
let response: OpenAiResponse = make_authenticated_request(service_config, prompt)
.unwrap()
.into_json()?;

Expand Down
13 changes: 9 additions & 4 deletions src/request.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};

use crate::config::ServiceConfig;

#[derive(Debug, Deserialize)]
pub struct Message {
pub role: String,
Expand Down Expand Up @@ -32,13 +34,16 @@ pub struct OpenAiResponse {
}

pub fn make_authenticated_request(
api_key: &str,
service_config: ServiceConfig,
data: impl Serialize,
) -> Result<ureq::Response, ureq::Error> {
println!("Trying to reach openai with {}", &api_key);
ureq::post("https://api.openai.com/v1/chat/completions")
println!("Trying to reach openai with {}", service_config.api_key);
ureq::post(&service_config.url)
.set("Content-Type", "application/json")
.set("Authorization", &format!("Bearer {}", api_key))
.set(
"Authorization",
&format!("Bearer {}", service_config.api_key),
)
.send_json(data)
// .send_json(ureq::json!(
// {
Expand Down

0 comments on commit 62b15d1

Please sign in to comment.