diff --git a/.env.example b/.env.example index 48b06c7..24ff444 100644 --- a/.env.example +++ b/.env.example @@ -34,3 +34,4 @@ ML_STOP_WORDS_IN_CHATS={"-123":["word1","word2"]} ML_MAX_NUMBER_OF_ITERATIONS=100 UPDATE_CHAT_ADMINS_INTERVAL_SEC=86400 UPDATE_CHAT_ADMINS=true +ML_CUSTOM_EMOJI_THRESHOLD=20 diff --git a/src/VahterBanBot.Tests/TgMessageUtils.fs b/src/VahterBanBot.Tests/TgMessageUtils.fs index 80d43b3..13caef5 100644 --- a/src/VahterBanBot.Tests/TgMessageUtils.fs +++ b/src/VahterBanBot.Tests/TgMessageUtils.fs @@ -33,8 +33,11 @@ type Tg() = ChatInstance = Guid.NewGuid().ToString() ) ) + + static member emoji(?offset: int) = MessageEntity(Type = MessageEntityType.CustomEmoji, Offset = defaultArg offset 0 , Length = 1) + static member emojies(n: int) = Array.init n (fun i -> Tg.emoji i) - static member quickMsg (?text: string, ?chat: Chat, ?from: User, ?date: DateTime, ?callback: CallbackQuery, ?caption: string, ?editedText: string) = + static member quickMsg (?text: string, ?chat: Chat, ?from: User, ?date: DateTime, ?callback: CallbackQuery, ?caption: string, ?editedText: string, ?entities: MessageEntity[]) = let updateId = next() let msgId = next() Update( @@ -47,7 +50,8 @@ type Tg() = From = (from |> Option.defaultValue (Tg.user())), Date = (date |> Option.defaultValue DateTime.UtcNow), Caption = (caption |> Option.defaultValue null), - ReplyToMessage = null + ReplyToMessage = null, + Entities = (entities |> Option.defaultValue null) ), EditedMessage = if editedText |> Option.isSome then diff --git a/src/VahterBanBot/Bot.fs b/src/VahterBanBot/Bot.fs index f397b85..1b37354 100644 --- a/src/VahterBanBot/Bot.fs +++ b/src/VahterBanBot/Bot.fs @@ -513,7 +513,7 @@ let justMessage if not shouldBeSkipped then let! usrMsgCount = DB.countUniqueUserMsg message.From.Id - match ml.Predict(message.TextOrCaption, usrMsgCount) with + match ml.Predict(message.TextOrCaption, usrMsgCount, message.Entities) with | Some prediction -> %mlActivity.SetTag("spamScoreMl", prediction.Score) diff --git a/src/VahterBanBot/DB.fs b/src/VahterBanBot/DB.fs index 39266fb..89cdf6a 100644 --- a/src/VahterBanBot/DB.fs +++ b/src/VahterBanBot/DB.fs @@ -162,10 +162,12 @@ let getUserById (userId: int64): Task = return users |> Seq.tryHead } +[] type SpamOrHamDb = { text: string spam: bool less_than_n_messages: bool + custom_emoji_count: int created_at: DateTime } let mlData (criticalMsgCount: int) (criticalDate: DateTime) : Task = @@ -175,7 +177,11 @@ let mlData (criticalMsgCount: int) (criticalDate: DateTime) : Task> 'type' = 'custom_emoji') AS custom_emoji_count + FROM message, + LATERAL JSONB_ARRAY_ELEMENTS(raw_message -> 'entities') AS entities + GROUP BY message.id), + less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsgCount AS less_than_n_messages FROM "user" u LEFT JOIN message m ON u.id = m.user_id GROUP BY u.id), @@ -191,6 +197,7 @@ WITH less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsg spam_or_ham AS (SELECT x.text, x.spam, x.less_than_n_messages, + x.custom_emoji_count, MAX(x.created_at) AS created_at FROM (SELECT DISTINCT COALESCE(m.text, re_id.message_text) AS text, CASE @@ -210,6 +217,7 @@ WITH less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsg ELSE TRUE END AS spam, COALESCE(l.less_than_n_messages, TRUE) AS less_than_n_messages, + COALESCE(ce.custom_emoji_count, 0) AS custom_emoji_count, COALESCE(re_id.banned_at, re_text.banned_at, m.created_at) AS created_at FROM (SELECT * FROM message @@ -218,8 +226,9 @@ WITH less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsg FULL OUTER JOIN really_banned re_id ON m.message_id = re_id.message_id AND m.chat_id = re_id.banned_in_chat_id LEFT JOIN really_banned re_text ON m.text = re_text.message_text + LEFT JOIN custom_emojis ce ON m.id = ce.id LEFT JOIN less_than_n_messages l ON m.user_id = l.id) x - GROUP BY text, spam, less_than_n_messages) + GROUP BY text, spam, less_than_n_messages, custom_emoji_count) SELECT * FROM spam_or_ham ORDER BY created_at; diff --git a/src/VahterBanBot/ML.fs b/src/VahterBanBot/ML.fs index 23f549b..c741231 100644 --- a/src/VahterBanBot/ML.fs +++ b/src/VahterBanBot/ML.fs @@ -21,6 +21,7 @@ type SpamOrHam = { text: string spam: bool lessThanNMessagesF: single + moreThanNEmojisF: single createdAt: DateTime } [] @@ -75,6 +76,7 @@ type MachineLearning( { text = x.text spam = x.spam createdAt = x.created_at + moreThanNEmojisF = if x.custom_emoji_count > botConf.MlCustomEmojiThreshold then 1.0f else 0.0f lessThanNMessagesF = if x.less_than_n_messages then 1.0f else 0.0f } ) |> fun x -> @@ -90,7 +92,7 @@ type MachineLearning( let dataProcessPipeline = mlContext.Transforms.Text .FeaturizeText(outputColumnName = "TextFeaturized", inputColumnName = "text") - .Append(mlContext.Transforms.Concatenate(outputColumnName = "Features", inputColumnNames = [|"TextFeaturized"; "lessThanNMessagesF";|])) + .Append(mlContext.Transforms.Concatenate(outputColumnName = "Features", inputColumnNames = [|"TextFeaturized"; "lessThanNMessagesF"; "moreThanNEmojisF"|])) .Append(mlContext.BinaryClassification.Trainers.SdcaLogisticRegression( labelColumnName = "spam", featureColumnName = "Features", @@ -121,14 +123,22 @@ type MachineLearning( logger.LogError(ex, "Error training model") } - member _.Predict(text: string, userMsgCount: int) = + member _.Predict(text: string, userMsgCount: int, entities: MessageEntity array) = try match predictionEngine with | Some predictionEngine -> + let emojiCount = + entities + |> Option.ofObj + |> Option.defaultValue [||] + |> Seq.filter (fun x -> x.Type = MessageEntityType.CustomEmoji) + |> Seq.length + predictionEngine.Predict { text = text spam = false lessThanNMessagesF = if userMsgCount < botConf.MlTrainCriticalMsgCount then 1.0f else 0.0f + moreThanNEmojisF = if emojiCount > botConf.MlCustomEmojiThreshold then 1.0f else 0.0f createdAt = DateTime.UtcNow } |> Some | None -> diff --git a/src/VahterBanBot/Program.fs b/src/VahterBanBot/Program.fs index 5615993..f81382d 100644 --- a/src/VahterBanBot/Program.fs +++ b/src/VahterBanBot/Program.fs @@ -75,7 +75,8 @@ let botConf = MlTrainingSetFraction = getEnvOr "ML_TRAINING_SET_FRACTION" "0.2" |> float MlSpamThreshold = getEnvOr "ML_SPAM_THRESHOLD" "0.5" |> single MlWarningThreshold = getEnvOr "ML_WARNING_THRESHOLD" "0.0" |> single - MlMaxNumberOfIterations = getEnvOr "ML_MAX_NUMBER_OF_ITERATIONS" "50" |> int + MlMaxNumberOfIterations = getEnvOr "ML_MAX_NUMBER_OF_ITERATIONS" "50" |> int + MlCustomEmojiThreshold = getEnvOr "ML_CUSTOM_EMOJI_THRESHOLD" "20" |> int MlStopWordsInChats = getEnvOr "ML_STOP_WORDS_IN_CHATS" "{}" |> fromJson } let validateApiKey (ctx : HttpContext) = diff --git a/src/VahterBanBot/Types.fs b/src/VahterBanBot/Types.fs index 4f190a7..f1dcbe8 100644 --- a/src/VahterBanBot/Types.fs +++ b/src/VahterBanBot/Types.fs @@ -42,6 +42,7 @@ type BotConfiguration = MlSpamThreshold: single MlWarningThreshold: single MlMaxNumberOfIterations: int + MlCustomEmojiThreshold: int MlStopWordsInChats: Dictionary } []