From 1e276881d995071658c18cc1f440d5f8d6aaf533 Mon Sep 17 00:00:00 2001 From: Toby Date: Wed, 6 Nov 2024 15:51:25 +0000 Subject: [PATCH] feat: maintain backwards compatibility with conversastion.toml - Use conversation.toml for storing latest conversation state - Remove unnecessary directory creation in config initialization - Update test coverage --- src/config/mod.rs | 12 +-- src/config/prompt.rs | 184 ++++++++++++++++++++++++++++++++++--------- src/main.rs | 142 ++++++++++++++++++++++++++------- src/utils.rs | 2 +- 4 files changed, 264 insertions(+), 76 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index a9f1b4d..358b904 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -5,7 +5,7 @@ use std::{path::PathBuf, process::Command}; use self::{ api::{api_keys_path, generate_api_keys_file, get_api_config}, - prompt::{generate_prompts_file, get_prompts, prompts_path, conversations_path}, + prompt::{generate_prompts_file, get_prompts, prompts_path}, }; use crate::utils::is_interactive; @@ -58,12 +58,6 @@ pub fn ensure_config_files() -> std::io::Result<()> { } }; - // Create the conversations directory if it doesn't exist - if !conversations_path().exists() { - std::fs::create_dir_all(conversations_path()) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("Failed to create conversations directory: {}", e)))?; - } - Ok(()) } @@ -113,7 +107,7 @@ mod tests { config::{ api::{api_keys_path, default_timeout_seconds, Api, ApiConfig}, ensure_config_files, - prompt::{prompts_path, conversations_path, Prompt}, + prompt::{prompts_path, Prompt}, resolve_config_path, CUSTOM_CONFIG_ENV_VAR, DEFAULT_CONFIG_PATH, }, utils::IS_NONINTERACTIVE_ENV_VAR, @@ -181,7 +175,6 @@ mod tests { assert!(!api_keys_path.exists()); assert!(!prompts_path.exists()); - assert!(!conversations_path().exists()); let result = ensure_config_files(); @@ -194,7 +187,6 @@ mod tests { assert!(api_keys_path.exists()); assert!(prompts_path.exists()); - assert!(conversations_path().exists()); Ok(()) } diff --git a/src/config/prompt.rs b/src/config/prompt.rs index e38e3e1..1ae56a8 100644 --- a/src/config/prompt.rs +++ b/src/config/prompt.rs @@ -9,7 +9,8 @@ use std::path::PathBuf; use crate::config::{api::Api, resolve_config_path}; const PROMPT_FILE: &str = "prompts.toml"; -const CONVERSATIONS_PATH: &str = "conversations/"; +const CONVERSATION_FILE: &str = "conversation.toml"; +const CONVERSATIONS_PATH: &str = "saved_conversations"; #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] pub struct Prompt { @@ -96,31 +97,66 @@ pub(super) fn prompts_path() -> PathBuf { resolve_config_path().join(PROMPT_FILE) } +pub fn conversation_file_path() -> PathBuf { + resolve_config_path().join(CONVERSATION_FILE) +} + // Get the path to the conversations directory pub fn conversations_path() -> PathBuf { resolve_config_path().join(CONVERSATIONS_PATH) } // Get the path to a specific conversation file -pub fn conversation_file_path(name: &str) -> PathBuf { +pub fn named_conversation_path(name: &str) -> PathBuf { conversations_path().join(format!("{}.toml", name)) } // Get the last conversation as a prompt, if it exists -pub fn get_last_conversation_as_prompt(name: &str) -> Option { - let file_path = conversation_file_path(name); - if !file_path.exists() { - return None; +pub fn get_last_conversation_as_prompt(name: Option<&str>) -> Option { + if let Some(name) = name { + let named_path = named_conversation_path(name); + if !named_path.exists() { + return None; + } + let content = fs::read_to_string(named_path) + .unwrap_or_else(|error| { + panic!( + "Could not read file {:?}, {:?}", + named_conversation_path(name), + error + ) + }); + Some(toml::from_str(&content).expect("failed to load the conversation file")) + } else { + let path = conversation_file_path(); + if !path.exists() { + return None; + } + let content = fs::read_to_string(path) + .unwrap_or_else(|error| { + panic!( + "Could not read file {:?}, {:?}", + conversation_file_path(), + error + ) + }); + Some(toml::from_str(&content).expect("failed to load the conversation file")) + } +} + +pub fn save_conversation(prompt: &Prompt, name: Option<&str>) -> std::io::Result<()> { + let toml_string = toml::to_string(prompt).expect("Failed to serialize prompt"); + + // Always save to conversation.toml + fs::write(conversation_file_path(), &toml_string)?; + + // If name is provided, also save to named conversation file + if let Some(name) = name { + fs::create_dir_all(conversations_path())?; + fs::write(named_conversation_path(name), &toml_string)?; } - let content = fs::read_to_string(file_path).unwrap_or_else(|error| { - panic!( - "Could not read file {:?}, {:?}", - conversation_file_path(name), - error - ) - }); - Some(toml::from_str(&content).expect("failed to load the conversation file")) + Ok(()) } pub(super) fn generate_prompts_file() -> std::io::Result<()> { @@ -153,40 +189,114 @@ pub fn get_prompts() -> HashMap { mod tests { use super::*; use std::fs; + use tempfile::tempdir; + use crate::config::prompt::Prompt; + use serial_test::serial; + + fn setup() -> tempfile::TempDir { + let temp_dir = tempdir().unwrap(); + std::env::set_var("SMARTCAT_CONFIG_PATH", temp_dir.path()); + temp_dir + } + + fn create_test_prompt() -> Prompt { + let mut prompt = Prompt::default(); + prompt.messages = vec![(Message::user("test"))]; + prompt + } #[test] - fn test_conversation_file_path() { - let name = "test_conversation"; - let file_path = conversation_file_path(name); - assert_eq!( - file_path.file_name().unwrap().to_str().unwrap(), - format!("{}.toml", name) - ); - assert_eq!(file_path.parent().unwrap(), conversations_path()); + #[serial] + fn test_get_and_save_default_conversation() { + let _temp_dir = setup(); + let test_prompt = create_test_prompt(); + + // Test saving conversation + save_conversation(&test_prompt, None).unwrap(); + assert!(conversation_file_path().exists()); + + // Test retrieving conversation + let loaded_prompt = get_last_conversation_as_prompt(None).unwrap(); + assert_eq!(loaded_prompt, test_prompt); } #[test] - fn test_get_last_conversation_as_prompt() { - let name = "test_conversation"; - let file_path = conversation_file_path(name); - let prompt = Prompt::default(); + #[serial] + fn test_get_and_save_named_conversation() { + let _temp_dir = setup(); + let test_prompt = create_test_prompt(); + let conv_name = "test_conversation"; - // Create a test conversation file - let toml_string = toml::to_string(&prompt).expect("Failed to serialize prompt"); - fs::write(&file_path, toml_string).expect("Failed to write test conversation file"); + // Test saving named conversation + save_conversation(&test_prompt, Some(conv_name)).unwrap(); + assert!(named_conversation_path(conv_name).exists()); + assert!(conversation_file_path().exists()); // Should also save to default location - let loaded_prompt = get_last_conversation_as_prompt(name); - assert_eq!(loaded_prompt, Some(prompt)); + // Test retrieving named conversation + let loaded_prompt = get_last_conversation_as_prompt(Some(conv_name)).unwrap(); + assert_eq!(loaded_prompt, test_prompt); + } - // Clean up the test conversation file - fs::remove_file(&file_path).expect("Failed to remove test conversation file"); + #[test] + #[serial] + fn test_nonexistent_conversation() { + let _temp_dir = setup(); + + // Test getting nonexistent default conversation + assert!(get_last_conversation_as_prompt(None).is_none()); + + // Test getting nonexistent named conversation + assert!(get_last_conversation_as_prompt(Some("nonexistent")).is_none()); + } + + #[test] + #[serial] + fn test_conversation_file_contents() { + let _temp_dir = setup(); + let test_prompt = create_test_prompt(); + let conv_name = "test_conversation"; + + // Save conversation + save_conversation(&test_prompt, Some(conv_name)).unwrap(); + + // Verify default and named files have identical content + let default_content = fs::read_to_string(conversation_file_path()).unwrap(); + let named_content = fs::read_to_string(named_conversation_path(conv_name)).unwrap(); + assert_eq!(default_content, named_content); + + // Verify content can be parsed back to original prompt + let parsed_prompt: Prompt = toml::from_str(&default_content).unwrap(); + assert_eq!(parsed_prompt, test_prompt); } #[test] - fn test_get_last_conversation_as_prompt_missing_file() { - let name = "nonexistent_conversation"; - let loaded_prompt = get_last_conversation_as_prompt(name); - assert_eq!(loaded_prompt, None); + #[serial] + fn test_generate_prompts_file() { + let _temp_dir = setup(); + + // Test file generation + generate_prompts_file().unwrap(); + assert!(prompts_path().exists()); + + // Verify file is valid TOML and contains expected content + let content = fs::read_to_string(prompts_path()).unwrap(); + let prompts: HashMap = toml::from_str(&content).unwrap(); + assert!(!prompts.is_empty()); } + #[test] + #[serial] + fn test_get_prompts() { + let _temp_dir = setup(); + + // Generate prompts file + generate_prompts_file().unwrap(); + + // Test loading prompts + let prompts = get_prompts(); + assert!(!prompts.is_empty()); + + // Verify at least one default prompt exists + assert!(prompts.contains_key("default")); + } } diff --git a/src/main.rs b/src/main.rs index 2a52a3a..6cd35c5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,20 +6,18 @@ mod utils; use crate::config::{ api::Api, ensure_config_usable, - prompt::{conversation_file_path, get_last_conversation_as_prompt, get_prompts, Prompt}, + prompt::{get_last_conversation_as_prompt, save_conversation, get_prompts, Prompt}, }; use prompt_customization::customize_prompt; use crate::utils::valid_conversation_name; use clap::{Args, Parser}; use log::debug; -use std::fs; use std::io::{self, IsTerminal, Read, Write}; use text::process_input_with_request; const DEFAULT_PROMPT_NAME: &str = "default"; -const DEFAULT_CONVERSATION_NAME: &str = "default"; #[derive(Debug, Parser)] #[command( @@ -109,7 +107,6 @@ fn main() { } let args = Cli::parse(); - let name = args.name.as_deref().unwrap_or(DEFAULT_CONVERSATION_NAME); debug!("args: {:?}", args); @@ -119,12 +116,9 @@ fn main() { let is_piped = !stdin.is_terminal(); let mut prompt_customizaton_text: Option = None; - let prompt: Prompt = if !args.extend_conversation { - // try to get prompt matching the first arg and use second arg as customization text - // if it doesn't use default prompt and treat that first arg as customization text - get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text) - } else { + let prompt = if args.extend_conversation { prompt_customizaton_text = args.input_or_template_ref.clone(); + if args.input_if_template_ref.is_some() { panic!( "Invalid parameters, cannot provide a config ref when extending a conversation.\n\ @@ -132,10 +126,19 @@ fn main() { ); } - match get_last_conversation_as_prompt(name) { + match get_last_conversation_as_prompt(args.name.as_deref()) { Some(prompt) => prompt, - None => get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text), + None => { + if args.name.is_some() { + panic!("Named conversation does not exist: {}", args.name.unwrap()); + } + get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text) + } } + } else { + // try to get prompt matching the first arg and use second arg as customization text + // if it doesn't use default prompt and treat that first arg as customization text + get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text) }; // if no text was piped, use the custom prompt as input @@ -156,13 +159,9 @@ fn main() { debug!("{:?}", prompt); match process_input_with_request(prompt, input, &mut output, args.repeat_input) { - Ok(prompt) => { - let toml_string = - toml::to_string(&prompt).expect("Failed to serialize prompt after response."); - let mut file = fs::File::create(conversation_file_path(name)) - .expect("Failed to create the conversation save file."); - file.write_all(toml_string.as_bytes()) - .expect("Failed to write to the conversation file."); + Ok(new_prompt) => { + save_conversation(&new_prompt, args.name.as_deref()) + .expect("Failed to save conversation"); } Err(e) => { eprintln!("Error: {}", e); @@ -217,12 +216,28 @@ fn get_default_and_or_custom_prompt( #[cfg(test)] mod tests { - use super::*; - use crate::config::prompt::Prompt; + use crate::config::prompt::{Prompt, Message}; + use tempfile::tempdir; + use serial_test::serial; + + fn setup() -> tempfile::TempDir { + let temp_dir = tempdir().unwrap(); + std::env::set_var("SMARTCAT_CONFIG_PATH", temp_dir.path()); + temp_dir + } + + fn create_test_prompt() -> Prompt { + let mut prompt = Prompt::default(); + prompt.messages = vec![(Message::user("test"))]; + prompt + } #[test] - fn test_get_last_conversation_as_prompt_missing_file() { + #[serial] + fn test_cli_with_nonexistent_conversation() { + let _temp_dir = setup(); + let args = Cli { input_or_template_ref: Some("test_input".to_string()), input_if_template_ref: None, @@ -231,17 +246,88 @@ mod tests { name: Some("nonexistent_conversation".to_string()), prompt_params: PromptParams::default(), }; - let mut prompt_customizaton_text = None; - let name = args.name.as_deref().unwrap_or(DEFAULT_CONVERSATION_NAME); - let prompt = get_last_conversation_as_prompt(name); + // Test that getting a nonexistent conversation returns None + let prompt = get_last_conversation_as_prompt(args.name.as_deref()); + assert!(prompt.is_none()); + } - assert_eq!(prompt, None); + #[test] + #[serial] + fn test_cli_with_existing_conversation() { + let _temp_dir = setup(); + + // Create a test conversation + let test_prompt = create_test_prompt(); + save_conversation(&test_prompt, Some("test_conversation")).unwrap(); + + let args = Cli { + input_or_template_ref: Some("test_input".to_string()), + input_if_template_ref: None, + extend_conversation: true, + repeat_input: false, + name: Some("test_conversation".to_string()), + prompt_params: PromptParams::default(), + }; + + // Test retrieving the saved conversation + let prompt = get_last_conversation_as_prompt(args.name.as_deref()); + assert!(prompt.is_some()); + assert_eq!(prompt.unwrap(), test_prompt); + } + + #[test] + #[serial] + fn test_valid_conversation_name() { + assert!(valid_conversation_name("valid_name").is_ok()); + assert!(valid_conversation_name("valid-name").is_ok()); + assert!(valid_conversation_name("valid123").is_ok()); + assert!(valid_conversation_name("VALID_NAME").is_ok()); + + assert!(valid_conversation_name("invalid name").is_err()); + assert!(valid_conversation_name("invalid/name").is_err()); + assert!(valid_conversation_name("invalid.name").is_err()); + assert!(valid_conversation_name("").is_err()); + } + + #[test] + #[serial] + fn test_conversation_persistence() { + let _temp_dir = setup(); + let test_prompt = create_test_prompt(); + + // Test saving and loading default conversation + save_conversation(&test_prompt, None).unwrap(); + let loaded_prompt = get_last_conversation_as_prompt(None); + assert!(loaded_prompt.is_some()); + assert_eq!(loaded_prompt.unwrap(), test_prompt); + + // Test saving and loading named conversation + save_conversation(&test_prompt, Some("test_conv")).unwrap(); + let loaded_named_prompt = get_last_conversation_as_prompt(Some("test_conv")); + assert!(loaded_named_prompt.is_some()); + assert_eq!(loaded_named_prompt.unwrap(), test_prompt); + } + + #[test] + #[serial] + fn test_default_prompt_fallback() { + let _temp_dir = setup(); + let args = Cli { + input_or_template_ref: Some("test_input".to_string()), + input_if_template_ref: None, + extend_conversation: true, + repeat_input: false, + name: None, + prompt_params: PromptParams::default(), + }; - let default_prompt = get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text); - assert_eq!(default_prompt, Prompt::default()); - assert_eq!(prompt_customizaton_text, Some("test_input".to_string())); + let prompt = get_last_conversation_as_prompt(args.name.as_deref()); + assert!(prompt.is_none()); // Should be None when no conversation exists + // Verify the prompt customization text is set correctly + let prompt_customization_text = args.input_or_template_ref; + assert_eq!(prompt_customization_text, Some("test_input".to_string())); } } diff --git a/src/utils.rs b/src/utils.rs index 5643f45..fb676a4 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -33,6 +33,6 @@ pub fn valid_conversation_name(s: &str) -> Result { if re.is_match(s) { Ok(s.to_string()) } else { - Err(format!("Invalid name: {}", s)) + Err(format!("Invalid conversation name: {}. Use only letters, numbers, underscores, and hyphens.", s)) } }