Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added knowledge to training about emoji amount #65

Merged
merged 2 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ ML_STOP_WORDS_IN_CHATS={"-123":["word1","word2"]}
ML_MAX_NUMBER_OF_ITERATIONS=100
UPDATE_CHAT_ADMINS_INTERVAL_SEC=86400
UPDATE_CHAT_ADMINS=true
ML_CUSTOM_EMOJI_THRESHOLD=20
8 changes: 6 additions & 2 deletions src/VahterBanBot.Tests/TgMessageUtils.fs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ type Tg() =
ChatInstance = Guid.NewGuid().ToString()
)
)

static member emoji(?offset: int) = MessageEntity(Type = MessageEntityType.CustomEmoji, Offset = defaultArg offset 0 , Length = 1)
static member emojies(n: int) = Array.init n (fun i -> Tg.emoji i)

static member quickMsg (?text: string, ?chat: Chat, ?from: User, ?date: DateTime, ?callback: CallbackQuery, ?caption: string, ?editedText: string) =
static member quickMsg (?text: string, ?chat: Chat, ?from: User, ?date: DateTime, ?callback: CallbackQuery, ?caption: string, ?editedText: string, ?entities: MessageEntity[]) =
let updateId = next()
let msgId = next()
Update(
Expand All @@ -47,7 +50,8 @@ type Tg() =
From = (from |> Option.defaultValue (Tg.user())),
Date = (date |> Option.defaultValue DateTime.UtcNow),
Caption = (caption |> Option.defaultValue null),
ReplyToMessage = null
ReplyToMessage = null,
Entities = (entities |> Option.defaultValue null)
),
EditedMessage =
if editedText |> Option.isSome then
Expand Down
2 changes: 1 addition & 1 deletion src/VahterBanBot/Bot.fs
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ let justMessage
if not shouldBeSkipped then
let! usrMsgCount = DB.countUniqueUserMsg message.From.Id

match ml.Predict(message.TextOrCaption, usrMsgCount) with
match ml.Predict(message.TextOrCaption, usrMsgCount, message.Entities) with
| Some prediction ->
%mlActivity.SetTag("spamScoreMl", prediction.Score)

Expand Down
13 changes: 11 additions & 2 deletions src/VahterBanBot/DB.fs
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,12 @@ let getUserById (userId: int64): Task<DbUser option> =
return users |> Seq.tryHead
}

[<CLIMutable>]
type SpamOrHamDb =
{ text: string
spam: bool
less_than_n_messages: bool
custom_emoji_count: int
created_at: DateTime }

