Skip to content
This repository has been archived by the owner on Oct 2, 2023. It is now read-only.

Be a little friendlier to the Reddit API #125

Closed
wants to merge 12 commits into from
87 changes: 74 additions & 13 deletions integrations/reddit/cog.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import datetime
from typing import Optional, List
import re
from datetime import datetime
from itertools import groupby
from sys import platform
from typing import Optional, List, Union

from aiohttp import ClientSession
from discord import Embed, TextChannel
Expand All @@ -27,12 +29,25 @@
logger = get_logger(__name__)


async def user_agent() -> str:
"""
construct a user agent string in accordance with the Reddit API rules:
<https://github.com/reddit-archive/reddit/wiki/API#rules>
"""
author = f"/u/{await RedditSettings.author.get()}"
target_platform = platform
app_id = Config.NAME.replace(" ", "").lower()
version_string = f"v{Config.VERSION}"

return f"{target_platform}:{app_id}:{version_string} (by {author})"


async def exists_subreddit(subreddit: str) -> bool:
subreddit = re.sub("^(/r/|r/)", "", subreddit)
async with ClientSession() as session, session.get(
# raw_json=1 as parameter to get unicode characters instead of html escape sequences
f"https://www.reddit.com/r/{subreddit}/about.json?raw_json=1",
headers={"User-agent": f"{Config.NAME}/{Config.VERSION}"},
headers={"User-agent": await user_agent()},
) as response:
return response.ok

Expand All @@ -42,17 +57,18 @@ async def get_subreddit_name(subreddit: str) -> str:
async with ClientSession() as session, session.get(
Scriptim marked this conversation as resolved.
Show resolved Hide resolved
# raw_json=1 as parameter to get unicode characters instead of html escape sequences
f"https://www.reddit.com/r/{subreddit}/about.json?raw_json=1",
headers={"User-agent": f"{Config.NAME}/{Config.VERSION}"},
headers={"User-agent": await user_agent()},
) as response:
return (await response.json())["data"]["display_name"]


async def fetch_reddit_posts(subreddit: str, limit: int) -> Optional[List[dict]]:
subreddit = re.sub("^(/r/|r/)", "", subreddit)
async def fetch_reddit_posts(subreddits: Union[str, List[str]], limit: int) -> Optional[List[dict]]:
subreddits_str = subreddits if isinstance(subreddits, str) else "+".join(subreddits)

async with ClientSession() as session, session.get(
# raw_json=1 as parameter to get unicode characters instead of html escape sequences
f"https://www.reddit.com/r/{subreddit}/hot.json?raw_json=1",
headers={"User-agent": f"{Config.NAME}/{Config.VERSION}"},
f"https://www.reddit.com/r/{subreddits_str}/hot.json?raw_json=1",
headers={"User-agent": await user_agent()},
params={"limit": str(limit)},
) as response:
if response.status != 200:
Expand Down Expand Up @@ -95,6 +111,12 @@ def create_embed(post: dict) -> Embed:
return embed


async def is_author_set() -> bool:
"""return whether a Reddit author username has been specified in the cog settings"""
author = await RedditSettings.author.get()
return bool(author)


class RedditCog(Cog, name="Reddit"):
CONTRIBUTORS = [Contributor.Scriptim, Contributor.Defelo, Contributor.wolflu, Contributor.Anorak]

Expand All @@ -110,10 +132,16 @@ async def reddit_loop(self):
async def pull_hot_posts(self):
logger.info("pulling hot reddit posts")
limit = await RedditSettings.limit.get()
async for reddit_channel in await db.stream(select(RedditChannel)): # type: RedditChannel
text_channel: Optional[TextChannel] = self.bot.get_channel(reddit_channel.channel)

def text_channel_key(reddit_channel: RedditChannel) -> Optional[TextChannel]:
return self.bot.get_channel(reddit_channel.channel)

reddit_channels = sorted(await db.all(select(RedditChannel)), key=text_channel_key)

for text_channel, reddit_channel_group in groupby(reddit_channels, text_channel_key):
if text_channel is None:
await db.delete(reddit_channel)
for reddit_channel in reddit_channel_group:
await db.delete(reddit_channel)
continue

