Skip to content

Commit

Permalink
added knowledge to training about emoji amount (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
Szer authored Dec 24, 2024
1 parent dfa46b2 commit cbd8c11
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 8 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ ML_STOP_WORDS_IN_CHATS={"-123":["word1","word2"]}
ML_MAX_NUMBER_OF_ITERATIONS=100
UPDATE_CHAT_ADMINS_INTERVAL_SEC=86400
UPDATE_CHAT_ADMINS=true
ML_CUSTOM_EMOJI_THRESHOLD=20
8 changes: 6 additions & 2 deletions src/VahterBanBot.Tests/TgMessageUtils.fs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ type Tg() =
ChatInstance = Guid.NewGuid().ToString()
)
)

static member emoji(?offset: int) = MessageEntity(Type = MessageEntityType.CustomEmoji, Offset = defaultArg offset 0 , Length = 1)
static member emojies(n: int) = Array.init n (fun i -> Tg.emoji i)

static member quickMsg (?text: string, ?chat: Chat, ?from: User, ?date: DateTime, ?callback: CallbackQuery, ?caption: string, ?editedText: string) =
static member quickMsg (?text: string, ?chat: Chat, ?from: User, ?date: DateTime, ?callback: CallbackQuery, ?caption: string, ?editedText: string, ?entities: MessageEntity[]) =
let updateId = next()
let msgId = next()
Update(
Expand All @@ -47,7 +50,8 @@ type Tg() =
From = (from |> Option.defaultValue (Tg.user())),
Date = (date |> Option.defaultValue DateTime.UtcNow),
Caption = (caption |> Option.defaultValue null),
ReplyToMessage = null
ReplyToMessage = null,
Entities = (entities |> Option.defaultValue null)
),
EditedMessage =
if editedText |> Option.isSome then
Expand Down
2 changes: 1 addition & 1 deletion src/VahterBanBot/Bot.fs
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ let justMessage
if not shouldBeSkipped then
let! usrMsgCount = DB.countUniqueUserMsg message.From.Id

match ml.Predict(message.TextOrCaption, usrMsgCount) with
match ml.Predict(message.TextOrCaption, usrMsgCount, message.Entities) with
| Some prediction ->
%mlActivity.SetTag("spamScoreMl", prediction.Score)

Expand Down
13 changes: 11 additions & 2 deletions src/VahterBanBot/DB.fs
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,12 @@ let getUserById (userId: int64): Task<DbUser option> =
return users |> Seq.tryHead
}

[<CLIMutable>]
type SpamOrHamDb =
{ text: string
spam: bool
less_than_n_messages: bool
custom_emoji_count: int
created_at: DateTime }

