Skip to content

Commit

Permalink
add additional autocompletions for edit and delete
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacherr committed Sep 14, 2024
1 parent 73ae311 commit 1ab1bb8
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 16 deletions.
10 changes: 8 additions & 2 deletions assyst-core/src/command/autocomplete.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
use assyst_common::err;
use twilight_model::application::command::{CommandOptionChoice, CommandOptionChoiceValue};
use twilight_model::http::interaction::{InteractionResponse, InteractionResponseType};
use twilight_model::id::marker::{GuildMarker, InteractionMarker};
use twilight_model::id::marker::{GuildMarker, InteractionMarker, UserMarker};
use twilight_model::id::Id;
use twilight_util::builder::InteractionResponseDataBuilder;

use super::misc::tag::tag_names_autocomplete;
use super::misc::tag::{tag_names_autocomplete, tag_names_autocomplete_for_user};
use super::services::cooltext::cooltext_options_autocomplete;
use crate::assyst::ThreadSafeAssyst;

const SUGG_LIMIT: usize = 25;

// FIXME: pass a struct with data instead of having so many arguments
#[allow(clippy::too_many_arguments)]
pub async fn handle_autocomplete(
assyst: ThreadSafeAssyst,
interaction_id: Id<InteractionMarker>,
interaction_token: String,
guild_id: Option<Id<GuildMarker>>,
user_id: Id<UserMarker>,
command_full_name: &str,
option: &str,
text_to_autocomplete: &str,
Expand All @@ -29,6 +32,9 @@ pub async fn handle_autocomplete(
("tag run", "name") | ("tag raw", "name") | ("tag copy", "name") | ("tag info", "name") => {
tag_names_autocomplete(assyst.clone(), guild_id.unwrap().get()).await
},
("tag edit", "name") | ("tag delete", "name") => {
tag_names_autocomplete_for_user(assyst.clone(), guild_id.unwrap().get(), user_id.get()).await
},
_ => {
err!("Trying to autocomplete for invalid command: {command_full_name} (arg {option})");
return;
Expand Down
16 changes: 14 additions & 2 deletions assyst-core/src/command/misc/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub async fn create(ctxt: CommandCtxt<'_>, name: Word, contents: RestNoFlags) ->
examples = ["test hello there", "script 2+2 is: {js:2+2}"],
guild_only = true
)]
pub async fn edit(ctxt: CommandCtxt<'_>, name: Word, contents: RestNoFlags) -> anyhow::Result<()> {
pub async fn edit(ctxt: CommandCtxt<'_>, name: WordAutocomplete, contents: RestNoFlags) -> anyhow::Result<()> {
let author = ctxt.data.author.id.get();
let Some(guild_id) = ctxt.data.guild_id else {
bail!("Tags can only be edited in guilds.")
Expand Down Expand Up @@ -130,7 +130,7 @@ pub async fn edit(ctxt: CommandCtxt<'_>, name: Word, contents: RestNoFlags) -> a
examples = ["test", "script"],
guild_only = true
)]
pub async fn delete(ctxt: CommandCtxt<'_>, name: Word) -> anyhow::Result<()> {
pub async fn delete(ctxt: CommandCtxt<'_>, name: WordAutocomplete) -> anyhow::Result<()> {
let author = ctxt.data.author.id.get();
let Some(guild_id) = ctxt.data.guild_id else {
bail!("Tags can only be deleted in guilds.")
Expand Down Expand Up @@ -881,6 +881,18 @@ pub async fn tag_names_autocomplete(assyst: ThreadSafeAssyst, guild_id: u64) ->
Tag::get_names_in_guild(&assyst.database_handler, guild_id as i64)
.await
.unwrap_or(vec![])
.iter()
.map(|x| x.1.clone())
.collect::<Vec<_>>()
}

pub async fn tag_names_autocomplete_for_user(assyst: ThreadSafeAssyst, guild_id: u64, user_id: u64) -> Vec<String> {
Tag::get_names_in_guild(&assyst.database_handler, guild_id as i64)
.await
.unwrap_or(vec![])
.iter()
.filter_map(|x| if x.0 == user_id { Some(x.1.clone()) } else { None })
.collect::<Vec<_>>()
}

#[command(
Expand Down
2 changes: 1 addition & 1 deletion assyst-core/src/downloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async fn download_with_proxy(
limit: usize,
) -> Result<impl Stream<Item = Result<Bytes, reqwest::Error>>, DownloadError> {
let resp = client
.get(&format!("{}/proxy", get_next_proxy()))
.get(format!("{}/proxy", get_next_proxy()))
.query(&[("url", url), ("limit", &limit.to_string())])
.timeout(Duration::from_secs(10))
.send()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub async fn handle(assyst: ThreadSafeAssyst, InteractionCreate(interaction): In
// look at entitlements to see if there is anything new - we can cache this if so
// this usually shouldnt happen except for some edge cases such as a new entitlement was created
// when the bot was down
let entitlements = interaction.entitlements;
let entitlements = interaction.entitlements.clone();
let lock = assyst.entitlements.lock().unwrap().clone();
let mut new = vec![];
for entitlement in entitlements {
Expand Down Expand Up @@ -266,7 +266,7 @@ pub async fn handle(assyst: ThreadSafeAssyst, InteractionCreate(interaction): In
},
}
} else if interaction.kind == InteractionType::ApplicationCommandAutocomplete
&& let Some(InteractionData::ApplicationCommand(command_data)) = interaction.data
&& let Some(InteractionData::ApplicationCommand(command_data)) = interaction.data.clone()
{
let command = find_command_by_name(&command_data.name);
let subcommand_data = parse_subcommand_data(&command_data);
Expand Down Expand Up @@ -311,11 +311,14 @@ pub async fn handle(assyst: ThreadSafeAssyst, InteractionCreate(interaction): In
command.metadata().name.to_owned()
};

let author_id = interaction.author().unwrap().id;

handle_autocomplete(
assyst.clone(),
interaction.id,
interaction.token,
interaction.guild_id,
author_id,
&full_name,
&focused_option.name,
&inner_option.0,
Expand Down
6 changes: 3 additions & 3 deletions assyst-database/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub struct DatabaseCache {
global_blacklist: Cache<u64, bool>,
disabled_commands: Cache<u64, Arc<Mutex<HashSet<String>>>>,
copied_tags: Cache<u64 /* user id */, String /* content */>,
guild_tag_names: Cache<u64, Vec<String>>,
guild_tag_names: Cache<u64, Vec<(u64 /* auhtor id */, String)>>,
}
impl DatabaseCache {
pub fn new() -> Self {
Expand Down Expand Up @@ -143,11 +143,11 @@ impl DatabaseCache {
self.copied_tags.get(&user_id)
}

pub fn insert_guild_tag_names(&self, guild_id: u64, names: Vec<String>) {
pub fn insert_guild_tag_names(&self, guild_id: u64, names: Vec<(u64, String)>) {
self.guild_tag_names.insert(guild_id, names);
}

pub fn get_guild_tag_names(&self, guild_id: u64) -> Option<Vec<String>> {
pub fn get_guild_tag_names(&self, guild_id: u64) -> Option<Vec<(u64, String)>> {
self.guild_tag_names.get(&guild_id)
}
}
Expand Down
11 changes: 9 additions & 2 deletions assyst-database/src/model/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,22 @@ impl Tag {
.await
}

pub async fn get_names_in_guild(handler: &DatabaseHandler, guild_id: i64) -> Result<Vec<String>, sqlx::Error> {
pub async fn get_names_in_guild(
handler: &DatabaseHandler,
guild_id: i64,
) -> Result<Vec<(u64, String)>, sqlx::Error> {
if let Some(g) = handler.cache.get_guild_tag_names(guild_id as u64) {
return Ok(g);
}

let query = r#"SELECT * FROM tags WHERE guild_id = $1"#;

let result: Vec<Tag> = sqlx::query_as(query).bind(guild_id).fetch_all(&handler.pool).await?;
let names = result.iter().map(|x| &x.name).cloned().collect::<Vec<_>>();
let names = result
.iter()
.map(|x| (x.author as u64, x.name.clone()))
.collect::<Vec<_>>();

handler.cache.insert_guild_tag_names(guild_id as u64, names.clone());

Ok(names)
Expand Down
7 changes: 3 additions & 4 deletions assyst-proc-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub fn command(attrs: TokenStream, func: TokenStream) -> TokenStream {
let mut parse_idents = Vec::new();
let mut parse_exprs = Vec::new();
let mut parse_usage = Vec::new();
let mut parse_attrs = Vec::new();
let mut interaction_parse_exprs = Vec::new();
let mut command_option_exprs = Vec::new();

Expand All @@ -84,7 +85,7 @@ pub fn command(attrs: TokenStream, func: TokenStream) -> TokenStream {
for (index, input) in item.sig.inputs.iter().skip(1).enumerate() {
match input {
FnArg::Receiver(_) => panic!("#[command] cannot have `self` arguments"),
FnArg::Typed(PatType { ty, pat, .. }) => {
FnArg::Typed(PatType { ty, pat, attrs, .. }) => {
if let Pat::Ident(ident) = &**pat {
let ident_string = ident.ident.to_string();

Expand All @@ -93,6 +94,7 @@ pub fn command(attrs: TokenStream, func: TokenStream) -> TokenStream {
}});
}

parse_attrs.push((stringify!(#pat).to_string(), attrs.clone()));
parse_idents.push(Ident::new(&format!("p{index}"), Span::call_site()));
parse_exprs.push(quote!(<#ty>::parse_raw_message(&mut ctxt, Some((stringify!(#pat).to_string(), stringify!(#ty).to_string()))).await));
parse_usage.push(quote!(<#ty as crate::command::arguments::ParseArgument>::usage(stringify!(#pat))));
Expand Down Expand Up @@ -173,9 +175,6 @@ pub fn command(attrs: TokenStream, func: TokenStream) -> TokenStream {
default_member_permissions: None,
description: meta.description.to_owned(),
description_localizations: None,
// TODO: set based on if dms are allowed
// TODO: update to `contexts` once this is required
// (see https://discord.com/developers/docs/interactions/application-commands#create-global-application-command)
dm_permission: Some(true),
guild_id: None,
id: None,
Expand Down

0 comments on commit 1ab1bb8

Please sign in to comment.