From 752f1e7d1e7fc7d567d4099d1f9946a87cf5f39d Mon Sep 17 00:00:00 2001
From: efugier
Date: Wed, 8 Nov 2023 18:53:38 +0100
Subject: [PATCH] feat(config): configurable prompts

---
 Cargo.toml                |  1 +
 src/config.rs             | 59 ++++++++++++++++++++++++++++++---------
 src/input_processing.rs   | 20 +++++++------
 src/main.rs               | 52 +++++++++++++++++++++++++++++-----
 src/request.rs            | 39 ++++++++++++++------------
 tests/integration_test.rs |  3 +-
 6 files changed, 126 insertions(+), 48 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 68d95b5..343ec8a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,5 +7,6 @@ edition = "2021"
 
 [dependencies]
 toml = "*"
+clap = { version = "*", features = ["derive"] }
 ureq = { version="*", features = ["json"] }
 serde = { version = "*", features = ["derive"] }
diff --git a/src/config.rs b/src/config.rs
index 07e8d2c..1467c3c 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,23 +1,56 @@
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
 use std::fs;
+use std::path::PathBuf;
 use toml::Value;
 
-pub fn get_api_key() -> String {
-    let config_path = format!(
-        "{}/.config/pipelm/.api_configs.toml",
-        std::env::var("HOME").unwrap()
-    );
-    let content = fs::read_to_string(config_path).expect("Failed to read the TOML file");
+#[derive(Debug, Deserialize, Serialize)]
+pub struct Prompt {
+    #[serde(skip_serializing)] // internal use only
+    pub service: String,
+    pub model: String,
+    pub messages: Vec<Message>,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct Message {
+    pub role: String,
+    pub content: String,
+}
+
+pub const PLACEHOLDER_TOKEN: &str = "#[<input>]";
+
+const DEFAULT_CONFIG_PATH: &str = ".config/pipelm/";
+const CUSTOM_CONFIG_ENV_VAR: &str = "PIPLE_CONFIG_PATH";
+const API_KEYS_FILE: &str = ".api_keys.toml";
+const PROMPT_FILE: &str = "prompts.toml";
+
+fn resolve_config_path() -> PathBuf {
+    match std::env::var(CUSTOM_CONFIG_ENV_VAR) {
+        Ok(p) => PathBuf::new().join(p),
+        Err(_) => PathBuf::new().join(env!("HOME")).join(DEFAULT_CONFIG_PATH),
+    }
+}
+
+pub fn get_api_key(service: &str) -> String {
+    let api_keys_path = resolve_config_path().join(API_KEYS_FILE);
+    let content = fs::read_to_string(&api_keys_path)
+        .unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", api_keys_path, error));
     let value: Value = content.parse().expect("Failed to parse TOML");
 
     // Extract the API key from the TOML table.
     let api_key = value
-        .get("openai")
-        .and_then(|table| table.get("API_KEY"))
-        .and_then(|api_key| api_key.as_str())
-        .unwrap_or_else(|| {
-            eprintln!("API_KEY not found in the TOML file.");
-            std::process::exit(1);
-        });
+        .get("API_KEYS")
+        .expect("API_KEYS section not found")
+        .get(service)
+        .unwrap_or_else(|| panic!("No api key found for service {}.", &service));
 
     api_key.to_string()
 }
+
+pub fn get_prompts() -> HashMap<String, Prompt> {
+    let prompts_path = resolve_config_path().join(PROMPT_FILE);
+    let content = fs::read_to_string(&prompts_path)
+        .unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", prompts_path, error));
+    toml::from_str(&content).unwrap()
+}
diff --git a/src/input_processing.rs b/src/input_processing.rs
index eb04cb7..af0eff6 100644
--- a/src/input_processing.rs
+++ b/src/input_processing.rs
@@ -1,6 +1,8 @@
+use crate::config::{get_api_key, Prompt, PLACEHOLDER_TOKEN};
 use crate::request::{make_authenticated_request, OpenAiResponse};
 use std::io::{Read, Result, Write};
 
+// [tmp] mostly template to write tests
 pub fn chunk_process_input<R: Read, W: Write>(
     input: &mut R,
     output: &mut W,
@@ -9,7 +11,6 @@ pub fn chunk_process_input<R: Read, W: Write>(
 ) -> Result<()> {
     let mut first_chunk = true;
     let mut buffer = [0; 1024];
-
     loop {
         match input.read(&mut buffer) {
             Ok(0) => break, // end of input
@@ -38,10 +39,9 @@ pub fn chunk_process_input<R: Read, W: Write>(
 }
 
 pub fn process_input_with_request<R: Read, W: Write>(
+    prompt: &mut Prompt,
     input: &mut R,
     output: &mut W,
-    prefix: &str,
-    suffix: &str,
 ) -> Result<()> {
     let mut buffer = Vec::new();
     input.read_to_end(&mut buffer)?;
@@ -48,15 +48,17 @@ pub fn process_input_with_request<R: Read, W: Write>(
 
     let input = String::from_utf8(buffer).unwrap();
 
-    let mut result = String::from(prefix);
-    result.push_str(&input);
-    result.push_str(suffix);
-
-    let response: OpenAiResponse = make_authenticated_request(&result).unwrap().into_json()?;
+    for message in prompt.messages.iter_mut() {
+        message.content = message.content.replace(PLACEHOLDER_TOKEN, &input)
+    }
+    let api_key = get_api_key(&prompt.service);
+    let response: OpenAiResponse = make_authenticated_request(&api_key, prompt)
+        .unwrap()
+        .into_json()?;
 
     println!("{}", response.choices.first().unwrap().message.content);
 
-    output.write_all(suffix.as_bytes())?;
+    output.write_all(input.as_bytes())?;
 
     Ok(())
 }
diff --git a/src/main.rs b/src/main.rs
index 04d6615..25683bc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,18 +1,56 @@
+use clap::Parser;
 use std::io;
 
-mod config;
 mod input_processing;
 mod request;
+#[allow(dead_code)]
+mod config;
+
+#[derive(Debug, Parser)]
+#[command(author, version, about, long_about = None)]
+struct Cli {
+    #[arg(default_value_t = String::from("default"))]
+    prompt: String,
+    #[arg(short, long, default_value_t = String::from("openai"))]
+    service: String,
+}
 
 fn main() {
+    let args = Cli::parse();
+
     let mut output = io::stdout();
     let mut input = io::stdin();
 
-    if let Err(e) = input_processing::chunk_process_input(
-        &mut input,
-        &mut output,
-        "Hello, World!\n```\n",
-        "\n```\n",
-    ) {
+    let mut prompts = config::get_prompts();
+
+    // case for testing IO
+    if args.prompt == "test" {
+        if let Err(e) = input_processing::chunk_process_input(
+            &mut input,
+            &mut output,
+            "Hello, World!\n```\n",
+            "\n```\n",
+        ) {
+            eprintln!("Error: {}", e);
+            std::process::exit(1);
+        } else {
+            std::process::exit(0);
+        }
+    }
+
+    let available_prompts: Vec<&String> = prompts.keys().collect();
+    let prompt_not_found_error = format!(
+        "Prompt {} not found, available ones are: {:?}",
+        &args.prompt, &available_prompts
+    );
+
+    let prompt = prompts
+        .get_mut(&args.prompt)
+        .expect(&prompt_not_found_error);
+
+    println!("{:?}", prompt);
+
+    if let Err(e) = input_processing::process_input_with_request(prompt, &mut input, &mut output) {
         eprintln!("Error: {}", e);
         std::process::exit(1);
     }
diff --git a/src/request.rs b/src/request.rs
index 0f05036..b53a1f3 100644
--- a/src/request.rs
+++ b/src/request.rs
@@ -1,5 +1,4 @@
-use crate::config::get_api_key;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 
 #[derive(Debug, Deserialize)]
 pub struct Message {
@@ -32,24 +31,28 @@ pub struct OpenAiResponse {
     pub system_fingerprint: String,
 }
 
-pub fn make_authenticated_request(text: &str) -> Result<ureq::Response, ureq::Error> {
-    let api_key = get_api_key();
+pub fn make_authenticated_request(
+    api_key: &str,
+    data: impl Serialize,
+) -> Result<ureq::Response, ureq::Error> {
     println!("Trying to reach openai with {}", &api_key);
 
     ureq::post("https://api.openai.com/v1/chat/completions")
         .set("Content-Type", "application/json")
         .set("Authorization", &format!("Bearer {}", api_key))
-        .send_json(ureq::json!({
-            "model": "gpt-4-1106-preview",
-            "messages": [
-                {
-                    "role": "system",
-                    "content": "You are a poetic assistant, skilled in explaining complex programming concepts with creative flair."
-                },
-                {
-                    "role": "user",
-                    "content": text
-                }
-            ]
-        })
-        )
+        .send_json(data)
+    // .send_json(ureq::json!(
+    //     {
+    //         "model": "gpt-4-1106-preview",
+    //         "messages": [
+    //             {
+    //                 "role": "system",
+    //                 "content": "You are a poetic assistant, skilled in explaining complex programming concepts with creative flair."
+    //             },
+    //             {
+    //                 "role": "user",
+    //                 "content": data.messages.last().unwrap().content
+    //             }
+    //         ]
+    //     })
+    // )
 }
diff --git a/tests/integration_test.rs b/tests/integration_test.rs
index b35ed69..93be737 100644
--- a/tests/integration_test.rs
+++ b/tests/integration_test.rs
@@ -2,7 +2,7 @@ use std::io::{Read, Write};
 use std::process::{Command, Stdio};
 
 #[test]
-fn test_program_integration() {
+fn test_io() {
     let hardcoded_prefix = "Hello, World!\n```\n";
     let hardcoded_suffix = "\n```\n";
     let input_data = "Input data";
@@ -10,6 +10,7 @@ fn test_program_integration() {
     // launch the program and get the streams
     let mut child = Command::new("cargo")
         .arg("run")
+        .arg("test")
         .stdin(Stdio::piped())
         .stdout(Stdio::piped())
         .spawn()
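
Note for reviewers (not part of the commit): the patch never shows what
prompts.toml looks like, so here is a minimal sketch that deserializes into
the HashMap<String, Prompt> returned by config::get_prompts(). The entry name
"default" matches the CLI's default prompt argument; the model string and the
message contents are illustrative assumptions only:

    [default]
    service = "openai"
    model = "gpt-4-1106-preview"

    [[default.messages]]
    role = "system"
    content = "You are a helpful assistant."

    [[default.messages]]
    role = "user"
    content = "#[<input>]"

process_input_with_request() then replaces every occurrence of
PLACEHOLDER_TOKEN in the message contents with whatever arrives on stdin.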
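Similarly, get_api_key() expects an [API_KEYS] table in .api_keys.toml under
the config directory (~/.config/pipelm/ by default, or the directory named by
PIPLE_CONFIG_PATH when that variable is set). A sketch, with a placeholder
key value:

    [API_KEYS]
    openai = "your-openai-api-key"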
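Assuming both files above are in place, `cat src/main.rs | cargo run` should
send stdin through the "default" prompt, while `cargo run test` takes the
plain chunked-IO path that tests/integration_test.rs drives.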