let mlData (criticalMsgCount: int) (criticalDate: DateTime) : Task<SpamOrHamDb array> =
Expand All @@ -175,7 +177,11 @@ let mlData (criticalMsgCount: int) (criticalDate: DateTime) : Task<SpamOrHamDb a
//language=postgresql
let sql =
"""
WITH less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsgCount AS less_than_n_messages
WITH custom_emojis AS (SELECT message.id, COUNT(*) FILTER (WHERE entities ->> 'type' = 'custom_emoji') AS custom_emoji_count
FROM message,
LATERAL JSONB_ARRAY_ELEMENTS(raw_message -> 'entities') AS entities
GROUP BY message.id),
less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsgCount AS less_than_n_messages
FROM "user" u
LEFT JOIN message m ON u.id = m.user_id
GROUP BY u.id),
Expand All @@ -191,6 +197,7 @@ WITH less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsg
spam_or_ham AS (SELECT x.text,
x.spam,
x.less_than_n_messages,
x.custom_emoji_count,
MAX(x.created_at) AS created_at
FROM (SELECT DISTINCT COALESCE(m.text, re_id.message_text) AS text,
CASE
Expand All @@ -210,6 +217,7 @@ WITH less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsg
ELSE TRUE
END AS spam,
COALESCE(l.less_than_n_messages, TRUE) AS less_than_n_messages,
COALESCE(ce.custom_emoji_count, 0) AS custom_emoji_count,
COALESCE(re_id.banned_at, re_text.banned_at, m.created_at) AS created_at
FROM (SELECT *
FROM message
Expand All @@ -218,8 +226,9 @@ WITH less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsg
FULL OUTER JOIN really_banned re_id
ON m.message_id = re_id.message_id AND m.chat_id = re_id.banned_in_chat_id
LEFT JOIN really_banned re_text ON m.text = re_text.message_text
LEFT JOIN custom_emojis ce ON m.id = ce.id
LEFT JOIN less_than_n_messages l ON m.user_id = l.id) x
GROUP BY text, spam, less_than_n_messages)
GROUP BY text, spam, less_than_n_messages, custom_emoji_count)
SELECT *
FROM spam_or_ham
ORDER BY created_at;
Expand Down
14 changes: 12 additions & 2 deletions src/VahterBanBot/ML.fs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type SpamOrHam =
{ text: string
spam: bool
lessThanNMessagesF: single
moreThanNEmojisF: single
createdAt: DateTime }

[<CLIMutable>]
Expand Down Expand Up @@ -75,6 +76,7 @@ type MachineLearning(
{ text = x.text
spam = x.spam
createdAt = x.created_at
moreThanNEmojisF = if x.custom_emoji_count > botConf.MlCustomEmojiThreshold then 1.0f else 0.0f
lessThanNMessagesF = if x.less_than_n_messages then 1.0f else 0.0f }
)
|> fun x ->
Expand All @@ -90,7 +92,7 @@ type MachineLearning(
let dataProcessPipeline =
mlContext.Transforms.Text
.FeaturizeText(outputColumnName = "TextFeaturized", inputColumnName = "text")
.Append(mlContext.Transforms.Concatenate(outputColumnName = "Features", inputColumnNames = [|"TextFeaturized"; "lessThanNMessagesF";|]))
.Append(mlContext.Transforms.Concatenate(outputColumnName = "Features", inputColumnNames = [|"TextFeaturized"; "lessThanNMessagesF"; "moreThanNEmojisF"|]))
.Append(mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(
labelColumnName = "spam",
featureColumnName = "Features",
Expand Down Expand Up @@ -121,14 +123,22 @@ type MachineLearning(
logger.LogError(ex, "Error training model")
}

member _.Predict(text: string, userMsgCount: int) =
member _.Predict(text: string, userMsgCount: int, entities: MessageEntity array) =
try
match predictionEngine with
| Some predictionEngine ->
let emojiCount =
entities
|> Option.ofObj
|> Option.defaultValue [||]
|> Seq.filter (fun x -> x.Type = MessageEntityType.CustomEmoji)
|> Seq.length

predictionEngine.Predict
{ text = text
spam = false
lessThanNMessagesF = if userMsgCount < botConf.MlTrainCriticalMsgCount then 1.0f else 0.0f
moreThanNEmojisF = if emojiCount > botConf.MlCustomEmojiThreshold then 1.0f else 0.0f
createdAt = DateTime.UtcNow }
|> Some
| None ->
Expand Down
3 changes: 2 additions & 1 deletion src/VahterBanBot/Program.fs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ let botConf =
MlTrainingSetFraction = getEnvOr "ML_TRAINING_SET_FRACTION" "0.2" |> float
MlSpamThreshold = getEnvOr "ML_SPAM_THRESHOLD" "0.5" |> single
MlWarningThreshold = getEnvOr "ML_WARNING_THRESHOLD" "0.0" |> single
MlMaxNumberOfIterations = getEnvOr "ML_MAX_NUMBER_OF_ITERATIONS" "50" |> int
MlMaxNumberOfIterations = getEnvOr "ML_MAX_NUMBER_OF_ITERATIONS" "50" |> int
MlCustomEmojiThreshold = getEnvOr "ML_CUSTOM_EMOJI_THRESHOLD" "20" |> int
MlStopWordsInChats = getEnvOr "ML_STOP_WORDS_IN_CHATS" "{}" |> fromJson }

let validateApiKey (ctx : HttpContext) =
Expand Down
1 change: 1 addition & 0 deletions src/VahterBanBot/Types.fs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type BotConfiguration =
MlSpamThreshold: single
MlWarningThreshold: single
MlMaxNumberOfIterations: int
MlCustomEmojiThreshold: int
MlStopWordsInChats: Dictionary<int64, string list> }

[<CLIMutable>]
Expand Down

0 comments on commit cbd8c11

Please sign in to comment.