Skip to content

Commit

Permalink
add basic autocomplete
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacherr committed Sep 14, 2024
1 parent 1e18d97 commit bbae631
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 7 deletions.
33 changes: 33 additions & 0 deletions assyst-core/src/command/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,39 @@ impl ParseArgument for Word {
}
}

/// A single word argument, with autocompletion.
#[derive(Debug)]
pub struct WordAutocomplete(pub String);

impl ParseArgument for WordAutocomplete {
async fn parse_raw_message(ctxt: &mut RawMessageParseCtxt<'_>, label: Label) -> Result<Self, TagParseError> {
Ok(Self(ctxt.next_word(label)?.to_owned()))
}

async fn parse_command_option(
ctxt: &mut InteractionCommandParseCtxt<'_>,
label: Label,
) -> Result<Self, TagParseError> {
let word = &ctxt.option_by_name(&label.unwrap().0)?.value;

if let CommandOptionValue::String(ref option) = word {
Ok(WordAutocomplete(option.clone()))
} else {
Err(TagParseError::MismatchedCommandOptionType((
"String (Word Autocomplete)".to_owned(),
word.clone(),
)))
}
}

fn as_command_option(name: &str) -> CommandOption {
StringBuilder::new(name, "word input")
.autocomplete(true)
.required(true)
.build()
}
}

/// A codeblock argument (may also be plaintext).
#[derive(Debug)]
pub struct Codeblock(pub String);
Expand Down
70 changes: 70 additions & 0 deletions assyst-core/src/command/autocomplete.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use assyst_common::err;
use twilight_model::application::command::{CommandOptionChoice, CommandOptionChoiceValue};
use twilight_model::http::interaction::{InteractionResponse, InteractionResponseType};
use twilight_model::id::marker::{GuildMarker, InteractionMarker};
use twilight_model::id::Id;
use twilight_util::builder::InteractionResponseDataBuilder;

use super::misc::tag::tag_names_autocomplete;
use super::services::cooltext::cooltext_options_autocomplete;
use crate::assyst::ThreadSafeAssyst;

const SUGG_LIMIT: usize = 25;

pub async fn handle_autocomplete(
assyst: ThreadSafeAssyst,
interaction_id: Id<InteractionMarker>,
interaction_token: String,
guild_id: Option<Id<GuildMarker>>,
command_full_name: &str,
option: &str,
text_to_autocomplete: &str,
) {
// FIXME: minimise hardcoding strings etc as much as possible
// future improvement is to use callbacks, but quite a lot of work
// considering this is only used in a small handful of places
let opts = match command_full_name {
"cooltext create" => cooltext_options_autocomplete(),
// FIXME: this unwrap needs handling properly when tags come to dms etc
"tag run" => tag_names_autocomplete(assyst.clone(), guild_id.unwrap().get()).await,
_ => {
err!("Trying to autocomplete for invalid command: {command_full_name} (arg {option})");
return;
},
};

let suggestions = get_autocomplete_suggestions(text_to_autocomplete, &opts);

let b = InteractionResponseDataBuilder::new();
let b = b.choices(suggestions);
let r = b.build();
let r = InteractionResponse {
kind: InteractionResponseType::ApplicationCommandAutocompleteResult,
data: Some(r),
};

if let Err(e) = assyst
.interaction_client()
.create_response(interaction_id, &interaction_token, &r)
.await
{
err!("Failed to send autocomplete options: {e:?}");
};
}

pub fn get_autocomplete_suggestions(text_to_autocomplete: &str, options: &[String]) -> Vec<CommandOptionChoice> {
options
.iter()
.filter(|x| {
x.to_ascii_lowercase()
.starts_with(&text_to_autocomplete.to_ascii_lowercase())
})
.take(SUGG_LIMIT)
.map(|x| CommandOptionChoice {
name: x.clone(),
name_localizations: None,
// FIXME: hardcoded string type
value: CommandOptionChoiceValue::String(x.clone()),
})
.collect::<Vec<_>>()
}
14 changes: 12 additions & 2 deletions assyst-core/src/command/misc/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use zip::ZipWriter;

