Skip to content

Commit

Permalink
- Fix placement of version checking in asyncio loop (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
nwithan8 authored Mar 14, 2024
1 parent d118b0e commit c0c7793
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 32 deletions.
17 changes: 12 additions & 5 deletions modules/discord/discord_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import modules.statics as statics
import modules.system_stats as system_stats
import modules.tautulli.tautulli_connector
from modules import emojis, utils
from modules import emojis, utils, versioning
from modules.discord.command_manager import CommandManager
from modules.emojis import EmojiManager
from modules.settings.settings_transports import LibraryVoiceChannelsVisibilities
Expand Down Expand Up @@ -222,7 +222,7 @@ def __init__(self,
performance_monitoring: dict,
enable_slash_commands: bool,
analytics,
new_version_available_func: callable,
version_checker: versioning.VersionChecker,
thousands_separator: str = ""):
self.token = token
self.guild_id = guild_id
Expand Down Expand Up @@ -269,12 +269,12 @@ def __init__(self,
admin_ids=self.admin_ids,
)

self.new_version_available_func = new_version_available_func
self.version_checker = version_checker

def connect(self) -> None:
logging.info('Connecting to Discord...')
# TODO: Look at merging loggers
self.client.run(self.token, reconnect=True)
self.client.run(self.token, reconnect=True) # Can't have any asyncio tasks running before the bot is ready

@property
def stats_voice_category_name(self) -> str:
Expand All @@ -292,6 +292,8 @@ def get_voice_channel_id(self, key: str) -> int:
return self.voice_channel_settings.get(statics.KEY_STATS_CHANNEL_IDS, {}).get(key, 0)

async def on_ready(self) -> None:
# NOTE: This is the first place in the stack where asyncio tasks can be started

logging.info('Connected to Discord.')

logging.info('Setting up Discord slash commands...')
Expand Down Expand Up @@ -349,6 +351,11 @@ async def on_ready(self) -> None:
asyncio.create_task(
self.run_performance_monitoring_service(refresh_time=5 * 60)) # Hard-coded 5-minute refresh time

if self.version_checker:
logging.info("Starting version checking service...")
# noinspection PyAsyncCall
asyncio.create_task(self.version_checker.check_for_new_version())

async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent) -> None:
emoji = payload.emoji
user_id = payload.user_id
Expand Down Expand Up @@ -452,7 +459,7 @@ async def edit_summary_data(self, previous_message) -> discord.Message:
:return: new discord.Message
"""
embed_fields = []
if self.new_version_available_func():
if self.version_checker.is_new_version_available():
embed_fields.append(
{"name": "🔔 New Version Available",
"value": f"A new version of Tauticord is available! [Click here]({consts.GITHUB_REPO_FULL_LINK}) to download it."})
Expand Down
19 changes: 19 additions & 0 deletions modules/versioning.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Union

import objectrest
Expand Down Expand Up @@ -58,3 +59,21 @@ def newer_version_available() -> bool:
return _newer_github_commit_available(current_commit_hash=current_version)
else:
return _newer_github_release_available(current_version=current_version)


class VersionChecker:
def __init__(self):
self._new_version_available = False

async def check_for_new_version(self):
while True:
try:
self._new_version_available = newer_version_available()
if self._new_version_available:
logging.debug(f"New version available")
await asyncio.sleep(60 * 60) # Check for new version every hour
except Exception:
exit(1) # Die on any unhandled exception for this subprocess (i.e. internet connection loss)

def is_new_version_available(self) -> bool:
return self._new_version_available
30 changes: 3 additions & 27 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Tauticord is released as-is under the "GNU General Public License".
# Please see the LICENSE file that should have been included as part of this package.
import argparse
import asyncio

import modules.discord.discord_connector as discord
import modules.logs as logging
Expand Down Expand Up @@ -57,34 +56,11 @@
anonymous_ip=True,
do_not_track=not config.extras.allow_analytics)

NEW_VERSION_AVAILABLE = False


def is_new_version_available() -> bool:
return NEW_VERSION_AVAILABLE


async def check_for_new_version():
global NEW_VERSION_AVAILABLE
while True:
try:
NEW_VERSION_AVAILABLE = versioning.newer_version_available()
if NEW_VERSION_AVAILABLE:
logging.debug(f"New version available")
await asyncio.sleep(60 * 60) # Check for new version every hour
except Exception:
exit(1) # Die on any unhandled exception for this subprocess (i.e. internet connection loss)


async def start():
def start():
logging.info(splash_logo())
logging.info("Starting Tauticord...")

# Set up version checking schedule
logging.info("Setting up version checking schedule")
# noinspection PyAsyncCall
asyncio.create_task(check_for_new_version()) # This is purposefully not awaited, it's a background task

# noinspection PyBroadException
try:
logging.info("Setting up Tautulli connector")
Expand Down Expand Up @@ -117,7 +93,7 @@ async def start():
nitro=config.discord.has_discord_nitro,
performance_monitoring=config.performance,
analytics=analytics,
new_version_available_func=is_new_version_available,
version_checker=versioning.VersionChecker(),
)
discord_connector.connect()
except Exception as e:
Expand All @@ -128,4 +104,4 @@ async def start():


if __name__ == '__main__':
asyncio.run(start())
start()

0 comments on commit c0c7793

Please sign in to comment.