Skip to content

Commit

Permalink
Added ML (#29)
Browse files Browse the repository at this point in the history
* added ML training on start (might be cached in future?) and ability to autoban
  • Loading branch information
Szer authored Jul 20, 2024
1 parent bc4cd2e commit 7858368
Show file tree
Hide file tree
Showing 14 changed files with 829 additions and 27 deletions.
7 changes: 7 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,10 @@ USE_FAKE_TG_API=false
CLEANUP_OLD_MESSAGES=true
CLEANUP_INTERVAL_SEC=86400
CLEANUP_OLD_LIMIT_SEC=259200
ML_ENABLED=false
ML_SEED=
ML_SPAM_DELETION_ENABLED=false
ML_TRAIN_BEFORE_DATE=2021-01-01
ML_TRAINING_SET_FRACTION=0.2
ML_SPAM_THRESHOLD=0.5
ML_STOP_WORDS_IN_CHATS={"-123":["word1","word2"]}
36 changes: 32 additions & 4 deletions src/VahterBanBot.Tests/ContainerTestBase.fs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ open System
open System.IO
open System.Net.Http
open System.Text
open System.Threading.Tasks
open DotNet.Testcontainers.Builders
open DotNet.Testcontainers.Configurations
open DotNet.Testcontainers.Containers
Expand Down Expand Up @@ -90,6 +91,14 @@ type VahterTestContainers() =
.WithEnvironment("USE_FAKE_TG_API", "true")
.WithEnvironment("USE_POLLING", "false")
.WithEnvironment("DATABASE_URL", internalConnectionString)
.WithEnvironment("CLEANUP_OLD_MESSAGES", "false")
.WithEnvironment("ML_ENABLED", "true")
// seed data uses 2021-01-01 as a date for all messages
.WithEnvironment("ML_TRAIN_BEFORE_DATE", "2021-01-02T00:00:00Z")
.WithEnvironment("ML_SEED", "42")
.WithEnvironment("ML_SPAM_DELETION_ENABLED", "true")
.WithEnvironment("ML_SPAM_THRESHOLD", "1.0")
.WithEnvironment("ML_STOP_WORDS_IN_CHATS", """{"-42":["2"]}""")
// .net 8.0 upgrade has a breaking change
// https://learn.microsoft.com/en-us/dotnet/core/compatibility/containers/8.0/aspnet-port
// Azure default port for containers is 80, se we need explicitly set it
Expand Down Expand Up @@ -124,10 +133,14 @@ type VahterTestContainers() =
failwith out

// seed some test data
// inserting the only admin users we have
// TODO might be a script in test assembly
let! _ = dbContainer.ExecAsync([|"""INSERT INTO "user"(id, username, banned_by, banned_at, ban_reason) VALUES (34, 'vahter_1', NULL, NULL, NULL), (69, 'vahter_2', NULL, NULL, NULL);"""|])

let script = File.ReadAllText(CommonDirectoryPath.GetCallerFileDirectory().DirectoryPath + "/test_seed.sql")
let scriptFilePath = String.Join("/", String.Empty, "tmp", Guid.NewGuid().ToString("D"), Path.GetRandomFileName())
do! dbContainer.CopyAsync(Encoding.Default.GetBytes script, scriptFilePath, Unix.FileMode644)
let! scriptResult = dbContainer.ExecAsync [|"psql"; "--username"; "vahter_bot_ban_service"; "--dbname"; "vahter_bot_ban"; "--file"; scriptFilePath |]

if scriptResult.Stderr <> "" then
failwith scriptResult.Stderr

// start the app container
do! appContainer.StartAsync()

Expand Down Expand Up @@ -185,3 +198,18 @@ type VahterTestContainers() =
let! count = conn.QuerySingleAsync<int>(sql, {| chatId = msg.Chat.Id; messageId = msg.MessageId |})
return count > 0
}

member _.MessageIsAutoBanned(msg: Message) = task {
use conn = new NpgsqlConnection(publicConnectionString)
//language=postgresql
let sql = "SELECT COUNT(*) FROM banned_by_bot WHERE banned_in_chat_id = @chatId AND message_id = @messageId"
let! count = conn.QuerySingleAsync<int>(sql, {| chatId = msg.Chat.Id; messageId = msg.MessageId |})
return count > 0
}