let mlData (criticalMsgCount: int) (criticalDate: DateTime) : Task<SpamOrHamDb array> =
Expand All @@ -175,7 +177,11 @@ let mlData (criticalMsgCount: int) (criticalDate: DateTime) : Task<SpamOrHamDb a
//language=postgresql
let sql =
"""
WITH less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsgCount AS less_than_n_messages
WITH custom_emojis AS (SELECT message.id, COUNT(*) FILTER (WHERE entities ->> 'type' = 'custom_emoji') AS custom_emoji_count
FROM message,
LATERAL JSONB_ARRAY_ELEMENTS(raw_message -> 'entities') AS entities
GROUP BY message.id),
less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsgCount AS less_than_n_messages
FROM "user" u
LEFT JOIN message m ON u.id = m.user_id
GROUP BY u.id),
Expand All @@ -191,6 +197,7 @@ WITH less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsg
spam_or_ham AS (SELECT x.text,
x.spam,
x.less_than_n_messages,
x.custom_emoji_count,
MAX(x.created_at) AS created_at
FROM (SELECT DISTINCT COALESCE(m.text, re_id.message_text) AS text,
CASE
Expand All @@ -210,6 +217,7 @@ WITH less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsg
ELSE TRUE
END AS spam,
COALESCE(l.less_than_n_messages, TRUE) AS less_than_n_messages,
COALESCE(ce.custom_emoji_count, 0) AS custom_emoji_count,
COALESCE(re_id.banned_at, re_text.banned_at, m.created_at) AS created_at
FROM (SELECT *
FROM message
Expand All @@ -218,8 +226,9 @@ WITH less_than_n_messages AS (SELECT u.id, COUNT(DISTINCT m.text) < @criticalMsg
FULL OUTER JOIN really_banned re_id
ON m.message_id = re_id.message_id AND m.chat_id = re_id.banned_in_chat_id
LEFT JOIN really_banned re_text ON m.text = re_text.message_text
LEFT JOIN custom_emojis ce ON m.id = ce.id
LEFT JOIN less_than_n_messages l ON m.user_id = l.id) x
GROUP BY text, spam, less_than_n_messages)
GROUP BY text, spam, less_than_n_messages, custom_emoji_count)
SELECT *
FROM spam_or_ham
ORDER BY created_at;
Expand Down
14 changes: 12 additions & 2 deletions src/VahterBanBot/ML.fs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type SpamOrHam =
{ text: string
spam: bool
lessThanNMessagesF: single
moreThanNEmojisF: single
createdAt: DateTime }

[<CLIMutable>]
Expand Down Expand Up @@ -75,6 +76,7 @@ type MachineLearning(
{ text = x.text
spam = x.spam
createdAt = x.created_at
moreThanNEmojisF = if x.custom_emoji_count > botConf.MlCustomEmojiThreshold then 1.0f else 0.0f
lessThanNMessagesF = if x.less_than_n_messages then 1.0f else 0.0f }
)
|> fun x ->
Expand All @@ -90,7 +92,7 @@ type MachineLearning(
let dataProcessPipeline =
mlContext.Transforms.Text
.FeaturizeText(outputColumnName = "TextFeaturized", inputColumnName = "text")
.Append(mlContext.Transforms.Concatenate(outputColumnName = "Features", inputColumnNames = [|"TextFeaturized"; "lessThanNMessagesF";|]))
.Append(mlContext.Transforms.Concatenate(outputColumnName = "Features", inputColumnNames = [|"TextFeaturized"; "lessThanNMessagesF"; "moreThanNEmojisF"|]))
.Append(mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(
labelColumnName = "spam",
featureColumnName = "Features",
Expand Down Expand Up @@ -121,14 +123,22 @@ type MachineLearning(
logger.LogError(ex, "Error training model")
}

member _.Predict(text: string, userMsgCount: int) =
member _.Predict(text: string, userMsgCount: int, entities: MessageEntity array) =
try
match predictionEngine with
| Some predictionEngine ->
let emojiCount =
entities
|> Option.ofObj
|> Option.defaultValue [||]
|> Seq.filter (fun x -> x.Type = MessageEntityType.CustomEmoji)
|> Seq.length

predictionEngine.Predict
{ text = text
spam = false
lessThanNMessagesF = if userMsgCount < botConf.MlTrainCriticalMsgCount then 1.0f else 0.0f
moreThanNEmojisF = if emojiCount > botConf.MlCustomEmojiThreshold then 1.0f else 0.0f
createdAt = DateTime.UtcNow }
|> Some
| None ->
Expand Down
3 changes: 2 additions & 1 deletion src/VahterBanBot/Program.fs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ let botConf =
MlTrainingSetFraction = getEnvOr "ML_TRAINING_SET_FRACTION" "0.2" |> float
MlSpamThreshold = getEnvOr "ML_SPAM_THRESHOLD" "0.5" |> single
MlWarningThreshold = getEnvOr "ML_WARNING_THRESHOLD" "0.0" |> single
MlMaxNumberOfIterations = getEnvOr "ML_MAX_NUMBER_OF_ITERATIONS" "50" |> int
MlMaxNumberOfIterations = getEnvOr "ML_MAX_NUMBER_OF_ITERATIONS" "50" |> int
MlCustomEmojiThreshold = getEnvOr "ML_CUSTOM_EMOJI_THRESHOLD" "20" |> int
MlStopWordsInChats = getEnvOr "ML_STOP_WORDS_IN_CHATS" "{}" |> fromJson }

let validateApiKey (ctx : HttpContext) =
Expand Down
1 change: 1 addition & 0 deletions src/VahterBanBot/Types.fs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type BotConfiguration =
MlSpamThreshold: single
MlWarningThreshold: single
MlMaxNumberOfIterations: int
MlCustomEmojiThreshold: int
MlStopWordsInChats: Dictionary<int64, string list> }

[<CLIMutable>]
Expand Down
Loading