diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 472e096..fbf7c2c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,9 +41,13 @@ jobs: - name: Test with pytest id: test env: + CI_EVENT_ID: ${{ github.event.number || github.sha }} GITHUB_PYTEST: "true" + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} DISCORD_BOT_TOKEN: ${{ secrets.DISCORD_TEST_BOT_TOKEN }} - DISCORD_WEBHOOK: ${{ secrets.DISCORD_TEST_BOT_WEBHOOK }} + DISCORD_GITHUB_STATUS_CHANNEL_ID: ${{ vars.DISCORD_GITHUB_STATUS_CHANNEL_ID }} + DISCORD_REDDIT_CHANNEL_ID: ${{ vars.DISCORD_REDDIT_CHANNEL_ID }} + DISCORD_SPONSORS_CHANNEL_ID: ${{ vars.DISCORD_SPONSORS_CHANNEL_ID }} GRAVATAR_EMAIL: ${{ secrets.GRAVATAR_EMAIL }} PRAW_CLIENT_ID: ${{ secrets.REDDIT_CLIENT_ID }} PRAW_CLIENT_SECRET: ${{ secrets.REDDIT_CLIENT_SECRET }} diff --git a/Dockerfile b/Dockerfile index 8cf0ff5..b1223ab 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,19 +17,23 @@ ENV COMMIT=${COMMIT} ARG DAILY_TASKS=true ARG DAILY_RELEASES=true ARG DAILY_TASKS_UTC_HOUR=12 +ARG DISCORD_GITHUB_STATUS_CHANNEL_ID +ARG DISCORD_REDDIT_CHANNEL_ID +ARG DISCORD_SPONSORS_CHANNEL_ID # Secret config -ARG DISCORD_BOT_TOKEN ARG DAILY_CHANNEL_ID +ARG DISCORD_BOT_TOKEN +ARG DISCORD_CLIENT_ID +ARG DISCORD_CLIENT_SECRET +ARG DISCORD_REDIRECT_URI +ARG GITHUB_WEBHOOK_SECRET_KEY ARG GRAVATAR_EMAIL ARG IGDB_CLIENT_ID ARG IGDB_CLIENT_SECRET ARG PRAW_CLIENT_ID ARG PRAW_CLIENT_SECRET ARG PRAW_SUBREDDIT -ARG DISCORD_WEBHOOK -ARG GRAVATAR_EMAIL -ARG REDIRECT_URI # Environment variables ENV DAILY_TASKS=$DAILY_TASKS @@ -37,6 +41,13 @@ ENV DAILY_RELEASES=$DAILY_RELEASES ENV DAILY_CHANNEL_ID=$DAILY_CHANNEL_ID ENV DAILY_TASKS_UTC_HOUR=$DAILY_TASKS_UTC_HOUR ENV DISCORD_BOT_TOKEN=$DISCORD_BOT_TOKEN +ENV DISCORD_CLIENT_ID=$DISCORD_CLIENT_ID +ENV DISCORD_CLIENT_SECRET=$DISCORD_CLIENT_SECRET +ENV DISCORD_GITHUB_STATUS_CHANNEL_ID=$DISCORD_GITHUB_STATUS_CHANNEL_ID +ENV DISCORD_REDDIT_CHANNEL_ID=$DISCORD_REDDIT_CHANNEL_ID +ENV DISCORD_REDIRECT_URI=$DISCORD_REDIRECT_URI +ENV DISCORD_SPONSORS_CHANNEL_ID=$DISCORD_SPONSORS_CHANNEL_ID +ENV GITHUB_WEBHOOK_SECRET_KEY=$GITHUB_WEBHOOK_SECRET_KEY ENV GRAVATAR_EMAIL=$GRAVATAR_EMAIL ENV IGDB_CLIENT_ID=$IGDB_CLIENT_ID ENV IGDB_CLIENT_SECRET=$IGDB_CLIENT_SECRET @@ -44,8 +55,6 @@ ENV PRAW_CLIENT_ID=$PRAW_CLIENT_ID ENV PRAW_CLIENT_SECRET=$PRAW_CLIENT_SECRET ENV PRAW_SUBREDDIT=$PRAW_SUBREDDIT ENV DISCORD_WEBHOOK=$DISCORD_WEBHOOK -ENV GRAVATAR_EMAIL=$GRAVATAR_EMAIL -ENV REDIRECT_URI=$REDIRECT_URI SHELL ["/bin/bash", "-o", "pipefail", "-c"] # install dependencies @@ -69,7 +78,7 @@ RUN <<_SETUP set -e # replace the version in the code -sed -i "s/version = '0.0.0'/version = '${BUILD_VERSION}'/g" src/common.py +sed -i "s/version = '0.0.0'/version = '${BUILD_VERSION}'/g" src/common/common.py # install dependencies python -m pip install --no-cache-dir -r requirements.txt diff --git a/README.md b/README.md index 2d1f458..c8466a9 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ [![GitHub Workflow Status (CI)](https://img.shields.io/github/actions/workflow/status/lizardbyte/support-bot/ci.yml.svg?branch=master&label=CI%20build&logo=github&style=for-the-badge)](https://github.com/LizardByte/support-bot/actions/workflows/ci.yml?query=branch%3Amaster) [![Codecov](https://img.shields.io/codecov/c/gh/LizardByte/support-bot.svg?token=900Q93P1DE&style=for-the-badge&logo=codecov&label=codecov)](https://app.codecov.io/gh/LizardByte/support-bot) -Support bot written in python to help manage LizardByte communities. The current focus is discord and reddit, but other -platforms such as GitHub discussions/issues could be added. +Support bot written in python to help manage LizardByte communities. The current focus is Discord and Reddit, but other +platforms such as GitHub discussions/issues might be added in the future. ## Overview @@ -28,23 +28,25 @@ platforms such as GitHub discussions/issues could be added. :exclamation: if using Docker these can be arguments. :warning: Never publicly expose your tokens, secrets, or ids. -| variable | required | default | description | -|-------------------------|----------|------------------------------------------------------|---------------------------------------------------------------| -| DISCORD_BOT_TOKEN | True | `None` | Token from Bot page on discord developer portal. | -| DAILY_TASKS | False | `true` | Daily tasks on or off. | -| DAILY_RELEASES | False | `true` | Send a message for each game released on this day in history. | -| DAILY_CHANNEL_ID | False | `None` | Required if daily_tasks is enabled. | -| DAILY_TASKS_UTC_HOUR | False | `12` | The hour to run daily tasks. | -| GRAVATAR_EMAIL | False | `None` | Gravatar email address for bot avatar. | -| IGDB_CLIENT_ID | False | `None` | Required if daily_releases is enabled. | -| IGDB_CLIENT_SECRET | False | `None` | Required if daily_releases is enabled. | -| SUPPORT_COMMANDS_REPO | False | `https://github.com/LizardByte/support-bot-commands` | Repository for support commands. | -| SUPPORT_COMMANDS_BRANCH | False | `master` | Branch for support commands. | - -* Running bot: - * `python -m src` -* Invite bot to server: - * `https://discord.com/api/oauth2/authorize?client_id=&permissions=8&scope=bot%20applications.commands` +| variable | required | default | description | +|----------------------------------|----------|------------------------------------------------------|---------------------------------------------------------------| +| DISCORD_BOT_TOKEN | True | `None` | Token from Bot page on discord developer portal. | +| DISCORD_CLIENT_ID | True | `None` | Discord OAuth2 client id. | +| DISCORD_CLIENT_SECRET | True | `None` | Discord OAuth2 client secret. | +| DISCORD_GITHUB_STATUS_CHANNEL_ID | True | `None` | Channel ID to send GitHub status updates to. | +| DISCORD_REDDIT_CHANNEL_ID | True | `None` | Channel ID to send Reddit post updates to. | +| DISCORD_REDIRECT_URI | False | `https://localhost:8080/discord/callback` | The redirect uri for OAuth2. Must be publicly accessible. | +| DISCORD_SPONSORS_CHANNEL_ID | True | `None` | Channel ID to send sponsorship updates to. | +| GITHUB_WEBHOOK_SECRET_KEY | True | `None` | A secret value to ensure webhooks are from trusted sources. | +| DAILY_TASKS | False | `true` | Daily tasks on or off. | +| DAILY_RELEASES | False | `true` | Send a message for each game released on this day in history. | +| DAILY_CHANNEL_ID | False | `None` | Required if daily_tasks is enabled. | +| DAILY_TASKS_UTC_HOUR | False | `12` | The hour to run daily tasks. | +| GRAVATAR_EMAIL | False | `None` | Gravatar email address for bot avatar. | +| IGDB_CLIENT_ID | False | `None` | Required if daily_releases is enabled. | +| IGDB_CLIENT_SECRET | False | `None` | Required if daily_releases is enabled. | +| SUPPORT_COMMANDS_REPO | False | `https://github.com/LizardByte/support-bot-commands` | Repository for support commands. | +| SUPPORT_COMMANDS_BRANCH | False | `master` | Branch for support commands. | ### Reddit @@ -62,7 +64,13 @@ platforms such as GitHub discussions/issues could be added. | DISCORD_WEBHOOK | False | None | URL of webhook to send discord notifications to | | GRAVATAR_EMAIL | False | None | Gravatar email address to get avatar from | | REDDIT_USERNAME | True | None | Reddit username | -* | REDDIT_PASSWORD | True | None | Reddit password | + | REDDIT_PASSWORD | True | None | Reddit password | + +### Start -* Running bot: - * `python -m src` +```bash +python -m src +``` + +* Invite bot to server: + * `https://discord.com/api/oauth2/authorize?client_id=&permissions=8&scope=bot%20applications.commands` diff --git a/assets/favicon.ico b/assets/favicon.ico new file mode 100644 index 0000000..79620bf Binary files /dev/null and b/assets/favicon.ico differ diff --git a/requirements-dev.txt b/requirements-dev.txt index 8efeb9b..b6e4092 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ betamax==0.9.0 betamax-serializers==0.2.1 pytest==8.3.3 -pytest-asyncio==0.24.0 pytest-cov==6.0.0 +pytest-mock==3.14.0 diff --git a/requirements.txt b/requirements.txt index 2f48774..d046f9f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +cryptography==43.0.3 Flask==3.0.3 GitPython==3.1.43 igdb-api-v4==0.3.3 @@ -7,3 +8,4 @@ praw==7.8.1 py-cord==2.6.1 python-dotenv==1.0.1 requests==2.32.3 +requests-oauthlib==2.0.0 diff --git a/src/__main__.py b/src/__main__.py index 7968744..5411ceb 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -1,40 +1,33 @@ # standard imports -import os import time # development imports from dotenv import load_dotenv load_dotenv(override=False) # environment secrets take priority over .env file -# local imports -if True: # hack for flake8 - from src.discord import bot as d_bot - from src import keep_alive - from src.reddit import bot as r_bot +# local imports, import after env loaded +from src.common import globals # noqa: E402 +from src.discord import bot as d_bot # noqa: E402 +from src.common import webapp # noqa: E402 +from src.reddit import bot as r_bot # noqa: E402 def main(): - # to run in replit - try: - os.environ['REPL_SLUG'] - except KeyError: - pass # not running in replit - else: - keep_alive.keep_alive() # Start the web server + webapp.start() # Start the web server - discord_bot = d_bot.Bot() - discord_bot.start_threaded() # Start the discord bot + globals.DISCORD_BOT = d_bot.Bot() + globals.DISCORD_BOT.start_threaded() # Start the discord bot - reddit_bot = r_bot.Bot() - reddit_bot.start_threaded() # Start the reddit bot + globals.REDDIT_BOT = r_bot.Bot() + globals.REDDIT_BOT.start_threaded() # Start the reddit bot try: - while discord_bot.bot_thread.is_alive() or reddit_bot.bot_thread.is_alive(): + while globals.DISCORD_BOT.bot_thread.is_alive() or globals.REDDIT_BOT.bot_thread.is_alive(): time.sleep(0.5) except KeyboardInterrupt: print("Keyboard Interrupt Detected") - discord_bot.stop() - reddit_bot.stop() + globals.DISCORD_BOT.stop() + globals.REDDIT_BOT.stop() if __name__ == '__main__': # pragma: no cover diff --git a/src/common/__init__.py b/src/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/common.py b/src/common/common.py similarity index 89% rename from src/common.py rename to src/common/common.py index ef33e84..2cfbce4 100644 --- a/src/common.py +++ b/src/common/common.py @@ -36,15 +36,17 @@ def get_avatar_bytes(): return avatar_img -def get_data_dir(): +def get_app_dirs(): # parent directory name of this file, not full path - parent_dir = os.path.dirname(os.path.abspath(__file__)).split(os.sep)[-2] + parent_dir = os.path.dirname(os.path.abspath(__file__)).split(os.sep)[-3] if parent_dir == 'app': # running in Docker container + a = '/app' d = '/data' else: # running locally + a = os.getcwd() d = os.path.join(os.getcwd(), 'data') os.makedirs(d, exist_ok=True) - return d + return a, d # constants @@ -52,5 +54,5 @@ def get_data_dir(): org_name = 'LizardByte' bot_name = f'{org_name}-Bot' bot_url = 'https://app.lizardbyte.dev' -data_dir = get_data_dir() +app_dir, data_dir = get_app_dirs() version = '0.0.0' diff --git a/src/common/crypto.py b/src/common/crypto.py new file mode 100644 index 0000000..a59cf77 --- /dev/null +++ b/src/common/crypto.py @@ -0,0 +1,69 @@ +# standard imports +import os + +# lib imports +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption +from datetime import datetime, timedelta, UTC + +# local imports +from src.common import common + +CERT_FILE = os.path.join(common.data_dir, "cert.pem") +KEY_FILE = os.path.join(common.data_dir, "key.pem") + + +def check_expiration(cert_path: str) -> int: + with open(cert_path, "rb") as cert_file: + cert_data = cert_file.read() + cert = x509.load_pem_x509_certificate(cert_data, default_backend()) + expiry_date = cert.not_valid_after_utc + return (expiry_date - datetime.now(UTC)).days + + +def generate_certificate(): + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=4096, + ) + subject = issuer = x509.Name([ + x509.NameAttribute(x509.NameOID.COMMON_NAME, u"localhost"), + ]) + cert = x509.CertificateBuilder().subject_name( + subject + ).issuer_name( + issuer + ).public_key( + private_key.public_key() + ).serial_number( + x509.random_serial_number() + ).not_valid_before( + datetime.now(UTC) + ).not_valid_after( + datetime.now(UTC) + timedelta(days=365) + ).sign(private_key, hashes.SHA256()) + + with open(KEY_FILE, "wb") as f: + f.write(private_key.private_bytes( + encoding=Encoding.PEM, + format=PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=NoEncryption(), + )) + + with open(CERT_FILE, "wb") as f: + f.write(cert.public_bytes(Encoding.PEM)) + + +def initialize_certificate() -> tuple[str, str]: + print("Initializing SSL certificate") + if os.path.exists(CERT_FILE) and os.path.exists(KEY_FILE): + cert_expires_in = check_expiration(CERT_FILE) + print(f"Certificate expires in {cert_expires_in} days.") + if cert_expires_in >= 90: + return CERT_FILE, KEY_FILE + print("Generating new certificate") + generate_certificate() + return CERT_FILE, KEY_FILE diff --git a/src/common/database.py b/src/common/database.py new file mode 100644 index 0000000..b27fdd4 --- /dev/null +++ b/src/common/database.py @@ -0,0 +1,22 @@ +# standard imports +import shelve +import threading + + +class Database: + def __init__(self, db_path): + self.db_path = db_path + self.lock = threading.Lock() + + def __enter__(self): + self.lock.acquire() + self.db = shelve.open(self.db_path, writeback=True) + return self.db + + def __exit__(self, exc_type, exc_val, exc_tb): + self.sync() + self.db.close() + self.lock.release() + + def sync(self): + self.db.sync() diff --git a/src/common/globals.py b/src/common/globals.py new file mode 100644 index 0000000..f185cab --- /dev/null +++ b/src/common/globals.py @@ -0,0 +1,2 @@ +DISCORD_BOT = None +REDDIT_BOT = None diff --git a/src/common/sponsors.py b/src/common/sponsors.py new file mode 100644 index 0000000..56b342a --- /dev/null +++ b/src/common/sponsors.py @@ -0,0 +1,73 @@ +# standard imports +import os +from typing import Union + +# lib imports +import requests + + +tier_map = { + 't4-sponsors': 15, + 't3-sponsors': 10, + 't2-sponsors': 5, + 't1-sponsors': 3, +} + + +def get_github_sponsors() -> Union[dict, False]: + """ + Get list of GitHub sponsors. + + Returns + ------- + Union[dict, False] + JSON response containing the list of sponsors. False if an error occurred. + """ + token = os.getenv("GITHUB_TOKEN") + org_name = os.getenv("GITHUB_ORG_NAME", "LizardByte") + + graphql_url = "https://api.github.com/graphql" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + query = """ + query { + organization(login: "%s") { + sponsorshipsAsMaintainer(first: 100) { + edges { + node { + sponsorEntity { + ... on User { + login + name + avatarUrl + url + } + ... on Organization { + login + name + avatarUrl + url + } + } + tier { + name + monthlyPriceInDollars + } + } + } + } + } + } + """ % org_name + + response = requests.post(graphql_url, json={'query': query}, headers=headers) + data = response.json() + + if 'errors' in data or 'message' in data: + print(data) + print('::error::An error occurred while fetching sponsors.') + return False + + return data diff --git a/src/common/time.py b/src/common/time.py new file mode 100644 index 0000000..d0f0f86 --- /dev/null +++ b/src/common/time.py @@ -0,0 +1,19 @@ +# standard imports +import datetime + + +def iso_to_datetime(iso_str): + """ + Convert an ISO 8601 string to a datetime object. + + Parameters + ---------- + iso_str : str + The ISO 8601 string to convert. + + Returns + ------- + datetime.datetime + The datetime object. + """ + return datetime.datetime.fromisoformat(iso_str) diff --git a/src/common/webapp.py b/src/common/webapp.py new file mode 100644 index 0000000..313401a --- /dev/null +++ b/src/common/webapp.py @@ -0,0 +1,207 @@ +# standard imports +import os +from threading import Thread +from typing import Tuple + +# lib imports +import discord +from flask import Flask, jsonify, redirect, request, Response, send_from_directory +from requests_oauthlib import OAuth2Session +from werkzeug.middleware.proxy_fix import ProxyFix + +# local imports +from src.common.common import app_dir +from src.common import crypto +from src.common import globals +from src.common import time + + +DISCORD_CLIENT_ID = os.getenv("DISCORD_CLIENT_ID") +DISCORD_CLIENT_SECRET = os.getenv("DISCORD_CLIENT_SECRET") +DISCORD_REDIRECT_URI = os.getenv("DISCORD_REDIRECT_URI", "https://localhost:8080/discord/callback") + +app = Flask( + import_name='LizardByte-bot', + static_folder=os.path.join(app_dir, 'assets'), +) + +# this allows us to log the real IP address of the client, instead of the IP address of the proxy host +app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1) + + +@app.route('/status') +def status(): + return "LizardByte-bot is live!" + + +@app.route("/favicon.ico") +def favicon(): + return send_from_directory( + directory=app.static_folder, + path="favicon.ico", + mimetype="image/vnd.microsoft.icon", + ) + + +@app.route("/discord/callback") +def discord_callback(): + # get all active states from the global state manager + with globals.DISCORD_BOT.db as db: + active_states = db['oauth_states'] + + discord_oauth = OAuth2Session(DISCORD_CLIENT_ID, redirect_uri=DISCORD_REDIRECT_URI) + token = discord_oauth.fetch_token("https://discord.com/api/oauth2/token", + client_secret=DISCORD_CLIENT_SECRET, + authorization_response=request.url) + + # Fetch the user's Discord profile + response = discord_oauth.get("https://discord.com/api/users/@me") + discord_user = response.json() + + # if the user is not in the active states, return an error + if discord_user['id'] not in active_states: + return "Invalid state" + + # remove the user from the active states + del active_states[discord_user['id']] + + # Fetch the user's connected accounts + connections_response = discord_oauth.get("https://discord.com/api/users/@me/connections") + connections = connections_response.json() + + with globals.DISCORD_BOT.db as db: + db['discord_users'] = db.get('discord_users', {}) + db['discord_users'][discord_user['id']] = { + 'discord_username': discord_user['username'], + 'discord_global_name': discord_user['global_name'], + 'github_id': None, + 'github_username': None, + 'token': token, # TODO: should we store the token at all? + } + + for connection in connections: + if connection['type'] == 'github': + db['discord_users'][discord_user['id']]['github_id'] = connection['id'] + db['discord_users'][discord_user['id']]['github_username'] = connection['name'] + + # Redirect to our main website + return redirect("https://app.lizardbyte.dev") + + +@app.route("/webhook//", methods=["POST"]) +def webhook(source: str, key: str) -> Tuple[Response, int]: + """ + Process webhooks from various sources. + + * GitHub sponsors: https://github.com/sponsors/LizardByte/dashboard/webhooks + * GitHub status: https://www.githubstatus.com + + Parameters + ---------- + source : str + The source of the webhook (e.g., 'github_sponsors', 'github_status'). + key : str + The secret key for the webhook. This must match an environment variable. + + Returns + ------- + flask.Response + Response to the webhook request + """ + valid_sources = [ + "github_sponsors", + "github_status", + ] + + if source not in valid_sources: + return jsonify({"status": "error", "message": "Invalid source"}), 400 + + if key != os.getenv("GITHUB_WEBHOOK_SECRET_KEY"): + return jsonify({"status": "error", "message": "Invalid key"}), 400 + + print(f"received webhook from {source}") + data = request.json + print(f"received webhook data: \n{data}") + + # process the webhook data + if source == "github_sponsors": + if data['action'] == "created": + message = f'New GitHub sponsor: {data["sponsorship"]["sponsor"]["login"]}' + + # create a discord embed + embed = discord.Embed( + author=discord.EmbedAuthor( + name=data["sponsorship"]["sponsor"]["login"], + url=data["sponsorship"]["sponsor"]["url"], + icon_url=data["sponsorship"]["sponsor"]["avatar_url"], + ), + color=0x00ff00, + description=message, + timestamp=time.iso_to_datetime(data['sponsorship']['created_at']), + title="New GitHub Sponsor", + ) + globals.DISCORD_BOT.send_message( + channel_id=os.getenv("DISCORD_SPONSORS_CHANNEL_ID"), + embeds=[embed], + ) + + elif source == "github_status": + # https://support.atlassian.com/statuspage/docs/enable-webhook-notifications + + embed = discord.Embed( + title="GitHub Status Update", + description=data['page']['status_description'], + color=0x00ff00, + ) + + # handle component updates + if 'component_update' in data: + component_update = data['component_update'] + component = data['component'] + embed = discord.Embed( + color=0x00ff00, + description=f"Status changed from {component_update['old_status']} to {component_update['new_status']}", + timestamp=time.iso_to_datetime(component_update['created_at']), + title=f"Component Update: {component['name']}", + ) + embed.add_field(name="Component ID", value=component['id']) + embed.add_field(name="Component Status", value=component['status']) + + # handle incident updates + if 'incident' in data: + incident = data['incident'] + embed = discord.Embed( + color=0xff0000, + description=incident['impact'], + timestamp=time.iso_to_datetime(incident['created_at']), + title=f"Incident: {incident['name']}", + url=incident.get('shortlink', 'https://www.githubstatus.com'), + ) + for update in incident['incident_updates']: + embed.add_field(name=update['status'], value=update['body'], inline=False) + + globals.DISCORD_BOT.send_message( + channel_id=os.getenv("DISCORD_GITHUB_STATUS_CHANNEL_ID"), + embeds=[embed], + ) + + return jsonify({"status": "success"}), 200 + + +def run(): + cert_file, key_file = crypto.initialize_certificate() + + app.run( + host="0.0.0.0", + port=8080, + ssl_context=(cert_file, key_file) + ) + + +def start(): + server = Thread( + name="Flask", + daemon=True, + target=run, + ) + server.start() diff --git a/src/discord/bot.py b/src/discord/bot.py index a9baf6c..ce69612 100644 --- a/src/discord/bot.py +++ b/src/discord/bot.py @@ -7,8 +7,8 @@ import discord # local imports -from src.common import bot_name, get_avatar_bytes, org_name -from src.discord.tasks import daily_task +from src.common.common import bot_name, data_dir, get_avatar_bytes, org_name +from src.common.database import Database from src.discord.views import DonateCommandView @@ -21,6 +21,9 @@ class Bot(discord.Bot): when the bot is ready. """ def __init__(self, *args, **kwargs): + # tasks need to be imported here to avoid circular imports + from src.discord.tasks import daily_task, hourly_task + if 'intents' not in kwargs: intents = discord.Intents.all() kwargs['intents'] = intents @@ -30,6 +33,9 @@ def __init__(self, *args, **kwargs): self.bot_thread = threading.Thread(target=lambda: None) self.token = os.environ['DISCORD_BOT_TOKEN'] + self.db = Database(db_path=os.path.join(data_dir, 'discord_bot_database')) + self.daily_task = daily_task + self.hourly_task = hourly_task self.load_extension( name='src.discord.cogs', @@ -37,12 +43,15 @@ def __init__(self, *args, **kwargs): store=False, ) + with self.db as db: + db['oauth_states'] = {} # clear any oauth states from previous sessions + async def on_ready(self): """ Bot on ready event. This function runs when the discord bot is ready. The function will update the bot presence, update the username - and avatar, and start daily tasks. + and avatar, and start tasks. """ print(f'py-cord version: {discord.__version__}') print(f'Logged in as {self.user.name} (ID: {self.user.id})') @@ -61,16 +70,84 @@ async def on_ready(self): await self.sync_commands() + self.hourly_task.start(bot=self) + try: os.environ['DAILY_TASKS'] except KeyError: - daily_task.start(bot=self) + self.daily_task.start(bot=self) else: if os.environ['DAILY_TASKS'].lower() == 'true': - daily_task.start(bot=self) + self.daily_task.start(bot=self) else: print("'DAILY_TASKS' environment variable is disabled") + async def async_send_message( + self, + channel_id: int, + message: str = None, + embeds: list[discord.Embed] = [], + ) -> discord.Message: + """ + Send a message to a specific channel asynchronously. + + Parameters + ---------- + channel_id : int + The ID of the channel to send the message to. + message : str, optional + The message to send. + embeds : list[discord.Embed], optional + A list of embeds to send. + + Returns + ------- + discord.Message + The message that was sent. + """ + # ensure embeds are within Discord's character limits + for embed in embeds: + if len(embed) > 6000: + cut_length = len(embed) - 6000 + 3 + embed.description = embed.description[:-cut_length] + "..." + if len(embed.description) > 4096: + cut_length = len(embed.description) - 4096 + 3 + embed.description = embed.description[:-cut_length] + "..." + + channel = await self.fetch_channel(channel_id) + return await channel.send(content=message, embeds=embeds) + + def send_message( + self, + channel_id: int, + message: str = None, + embeds: list[discord.Embed] = [], + ) -> discord.Message: + """ + Send a message to a specific channel synchronously. + + Parameters + ---------- + channel_id : int + The ID of the channel to send the message to. + message : str, optional + The message to send. + embeds : list[discord.Embed], optional + A list of embeds to send. + + Returns + ------- + discord.Message + The message that was sent. + """ + future = asyncio.run_coroutine_threadsafe( + self.async_send_message( + channel_id=channel_id, + message=message, + embeds=embeds, + ), self.loop) + return future.result() + def start_threaded(self): try: # Login the bot in a separate thread @@ -85,8 +162,9 @@ def start_threaded(self): self.stop() def stop(self, future: asyncio.Future = None): - print("Attempting to stop daily tasks") - daily_task.stop() + print("Attempting to stop tasks") + self.daily_task.stop() + self.hourly_task.stop() print("Attempting to close bot connection") if self.bot_thread is not None and self.bot_thread.is_alive(): asyncio.run_coroutine_threadsafe(self.close(), self.loop) diff --git a/src/discord/cogs/base_commands.py b/src/discord/cogs/base_commands.py index 99734d8..94b42f1 100644 --- a/src/discord/cogs/base_commands.py +++ b/src/discord/cogs/base_commands.py @@ -3,7 +3,7 @@ from discord.commands import Option # local imports -from src.common import avatar, bot_name, org_name, version +from src.common.common import avatar, bot_name, org_name, version from src.discord.views import DonateCommandView from src.discord import cogs_common diff --git a/src/discord/cogs/fun_commands.py b/src/discord/cogs/fun_commands.py index 98e53f2..820fd1d 100644 --- a/src/discord/cogs/fun_commands.py +++ b/src/discord/cogs/fun_commands.py @@ -7,7 +7,7 @@ import requests # local imports -from src.common import avatar, bot_name +from src.common.common import avatar, bot_name from src.discord.views import RefundCommandView from src.discord import cogs_common diff --git a/src/discord/cogs/github_commands.py b/src/discord/cogs/github_commands.py new file mode 100644 index 0000000..22e48c8 --- /dev/null +++ b/src/discord/cogs/github_commands.py @@ -0,0 +1,85 @@ +# standard imports +import os + +# lib imports +import discord +from requests_oauthlib import OAuth2Session + +# local imports +from src.common import sponsors + + +class GitHubCommandsCog(discord.Cog): + def __init__(self, bot): + self.bot = bot + + @discord.slash_command( + name="get_sponsors", + description="Get list of GitHub sponsors", + default_member_permissions=discord.Permissions(manage_guild=True), + ) + async def get_sponsors( + self, + ctx: discord.ApplicationContext, + ): + """ + Get list of GitHub sponsors. + + Parameters + ---------- + ctx : discord.ApplicationContext + Request message context. + """ + data = sponsors.get_github_sponsors() + + if not data: + await ctx.respond("An error occurred while fetching sponsors.", ephemeral=True) + return + + message = "List of GitHub sponsors" + for edge in data['data']['organization']['sponsorshipsAsMaintainer']['edges']: + sponsor = edge['node']['sponsorEntity'] + tier = edge['node'].get('tier', {}) + tier_info = f" - Tier: {tier.get('name', 'N/A')} (${tier.get('monthlyPriceInDollars', 'N/A')}/month)" + message += f"\n* [{sponsor['login']}]({sponsor['url']}){tier_info}" + + embed = discord.Embed(title="GitHub Sponsors", color=0x00ff00, description=message) + + await ctx.respond(embed=embed, ephemeral=True) + + @discord.slash_command( + name="link_github", + description="Validate GitHub sponsor status" + ) + async def link_github(self, ctx: discord.ApplicationContext): + """ + Link Discord account with GitHub account, by validating Discord user's "GitHub" connected account status. + + User to login to Discord via OAuth2, and check if their connected GitHub account is a sponsor of the project. + + Parameters + ---------- + ctx : discord.ApplicationContext + Request message context. + """ + discord_oauth = OAuth2Session( + os.environ['DISCORD_CLIENT_ID'], + redirect_uri=os.environ['DISCORD_REDIRECT_URI'], + scope=[ + "identify", + "connections", + ], + ) + authorization_url, state = discord_oauth.authorization_url("https://discord.com/oauth2/authorize") + + with self.bot.db as db: + db['oauth_states'] = db.get('oauth_states', {}) + db['oauth_states'][str(ctx.author.id)] = state + db.sync() + + # Store the state in the user's session or database + await ctx.respond(f"Please authorize the application by clicking [here]({authorization_url}).", ephemeral=True) + + +def setup(bot: discord.Bot): + bot.add_cog(GitHubCommandsCog(bot=bot)) diff --git a/src/discord/cogs/moderator_commands.py b/src/discord/cogs/moderator_commands.py index 2464b7d..2f88697 100644 --- a/src/discord/cogs/moderator_commands.py +++ b/src/discord/cogs/moderator_commands.py @@ -7,7 +7,7 @@ from discord.commands import Option # local imports -from src.common import avatar, bot_name +from src.common.common import avatar, bot_name # constants recommended_channel_desc = 'Select the recommended channel' # hack for flake8 F722 diff --git a/src/discord/cogs/support_commands.py b/src/discord/cogs/support_commands.py index edb1502..8995d76 100644 --- a/src/discord/cogs/support_commands.py +++ b/src/discord/cogs/support_commands.py @@ -11,7 +11,7 @@ from mistletoe.markdown_renderer import MarkdownRenderer # local imports -from src.common import avatar, bot_name, data_dir +from src.common.common import avatar, bot_name, data_dir from src.discord.views import DocsCommandView from src.discord import cogs_common diff --git a/src/discord/tasks.py b/src/discord/tasks.py index d4249dd..8c11ada 100644 --- a/src/discord/tasks.py +++ b/src/discord/tasks.py @@ -1,5 +1,5 @@ # standard imports -from datetime import datetime +from datetime import datetime, UTC import json import os @@ -9,165 +9,224 @@ from igdb.wrapper import IGDBWrapper # local imports -from src.common import avatar, bot_name, bot_url +from src.common.common import avatar, bot_name, bot_url +from src.common import sponsors +from src.discord.bot import Bot from src.discord.helpers import igdb_authorization, month_dictionary @tasks.loop(minutes=60.0) -async def daily_task(bot: discord.Bot): +async def daily_task(bot: Bot): """ Run daily task loop. This function runs on a schedule, every 60 minutes. Create an embed and thread for each game released on this day in history (according to IGDB), if enabled. """ - if datetime.utcnow().hour == int(os.getenv(key='DAILY_TASKS_UTC_HOUR', default=12)): - daily_releases = True if os.getenv(key='DAILY_RELEASES', default='true').lower() == 'true' else False - if not daily_releases: - print("'DAILY_RELEASES' environment variable is disabled") + if datetime.now(UTC).hour != int(os.getenv(key='DAILY_TASKS_UTC_HOUR', default=12)): + return + + daily_releases = True if os.getenv(key='DAILY_RELEASES', default='true').lower() == 'true' else False + if not daily_releases: + print("'DAILY_RELEASES' environment variable is disabled") + return + + try: + channel = bot.get_channel(int(os.environ['DAILY_CHANNEL_ID'])) + except KeyError: + print("'DAILY_CHANNEL_ID' not defined in environment variables.") + return + + igdb_auth = igdb_authorization(client_id=os.environ['IGDB_CLIENT_ID'], + client_secret=os.environ['IGDB_CLIENT_SECRET']) + wrapper = IGDBWrapper(client_id=os.environ['IGDB_CLIENT_ID'], auth_token=igdb_auth['access_token']) + + end_point = 'release_dates' + fields = [ + 'human', + 'game.name', + 'game.summary', + 'game.url', + 'game.genres.name', + 'game.rating', + 'game.cover.url', + 'game.artworks.url', + 'game.platforms.name', + 'game.platforms.url' + ] + + where = f'human="{month_dictionary[datetime.now(UTC).month]} {datetime.now(UTC).day:02d}"*' + limit = 500 + query = f'fields {", ".join(fields)}; where {where}; limit {limit};' + + byte_array = bytes(wrapper.api_request(endpoint=end_point, query=query)) + json_result = json.loads(byte_array) + + game_ids = [] + + for game in json_result: + color = 0x9147FF + + try: + game_id = game['game']['id'] + except KeyError: + continue else: - try: - channel = bot.get_channel(int(os.environ['DAILY_CHANNEL_ID'])) - except KeyError: - print("'DAILY_CHANNEL_ID' not defined in environment variables.") - else: - igdb_auth = igdb_authorization(client_id=os.environ['IGDB_CLIENT_ID'], - client_secret=os.environ['IGDB_CLIENT_SECRET']) - wrapper = IGDBWrapper(client_id=os.environ['IGDB_CLIENT_ID'], auth_token=igdb_auth['access_token']) - - end_point = 'release_dates' - fields = [ - 'human', - 'game.name', - 'game.summary', - 'game.url', - 'game.genres.name', - 'game.rating', - 'game.cover.url', - 'game.artworks.url', - 'game.platforms.name', - 'game.platforms.url' - ] - - where = f'human="{month_dictionary[datetime.utcnow().month]} {datetime.utcnow().day:02d}"*' - limit = 500 - query = f'fields {", ".join(fields)}; where {where}; limit {limit};' - - byte_array = bytes(wrapper.api_request(endpoint=end_point, query=query)) - json_result = json.loads(byte_array) - - game_ids = [] - - for game in json_result: - color = 0x9147FF - - try: - game_id = game['game']['id'] - except KeyError: - continue + if game_id not in game_ids: + game_ids.append(game_id) + else: # do not repeat the same game... even though it could be a different platform + continue + + try: + embed = discord.Embed( + title=game['game']['name'], + url=game['game']['url'], + description=game['game']['summary'][0:2000 - 1], + color=color + ) + except KeyError: + continue + + try: + embed.add_field( + name='Release Date', + value=game['human'], + inline=True + ) + except KeyError: + pass + + try: + rating = round(game['game']['rating'] / 20, 1) + embed.add_field( + name='Average Rating', + value=f'⭐{rating}', + inline=True + ) + + if rating < 4.0: # reduce the number of messages per day + continue + except KeyError: + continue + + try: + embed.set_thumbnail( + url=f"https:{game['game']['cover']['url'].replace('_thumb', '_original')}" + ) + except KeyError: + pass + + try: + embed.set_image( + url=f"https:{game['game']['artworks'][0]['url'].replace('_thumb', '_original')}" + ) + except KeyError: + pass + + try: + platforms = ', '.join(platform['name'] for platform in game['game']['platforms']) + name = 'Platforms' if len(game['game']['platforms']) > 1 else 'Platform' + + embed.add_field( + name=name, + value=platforms, + inline=False + ) + except KeyError: + pass + + try: + genres = ', '.join(genre['name'] for genre in game['game']['genres']) + name = 'Genres' if len(game['game']['genres']) > 1 else 'Genre' + + embed.add_field( + name=name, + value=genres, + inline=False + ) + except KeyError: + pass + + embed.set_author( + name=bot_name, + url=bot_url, + icon_url=avatar + ) + + embed.set_footer( + text='Data provided by IGDB', + icon_url='https://www.igdb.com/favicon-196x196.png' + ) + + message = await channel.send(embed=embed) + thread = await message.create_thread(name=embed.title) + + print(f'thread created: {thread.name}') + + +@tasks.loop(minutes=1.0) +async def hourly_task(bot: Bot): + """ + Run hourly task loop. + + This function runs on a schedule, every 1 minute. + If the current time is not at the top of the hour, return. + """ + if datetime.now(UTC).minute != 0: + return + + # check each user in the database for their GitHub sponsor status + with bot.db as db: + discord_users = db.get('discord_users', {}) + + if not discord_users: + return + + github_sponsors = sponsors.get_github_sponsors() + + for user_id, user_data in discord_users.items(): + # check if the user is a GitHub sponsor + for edge in github_sponsors['data']['organization']['sponsorshipsAsMaintainer']['edges']: + sponsor = edge['node']['sponsorEntity'] + if sponsor['login'] == user_data['github_username']: + # user is a sponsor + user_data['github_sponsor'] = True + + monthly_amount = edge['node'].get('tier', {}).get('monthlyPriceInDollars', 0) + + for tier, amount in sponsors.tier_map.items(): + if monthly_amount >= amount: + user_data['sponsor_tiers'] = [tier, 'supporters'] + break + else: + user_data['sponsor_tiers'] = [] + + break + else: + # user is not a sponsor + user_data['github_sponsor'] = False + user_data['sponsor_tiers'] = [] + + # update the discord user roles + for g in bot.guilds: + roles = g.roles + + role_map = { + 't4-sponsors': discord.utils.get(roles, name='t4-sponsors'), + 't3-sponsors': discord.utils.get(roles, name='t3-sponsors'), + 't2-sponsors': discord.utils.get(roles, name='t2-sponsors'), + 't1-sponsors': discord.utils.get(roles, name='t1-sponsors'), + 'supporters': discord.utils.get(roles, name='supporters'), + } + + tiers = user_data['sponsor_tiers'] + + for tier, role in role_map.items(): + role = role_map.get(tier, None) + + if role: + member = g.get_member(user_id) + if tier in tiers: + await member.add_roles(role) else: - if game_id not in game_ids: - game_ids.append(game_id) - else: # do not repeat the same game... even though it could be a different platform - continue - - try: - embed = discord.Embed( - title=game['game']['name'], - url=game['game']['url'], - description=game['game']['summary'][0:2000 - 1], - color=color - ) - except KeyError: - continue - - try: - embed.add_field( - name='Release Date', - value=game['human'], - inline=True - ) - except KeyError: - pass - - try: - rating = round(game['game']['rating'] / 20, 1) - embed.add_field( - name='Average Rating', - value=f'⭐{rating}', - inline=True - ) - - if rating < 4.0: # reduce number of messages per day - continue - except KeyError: - continue - - try: - embed.set_thumbnail( - url=f"https:{game['game']['cover']['url'].replace('_thumb', '_original')}" - ) - except KeyError: - pass - - try: - embed.set_image( - url=f"https:{game['game']['artworks'][0]['url'].replace('_thumb', '_original')}" - ) - except KeyError: - pass - - try: - platforms = '' - name = 'Platform' - - for platform in game['game']['platforms']: - if platforms: - platforms += ", " - name = 'Platforms' - platforms += platform['name'] - - embed.add_field( - name=name, - value=platforms, - inline=False - ) - except KeyError: - pass - - try: - genres = '' - name = 'Genre' - - for genre in game['game']['genres']: - if genres: - genres += ", " - name = 'Genres' - genres += genre['name'] - - embed.add_field( - name=name, - value=genres, - inline=False - ) - except KeyError: - pass - - try: - embed.set_author( - name=bot_name, - url=bot_url, - icon_url=avatar - ) - except KeyError: - pass - - embed.set_footer( - text='Data provided by IGDB', - icon_url='https://www.igdb.com/favicon-196x196.png' - ) - - message = await channel.send(embed=embed) - thread = await message.create_thread(name=embed.title) - - print(f'thread created: {thread.name}') + await member.remove_roles(role) diff --git a/src/discord/views.py b/src/discord/views.py index 4435d8e..f5649d5 100644 --- a/src/discord/views.py +++ b/src/discord/views.py @@ -7,7 +7,7 @@ from discord.ui.button import Button # local imports -from src.common import avatar, bot_name +from src.common.common import avatar, bot_name from src.discord.helpers import get_json from src.discord.modals import RefundModal diff --git a/src/keep_alive.py b/src/keep_alive.py deleted file mode 100644 index 74ab1c9..0000000 --- a/src/keep_alive.py +++ /dev/null @@ -1,20 +0,0 @@ -from flask import Flask -from threading import Thread -import os - -app = Flask('') - - -@app.route('/') -def main(): - return f"{os.environ['REPL_SLUG']} is live!" - - -def run(): - app.run(host="0.0.0.0", port=8080) - - -def keep_alive(): - server = Thread(name="Flask", target=run) - server.setDaemon(daemonic=True) - server.start() diff --git a/src/reddit/bot.py b/src/reddit/bot.py index 7520b9e..0856bb8 100644 --- a/src/reddit/bot.py +++ b/src/reddit/bot.py @@ -1,18 +1,19 @@ # standard imports from datetime import datetime import os -import requests import shelve import sys import threading import time # lib imports +import discord import praw from praw import models # local imports -from src import common +from src.common import common +from src.common import globals class Bot: @@ -31,14 +32,7 @@ def __init__(self, **kwargs): self.user_agent = kwargs.get('user_agent', f'{common.bot_name} {self.version}') self.avatar = kwargs.get('avatar', common.get_bot_avatar(gravatar=os.environ['GRAVATAR_EMAIL'])) self.subreddit_name = kwargs.get('subreddit', os.getenv('PRAW_SUBREDDIT', 'LizardByte')) - - if not kwargs.get('redirect_uri', None): - try: # for running in replit - self.redirect_uri = f'https://{os.environ["REPL_SLUG"]}.{os.environ["REPL_OWNER"].lower()}.repl.co' - except KeyError: - self.redirect_uri = os.getenv('REDIRECT_URI', 'http://localhost:8080') - else: - self.redirect_uri = kwargs['redirect_uri'] + self.redirect_uri = kwargs.get('redirect_uri', os.getenv('REDIRECT_URI', 'http://localhost:8080')) # directories self.data_dir = common.data_dir @@ -66,7 +60,7 @@ def __init__(self, **kwargs): @staticmethod def validate_env() -> bool: required_env = [ - 'DISCORD_WEBHOOK', + 'DISCORD_REDDIT_CHANNEL_ID', 'PRAW_CLIENT_ID', 'PRAW_CLIENT_SECRET', 'REDDIT_PASSWORD', @@ -141,7 +135,7 @@ def process_submission(self, submission: models.Submission): print(f'submission id: {submission.id}') print(f'submission title: {submission.title}') print('---------') - if os.getenv('DISCORD_WEBHOOK'): + if os.getenv('DISCORD_REDDIT_CHANNEL_ID'): self.discord(submission=submission) self.flair(submission=submission) self.karma(submission=submission) @@ -175,37 +169,31 @@ def discord(self, submission: models.Submission): submission_time = datetime.fromtimestamp(submission.created_utc) - # create the discord message - # todo: use the running discord bot, directly instead of using a webhook - discord_webhook = { - 'username': 'LizardByte-Bot', - 'avatar_url': self.avatar, - 'embeds': [ - { - 'author': { - 'name': str(submission.author), - 'url': f'https://www.reddit.com/user/{submission.author}', - 'icon_url': str(redditor.icon_img) - }, - 'title': str(submission.title), - 'url': str(submission.url), - 'description': str(submission.selftext), - 'color': color, - 'thumbnail': { - 'url': 'https://www.redditstatic.com/desktop2x/img/snoo_discovery@1x.png' - }, - 'footer': { - 'text': f'Posted on r/{self.subreddit_name} at {submission_time}', - 'icon_url': 'https://www.redditstatic.com/desktop2x/img/favicon/favicon-32x32.png' - } - } - ] - } - - # actually send the message - r = requests.post(os.environ['DISCORD_WEBHOOK'], json=discord_webhook) - - if r.status_code == 204: # successful completion of request, no additional content + # create the discord embed + embed = discord.Embed( + author=discord.EmbedAuthor( + name=str(submission.author), + url=f'https://www.reddit.com/user/{submission.author}', + icon_url=str(redditor.icon_img), + ), + title=submission.title, + url=submission.url, + description=submission.selftext, + color=color, + thumbnail='https://www.redditstatic.com/desktop2x/img/snoo_discovery@1x.png', + footer=discord.EmbedFooter( + text=f'Posted on r/{self.subreddit_name} at {submission_time}', + icon_url='https://www.redditstatic.com/desktop2x/img/favicon/favicon-32x32.png' + ) + ) + + # actually send the embed + message = globals.DISCORD_BOT.send_message( + channel_id=os.getenv("DISCORD_REDDIT_CHANNEL_ID"), + embeds=[embed], + ) + + if message: with self.lock, shelve.open(self.db) as db: # the shelve doesn't update unless we recreate the main key submissions = db['submissions'] diff --git a/tests/conftest.py b/tests/conftest.py index a9455c6..be0ff46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,39 @@ +# standard imports +import os +import time + # lib imports import dotenv +import pytest + +# local imports +from src.common import globals dotenv.load_dotenv(override=False) # environment secrets take priority over .env file + +# import after env loaded +from src.discord import bot as d_bot # noqa: E402 + + +@pytest.fixture(scope='session') +def discord_bot(): + bot = d_bot.Bot() + bot.start_threaded() + globals.DISCORD_BOT = bot + + while not bot.is_ready(): # Wait until the bot is ready + time.sleep(1) + + yield bot + + bot.stop() + globals.DISCORD_BOT = None + + +@pytest.fixture(scope='function') +def no_github_token(): + og_token = os.getenv('GITHUB_TOKEN') + del os.environ['GITHUB_TOKEN'] + yield + + os.environ['GITHUB_TOKEN'] = og_token diff --git a/tests/fixtures/certs/expired/cert.pem b/tests/fixtures/certs/expired/cert.pem new file mode 100644 index 0000000..4084b6f --- /dev/null +++ b/tests/fixtures/certs/expired/cert.pem @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIEtDCCApygAwIBAgIUO9/D0xVF8jI0w7lJQEbuOAkpeXIwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI0MTEyMzAwNDUwM1oXDTI0MTEy +MzAwNDUwM1owFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEF +AAOCAg8AMIICCgKCAgEAx2ytodLmx/I7DRe6JTGn98I/DEcRdow+f+6UjjIQczPB +jD97JsfV45eVaIWmRMjqn+A8zAKnsBdRpGlFwbAdG174cu/BLdNb/OoVxCSkiZpH +wtmuRVofOgo2VnFTjgG7gu/4GV7SIOsngz/uB+W7xw/GVQfEsDld3cjiLObn0aFv +D7oKE6wAP0gmbGrNDmvFikVUQIU0tuWfc6DLK2QEh00KOsOAjX9bwk7cOO6+W4Xg +EAqYx0XsyNPkWOG8D9FzttnC1UESSoOyr41ne6sn5knkxdk00dxAVpabt7Z1Do4u +NlAR324+u62GkMP9Pv0tXU8YZxnHFT5njXDJMKXwj7vWiHPzR7Ykw6/fCUOYOcop +rOweSUmevmUIGZKQVyKDrcLumGC4IPfv+UCPQyA3eUEyTPnkDawDRepha0FsZ7Pt +R/0Ftm7XW5u4HFMhRyrnDrHBGNywUg+bYGl7MWIqr0p3o9CVENDIRgLyuDrCnnMB +UFrqpbPp7Q6Z64ohdpvb+eJRYBCUJkbbFawUa/SXe6c5/cFAFwoNgN0UHcBvWKpV +INc5WJPgiaHsauADeUuiU4+n3ZdOu8YMCpei+lM+eRR3KadZ2/UE9lzWQ7PfKDtM +iIeon6oudZIlaTJPsV/AFIwJadJKpxYhgJ6JtlcORBUgzaWFEXRL+/rhVRZInNUC +AwEAATANBgkqhkiG9w0BAQsFAAOCAgEAOzJmOXAcERo/kuB11AnrNqk5bpxqF1Gl +ORNxUQflB0f3qooHkuPH6CdrrZ32yUIN+54fcVCCnQfx04PCC4bPFRreTCqyPCtb +Oinfk5BgEmIvE4x9PibPcmQG6zQfHHqOQzsxio6Fjhfk+iL9Fy30W2K3RBvIicOM +BQ+kGysltV+9tMX4wI/VnLCN5LORbBX7fiMnFtmVKeLZalnOWcMqZuc6opQFjWzg +r4vqu6//STkrCvze4tLUMipS8uKXQ9hvrdiXgQGHOZDhRaQCC+TAXYxPn2pxvYYK +l7dlQS1mWY8pPB7X9FMsACmZR2myBIqbHzFsde+Mqyf5fWHihtWwNYPYreCXKZdr +A7LtgQG9KhTUQO9HjFkbG/VYiH5rPUlewd+qLVvdZ8vFS6ZMvMH7eJPdL0ubuM4s +vTDgPXxqE4GqfzuT0d+vmJujllkiOYdbkDNRYv0rekojNbJcNyyDCs1056ke5JPr +//XfgeW1Lwz1yL9xB5U1lqVUaGIifzihO69yNESUSh/niuwDeWYkz/bgo9oM3L9+ +f1WznzC/tcibq+d9V6PE7KRiGfS5ZbRxAm95wrnRurZYkM+eeZHDmPs3InfYe0Zj +WarJjoO+x/+/ErjgsVUHt9JqB8GdXO3Xg7c5bkrt6LqgYxZ2GUDZZSbe/MTktYsp +E/Y7rCRq6LQ= +-----END CERTIFICATE----- diff --git a/tests/fixtures/certs/expired/key.pem b/tests/fixtures/certs/expired/key.pem new file mode 100644 index 0000000..a016fbb --- /dev/null +++ b/tests/fixtures/certs/expired/key.pem @@ -0,0 +1,51 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIJJwIBAAKCAgEAx2ytodLmx/I7DRe6JTGn98I/DEcRdow+f+6UjjIQczPBjD97 +JsfV45eVaIWmRMjqn+A8zAKnsBdRpGlFwbAdG174cu/BLdNb/OoVxCSkiZpHwtmu +RVofOgo2VnFTjgG7gu/4GV7SIOsngz/uB+W7xw/GVQfEsDld3cjiLObn0aFvD7oK +E6wAP0gmbGrNDmvFikVUQIU0tuWfc6DLK2QEh00KOsOAjX9bwk7cOO6+W4XgEAqY +x0XsyNPkWOG8D9FzttnC1UESSoOyr41ne6sn5knkxdk00dxAVpabt7Z1Do4uNlAR +324+u62GkMP9Pv0tXU8YZxnHFT5njXDJMKXwj7vWiHPzR7Ykw6/fCUOYOcoprOwe +SUmevmUIGZKQVyKDrcLumGC4IPfv+UCPQyA3eUEyTPnkDawDRepha0FsZ7PtR/0F +tm7XW5u4HFMhRyrnDrHBGNywUg+bYGl7MWIqr0p3o9CVENDIRgLyuDrCnnMBUFrq +pbPp7Q6Z64ohdpvb+eJRYBCUJkbbFawUa/SXe6c5/cFAFwoNgN0UHcBvWKpVINc5 +WJPgiaHsauADeUuiU4+n3ZdOu8YMCpei+lM+eRR3KadZ2/UE9lzWQ7PfKDtMiIeo +n6oudZIlaTJPsV/AFIwJadJKpxYhgJ6JtlcORBUgzaWFEXRL+/rhVRZInNUCAwEA +AQKCAgADzJG4OfzUhUxTsQaGS95fzW8HDFmMURqltEVXOiPvFebThagSco8kEVCy +14z11YAGwK5X0psgMymGgMzn5jN/wHzqL6AV/+dKN6lnfa02w94nG5+Cybc7k1M6 +rVkCpQzN70ViMli9cM1lZjPiKaG8ppPILeg01TrxDTEl2tZCu5kSiyBDBK1Sh0zY +FubGJg5y1mRHAGKjM1eoy8DjGDov26tcuDm8OFdmqbrvSLkOpEvC8ni7nxzmLIc2 +nEJJaNuT+a0JA/7VtZGTX5W/mOCfNfwqOruTXedJ3v+jbdHoD5RYy4izoXWHfMRK +ALnT193j36xe1nJg+LnfS21BxH+DLNk4OkdiZJzOo3+7uiklZNO0vo6+8RVAckSB +Cvot8x84kUgPPzNyxVilzL5lvfaGrRpAWkPgOFmxyYYkjoenFh0FkWF5foF3sm/S +oSqf6/ojOzxbazI1oZ0yx3YMhfxnTdxQOngy9QeNhEF37ZuU4tutWInyuiIE2u5s +9XNXO7hYqdPqrSV+JOIMuYPTLTxdsKkSbdS4tHUuLt7mO0E2opo1ti8/lvIy10qA +eF2bm/Vcrpt5Zs406uTOZD1a3JaFwVKC1rzVNEDrVomRxsie6Hubkhd8X3SaBx1z +0bicV4yzhPFYu3iqNS2ZU9cJa4H8qJoQ8n1H8Yhi1O7Sc9CBkQKCAQEA7qlyMfPl +N0Vk45BSIWnD+vCaxQDYgqqDXlluGR8NgL0Yhst9Q0XpcP12eIKcTZ6pzdfUe/vK +NsCJXQnDLtxlIzWznXyiY1n/hNRW2dwwYEzPi8IZzE2RM1PltoHn7noVeAmZiayb +mf9M1pRUhbfVsg2gpbGnIR/XLC/1Ll2xNChkSiD5lbRNccRQZqkkIjYl0qCpvnoQ +LQCqdFmVgz0YAkooOf1n2vxVL4g1Ad3761FRSXqsPCNfvPY5+8LhFfBauuMoeS4t +DFjss7yyA6zA9+hxke3dqHVVnXWRTZkuuCzrcT5zOGv31K3l97BmCcKW4lXq6Gja +o0786rg3dpFOOQKCAQEA1emCNwEx5nt2HE070e6TwHBrMPWPKLionodxPsD2hsb0 +RfQKwqxabVVQKEcx+ZpsFFRUF79cHqVwIK6umTwfUdkrgxosjUqr+XYpINfICGln +KIp/nfGac1Sfx528rj7X5ZPaVzldEwCiReZmW29cS/XMKj5fpWXsRTBDr6tvenJn +UiY/6xo5vftFuCHVs4MUntK5BBNSYt3AJq/ZTZx9Lso2+EhkVA1cOlf05TRcWaiI +5KUklf7Rm9owxB3L8uWuBDlxmNymYbn2SuZywaWZlLBA2dTokUK200S4Yhzq/Xrz +dKFn0uTMnATgyj9FUqVZn7p40i7jF93bbBebmMDDfQKCAQBGHMdsf18mRp+l7r8C +C+VEMiz1lRMGB/vB2vnqLWI1INg0uVEaU06KIBwOuSgb8XGnBDHrHoRAY323NGf/ +u0WG+37B1FyMXWMgbZT6OaKIl+gdAa+8gkkW0B3a6Pzu5TSrZ/6QIIIx0nuLSlYu +VlxUC4bXRoJ3y7fVxlz7+xBU50zXLirEXQynUGniTuxLlKa14vca+xcHcXuh5LN0 +s5z7BzgcGSLKhXitFxGjc8hPUDtWH9C7dhTpGVjdalnfrRWqc5NvTi5zwyf+gX+2 +bqjd6455tWx50caOFHzUVB0ShDfCs/r7Z1SOSWwWwN6pHV5gLaduEWextEG+3tGE +ZpmZAoIBAA1FeYC0IEZubnt/BzEVHjGYR+43rfQW0M9VE9+S1Tiza0BTzb8aNloG +Kvz0vdMAk6gHO1hl1O9J0FUWwVpccoz/bkWqAA2cDmNhw1d4S77J206WmShRbwWs +wGUAEk61M2vY6njy5CVjqq2vh7YwiIdl7o7IY+K9GhWI0wo5FqeAJYzhNqH9dIum +5UJxRvLmNQdNh5ELKddcbql3y4GXLeUTQqnQw/i7A3fTMSxvPTOK00NsQ4LS1mpW +9SOVvauKOGumrLeRKPlzMiafeYsuHQMulDdvkCZC/1jIMLBVnvavBB++S9S3wUIE +w3WIy2I/Q/o29XwE0K4QY6anKE4n13kCggEAL0pUNXZc9Sz3auY6PmWCJX+XEggJ +CX8DbUADPKVWGMRekhbSkcgeTZOzqDg3Tcbd5PBQ+PdP9/MG316S0KCvMeBFn4ze +GtKPFfFruS3aLfocQLHP/9p+dEL1g5LjaGvuWQDD++X4EfmNTFe+1yOCb6GTyFBg +AU1QlkSSqQ1TVlH5URUQsLizIwlhTSGy/9J1ylFcLwCnl4uW8VR+exul1tx8zX6P +YWFci1ppCwbwIpSY8zjE8MT5MA3KvOg22gdJUhBhHw36xsYs4MBbzP6YIwm0d+7M +vEcenOCHCMyOdA/Y4/036KnOBaRMyjZ/pABCK1KKaCPZQZR/FaRyyfJ36g== +-----END RSA PRIVATE KEY----- diff --git a/tests/unit/common/test_crypto.py b/tests/unit/common/test_crypto.py new file mode 100644 index 0000000..9eeee2b --- /dev/null +++ b/tests/unit/common/test_crypto.py @@ -0,0 +1,56 @@ +# standard imports +import os +from datetime import datetime, UTC + +# lib imports +from cryptography import x509 +import pytest + +# local imports +from src.common.crypto import check_expiration, generate_certificate, initialize_certificate, CERT_FILE, KEY_FILE + + +@pytest.fixture(scope='module') +def setup_certificates(): + # Ensure the certificates are generated for testing + if not os.path.exists(CERT_FILE) or not os.path.exists(KEY_FILE): + generate_certificate() + yield + # Cleanup after tests + if os.path.exists(CERT_FILE): + os.remove(CERT_FILE) + if os.path.exists(KEY_FILE): + os.remove(KEY_FILE) + + +def test_check_expiration(setup_certificates): + days_left = check_expiration(CERT_FILE) + assert days_left <= 365 + assert days_left >= 364 + + +def test_check_expiration_expired(): + cert_file = os.path.join("tests", "fixtures", "certs", "expired", "cert.pem") + days_left = check_expiration(cert_file) + assert days_left < 0 + + +def test_generate_certificate(setup_certificates): + assert os.path.exists(CERT_FILE) + assert os.path.exists(KEY_FILE) + + with open(CERT_FILE, "rb") as cert_file: + cert_data = cert_file.read() + + cert = x509.load_pem_x509_certificate(cert_data) + assert cert.not_valid_after_utc > datetime.now(UTC) + + +def test_initialize_certificate(setup_certificates): + cert_file, key_file = initialize_certificate() + assert os.path.exists(cert_file) + assert os.path.exists(key_file) + + cert_expires_in = check_expiration(cert_file) + assert cert_expires_in <= 365 + assert cert_expires_in >= 364 diff --git a/tests/unit/common/test_sponsors.py b/tests/unit/common/test_sponsors.py new file mode 100644 index 0000000..b0b0aa8 --- /dev/null +++ b/tests/unit/common/test_sponsors.py @@ -0,0 +1,17 @@ +# local imports +from src.common import sponsors + + +def test_get_github_sponsors(): + data = sponsors.get_github_sponsors() + assert data + assert 'errors' not in data + assert 'data' in data + assert 'organization' in data['data'] + assert 'sponsorshipsAsMaintainer' in data['data']['organization'] + assert 'edges' in data['data']['organization']['sponsorshipsAsMaintainer'] + + +def test_get_github_sponsors_error(no_github_token): + data = sponsors.get_github_sponsors() + assert not data diff --git a/tests/unit/common/test_time.py b/tests/unit/common/test_time.py new file mode 100644 index 0000000..ff41e9e --- /dev/null +++ b/tests/unit/common/test_time.py @@ -0,0 +1,17 @@ +# standard imports +import datetime + +# lib imports +import pytest + +# local imports +from src.common import time + + +@pytest.mark.parametrize("iso_str, expected", [ + ("2024-11-23T20:29:48", datetime.datetime(2024, 11, 23, 20, 29, 48)), + ("2023-01-01T00:00:00", datetime.datetime(2023, 1, 1, 0, 0, 0)), + ("2022-12-31T23:59:59", datetime.datetime(2022, 12, 31, 23, 59, 59)), +]) +def test_iso_to_datetime(iso_str, expected): + assert time.iso_to_datetime(iso_str) == expected diff --git a/tests/unit/common/test_webapp.py b/tests/unit/common/test_webapp.py new file mode 100644 index 0000000..dd6b166 --- /dev/null +++ b/tests/unit/common/test_webapp.py @@ -0,0 +1,209 @@ +# standard imports +import os + +# lib imports +import pytest + +# local imports +from src.common import webapp + + +@pytest.fixture(scope='function') +def test_client(): + """Create a test client for testing webapp endpoints""" + app = webapp.app + app.testing = True + + client = app.test_client() + + # Create a test client using the Flask application configured for testing + with client as test_client: + # Establish an application context + with app.app_context(): + yield test_client # this is where the testing happens! + + +def test_status(test_client): + """ + WHEN the '/status' page is requested (GET) + THEN check that the response is valid + """ + response = test_client.get('/status') + assert response.status_code == 200 + + +def test_favicon(test_client): + """ + WHEN the '/favicon.ico' file is requested (GET) + THEN check that the response is valid + THEN check the content type is 'image/vnd.microsoft.icon' + """ + response = test_client.get('/favicon.ico') + assert response.status_code == 200 + assert response.content_type == 'image/vnd.microsoft.icon' + + +def test_discord_callback_invalid_state(discord_bot, test_client, mocker): + """ + WHEN the '/discord/callback' endpoint is requested (GET) with an invalid state + THEN check that the response is 'Invalid state' + """ + mocker.patch('src.common.webapp.OAuth2Session.fetch_token', return_value={'access_token': 'fake_token'}) + mocker.patch('src.common.webapp.OAuth2Session.get', return_value=mocker.Mock(json=lambda: {'id': 'invalid_user'})) + + response = test_client.get('/discord/callback') + assert response.data == b'Invalid state' + assert response.status_code == 200 + + +def test_webhook_invalid_source(test_client): + """ + WHEN the '/webhook//' endpoint is requested (POST) with an invalid source + THEN check that the response is 'Invalid source' + """ + response = test_client.post('/webhook/invalid_source/invalid_key') + assert response.json == {"status": "error", "message": "Invalid source"} + assert response.status_code == 400 + + +def test_webhook_invalid_key(test_client, mocker): + """ + WHEN the '/webhook//' endpoint is requested (POST) with an invalid key + THEN check that the response is 'Invalid key' + """ + mocker.patch.dict(os.environ, {"GITHUB_WEBHOOK_SECRET_KEY": "valid_key"}) + response = test_client.post('/webhook/github_sponsors/invalid_key') + assert response.json == {"status": "error", "message": "Invalid key"} + assert response.status_code == 400 + + +def test_webhook_github_sponsors(discord_bot, test_client, mocker): + """ + WHEN the '/webhook/github_sponsors/' endpoint is requested (POST) with valid data + THEN check that the response is 'success' + """ + mocker.patch.dict(os.environ, {"GITHUB_WEBHOOK_SECRET_KEY": "valid_key"}) + data = { + 'action': 'created', + 'sponsorship': { + 'sponsor': { + 'login': 'octocat', + 'url': 'https://github.com/octocat', + 'avatar_url': 'https://avatars.githubusercontent.com/u/583231', + }, + 'created_at': '1970-01-01T00:00:00Z', + }, + } + response = test_client.post('/webhook/github_sponsors/valid_key', json=data) + assert response.json == {"status": "success"} + assert response.status_code == 200 + + +@pytest.mark.parametrize("data", [ + # https://support.atlassian.com/statuspage/docs/enable-webhook-notifications/ + { + "meta": { + "unsubscribe": "http://statustest.flyingkleinbrothers.com:5000/?unsubscribe=j0vqr9kl3513", + "documentation": "http://doers.statuspage.io/customer-notifications/webhooks/", + }, + "page": { + "id": "j2mfxwj97wnj", + "status_indicator": "major", + "status_description": "Partial System Outage", + }, + "component_update": { + "created_at": "2013-05-29T21:32:28Z", + "new_status": "operational", + "old_status": "major_outage", + "id": "k7730b5v92bv", + "component_id": "rb5wq1dczvbm", + }, + "component": { + "created_at": "2013-05-29T21:32:28Z", + "id": "rb5wq1dczvbm", + "name": "Some Component", + "status": "operational", + }, + }, + { + "meta": { + "unsubscribe": "http://statustest.flyingkleinbrothers.com:5000/?unsubscribe=j0vqr9kl3513", + "documentation": "http://doers.statuspage.io/customer-notifications/webhooks/", + }, + "page": { + "id": "j2mfxwj97wnj", + "status_indicator": "critical", + "status_description": "Major System Outage", + }, + "incident": { + "backfilled": False, + "created_at": "2013-05-29T15:08:51-06:00", + "impact": "critical", + "impact_override": None, + "monitoring_at": "2013-05-29T16:07:53-06:00", + "postmortem_body": None, + "postmortem_body_last_updated_at": None, + "postmortem_ignored": False, + "postmortem_notified_subscribers": False, + "postmortem_notified_twitter": False, + "postmortem_published_at": None, + "resolved_at": None, + "scheduled_auto_transition": False, + "scheduled_for": None, + "scheduled_remind_prior": False, + "scheduled_reminded_at": None, + "scheduled_until": None, + "shortlink": "http://j.mp/18zyDQx", + "status": "monitoring", + "updated_at": "2013-05-29T16:30:35-06:00", + "id": "lbkhbwn21v5q", + "organization_id": "j2mfxwj97wnj", + "incident_updates": [ + { + "body": "A fix has been implemented and we are monitoring the results.", + "created_at": "2013-05-29T16:07:53-06:00", + "display_at": "2013-05-29T16:07:53-06:00", + "status": "monitoring", + "twitter_updated_at": None, + "updated_at": "2013-05-29T16:09:09-06:00", + "wants_twitter_update": False, + "id": "drfcwbnpxnr6", + "incident_id": "lbkhbwn21v5q", + }, + { + "body": "We are waiting for the cloud to come back online " + "and will update when we have further information", + "created_at": "2013-05-29T15:18:51-06:00", + "display_at": "2013-05-29T15:18:51-06:00", + "status": "identified", + "twitter_updated_at": None, + "updated_at": "2013-05-29T15:28:51-06:00", + "wants_twitter_update": False, + "id": "2rryghr4qgrh", + "incident_id": "lbkhbwn21v5q", + }, + { + "body": "The cloud, located in Norther Virginia, has once again gone the way of the dodo.", + "created_at": "2013-05-29T15:08:51-06:00", + "display_at": "2013-05-29T15:08:51-06:00", + "status": "investigating", + "twitter_updated_at": None, + "updated_at": "2013-05-29T15:28:51-06:00", + "wants_twitter_update": False, + "id": "qbbsfhy5s9kk", + "incident_id": "lbkhbwn21v5q", + }, + ], + "name": "Virginia Is Down", + }, + } +]) +def test_webhook_github_status(discord_bot, test_client, mocker, data): + """ + WHEN the '/webhook/github_status/' endpoint is requested (POST) with valid data + THEN check that the response is 'success' + """ + mocker.patch.dict(os.environ, {"GITHUB_WEBHOOK_SECRET_KEY": "valid_key"}) + response = test_client.post('/webhook/github_status/valid_key', json=data) + assert response.json == {"status": "success"} + assert response.status_code == 200 diff --git a/tests/unit/discord/test_discord_bot.py b/tests/unit/discord/test_discord_bot.py index 500722c..dd56821 100644 --- a/tests/unit/discord/test_discord_bot.py +++ b/tests/unit/discord/test_discord_bot.py @@ -1,42 +1,38 @@ # standard imports import asyncio - -# lib imports -import pytest -import pytest_asyncio +import os # local imports -from src import common -from src.discord import bot as discord_bot - - -@pytest_asyncio.fixture -async def bot(): - # event_loop fixture is deprecated - _loop = asyncio.get_event_loop() - - bot = discord_bot.Bot(loop=_loop) - future = asyncio.run_coroutine_threadsafe(bot.start(token=bot.token), _loop) - await bot.wait_until_ready() # Wait until the bot is ready - yield bot - bot.stop(future=future) - - # wait for the bot to finish - counter = 0 - while not future.done() and counter < 30: - await asyncio.sleep(1) - counter += 1 - future.cancel() # Cancel the bot when the tests are done - - -@pytest.mark.asyncio -async def test_bot_on_ready(bot): - assert bot is not None - assert bot.guilds - assert bot.guilds[0].name == "ReenigneArcher's test server" - assert bot.user.id == 939171917578002502 - assert bot.user.name == common.bot_name - assert bot.user.avatar +from src.common import common + + +def test_bot_on_ready(discord_bot): + assert discord_bot is not None + assert discord_bot.guilds + assert discord_bot.guilds[0].name == "ReenigneArcher's test server" + assert discord_bot.user.id == 939171917578002502 + assert discord_bot.user.name == common.bot_name + assert discord_bot.user.avatar # compare the bot avatar to our intended avatar - assert await bot.user.avatar.read() == common.get_avatar_bytes() + future = asyncio.run_coroutine_threadsafe(discord_bot.user.avatar.read(), discord_bot.loop) + assert future.result() == common.get_avatar_bytes() + + +def test_send_message(discord_bot): + channel_id = int(os.environ['DISCORD_REDDIT_CHANNEL_ID']) + message = f"This is a test message from {os.getenv('CI_EVENT_ID', 'local')}." + embeds = [] + msg = discord_bot.send_message(channel_id=channel_id, message=message, embeds=embeds) + assert msg.content == message + assert msg.channel.id == channel_id + assert msg.author.id == 939171917578002502 + assert msg.author.name == common.bot_name + + avatar_future = asyncio.run_coroutine_threadsafe(msg.author.avatar.read(), discord_bot.loop) + assert avatar_future.result() == common.get_avatar_bytes() + + assert msg.author.display_name == common.bot_name + assert msg.author.discriminator == "7085" + assert msg.author.bot is True + assert msg.author.system is False diff --git a/tests/unit/reddit/test_reddit_bot.py b/tests/unit/reddit/test_reddit_bot.py index 8ff1a84..07a38bc 100644 --- a/tests/unit/reddit/test_reddit_bot.py +++ b/tests/unit/reddit/test_reddit_bot.py @@ -161,7 +161,7 @@ def _submission(self, bot, recorder): def test_validate_env(self, bot): with patch.dict( os.environ, { - "DISCORD_WEBHOOK": "test", + "DISCORD_REDDIT_CHANNEL_ID": "test", "PRAW_CLIENT_ID": "test", "PRAW_CLIENT_SECRET": "test", "REDDIT_PASSWORD": "test", @@ -198,7 +198,7 @@ def test_process_comment(self, bot, recorder, request, slash_command_comment): assert db['comments'][slash_command_comment.id]['slash_command']['project'] == 'sunshine' assert db['comments'][slash_command_comment.id]['slash_command']['command'] == 'vban' - def test_process_submission(self, bot, recorder, request, _submission): + def test_process_submission(self, bot, discord_bot, recorder, request, _submission): with recorder.use_cassette(request.node.name): bot.process_submission(submission=_submission) with bot.lock, shelve.open(bot.db) as db: @@ -213,7 +213,7 @@ def test_comment_loop(self, bot, recorder, request): comment = bot._comment_loop(test=True) assert comment.author - def test_submission_loop(self, bot, recorder, request): + def test_submission_loop(self, bot, discord_bot, recorder, request): with recorder.use_cassette(request.node.name): submission = bot._submission_loop(test=True) assert submission.author