// workaround to wait for ML to be ready
type MlAwaitFixture() =
interface IAsyncLifetime with
member this.DisposeAsync() = Task.CompletedTask
// we assume 5 seconds is enough for model to train. Could be flaky
member this.InitializeAsync() = Task.Delay 5000
59 changes: 59 additions & 0 deletions src/VahterBanBot.Tests/MLBanTests.fs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
module VahterBanBot.Tests.MLBanTests

open System.Net
open System.Threading.Tasks
open VahterBanBot.Tests.ContainerTestBase
open VahterBanBot.Tests.TgMessageUtils
open Xunit
open Xunit.Extensions.AssemblyFixture

type MLBanTests(fixture: VahterTestContainers, _unused: MlAwaitFixture) =

[<Fact>]
let ``Message IS autobanned if it looks like a spam`` () = task {
// record a message, where 2 is in a training set as spam word
// ChatsToMonitor[0] doesn't have stopwords
let msgUpdate = Tg.quickMsg(chat = fixture.ChatsToMonitor[0], text = "2")
let! _ = fixture.SendMessage msgUpdate

// assert that the message got auto banned
let! msgBanned = fixture.MessageIsAutoBanned msgUpdate.Message
Assert.True msgBanned
}

[<Fact>]
let ``Message is NOT autobanned if it has a stopword in specific chat`` () = task {
// record a message, where 2 is in a training set as spam word
// ChatsToMonitor[1] does have a stopword 2
let msgUpdate = Tg.quickMsg(chat = fixture.ChatsToMonitor[1], text = "2")
let! _ = fixture.SendMessage msgUpdate

// assert that the message got auto banned
let! msgBanned = fixture.MessageIsAutoBanned msgUpdate.Message
Assert.False msgBanned
}

[<Fact>]
let ``Message is NOT autobanned if it is a known false-positive spam`` () = task {
// record a message, where 3 is in a training set as spam word
let msgUpdate = Tg.quickMsg(chat = fixture.ChatsToMonitor[0], text = "a")
let! _ = fixture.SendMessage msgUpdate

// assert that the message got auto banned
let! msgBanned = fixture.MessageIsAutoBanned msgUpdate.Message
Assert.False msgBanned
}

[<Fact>]
let ``Message IS autobanned if it is a known false-negative spam`` () = task {
// record a message, where 3 is in a training set as false negative spam word
let msgUpdate = Tg.quickMsg(chat = fixture.ChatsToMonitor[0], text = "3")
let! _ = fixture.SendMessage msgUpdate

// assert that the message got auto banned
let! msgBanned = fixture.MessageIsAutoBanned msgUpdate.Message
Assert.True msgBanned
}

interface IAssemblyFixture<VahterTestContainers>
interface IClassFixture<MlAwaitFixture>
2 changes: 1 addition & 1 deletion src/VahterBanBot.Tests/TgMessageUtils.fs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ open System.Threading
open Telegram.Bot.Types

type Tg() =
static let mutable i = 0L
static let mutable i = 1L // higher than the data in the test_seed.sql
static let nextInt64() = Interlocked.Increment &i
static let next() = nextInt64() |> int
static member user (?id: int64, ?username: string, ?firstName: string) =
Expand Down
4 changes: 4 additions & 0 deletions src/VahterBanBot.Tests/VahterBanBot.Tests.fsproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
</PropertyGroup>

<ItemGroup>
<Content Include="test_seed.sql">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Compile Include="TgMessageUtils.fs" />
<Compile Include="ContainerTestBase.fs" />
<Compile Include="BaseTests.fs" />
<Compile Include="MessageTests.fs" />
<Compile Include="MLBanTests.fs" />
<Compile Include="BanTests.fs" />
<Compile Include="PingTests.fs" />
<Compile Include="Program.fs"/>
Expand Down
Loading

0 comments on commit 7858368

Please sign in to comment.