Skip to content

Commit

Permalink
feat: implement multiple named conversations
Browse files Browse the repository at this point in the history
- Replace single conversation file with a conversations directory structure
- Add support for named conversations via -n/--name flag
- Implement conversation name validation
- Add test coverage for conversation management
  • Loading branch information
bytesoverflow committed Nov 6, 2024
1 parent 08c5592 commit bd57694
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 19 deletions.
13 changes: 7 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0"
env_logger = "0"
reqwest = { version = "0", default-features = false, features = ["http2", "json", "blocking", "multipart", "rustls-tls"] }
reqwest = { version = "0", features = ["json", "blocking", "multipart"] }
regex = "1.11.1"

[dev-dependencies]
tempfile = "3"
Expand Down
12 changes: 10 additions & 2 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{path::PathBuf, process::Command};

use self::{
api::{api_keys_path, generate_api_keys_file, get_api_config},
prompt::{generate_prompts_file, get_prompts, prompts_path},
prompt::{generate_prompts_file, get_prompts, prompts_path, conversations_path},
};
use crate::utils::is_interactive;

Expand Down Expand Up @@ -58,6 +58,12 @@ pub fn ensure_config_files() -> std::io::Result<()> {
}
};

// Create the conversations directory if it doesn't exist
if !conversations_path().exists() {
std::fs::create_dir_all(conversations_path())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("Failed to create conversations directory: {}", e)))?;
}

Ok(())
}

