Update bot to use only PRAW, remove Pushshift API
Lauler committed Dec 17, 2023
1 parent c982d58 commit a973400
Showing 5 changed files with 365 additions and 124 deletions.
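For context: the bot previously pulled recent comments through the Pushshift API (via psaw) and now walks r/sweden's hot listing directly with PRAW, as the diff below shows. A minimal sketch of the new access pattern, with placeholder credentials (bot.py loads the real ones from a .env file):

import praw

# Placeholder credentials; in bot.py these come from environment variables.
reddit = praw.Reddit(
    client_id="CLIENT_ID",
    client_secret="CLIENT_SECRET",
    user_agent="SprakpolisenBot (sketch)",
    username="USERNAME",
    password="PASSWORD",
)

subreddit = reddit.subreddit("sweden")

# Iterate the hot listing instead of querying Pushshift by time window.
for submission in subreddit.hot(limit=35):
    if submission.num_comments == 0:
        continue  # skip threads with nothing to analyze
    print(submission.id, submission.num_comments)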
68 changes: 40 additions & 28 deletions bot.py
@@ -1,29 +1,25 @@
+import logging
 import os
 import torch
+import praw
 import pandas as pd
 import datetime as dt
-import logging
-import praw
-from psaw import PushshiftAPI
+from dotenv import load_dotenv
 from transformers import (
-    AutoTokenizer,
-    AutoModelForTokenClassification,
     AutoModelForSeq2SeqLM,
+    AutoModelForTokenClassification,
+    AutoTokenizer,
     pipeline,
 )
 from src.comment import choose_post, create_reply_msg
 from src.data import (
-    download_comments,
+    analyze_comments,
     download_submission,
-    filter_comments,
     get_posted_comments,
     merge_comment_submission,
-    predict_comments,
-    preprocess_comments,
     save_feather,
 )
 from src.translate import translation_preprocess
-from dotenv import load_dotenv
 
 logging.basicConfig(
     filename="sprakpolisen.log",
@@ -39,9 +35,14 @@
 model = AutoModelForTokenClassification.from_pretrained("Lauler/deformer")
 model.to(device)
 
+# NER pipeline
+pipe = pipeline("ner", model=model, tokenizer=tokenizer, device=0)
+
 # Machine Translation model
 tokenizer_translate = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-sv-en")
-model_translate = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-sv-en", output_attentions=True)
+model_translate = AutoModelForSeq2SeqLM.from_pretrained(
+    "Helsinki-NLP/opus-mt-sv-en", output_attentions=True
+)
 model_translate.eval()
 model_translate.to(device)
 
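This hunk moves pipe construction up next to the model setup. For reference, a token-classification pipeline returns one dict per flagged token; a sketch of thresholding those scores (the example sentence is illustrative, not from this repo):

from transformers import pipeline

# Lauler/deformer flags suspected de/dem confusions in Swedish text.
pipe = pipeline("ner", model="Lauler/deformer")

preds = pipe("Dem kommer imorgon.")  # "Dem" as subject is likely wrong
# Each prediction is a dict with keys like:
#   {"entity": ..., "score": 0.99, "index": 1, "word": "Dem", "start": 0, "end": 3}
confident = [p for p in preds if p["score"] > 0.98]  # a cutoff like the old predict_comments threshold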
@@ -61,22 +62,27 @@
     password=pw,
 )
 
-api = PushshiftAPI(reddit)
 subreddit = reddit.subreddit("sweden")
 
-df = download_comments(api, weeks=0, hours=4, minutes=45)
-df = preprocess_comments(df)  # Sentence splitting, and more
-pipe = pipeline("ner", model=model, tokenizer=tokenizer, device=0)
-df = predict_comments(df, pipe, threshold=0.98)  # Only saves preds above threshold
-df_comment = filter_comments(df)
-
-#### Write comment info to file ####
-date = dt.datetime.now().strftime("%Y-%m-%d_%H-%M")
-save_feather(df_comment, type="comment", date=date)
+df_subs = []
+df_comments = []
+
+for submission in subreddit.hot(limit=35):
+    if submission.num_comments == 0:
+        continue
 
     #### Download info about submission thread ####
-df_comment["id_sub"] = df_comment["link_id"].str.slice(start=3)
-df_sub = download_submission(df_comment["id_sub"].tolist(), reddit_api=reddit)
+    df_sub = download_submission(submission)
+    df_comment = analyze_comments(submission, pipe=pipe)
+    df_subs.append(df_sub)
+    df_comments.append(df_comment)
+
+df_sub = pd.concat(df_subs).reset_index(drop=True)
+df_comment = pd.concat(df_comments).reset_index(drop=True)
+
+#### Write comment and submission info to file ####
+date = dt.datetime.now().strftime("%Y-%m-%d_%H-%M")
+save_feather(df_comment, type="comment", date=date)
 save_feather(df_sub, type="submission", date=date)
 
 # Merge
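analyze_comments is the new helper doing the per-thread work; its body lives in src/data.py and is outside this diff. A plausible reading of what it folds together (the old download, preprocess, and predict steps), sketched with PRAW's comment-forest API; the function name and field names here are assumptions:

import pandas as pd

def analyze_comments_sketch(submission, pipe, threshold=0.98):
    """Hypothetical stand-in for src.data.analyze_comments."""
    submission.comments.replace_more(limit=None)  # expand "load more comments" stubs
    rows = []
    for comment in submission.comments.list():  # flattened comment tree
        preds = [p for p in pipe(comment.body) if p["score"] > threshold]
        if preds:
            rows.append({"id": comment.id, "link_id": comment.link_id, "body": comment.body})
    return pd.DataFrame(rows)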
@@ -89,12 +95,17 @@
 except:
     pass
 
+df_all = df_all[~(df_all["n_mis_det"] == 1)].reset_index(drop=True)
+
 # Choose which comment to post reply to
-df_post = choose_post(df_all, min_hour=0.7, max_hour=17)
+df_post = choose_post(df_all, min_hour=0.7, max_hour=19)
 
 df_all.columns
-# df_post = df_all.iloc[1:2].reset_index(drop=True)
-df_post["sentences"] = df_post["sentences"].apply(lambda sens: [sen.replace("…", ".") for sen in sens])
+
+# df_post = df_all.iloc[2:3].reset_index(drop=True)
+df_post["sentences"] = df_post["sentences"].apply(
+    lambda sens: [sen.replace("…", ".") for sen in sens]
+)
 
 #### Translate to English
 pipes = translation_preprocess(
@@ -106,8 +117,8 @@
 
 reply_msg = create_reply_msg(df_post, pipes=pipes)
 
-save_feather(df_all, type="all", date=date)
 
+save_feather(df_all, type="all", date=date)
 
 for i in range(len(df_all)):
     try:
@@ -123,7 +134,7 @@
 # if a single comment author in the comment chain has blocked SprakpolisenBot.
 logging.error(f'Failed replying to comment id {df_post["id"][0]} because of block.')
 df_all = df_all[df_all["id"] != df_post["id"][0]]  # Remove unsuccessful reply attempt
-df_post = choose_post(df_all, min_hour=1, max_hour=17)
+df_post = choose_post(df_all, min_hour=1, max_hour=19)
 
 #### Translate to English
 pipes = translation_preprocess(
@@ -134,6 +145,7 @@
 )
 reply_msg = create_reply_msg(df_post, pipes=pipes)
 
+
 logging.info("Successfully replied.")
 
 # Save replies/posted comments
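The retry branch above handles the case where any author in the comment chain has blocked SprakpolisenBot: the failed candidate is dropped and choose_post runs again. Roughly, as a sketch (the exception type shown is an assumption; the diff does not show which one bot.py catches):

import praw
from src.comment import choose_post, create_reply_msg

def reply_or_retry(reddit, df_all, df_post, reply_msg, pipes):
    """Sketch: attempt one reply; on failure, drop the candidate and repick."""
    try:
        reddit.comment(id=df_post["id"][0]).reply(reply_msg)
    except praw.exceptions.RedditAPIException:  # assumed failure mode (block in chain)
        df_all = df_all[df_all["id"] != df_post["id"][0]]  # remove failed candidate
        df_post = choose_post(df_all, min_hour=1, max_hour=19)
        reply_msg = create_reply_msg(df_post, pipes=pipes)
    return df_all, df_post, reply_msg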
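The English half of each reply comes from the Helsinki-NLP Marian model loaded at the top of bot.py; translation_preprocess wires it into this repo's own pipes object, which the diff does not show. Standalone, the model is used like this:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-sv-en")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-sv-en")
model.eval()

inputs = tokenizer("Det är dem som bestämmer.", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
# Prints something close to "It is they who decide."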
12 changes: 4 additions & 8 deletions src/comment.py
@@ -41,11 +41,11 @@ def choose_post(df_all, min_hour, max_hour):
         logger.exception("No suitable comment reply candidates. Exiting.")
         raise e
 
-    if any(df_all["nr_mistakes"] > 1):
-        df_multimistake = df_all[df_all["nr_mistakes"] > 1].reset_index(drop=True)
+    if any(df_all["n_mis"] > 1):
+        df_multimistake = df_all[df_all["n_mis"] > 1].reset_index(drop=True)
 
         if len(df_multimistake) > 1:
-            max_mistake_idx = df_multimistake["nr_mistakes"].idxmax()
+            max_mistake_idx = df_multimistake["n_mis"].idxmax()
             df_post = df_multimistake.iloc[max_mistake_idx : (max_mistake_idx + 1), :]
         else:
             df_post = df_multimistake
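A note on the selection logic above: idxmax returns an index label, while the iloc slice that follows is positional. The two only agree because of the reset_index(drop=True) on the line before; a small pandas illustration:

import pandas as pd

df = pd.DataFrame({"id": ["a", "b", "c"], "n_mis": [2, 5, 3]})
df_multi = df[df["n_mis"] > 1].reset_index(drop=True)  # labels are 0..n-1 again
i = df_multi["n_mis"].idxmax()         # label of the row with the most mistakes
df_post = df_multi.iloc[i : i + 1, :]  # positional slice; safe since labels == positions
print(df_post["id"].item())            # -> "b"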
@@ -81,7 +81,6 @@ def correct_sentence(preds, sentences):
     offset = 0
     correct_sens = []
     for pred, sentence in zip(preds, sentences):
-
         if len(pred) == 0:
             continue
 
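correct_sentence (mostly outside this hunk) rewrites each sentence from the model's character offsets; each replacement changes the sentence length, which is what the running offset compensates for. A sketch of that pattern, assuming each prediction carries start, end, and a replacement word (the field names are assumptions):

def apply_corrections(sentence, preds):
    """Sketch: apply span replacements left to right, shifting later offsets."""
    offset = 0
    for p in sorted(preds, key=lambda p: p["start"]):
        start, end = p["start"] + offset, p["end"] + offset
        sentence = sentence[:start] + p["correction"] + sentence[end:]
        offset += len(p["correction"]) - (end - start)
    return sentence

print(apply_corrections("Dem springer fort.", [{"start": 0, "end": 3, "correction": "De"}]))
# -> "De springer fort."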
@@ -108,7 +107,7 @@
 
 def correct_sentence_en(preds, correct_sens):
     """
-    We have correctly translated sentences already, but want to 
+    We have correctly translated sentences already, but want to
     introduce the wrong form of they/them with a strikethrough
     next to the already corrected instance of they/them/the/those.
     """
@@ -131,14 +130,12 @@ def correct_sentence_en(preds, correct_sens):
     them_words = ["Them", "them"]
 
     for i, (pred, sentence) in enumerate(zip(preds_en, correct_sens)):
-
         target_indices = []  # To check if we're trying to correct same word twice or more
         contains_pred_mismatch = False  # Check if 'de' maps to they_words, and 'dem' to them_words
         if len(pred) == 0:
             continue
 
         for j, entity in enumerate(pred):
-
             target_indices.append(entity["index"])
             if entity["word"] not in (they_words + them_words):
                 logger.info(
@@ -211,4 +208,3 @@ def create_reply_msg(df_post, pipes):
     message += create_guide(df_post)
 
     return message
-