Merge pull request #1 from efugier/feat/generate-default-config
feat(config): generate default if doesn't exist
efugier authored Nov 10, 2023
2 parents 360e233 + ee52ef8 commit 45ebaff
Showing 7 changed files with 189 additions and 43 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
smartcat

# Generated by Cargo
# will have compiled files and executables
debug/
9 changes: 5 additions & 4 deletions README.md
@@ -5,12 +5,13 @@ Puts a brain behind cat!
WIP cli interface to language models to bring them in the Unix ecosystem

```
Putting a brain behind `cat`. WIP cli interface to language models to bring them in the Unix ecosystem 🐈‍⬛
Usage: smartcat [OPTIONS] [PROMPT]
Arguments:
[PROMPT] which prompt in the config to fetch.
The config must have at least one named "default" containing which
model and api to hit by default. [default: default]
[PROMPT] which prompt in the config to fetch.
The config must have at least one named "default" containing which model and api to hit by default [default: default]
Options:
-c, --command <COMMAND>
@@ -20,7 +21,7 @@ Options:
-s, --system-message <SYSTEM_MESSAGE>
a system "config" message to send before the first user message
--api <API>
which api to hit
which api to hit [possible values: openai]
-m, --model <MODEL>
which model (of the api) to use
-f, --file <FILE>
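
For context, a hypothetical invocation combining the options documented above (piping input over stdin is the tool's premise; the file and command text here are made up):

```
cat src/main.rs | smartcat -c "write a doc comment for this function"
```
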
149 changes: 132 additions & 17 deletions src/config.rs
@@ -1,47 +1,110 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::default::Default;
use std::fmt::Debug;
use std::fs;
use std::io::Write;
use std::path::PathBuf;
use std::str::FromStr;

#[derive(Debug, Deserialize)]
pub const PLACEHOLDER_TOKEN: &str = "#[<input>]";

const DEFAULT_CONFIG_PATH: &str = ".config/smartcat/";
const CUSTOM_CONFIG_ENV_VAR: &str = "PIPELM_CONFIG_PATH";
const API_KEYS_FILE: &str = ".api_configs.toml";
const PROMPT_FILE: &str = "prompts.toml";

#[derive(clap::ValueEnum, Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Api {
Openai,
}

impl FromStr for Api {
type Err = ();

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"openai" => Ok(Api::Openai),
_ => Err(()),
}
}
}

impl ToString for Api {
fn to_string(&self) -> String {
match self {
Api::Openai => "openai".to_string(),
}
}
}

#[derive(Debug, Deserialize, Serialize)]
pub struct ApiConfig {
#[serde(skip_serializing)] // internal use only
pub api_key: String,
pub url: String,
}

impl Default for ApiConfig {
// default to openai
fn default() -> Self {
ApiConfig {
api_key: String::from("<insert_api_key_here>"),
url: String::from("https://api.openai.com/v1/chat/completions"),
}
}
}

impl ApiConfig {
fn default_with_api_key(api_key: String) -> Self {
ApiConfig {
api_key,
url: String::from("https://api.openai.com/v1/chat/completions"),
}
}
}

#[derive(Debug, Deserialize, Serialize)]
pub struct Prompt {
#[serde(skip_serializing)] // internal use only
pub api: String,
pub api: Api,
pub model: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub messages: Vec<Message>,
}

impl Default for Prompt {
// default to openai and gpt 4 with no preset messages
fn default() -> Self {
Prompt {
api: Api::Openai,
model: String::from("gpt-4"),
messages: Vec::new(),
}
}
}

#[derive(Debug, Deserialize, Serialize)]
pub struct Message {
pub role: String,
pub content: String,
}

pub const PLACEHOLDER_TOKEN: &str = "#[<input>]";

const DEFAULT_CONFIG_PATH: &str = ".config/smartcat/";
const CUSTOM_CONFIG_ENV_VAR: &str = "PIPELM_CONFIG_PATH";
const API_KEYS_FILE: &str = ".api_configs.toml";
const PROMPT_FILE: &str = "prompts.toml";

fn resolve_config_path() -> PathBuf {
match std::env::var(CUSTOM_CONFIG_ENV_VAR) {
Ok(p) => PathBuf::new().join(p),
Err(_) => PathBuf::new().join(env!("HOME")).join(DEFAULT_CONFIG_PATH),
}
}
fn prompts_path() -> PathBuf {
resolve_config_path().join(PROMPT_FILE)
}
fn api_keys_path() -> PathBuf {
resolve_config_path().join(API_KEYS_FILE)
}

pub fn get_api_config(api: &str) -> ApiConfig {
let api_keys_path = resolve_config_path().join(API_KEYS_FILE);
let content = fs::read_to_string(&api_keys_path)
.unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", api_keys_path, error));
let content = fs::read_to_string(api_keys_path())
.unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", api_keys_path(), error));

let mut api_configs: HashMap<String, ApiConfig> = toml::from_str(&content).unwrap();

@@ -55,8 +118,60 @@ pub fn get_api_config(api: &str) -> ApiConfig {
}

pub fn get_prompts() -> HashMap<String, Prompt> {
let prompts_path = resolve_config_path().join(PROMPT_FILE);
let content = fs::read_to_string(&prompts_path)
.unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", prompts_path, error));
let content = fs::read_to_string(prompts_path())
.unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", prompts_path(), error));
toml::from_str(&content).unwrap()
}

fn read_user_input() -> String {
let mut user_input = String::new();
std::io::stdin()
.read_line(&mut user_input)
.expect("Failed to read line");
user_input.trim().to_string()
}

fn prompt_user_for_config_file_creation(file_path: impl Debug) {
println!(
"Api config file not found at {:?}, do you wish to generate one? [y/n]",
file_path
);
if read_user_input().to_lowercase() != "y" {
println!("smartcat needs this file tu function, create it and come back 👋");
std::process::exit(1);
}
}

pub fn ensure_config_files() -> std::io::Result<()> {
if !api_keys_path().exists() {
prompt_user_for_config_file_creation(api_keys_path());
println!(
"Please paste your openai API key, it can be found at\n\
https://platform.openai.com/api-keys\n\
Press enter to skip (then edit the file at {:?}).",
api_keys_path()
);
let mut api_config = HashMap::new();
api_config.insert(
Prompt::default().api.to_string(),
ApiConfig::default_with_api_key(read_user_input()),
);

let api_config_str = toml::to_string_pretty(&api_config)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
let mut config_file = fs::File::create(api_keys_path())?;
config_file.write_all(api_config_str.as_bytes())?;
}

if !prompts_path().exists() {
prompt_user_for_config_file_creation(prompts_path());
let mut prompt_config = HashMap::new();
prompt_config.insert("default", Prompt::default());
let prompt_str = toml::to_string_pretty(&prompt_config)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
let mut prompts_file = fs::File::create(prompts_path())?;
prompts_file.write_all(prompt_str.as_bytes())?;
}

Ok(())
}
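
To make the new `ensure_config_files` flow concrete, here is roughly what a fresh install would write, assuming the serde attributes above behave as annotated: `api_key` carries `#[serde(skip_serializing)]` and is therefore omitted from the generated file, and the empty `messages` vector is dropped by `skip_serializing_if`.

```toml
# ~/.config/smartcat/.api_configs.toml (sketch of the generated content;
# api_key is skipped by the serializer, so only the url survives)
[openai]
url = "https://api.openai.com/v1/chat/completions"

# ~/.config/smartcat/prompts.toml (sketch of the generated content)
[default]
api = "openai"
model = "gpt-4"
```

Setting `PIPELM_CONFIG_PATH` relocates both files, per `resolve_config_path`.
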
18 changes: 10 additions & 8 deletions src/cutsom_prompt.rs
@@ -1,16 +1,16 @@
use log::debug;

use crate::config::{Message, Prompt, PLACEHOLDER_TOKEN};
use crate::config::{Api, Message, Prompt, PLACEHOLDER_TOKEN};

pub fn customize_prompt(
mut prompt: Prompt,
api: &Option<String>,
api: &Option<Api>,
model: &Option<String>,
command: &Option<String>,
after_input: &Option<String>,
system_message: &Option<String>,
) -> Prompt {
debug!("pre-customization promot {:?}", prompt);
debug!("pre-customization prompt {:?}", prompt);
// Override parameters
if let Some(api) = api {
prompt.api = api.to_owned();
@@ -35,14 +35,15 @@
}

// if prompt customization was provided, add it in a new message
let mut prompt_message = String::new();
if let Some(command_text) = command {
prompt_message.push_str(command_text);
let mut prompt_message = String::from(command_text);
if !prompt_message.contains(PLACEHOLDER_TOKEN) {
prompt_message.push_str(PLACEHOLDER_TOKEN);
}
}
if !prompt_message.is_empty() {
// remove existing input placeholder in order to get just one
for message in prompt.messages.iter_mut() {
message.content = message.content.replace(PLACEHOLDER_TOKEN, "");
}
prompt.messages.push(Message {
role: "user".to_string(),
content: prompt_message,
@@ -73,6 +74,7 @@ pub fn customize_prompt(

prompt.messages.push(last_message);

debug!("pre-customization promot {:?}", prompt);
debug!("post-customization prompt {:?}", prompt);

prompt
}
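
A minimal standalone sketch of the placeholder handling above (the function name and the test driver are illustrative, not part of the crate):

```rust
const PLACEHOLDER_TOKEN: &str = "#[<input>]";

/// Builds the custom command message so exactly one placeholder survives:
/// the command text gets the token appended if the user did not place one,
/// and tokens already present in preset messages are scrubbed.
fn build_command_message(command_text: &str, preset_messages: &mut [String]) -> String {
    let mut prompt_message = String::from(command_text);
    if !prompt_message.contains(PLACEHOLDER_TOKEN) {
        prompt_message.push_str(PLACEHOLDER_TOKEN);
    }
    for content in preset_messages.iter_mut() {
        *content = content.replace(PLACEHOLDER_TOKEN, "");
    }
    prompt_message
}

fn main() {
    let mut presets = vec![format!("fix this text: {}", PLACEHOLDER_TOKEN)];
    let msg = build_command_message("translate to french: ", &mut presets);
    assert_eq!(msg, "translate to french: #[<input>]");
    assert_eq!(presets[0], "fix this text: ");
}
```
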
6 changes: 5 additions & 1 deletion src/input_processing.rs
@@ -50,10 +50,14 @@ pub fn process_input_with_request<R: Read, W: Write>(

let input = String::from_utf8(buffer).unwrap();

// insert the input in the messages with placeholders
for message in prompt.messages.iter_mut() {
message.content = message.content.replace(PLACEHOLDER_TOKEN, &input)
}
let api_config = get_api_config(&prompt.api);
// fetch the api config tied to the prompt
let api_config = get_api_config(&prompt.api.to_string());

// make the request
let response: OpenAiResponse = make_authenticated_request(api_config, prompt)
.unwrap()
.into_json()?;
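
The substitution step itself, sketched in isolation (the strings are illustrative):

```rust
const PLACEHOLDER_TOKEN: &str = "#[<input>]";

fn main() {
    // whatever arrived on stdin lands where the placeholder token was
    let input = "fn add(a: i32, b: i32) -> i32 { a + b }";
    let message = format!("explain this code {}", PLACEHOLDER_TOKEN);
    println!("{}", message.replace(PLACEHOLDER_TOKEN, input));
}
```
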
7 changes: 5 additions & 2 deletions src/main.rs
@@ -15,7 +15,7 @@ mod config;
#[command(
author = "Emilien Fugier",
version = "0.1",
about = "WIP cli interface to language model to bring them in the Unix echosystem",
about = "Putting a brain behind `cat`. WIP cli interface to language model to bring them in the Unix echosystem 🐈‍⬛",
long_about = None
)]
struct Cli {
@@ -30,7 +30,7 @@ struct Cli {
system_message: Option<String>,
/// which api to hit
#[arg(long)]
api: Option<String>,
api: Option<config::Api>,
#[arg(short, long)]
/// which model (of the api) to use
model: Option<String>,
@@ -78,6 +78,9 @@ fn main() {
}
}

config::ensure_config_files()
.expect("Unable to verify that the config files exist or to generate new ones.");

let mut prompts = config::get_prompts();

let available_prompts: Vec<&String> = prompts.keys().collect();
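
With `ensure_config_files` wired into `main`, a first run with no config directory would go roughly like this (illustrative transcript; the path assumes `HOME=/home/user` and no `PIPELM_CONFIG_PATH` override):

```
$ smartcat
Api config file not found at "/home/user/.config/smartcat/.api_configs.toml", do you wish to generate one? [y/n]
y
Please paste your openai API key, it can be found at
https://platform.openai.com/api-keys
Press enter to skip (then edit the file at "/home/user/.config/smartcat/.api_configs.toml").
sk-...
```

The same y/n question then repeats for `prompts.toml` (reusing the "Api config file not found" wording, since `prompt_user_for_config_file_creation` hardcodes it) before the normal prompt lookup proceeds.
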
41 changes: 30 additions & 11 deletions src/request.rs
@@ -1,24 +1,25 @@
use crate::config::{Api, Message, Prompt};
use log::debug;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

use crate::config::ApiConfig;

#[derive(Debug, Deserialize)]
pub struct Message {
pub struct OpenAiMessage {
pub role: String,
pub content: String,
}

#[derive(Debug, Deserialize)]
pub struct Choice {
pub struct OpenAiChoice {
pub index: u32,
pub message: Message,
pub message: OpenAiMessage,
pub finish_reason: String,
}

#[derive(Debug, Deserialize)]
pub struct Usage {
pub struct OpenAiUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
@@ -30,19 +31,37 @@ pub struct OpenAiResponse {
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Usage,
pub choices: Vec<OpenAiChoice>,
pub usage: OpenAiUsage,
pub system_fingerprint: Option<String>,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct OpenAiPrompt {
pub model: String,
pub messages: Vec<Message>,
}

impl From<Prompt> for OpenAiPrompt {
fn from(prompt: Prompt) -> OpenAiPrompt {
OpenAiPrompt {
model: prompt.model,
messages: prompt.messages,
}
}
}

pub fn make_authenticated_request(
api_config: ApiConfig,
data: impl Serialize + Debug,
prompt: Prompt,
) -> Result<ureq::Response, ureq::Error> {
debug!("Trying to reach openai with {}", api_config.api_key);
debug!("request content: {:?}", data);
ureq::post(&api_config.url)
debug!("request content: {:?}", prompt);

let request = ureq::post(&api_config.url)
.set("Content-Type", "application/json")
.set("Authorization", &format!("Bearer {}", api_config.api_key))
.send_json(data)
.set("Authorization", &format!("Bearer {}", api_config.api_key));
match prompt.api {
Api::Openai => request.send_json(OpenAiPrompt::from(prompt)),
}
}
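
For reference, the JSON body `OpenAiPrompt` serializes to for a bare default prompt with one substituted user message would look something like this (sketch; the message content is invented):

```json
{
  "model": "gpt-4",
  "messages": [
    { "role": "user", "content": "explain this code fn add(a: i32, b: i32) -> i32 { a + b }" }
  ]
}
```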