use super::CommandCtxt;
use crate::assyst::ThreadSafeAssyst;
use crate::command::arguments::{Image, ImageUrl, RestNoFlags, User, Word};
use crate::command::arguments::{Image, ImageUrl, RestNoFlags, User, Word, WordAutocomplete};
use crate::command::componentctxt::{
button_emoji_new, button_new, respond_modal, respond_update_text, ComponentCtxt, ComponentInteractionData,
ComponentMetadata,
Expand Down Expand Up @@ -877,6 +877,12 @@ pub async fn paste(ctxt: CommandCtxt<'_>, name: Word) -> anyhow::Result<()> {
Ok(())
}

pub async fn tag_names_autocomplete(assyst: ThreadSafeAssyst, guild_id: u64) -> Vec<String> {
Tag::get_names_in_guild(&assyst.database_handler, guild_id as i64)
.await
.unwrap_or(vec![])
}

#[command(
description = "run a tag in the current server",
cooldown = Duration::from_secs(2),
Expand All @@ -887,7 +893,11 @@ pub async fn paste(ctxt: CommandCtxt<'_>, name: Word) -> anyhow::Result<()> {
send_processing = true,
guild_only = true
)]
pub async fn default(ctxt: CommandCtxt<'_>, tag_name: Word, arguments: Option<Vec<Word>>) -> anyhow::Result<()> {
pub async fn default(
ctxt: CommandCtxt<'_>,
tag_name: WordAutocomplete,
arguments: Option<Vec<Word>>,
) -> anyhow::Result<()> {
let Some(guild_id) = ctxt.data.guild_id else {
bail!("Tags can only be used in guilds.")
};
Expand Down
1 change: 1 addition & 0 deletions assyst-core/src/command/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ use super::gateway_handler::reply as gateway_reply;
use crate::assyst::ThreadSafeAssyst;

pub mod arguments;
pub mod autocomplete;
pub mod componentctxt;
pub mod errors;
pub mod flags;
Expand Down
11 changes: 8 additions & 3 deletions assyst-core/src/command/services/cooltext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@ use assyst_proc_macro::command;
use assyst_string_fmt::Markdown;
use rand::{thread_rng, Rng};

use crate::command::arguments::{Rest, Word};
use crate::command::arguments::{Rest, WordAutocomplete};
use crate::command::{Availability, Category, CommandCtxt};
use crate::define_commandgroup;
use crate::rest::cooltext::STYLES;

pub fn cooltext_options_autocomplete() -> Vec<String> {
let options = STYLES.iter().map(|x| x.0.to_owned()).collect::<Vec<_>>();
options
}

#[command(
description = "make some cool text",
access = Availability::Public,
Expand All @@ -17,7 +22,7 @@ use crate::rest::cooltext::STYLES;
examples = ["burning hello", "saint fancy", "random im random"],
send_processing = true
)]
pub async fn default(ctxt: CommandCtxt<'_>, style: Word, text: Rest) -> anyhow::Result<()> {
pub async fn default(ctxt: CommandCtxt<'_>, style: WordAutocomplete, text: Rest) -> anyhow::Result<()> {
let style = if &style.0 == "random" {
let rand = thread_rng().gen_range(0..STYLES.len());
STYLES[rand].0
Expand Down Expand Up @@ -63,6 +68,6 @@ define_commandgroup! {
commands: [
"list" => list
],
default_interaction_subcommand: "run",
default_interaction_subcommand: "create",
default: default
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ use tracing::{debug, warn};
use twilight_model::application::interaction::application_command::{
CommandData as DiscordCommandData, CommandDataOption, CommandOptionValue,
};
use twilight_model::application::interaction::{InteractionContextType, InteractionData};
use twilight_model::application::interaction::{InteractionContextType, InteractionData, InteractionType};
use twilight_model::gateway::payload::incoming::InteractionCreate;
use twilight_model::http::interaction::{InteractionResponse, InteractionResponseType};
use twilight_model::util::Timestamp;

use super::after_command_execution_success;
use crate::assyst::ThreadSafeAssyst;
use crate::command::autocomplete::handle_autocomplete;
use crate::command::componentctxt::ComponentInteractionData;
use crate::command::registry::find_command_by_name;
use crate::command::source::Source;
Expand Down Expand Up @@ -74,7 +75,9 @@ pub async fn handle(assyst: ThreadSafeAssyst, InteractionCreate(interaction): In
}
}

if let Some(InteractionData::ApplicationCommand(command_data)) = interaction.data {
if interaction.kind == InteractionType::ApplicationCommand
&& let Some(InteractionData::ApplicationCommand(command_data)) = interaction.data
{
let command = find_command_by_name(&command_data.name);
let subcommand_data = parse_subcommand_data(&command_data);

Expand Down Expand Up @@ -262,5 +265,62 @@ pub async fn handle(assyst: ThreadSafeAssyst, InteractionCreate(interaction): In
};
},
}
} else if interaction.kind == InteractionType::ApplicationCommandAutocomplete
&& let Some(InteractionData::ApplicationCommand(command_data)) = interaction.data
{
let command = find_command_by_name(&command_data.name);
let subcommand_data = parse_subcommand_data(&command_data);

if let Some(command) = command {
let incoming_options = if let Some(d) = subcommand_data.clone() {
match d.1 {
CommandOptionValue::SubCommand(s) => s,
_ => unreachable!(),
}
} else {
command_data.options
};

let interaction_subcommand = if let Some(d) = subcommand_data {
match d.1 {
CommandOptionValue::SubCommand(_) => Some(d),
_ => unreachable!(),
}
} else {
None
};

let focused_option = incoming_options
.iter()
.find(|x| matches!(x.value, CommandOptionValue::Focused(_, _)))
.expect("no focused option?");

// we will probably only ever use autocomplete on `Word` arguments
// FIXME: add support for more arg types here?
let inner_option = if let CommandOptionValue::Focused(x, y) = focused_option.value.clone() {
(x, y)
} else {
unreachable!()
};

// full_name will use the interaction command replacement for any "default" subcommand (e.g., "tag
// run")
let full_name = if let Some(i) = interaction_subcommand {
format!("{} {}", command.metadata().name, i.0)
} else {
command.metadata().name.to_owned()
};

handle_autocomplete(
assyst.clone(),
interaction.id,
interaction.token,
interaction.guild_id,
&full_name,
&focused_option.name,
&inner_option.0,
)
.await;
}
}
}
10 changes: 10 additions & 0 deletions assyst-database/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub struct DatabaseCache {
global_blacklist: Cache<u64, bool>,
disabled_commands: Cache<u64, Arc<Mutex<HashSet<String>>>>,
copied_tags: Cache<u64 /* user id */, String /* content */>,
guild_tag_names: Cache<u64, Vec<String>>,
}
impl DatabaseCache {
pub fn new() -> Self {
Expand All @@ -39,6 +40,7 @@ impl DatabaseCache {
global_blacklist: default_cache(),
disabled_commands: default_cache(),
copied_tags: default_cache_sized(u64::MAX),
guild_tag_names: default_cache(),
}
}

Expand Down Expand Up @@ -140,6 +142,14 @@ impl DatabaseCache {
pub fn get_copied_tag(&self, user_id: u64) -> Option<String> {
self.copied_tags.get(&user_id)
}

pub fn insert_guild_tag_names(&self, guild_id: u64, names: Vec<String>) {
self.guild_tag_names.insert(guild_id, names);
}

pub fn get_guild_tag_names(&self, guild_id: u64) -> Option<Vec<String>> {
self.guild_tag_names.get(&guild_id)
}
}

impl Default for DatabaseCache {
Expand Down
14 changes: 14 additions & 0 deletions assyst-database/src/model/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,20 @@ impl Tag {
.await
}

pub async fn get_names_in_guild(handler: &DatabaseHandler, guild_id: i64) -> Result<Vec<String>, sqlx::Error> {
if let Some(g) = handler.cache.get_guild_tag_names(guild_id as u64) {
return Ok(g);
}

let query = r#"SELECT * FROM tags WHERE guild_id = $1"#;

let result: Vec<Tag> = sqlx::query_as(query).bind(guild_id).fetch_all(&handler.pool).await?;
let names = result.iter().map(|x| &x.name).cloned().collect::<Vec<_>>();
handler.cache.insert_guild_tag_names(guild_id as u64, names.clone());

Ok(names)
}

pub async fn get_paged_for_user(
handler: &DatabaseHandler,
guild_id: i64,
Expand Down

0 comments on commit bbae631

Please sign in to comment.