try:
Expand All @@ -122,9 +150,11 @@ async def pull_hot_posts(self):
await send_alert(self.bot.guilds[0], t.cannot_send(text_channel.mention))
continue

posts = await fetch_reddit_posts(reddit_channel.subreddit, limit)
subreddits = [reddit_channel.subreddit for reddit_channel in reddit_channel_group]

posts = await fetch_reddit_posts(subreddits, limit)
if posts is None:
await send_alert(self.bot.guilds[0], t.could_not_fetch(reddit_channel.subreddit))
await send_alert(self.bot.guilds[0], t.could_not_fetch(",".join(subreddits)))
continue

for post in posts:
Expand All @@ -134,6 +164,10 @@ async def pull_hot_posts(self):
await RedditPost.clean()

async def start_loop(self, interval):
if not await is_author_set():
await send_alert(self.bot.guilds[0], t.log_no_reddit_author_set)
return

self.reddit_loop.cancel()
self.reddit_loop.change_interval(hours=interval)
try:
Expand Down Expand Up @@ -181,6 +215,9 @@ async def reddit_add(self, ctx: Context, subreddit: str, channel: TextChannel):
create a link between a subreddit and a channel
"""

if not await is_author_set():
raise CommandError(t.no_reddit_author_set)

if not await exists_subreddit(subreddit):
raise CommandError(t.subreddit_not_found)

Expand All @@ -202,6 +239,9 @@ async def reddit_remove(self, ctx: Context, subreddit: str, channel: TextChannel
remove a reddit link
"""

if not await is_author_set():
raise CommandError(t.no_reddit_author_set)

subreddit = await get_subreddit_name(subreddit)
link: Optional[RedditChannel] = await db.get(RedditChannel, subreddit=subreddit, channel=channel.id)
if link is None:
Expand Down Expand Up @@ -243,13 +283,34 @@ async def reddit_limit(self, ctx: Context, limit: int):
await reply(ctx, embed=embed)
await send_to_changelog(ctx.guild, t.log_reddit_limit_set(limit))

@reddit.command(name="author", aliases=["user"])
@RedditPermission.write.check
async def reddit_ua_author(self, ctx: Context, author: str):
"""
set author (Reddit username) for user agent string
"""

await RedditSettings.author.set(author)
embed = Embed(title=t.reddit, colour=Colors.Reddit, description=t.reddit_author_set)
await send_to_changelog(ctx.guild, t.log_reddit_author_set(author))
await reply(ctx, embed=embed)
Defelo marked this conversation as resolved.
Show resolved Hide resolved

if not self.reddit_loop.is_running():
try:
self.reddit_loop.start()
except RuntimeError:
self.reddit_loop.restart()

@reddit.command(name="trigger", aliases=["t"])
@RedditPermission.trigger.check
async def reddit_trigger(self, ctx: Context):
"""
pull hot posts now and reset the timer
"""

if not await is_author_set():
raise CommandError(t.no_reddit_author_set)

await self.start_loop(await RedditSettings.interval.get())
embed = Embed(title=t.reddit, colour=Colors.Reddit, description=t.done)
await reply(ctx, embed=embed)
1 change: 1 addition & 0 deletions integrations/reddit/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
class RedditSettings(Settings):
interval = 4
limit = 4
author = ""
4 changes: 4 additions & 0 deletions integrations/reddit/translations/en.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,7 @@ reddit_limit_set: "Reddit limit has been updated. :white_check_mark:"
log_reddit_limit_set: "**Reddit limit** has been **set** to {}."
could_not_fetch: Could not fetch reddit posts of r/{}.
cannot_send: Reddit integration cannot send messages to {}.
reddit_author_set: "Reddit author username has been updated. :white_check_mark:"
log_reddit_author_set: "**Reddit author username** has been **set** to u/{}."
no_reddit_author_set: "Please set a Reddit author username."
log_no_reddit_author_set: "Reddit loop has been disabled because no **Reddit author username** is specified."