Skip to content

Commit

Permalink
add reminder task, refactor & cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacherr committed Jun 8, 2024
1 parent 8b35f2e commit 0c566c7
Show file tree
Hide file tree
Showing 13 changed files with 191 additions and 62 deletions.
8 changes: 8 additions & 0 deletions assyst-common/src/util/discord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,11 @@ pub fn id_from_mention(word: &str) -> Option<u64> {
pub fn format_tag(user: &User) -> String {
format!("{}#{}", user.name, user.discriminator)
}

/// Generates a message link
pub fn message_link(guild_id: u64, channel_id: u64, message_id: u64) -> String {
format!(
"https://discord.com/channels/{}/{}/{}",
guild_id, channel_id, message_id
)
}
18 changes: 9 additions & 9 deletions assyst-core/src/assyst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ pub struct Assyst {
/// HTTP client for Discord. Handles all HTTP requests to Discord, storing stateful information
/// about current ratelimits.
pub http_client: Arc<HttpClient>,
/// List of the current patrons to Assyst.
pub patrons: Arc<Mutex<Vec<Patron>>>,
/// List of the current premim users of Assyst.
pub premium_users: Arc<Mutex<Vec<Patron>>>,
/// Metrics handler for Prometheus, rate trackers etc.
pub metrics_handler: Arc<MetricsHandler>,
/// The reqwest client, used to issue general HTTP requests
Expand All @@ -51,30 +51,30 @@ impl Assyst {
let shard_count = http_client.gateway().authed().await?.model().await?.shards as u64;
let database_handler =
Arc::new(DatabaseHandler::new(CONFIG.database.to_url(), CONFIG.database.to_url_safe()).await?);
let patrons = Arc::new(Mutex::new(vec![]));
let premium_users = Arc::new(Mutex::new(vec![]));

Ok(Assyst {
persistent_cache_handler: PersistentCacheHandler::new(CACHE_PIPE_PATH),
database_handler: database_handler.clone(),
http_client: http_client.clone(),
patrons: patrons.clone(),
premium_users: premium_users.clone(),
metrics_handler: Arc::new(MetricsHandler::new(database_handler.clone())?),
reqwest_client: reqwest::Client::new(),
tasks: Mutex::new(vec![]),
shard_count,
replies: Replies::new(),
wsi_handler: WsiHandler::new(database_handler.clone(), patrons.clone()),
wsi_handler: WsiHandler::new(database_handler.clone(), premium_users.clone()),
rest_cache_handler: RestCacheHandler::new(http_client.clone()),
command_ratelimits: CommandRatelimits::new(),
})
}

/// Register a new Task to Assyst.
pub async fn register_task(&self, task: Task) {
/// Register a new `Task` to Assyst.
pub fn register_task(&self, task: Task) {
self.tasks.lock().unwrap().push(task);
}

pub async fn update_patron_list(&self, patrons: Vec<Patron>) {
*self.patrons.lock().unwrap() = patrons;
pub fn update_premium_user_list(&self, patrons: Vec<Patron>) {
*self.premium_users.lock().unwrap() = patrons;
}
}
2 changes: 1 addition & 1 deletion assyst-core/src/command/services/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub async fn burntext(ctxt: CommandCtxt<'_>, text: Rest) -> anyhow::Result<()> {
)]
pub async fn r34(ctxt: CommandCtxt<'_>, tags: Rest) -> anyhow::Result<()> {
let result = get_random_r34(ctxt.assyst().clone(), &tags.0).await?;
let reply = format!("{} (Score: {})", result.file_url, result.score);
let reply = format!("{} (Score: **{}**)", result.file_url, result.score);

ctxt.reply(reply).await?;

Expand Down
10 changes: 5 additions & 5 deletions assyst-core/src/gateway_handler/incoming_event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ use twilight_model::gateway::payload::incoming::{

#[derive(Debug)]
pub enum IncomingEvent {
MessageCreate(Box<MessageCreate>), // this struct is huge.
MessageUpdate(MessageUpdate),
MessageDelete(MessageDelete),
GuildCreate(Box<GuildCreate>), // same problem
ChannelUpdate(ChannelUpdate),
GuildCreate(Box<GuildCreate>), // this struct is huge.
GuildDelete(GuildDelete),
GuildUpdate(GuildUpdate),
MessageCreate(Box<MessageCreate>), // same problem
MessageDelete(MessageDelete),
MessageUpdate(MessageUpdate),
ShardReady(Ready),
ChannelUpdate(ChannelUpdate),
}
impl TryFrom<GatewayEvent> for IncomingEvent {
type Error = ();
Expand Down
42 changes: 25 additions & 17 deletions assyst-core/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::sync::Arc;
use std::time::Duration;

use crate::assyst::{Assyst, ThreadSafeAssyst};
use crate::task::tasks::get_patrons::get_patrons;
use crate::task::tasks::get_premium_users::get_premium_users;
use crate::task::tasks::top_gg_stats::post_top_gg_stats;
use crate::task::Task;
use assyst_common::config::config::LoggingWebhook;
Expand All @@ -21,6 +21,7 @@ use assyst_common::{err, ok_or_break};
use gateway_handler::handle_raw_event;
use gateway_handler::incoming_event::IncomingEvent;
use rest::web_media_download::get_web_download_api_urls;
use task::tasks::reminders::handle_reminders;
use tokio::spawn;
use tracing::{debug, info, trace};
use twilight_gateway::EventTypeFlags;
Expand Down Expand Up @@ -90,30 +91,37 @@ async fn main() {
);
}

assyst
.register_task(Task::new(
assyst.clone(),
// 10 mins
Duration::from_secs(60 * 10),
function_task_callback!(get_patrons),
))
.await;
assyst.register_task(Task::new(
assyst.clone(),
// 10 mins
Duration::from_secs(60 * 10),
function_task_callback!(get_premium_users),
));
info!("Registered patreon synchronisation task");

if !CONFIG.dev.disable_bot_list_posting {
assyst
.register_task(Task::new(
assyst.clone(),
// 10 mins
Duration::from_secs(60 * 10),
function_task_callback!(post_top_gg_stats),
))
.await;
assyst.register_task(Task::new(
assyst.clone(),
// 10 mins
Duration::from_secs(60 * 10),
function_task_callback!(post_top_gg_stats),
));
info!("Registered top.gg stats POSTing task");
} else {
info!("Bot list POSTing disabled in config.dev.disable_bot_list_posting: not registering task");
}

if !CONFIG.dev.disable_reminder_check {
assyst.register_task(Task::new(
assyst.clone(),
Duration::from_millis(crate::task::tasks::reminders::FETCH_INTERVAL as u64),
function_task_callback!(handle_reminders),
));
info!("Registered reminder check task");
} else {
info!("Reminder processing disabled in config.dev.disable_reminder_check: not registering task");
}

info!("Caching web download API URLs");
let web_download_urls = get_web_download_api_urls(assyst.clone()).await.unwrap();
info!("Got {} URLs to cache", web_download_urls.len());
Expand Down
8 changes: 0 additions & 8 deletions assyst-core/src/rest/patreon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,6 @@ pub async fn get_patrons(assyst: ThreadSafeAssyst) -> anyhow::Result<Vec<Patron>
};
}

for i in CONFIG.dev.admin_users.iter() {
patrons.push(Patron {
user_id: *i,
tier: PatronTier::Tier4,
_admin: true,
})
}

Ok(patrons)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use assyst_common::err;
use tracing::info;

/// Synchronises Assyst with an updated list of patrons.
pub async fn get_patrons(assyst: ThreadSafeAssyst) {
pub async fn get_premium_users(assyst: ThreadSafeAssyst) {
let mut premium_users: Vec<Patron> = vec![];

if !CONFIG.dev.disable_patreon_synchronisation {
info!("Synchronising patron list");

Expand All @@ -18,20 +20,20 @@ pub async fn get_patrons(assyst: ThreadSafeAssyst) {
},
};

assyst.update_patron_list(patrons.clone()).await;
premium_users.extend(patrons.into_iter());

info!("Synchronised patrons: total {}", patrons.len());
} else {
let mut patrons: Vec<Patron> = vec![];
info!("Synchronised patrons from Patreon: total {}", premium_users.len());
}

for i in CONFIG.dev.admin_users.iter() {
patrons.push(Patron {
user_id: *i,
tier: PatronTier::Tier4,
_admin: true,
})
}
// todo: load premium users via entitlements once twilight supports this

assyst.update_patron_list(patrons.clone()).await;
for i in CONFIG.dev.admin_users.iter() {
premium_users.push(Patron {
user_id: *i,
tier: PatronTier::Tier4,
_admin: true,
})
}

assyst.update_premium_user_list(premium_users.clone());
}
3 changes: 2 additions & 1 deletion assyst-core/src/task/tasks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod get_patrons;
pub mod get_premium_users;
pub mod reminders;
pub mod top_gg_stats;
68 changes: 68 additions & 0 deletions assyst-core/src/task/tasks/reminders.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use assyst_common::err;
use assyst_common::util::discord::message_link;
use assyst_database::model::reminder::Reminder;
use twilight_model::channel::message::AllowedMentions;
use twilight_model::id::marker::{ChannelMarker, UserMarker};
use twilight_model::id::Id;

use crate::assyst::ThreadSafeAssyst;

// 30 seconds
pub static FETCH_INTERVAL: i64 = 30000;

async fn process_single_reminder(assyst: ThreadSafeAssyst, reminder: &Reminder) -> anyhow::Result<()> {
assyst
.http_client
.create_message(Id::<ChannelMarker>::new(reminder.channel_id as u64))
.allowed_mentions(Some(&AllowedMentions {
parse: vec![],
replied_user: false,
roles: vec![],
users: vec![Id::<UserMarker>::new(reminder.user_id as u64)],
}))
.content(&format!(
"<@{}> Reminder: {}\n{}",
reminder.user_id,
reminder.message,
message_link(
reminder.guild_id as u64,
reminder.channel_id as u64,
reminder.message_id as u64
)
))
.await?;

Ok(())
}

async fn process_reminders(assyst: ThreadSafeAssyst, reminders: Vec<Reminder>) -> Result<(), anyhow::Error> {
if reminders.len() < 1 {
return Ok(());
}

for reminder in &reminders {
if let Err(e) = process_single_reminder(assyst.clone(), &reminder).await {
err!("Failed to process reminder: {:?}", e);
}

// Once we're done, delete them from database
reminder.remove(&assyst.database_handler).await?;
}

Ok(())
}

pub async fn handle_reminders(assyst: ThreadSafeAssyst) {
let reminders = Reminder::fetch_expiring_max(&assyst.database_handler, FETCH_INTERVAL).await;

match reminders {
Ok(reminders) => {
if let Err(e) = process_reminders(assyst.clone(), reminders).await {
err!("Processing reminder queue failed: {:?}", e);
}
},
Err(e) => {
err!("Fetching reminders failed: {:?}", e);
},
}
}
16 changes: 8 additions & 8 deletions assyst-core/src/wsi_handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,17 @@ pub type WsiSender = (Sender<JobResult>, FifoSend, usize);

pub struct WsiHandler {
database_handler: Arc<DatabaseHandler>,
patrons: Arc<Mutex<Vec<Patron>>>,
premium_users: Arc<Mutex<Vec<Patron>>>,
pub wsi_tx: UnboundedSender<WsiSender>,
}
impl WsiHandler {
pub fn new(database_handler: Arc<DatabaseHandler>, patrons: Arc<Mutex<Vec<Patron>>>) -> WsiHandler {
pub fn new(database_handler: Arc<DatabaseHandler>, premium_users: Arc<Mutex<Vec<Patron>>>) -> WsiHandler {
let (tx, rx) = unbounded_channel::<WsiSender>();
Self::listen(rx, &CONFIG.urls.wsi);
WsiHandler {
wsi_tx: tx,
database_handler,
patrons,
premium_users,
}
}

Expand Down Expand Up @@ -163,16 +163,16 @@ impl WsiHandler {
/// and are not a patron!
pub async fn get_request_tier(&self, user_id: u64) -> Result<usize, anyhow::Error> {
if let Some(p) = {
let patrons = self.patrons.lock().unwrap();
patrons.iter().find(|i| i.user_id == user_id).cloned()
let premium_users = self.premium_users.lock().unwrap();
premium_users.iter().find(|i| i.user_id == user_id).cloned()
} {
return Ok(p.tier as usize);
}

let user_tier1 = FreeTier2Requests::get_user_free_tier_2_requests(&*self.database_handler, user_id).await?;
let user_tier2 = FreeTier2Requests::get_user_free_tier_2_requests(&*self.database_handler, user_id).await?;

if user_tier1.count > 0 {
user_tier1
if user_tier2.count > 0 {
user_tier2
.change_free_tier_2_requests(&*self.database_handler, -1)
.await?;
Ok(2)
Expand Down
1 change: 1 addition & 0 deletions assyst-database/src/model/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod free_tier_2_requests;
pub mod global_blacklist;
pub mod prefix;
pub mod reminder;
pub mod user_votes;
42 changes: 42 additions & 0 deletions assyst-database/src/model/reminder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use crate::DatabaseHandler;
use std::time::{SystemTime, UNIX_EPOCH};

#[derive(sqlx::FromRow, Debug)]
pub struct Reminder {
pub id: i32,
pub user_id: i64,
pub timestamp: i64,
pub guild_id: i64,
pub channel_id: i64,
pub message_id: i64,
pub message: String,
}
impl Reminder {
pub async fn fetch_expiring_max(handler: &DatabaseHandler, time_delta: i64) -> Result<Vec<Self>, sqlx::Error> {
let query = "SELECT * FROM reminders WHERE timestamp < $1";

let unix: i64 = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_millis()
.try_into()
.expect("count not fit u128 into target type");

sqlx::query_as::<_, Self>(query)
.bind(unix + time_delta)
.fetch_all(&handler.pool)
.await
}

/// True on successful remove, false otherwise
pub async fn remove(&self, handler: &DatabaseHandler) -> Result<bool, sqlx::Error> {
let query = r#"DELETE FROM reminders WHERE user_id = $1 AND id = $2 RETURNING *"#;

sqlx::query(query)
.bind(self.user_id as i64)
.bind(self.id)
.fetch_all(&handler.pool)
.await
.map(|s| !s.is_empty())
}
}
7 changes: 7 additions & 0 deletions config.template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ vote = { token = "", id = 0 }
# Whether to use the webhooks on vote, panic, and error.
enable_webhooks = true

# Entitlements are the subscriptions for the app within Discord. You can probably leave these zeroed and the system will ignore them.
[entitlements]
tier_1_entitlement_id = 0
tier_2_entitlement_id = 0
tier_3_entitlement_id = 0
tier_4_entitlement_id = 0

[dev]
# These Discord user IDs have full control of the bot, including developer-only commands.
# Also grants max-tier premium access.
Expand Down

0 comments on commit 0c566c7

Please sign in to comment.