Expand Down Expand Up @@ -107,7 +113,7 @@ mod tests {
config::{
api::{api_keys_path, default_timeout_seconds, Api, ApiConfig},
ensure_config_files,
prompt::{prompts_path, Prompt},
prompt::{prompts_path, conversations_path, Prompt},
resolve_config_path, CUSTOM_CONFIG_ENV_VAR, DEFAULT_CONFIG_PATH,
},
utils::IS_NONINTERACTIVE_ENV_VAR,
Expand Down Expand Up @@ -175,6 +181,7 @@ mod tests {

assert!(!api_keys_path.exists());
assert!(!prompts_path.exists());
assert!(!conversations_path().exists());

let result = ensure_config_files();

Expand All @@ -187,6 +194,7 @@ mod tests {

assert!(api_keys_path.exists());
assert!(prompts_path.exists());
assert!(conversations_path().exists());

Ok(())
}
Expand Down
68 changes: 61 additions & 7 deletions src/config/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::path::PathBuf;
use crate::config::{api::Api, resolve_config_path};

const PROMPT_FILE: &str = "prompts.toml";
const CONVERSATION_FILE: &str = "conversation.toml";
const CONVERSATIONS_PATH: &str = "conversations/";

#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Prompt {
Expand Down Expand Up @@ -96,19 +96,31 @@ pub(super) fn prompts_path() -> PathBuf {
resolve_config_path().join(PROMPT_FILE)
}

pub fn conversation_file_path() -> PathBuf {
resolve_config_path().join(CONVERSATION_FILE)
// Get the path to the conversations directory
pub fn conversations_path() -> PathBuf {
resolve_config_path().join(CONVERSATIONS_PATH)
}

pub fn get_last_conversation_as_prompt() -> Prompt {
let content = fs::read_to_string(conversation_file_path()).unwrap_or_else(|error| {
// Get the path to a specific conversation file
pub fn conversation_file_path(name: &str) -> PathBuf {
conversations_path().join(format!("{}.toml", name))
}

// Get the last conversation as a prompt, if it exists
pub fn get_last_conversation_as_prompt(name: &str) -> Option<Prompt> {
let file_path = conversation_file_path(name);
if !file_path.exists() {
return None;
}

let content = fs::read_to_string(file_path).unwrap_or_else(|error| {
panic!(
"Could not read file {:?}, {:?}",
conversation_file_path(),
conversation_file_path(name),
error
)
});
toml::from_str(&content).expect("failed to load the conversation file")
Some(toml::from_str(&content).expect("failed to load the conversation file"))
}

pub(super) fn generate_prompts_file() -> std::io::Result<()> {
Expand Down Expand Up @@ -136,3 +148,45 @@ pub fn get_prompts() -> HashMap<String, Prompt> {
.unwrap_or_else(|error| panic!("Could not read file {:?}, {:?}", prompts_path(), error));
toml::from_str(&content).expect("could not parse prompt file content")
}

#[cfg(test)]
mod tests {
use super::*;
use std::fs;

#[test]
fn test_conversation_file_path() {
let name = "test_conversation";
let file_path = conversation_file_path(name);
assert_eq!(
file_path.file_name().unwrap().to_str().unwrap(),
format!("{}.toml", name)
);
assert_eq!(file_path.parent().unwrap(), conversations_path());
}

#[test]
fn test_get_last_conversation_as_prompt() {
let name = "test_conversation";
let file_path = conversation_file_path(name);
let prompt = Prompt::default();

// Create a test conversation file
let toml_string = toml::to_string(&prompt).expect("Failed to serialize prompt");
fs::write(&file_path, toml_string).expect("Failed to write test conversation file");

let loaded_prompt = get_last_conversation_as_prompt(name);
assert_eq!(loaded_prompt, Some(prompt));

// Clean up the test conversation file
fs::remove_file(&file_path).expect("Failed to remove test conversation file");
}

#[test]
fn test_get_last_conversation_as_prompt_missing_file() {
let name = "nonexistent_conversation";
let loaded_prompt = get_last_conversation_as_prompt(name);
assert_eq!(loaded_prompt, None);
}

}
47 changes: 44 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::config::{
prompt::{conversation_file_path, get_last_conversation_as_prompt, get_prompts, Prompt},
};
use prompt_customization::customize_prompt;
use crate::utils::valid_conversation_name;

use clap::{Args, Parser};
use log::debug;
Expand All @@ -18,6 +19,7 @@ use std::io::{self, IsTerminal, Read, Write};
use text::process_input_with_request;

const DEFAULT_PROMPT_NAME: &str = "default";
const DEFAULT_CONVERSATION_NAME: &str = "default";

#[derive(Debug, Parser)]
#[command(
Expand Down Expand Up @@ -56,6 +58,9 @@ struct Cli {
/// whether to repeat the input before the output, useful to extend instead of replacing
#[arg(short, long)]
repeat_input: bool,
/// conversation name
#[arg(short, long, value_parser = valid_conversation_name)]
name: Option<String>,
#[command(flatten)]
prompt_params: PromptParams,
}
Expand Down Expand Up @@ -104,6 +109,7 @@ fn main() {
}

let args = Cli::parse();
let name = args.name.as_deref().unwrap_or(DEFAULT_CONVERSATION_NAME);

debug!("args: {:?}", args);

Expand All @@ -118,14 +124,18 @@ fn main() {
// if it doesn't use default prompt and treat that first arg as customization text
get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text)
} else {
prompt_customizaton_text = args.input_or_template_ref;
prompt_customizaton_text = args.input_or_template_ref.clone();
if args.input_if_template_ref.is_some() {
panic!(
"Invalid parameters, cannot provide a config ref when extending a conversation.\n\
Use `sc -e \"<your_prompt>.\"`"
);
}
get_last_conversation_as_prompt()

match get_last_conversation_as_prompt(name) {
Some(prompt) => prompt,
None => get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text),
}
};

// if no text was piped, use the custom prompt as input
Expand All @@ -149,7 +159,7 @@ fn main() {
Ok(prompt) => {
let toml_string =
toml::to_string(&prompt).expect("Failed to serialize prompt after response.");
let mut file = fs::File::create(conversation_file_path())
let mut file = fs::File::create(conversation_file_path(name))
.expect("Failed to create the conversation save file.");
file.write_all(toml_string.as_bytes())
.expect("Failed to write to the conversation file.");
Expand Down Expand Up @@ -204,3 +214,34 @@ fn get_default_and_or_custom_prompt(
.expect(&prompt_not_found_error)
}
}

#[cfg(test)]
mod tests {

use super::*;
use crate::config::prompt::Prompt;

#[test]
fn test_get_last_conversation_as_prompt_missing_file() {
let args = Cli {
input_or_template_ref: Some("test_input".to_string()),
input_if_template_ref: None,
extend_conversation: true,
repeat_input: false,
name: Some("nonexistent_conversation".to_string()),
prompt_params: PromptParams::default(),
};
let mut prompt_customizaton_text = None;
let name = args.name.as_deref().unwrap_or(DEFAULT_CONVERSATION_NAME);

let prompt = get_last_conversation_as_prompt(name);

assert_eq!(prompt, None);

let default_prompt = get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text);
assert_eq!(default_prompt, Prompt::default());
assert_eq!(prompt_customizaton_text, Some("test_input".to_string()));

}

}
12 changes: 12 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use regex::Regex;

pub const IS_NONINTERACTIVE_ENV_VAR: &str = "SMARTCAT_NONINTERACTIVE";

/// clean error logging
Expand All @@ -24,3 +26,13 @@ pub fn read_user_input() -> String {
.expect("Failed to read line");
user_input.trim().to_string()
}

// Validate the conversation name
pub fn valid_conversation_name(s: &str) -> Result<String, String> {
let re = Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap();
if re.is_match(s) {
Ok(s.to_string())
} else {
Err(format!("Invalid name: {}", s))
}
}

0 comments on commit bd57694

Please sign in to comment.