diff --git a/.env.example b/.env.example index 84107c5..b061ce9 100644 --- a/.env.example +++ b/.env.example @@ -15,3 +15,10 @@ USE_FAKE_TG_API=false CLEANUP_OLD_MESSAGES=true CLEANUP_INTERVAL_SEC=86400 CLEANUP_OLD_LIMIT_SEC=259200 +ML_ENABLED=false +ML_SEED= +ML_SPAM_DELETION_ENABLED=false +ML_TRAIN_BEFORE_DATE=2021-01-01 +ML_TRAINING_SET_FRACTION=0.2 +ML_SPAM_THRESHOLD=0.5 +ML_STOP_WORDS_IN_CHATS={"-123":["word1","word2"]} diff --git a/src/VahterBanBot.Tests/ContainerTestBase.fs b/src/VahterBanBot.Tests/ContainerTestBase.fs index 776c35c..3f2413d 100644 --- a/src/VahterBanBot.Tests/ContainerTestBase.fs +++ b/src/VahterBanBot.Tests/ContainerTestBase.fs @@ -4,6 +4,7 @@ open System open System.IO open System.Net.Http open System.Text +open System.Threading.Tasks open DotNet.Testcontainers.Builders open DotNet.Testcontainers.Configurations open DotNet.Testcontainers.Containers @@ -90,6 +91,14 @@ type VahterTestContainers() = .WithEnvironment("USE_FAKE_TG_API", "true") .WithEnvironment("USE_POLLING", "false") .WithEnvironment("DATABASE_URL", internalConnectionString) + .WithEnvironment("CLEANUP_OLD_MESSAGES", "false") + .WithEnvironment("ML_ENABLED", "true") + // seed data uses 2021-01-01 as a date for all messages + .WithEnvironment("ML_TRAIN_BEFORE_DATE", "2021-01-02T00:00:00Z") + .WithEnvironment("ML_SEED", "42") + .WithEnvironment("ML_SPAM_DELETION_ENABLED", "true") + .WithEnvironment("ML_SPAM_THRESHOLD", "1.0") + .WithEnvironment("ML_STOP_WORDS_IN_CHATS", """{"-42":["2"]}""") // .net 8.0 upgrade has a breaking change // https://learn.microsoft.com/en-us/dotnet/core/compatibility/containers/8.0/aspnet-port // Azure default port for containers is 80, se we need explicitly set it @@ -124,10 +133,14 @@ type VahterTestContainers() = failwith out // seed some test data - // inserting the only admin users we have - // TODO might be a script in test assembly - let! 
_ = dbContainer.ExecAsync([|"""INSERT INTO "user"(id, username, banned_by, banned_at, ban_reason) VALUES (34, 'vahter_1', NULL, NULL, NULL), (69, 'vahter_2', NULL, NULL, NULL);"""|]) - + let script = File.ReadAllText(CommonDirectoryPath.GetCallerFileDirectory().DirectoryPath + "/test_seed.sql") + let scriptFilePath = String.Join("/", String.Empty, "tmp", Guid.NewGuid().ToString("D"), Path.GetRandomFileName()) + do! dbContainer.CopyAsync(Encoding.Default.GetBytes script, scriptFilePath, Unix.FileMode644) + let! scriptResult = dbContainer.ExecAsync [|"psql"; "--username"; "vahter_bot_ban_service"; "--dbname"; "vahter_bot_ban"; "--file"; scriptFilePath |] + + if scriptResult.Stderr <> "" then + failwith scriptResult.Stderr + // start the app container do! appContainer.StartAsync() @@ -185,3 +198,18 @@ type VahterTestContainers() = let! count = conn.QuerySingleAsync(sql, {| chatId = msg.Chat.Id; messageId = msg.MessageId |}) return count > 0 } + + member _.MessageIsAutoBanned(msg: Message) = task { + use conn = new NpgsqlConnection(publicConnectionString) + //language=postgresql + let sql = "SELECT COUNT(*) FROM banned_by_bot WHERE banned_in_chat_id = @chatId AND message_id = @messageId" + let! count = conn.QuerySingleAsync(sql, {| chatId = msg.Chat.Id; messageId = msg.MessageId |}) + return count > 0 + } + +// workaround to wait for ML to be ready +type MlAwaitFixture() = + interface IAsyncLifetime with + member this.DisposeAsync() = Task.CompletedTask + // we assume 5 seconds is enough for model to train. 
Could be flaky + member this.InitializeAsync() = Task.Delay 5000 diff --git a/src/VahterBanBot.Tests/MLBanTests.fs b/src/VahterBanBot.Tests/MLBanTests.fs new file mode 100644 index 0000000..6ae35ac --- /dev/null +++ b/src/VahterBanBot.Tests/MLBanTests.fs @@ -0,0 +1,59 @@ +module VahterBanBot.Tests.MLBanTests + +open System.Net +open System.Threading.Tasks +open VahterBanBot.Tests.ContainerTestBase +open VahterBanBot.Tests.TgMessageUtils +open Xunit +open Xunit.Extensions.AssemblyFixture + +type MLBanTests(fixture: VahterTestContainers, _unused: MlAwaitFixture) = + + [<Fact>] + let ``Message IS autobanned if it looks like a spam`` () = task { + // record a message, where 2 is in a training set as spam word + // ChatsToMonitor[0] doesn't have stopwords + let msgUpdate = Tg.quickMsg(chat = fixture.ChatsToMonitor[0], text = "2") + let! _ = fixture.SendMessage msgUpdate + + // assert that the message got auto banned + let! msgBanned = fixture.MessageIsAutoBanned msgUpdate.Message + Assert.True msgBanned + } + + [<Fact>] + let ``Message is NOT autobanned if it has a stopword in specific chat`` () = task { + // record a message, where 2 is in a training set as spam word + // ChatsToMonitor[1] does have a stopword 2 + let msgUpdate = Tg.quickMsg(chat = fixture.ChatsToMonitor[1], text = "2") + let! _ = fixture.SendMessage msgUpdate + + // assert that the message did NOT get auto banned + let! msgBanned = fixture.MessageIsAutoBanned msgUpdate.Message + Assert.False msgBanned + } + + [<Fact>] + let ``Message is NOT autobanned if it is a known false-positive spam`` () = task { + // record a message, where a is in a training set as a known false-positive spam word + let msgUpdate = Tg.quickMsg(chat = fixture.ChatsToMonitor[0], text = "a") + let! _ = fixture.SendMessage msgUpdate + + // assert that the message did NOT get auto banned + let! 
msgBanned = fixture.MessageIsAutoBanned msgUpdate.Message + Assert.False msgBanned + } + + [<Fact>] + let ``Message IS autobanned if it is a known false-negative spam`` () = task { + // record a message, where 3 is in a training set as false negative spam word + let msgUpdate = Tg.quickMsg(chat = fixture.ChatsToMonitor[0], text = "3") + let! _ = fixture.SendMessage msgUpdate + + // assert that the message got auto banned + let! msgBanned = fixture.MessageIsAutoBanned msgUpdate.Message + Assert.True msgBanned + } + + interface IAssemblyFixture<VahterTestContainers> + interface IClassFixture<MlAwaitFixture> diff --git a/src/VahterBanBot.Tests/TgMessageUtils.fs b/src/VahterBanBot.Tests/TgMessageUtils.fs index f09aa72..e9703c1 100644 --- a/src/VahterBanBot.Tests/TgMessageUtils.fs +++ b/src/VahterBanBot.Tests/TgMessageUtils.fs @@ -5,7 +5,7 @@ open System.Threading open Telegram.Bot.Types type Tg() = - static let mutable i = 0L + static let mutable i = 1L // higher than the data in the test_seed.sql static let nextInt64() = Interlocked.Increment &i static let next() = nextInt64() |> int static member user (?id: int64, ?username: string, ?firstName: string) = diff --git a/src/VahterBanBot.Tests/VahterBanBot.Tests.fsproj b/src/VahterBanBot.Tests/VahterBanBot.Tests.fsproj index 526b262..acb0158 100644 --- a/src/VahterBanBot.Tests/VahterBanBot.Tests.fsproj +++ b/src/VahterBanBot.Tests/VahterBanBot.Tests.fsproj @@ -9,10 +9,14 @@ + + PreserveNewest + + diff --git a/src/VahterBanBot.Tests/test_seed.sql b/src/VahterBanBot.Tests/test_seed.sql new file mode 100644 index 0000000..de9c6ca --- /dev/null +++ b/src/VahterBanBot.Tests/test_seed.sql @@ -0,0 +1,448 @@ +INSERT INTO public."user"(id, username, banned_by, banned_at, ban_reason) +VALUES (34, 'vahter_1', NULL, NULL, NULL), + (69, 'vahter_2', NULL, NULL, NULL); + +-- insert some fake data for ML training +INSERT INTO public."user"(id, username, banned_by, banned_at, ban_reason) +VALUES (1001, 'a', NULL, NULL, NULL), + (1002, 'b', NULL, NULL, NULL), + (1003, 'c', NULL, 
NULL, NULL), + (1004, 'd', NULL, NULL, NULL), + (1005, 'e', NULL, NULL, NULL), + (1006, 'f', NULL, NULL, NULL), + (1007, 'g', NULL, NULL, NULL), + (1008, 'h', NULL, NULL, NULL), + (1009, 'i', NULL, NULL, NULL), + (1010, 'j', NULL, NULL, NULL); + +INSERT INTO public.message(chat_id, message_id, user_id, created_at, text, raw_message) +VALUES (-666, 10001, 1001, '2021-01-01 00:00:00', 'a', '{}'), -- false positive user banned + (-666, 10002, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10003, 1001, '2021-01-01 00:00:02', 'a', '{}'), + (-666, 10004, 1002, '2021-01-01 00:00:03', 'a', '{}'), + (-666, 10005, 1002, '2021-01-01 00:00:04', 'a', '{}'), + (-666, 10006, 1003, '2021-01-01 00:00:05', 'a', '{}'), + (-666, 10007, 1003, '2021-01-01 00:00:06', 'a', '{}'), + (-666, 10008, 1004, '2021-01-01 00:00:07', 'a', '{}'), -- false positive message banned + (-666, 10009, 1005, '2021-01-01 00:00:08', '1', '{}'), + (-666, 10010, 1005, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10001, 1001, '2021-01-01 00:00:00', 'a', '{}'), + (-42, 10002, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-42, 10003, 1001, '2021-01-01 00:00:02', 'a', '{}'), + (-42, 10004, 1002, '2021-01-01 00:00:03', 'a', '{}'), + (-42, 10005, 1002, '2021-01-01 00:00:04', 'a', '{}'), + (-42, 10006, 1003, '2021-01-01 00:00:05', 'a', '{}'), + (-42, 10007, 1003, '2021-01-01 00:00:06', 'a', '{}'), + (-42, 10008, 1004, '2021-01-01 00:00:07', '3', '{}'), -- false negative + (-42, 10009, 1006, '2021-01-01 00:00:08', '1', '{}'), + + -- to prevent small sample size, we'll copy the next line 100 times + -- this is spam + (-42, 10010, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10011, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10012, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10013, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10014, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10015, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10016, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10017, 1006, 
'2021-01-01 00:00:09', '1', '{}'), + (-42, 10018, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10019, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10020, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10021, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10022, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10023, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10024, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10025, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10026, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10027, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10028, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10029, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10030, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10031, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10032, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10033, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10034, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10035, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10036, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10037, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10038, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10039, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10040, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10041, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10042, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10043, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10044, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10045, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10046, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10047, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10048, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10049, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10050, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10051, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10052, 1006, '2021-01-01 00:00:09', '1', '{}'), + 
(-42, 10053, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10054, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10055, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10056, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10057, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10058, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10059, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10060, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10061, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10062, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10063, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10064, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10065, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10066, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10067, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10068, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10069, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10070, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10071, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10072, 1006, '2021-01-01 00:00:09', '1', '{}'), + (-42, 10073, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10074, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10075, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10076, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10077, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10078, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10079, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10080, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10081, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10082, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10083, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10084, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10085, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10086, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10087, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10088, 1006, '2021-01-01 
00:00:09', '2', '{}'), + (-42, 10089, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10090, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10091, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10092, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10093, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10094, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10095, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10096, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10097, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10098, 1006, '2021-01-01 00:00:09', '2', '{}'), + (-42, 10099, 1006, '2021-01-01 00:00:09', '2', '{}'), + -- this is not spam + (-666, 10100, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10101, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10102, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10103, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10104, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10105, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10106, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10107, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10108, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10109, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10110, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10111, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10112, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10113, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10114, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10115, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10116, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10117, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10118, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10119, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10120, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10121, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10122, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10123, 1001, 
'2021-01-01 00:00:01', 'a', '{}'), + (-666, 10124, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10125, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10126, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10127, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10128, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10129, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10130, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10131, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10132, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10133, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10134, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10135, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10136, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10137, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10138, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10139, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10140, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10141, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10142, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10143, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10144, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10145, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10146, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10147, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10148, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10149, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10150, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10151, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10152, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10153, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10154, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10155, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10156, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10157, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10158, 1001, 
'2021-01-01 00:00:01', 'a', '{}'), + (-666, 10159, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10160, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10161, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10162, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10163, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10164, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10165, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10166, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10167, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10168, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10169, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10170, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10171, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10172, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10173, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10174, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10175, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10176, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10177, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10178, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10179, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10180, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10181, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10182, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10183, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10184, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10185, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10186, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10187, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10188, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10189, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10190, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10191, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10192, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10193, 1001, 
'2021-01-01 00:00:01', 'a', '{}'), + (-666, 10194, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10195, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10196, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10197, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10198, 1001, '2021-01-01 00:00:01', 'a', '{}'), + (-666, 10199, 1001, '2021-01-01 00:00:01', 'a', '{}'), + + -- to enforce false-negative appearance + (-666, 10200, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10201, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10202, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10203, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10204, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10205, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10206, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10207, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10208, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10209, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10210, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10211, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10212, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10213, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10214, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10215, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10216, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10217, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10218, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10219, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10220, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10221, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10222, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10223, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10224, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10225, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10226, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10227, 1001, '2021-01-01 
00:00:01', '3', '{}'), + (-666, 10228, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10229, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10230, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10231, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10232, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10233, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10234, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10235, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10236, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10237, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10238, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10239, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10240, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10241, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10242, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10243, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10244, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10245, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10246, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10247, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10248, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10249, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10250, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10251, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10252, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10253, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10254, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10255, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10256, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10257, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10258, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10259, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10260, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10261, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10262, 1001, '2021-01-01 
00:00:01', '3', '{}'), + (-666, 10263, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10264, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10265, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10266, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10267, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10268, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10269, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10270, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10271, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10272, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10273, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10274, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10275, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10276, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10277, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10278, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10279, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10280, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10281, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10282, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10283, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10284, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10285, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10286, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10287, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10288, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10289, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10290, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10291, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10292, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10293, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10294, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10295, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10296, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10297, 1001, '2021-01-01 
00:00:01', '3', '{}'), + (-666, 10298, 1001, '2021-01-01 00:00:01', '3', '{}'), + (-666, 10299, 1001, '2021-01-01 00:00:01', '3', '{}'); + +INSERT INTO public.banned(id, message_id, message_text, banned_user_id, banned_at, banned_in_chat_id, banned_in_chat_username, banned_by) +VALUES (100001, 10001, 'a', 1001, '2021-01-01 00:00:00', -666, 'pro.hell', 34), + (100002, 10008, 'a', 1004, '2021-01-01 00:00:07', -666, 'pro.hell', 69), + (100003, 10009, '1', 1005, '2021-01-01 00:00:08', -666, 'pro.hell', 34), + (100004, 10010, '2', 1006, '2021-01-01 00:00:09', -42, 'dotnetru', 69); + +INSERT INTO public.false_positive_users(user_id) +VALUES (1001); + +INSERT INTO public.false_positive_messages(id) +VALUES (100002); + +INSERT INTO public.false_negative_messages(chat_id, message_id) +VALUES (-42, 10008), + (-666, 10200), + (-666, 10201), + (-666, 10202), + (-666, 10203), + (-666, 10204), + (-666, 10205), + (-666, 10206), + (-666, 10207), + (-666, 10208), + (-666, 10209), + (-666, 10210), + (-666, 10211), + (-666, 10212), + (-666, 10213), + (-666, 10214), + (-666, 10215), + (-666, 10216), + (-666, 10217), + (-666, 10218), + (-666, 10219), + (-666, 10220), + (-666, 10221), + (-666, 10222), + (-666, 10223), + (-666, 10224), + (-666, 10225), + (-666, 10226), + (-666, 10227), + (-666, 10228), + (-666, 10229), + (-666, 10230), + (-666, 10231), + (-666, 10232), + (-666, 10233), + (-666, 10234), + (-666, 10235), + (-666, 10236), + (-666, 10237), + (-666, 10238), + (-666, 10239), + (-666, 10240), + (-666, 10241), + (-666, 10242), + (-666, 10243), + (-666, 10244), + (-666, 10245), + (-666, 10246), + (-666, 10247), + (-666, 10248), + (-666, 10249), + (-666, 10250), + (-666, 10251), + (-666, 10252), + (-666, 10253), + (-666, 10254), + (-666, 10255), + (-666, 10256), + (-666, 10257), + (-666, 10258), + (-666, 10259), + (-666, 10260), + (-666, 10261), + (-666, 10262), + (-666, 10263), + (-666, 10264), + (-666, 10265), + (-666, 10266), + (-666, 10267), + (-666, 10268), + (-666, 10269), + 
(-666, 10270), + (-666, 10271), + (-666, 10272), + (-666, 10273), + (-666, 10274), + (-666, 10275), + (-666, 10276), + (-666, 10277), + (-666, 10278), + (-666, 10279), + (-666, 10280), + (-666, 10281), + (-666, 10282), + (-666, 10283), + (-666, 10284), + (-666, 10285), + (-666, 10286), + (-666, 10287), + (-666, 10288), + (-666, 10289), + (-666, 10290), + (-666, 10291), + (-666, 10292), + (-666, 10293), + (-666, 10294), + (-666, 10295), + (-666, 10296), + (-666, 10297), + (-666, 10298), + (-666, 10299); diff --git a/src/VahterBanBot/Bot.fs b/src/VahterBanBot/Bot.fs index 48614dc..98e6600 100644 --- a/src/VahterBanBot/Bot.fs +++ b/src/VahterBanBot/Bot.fs @@ -7,6 +7,7 @@ open System.Threading.Tasks open Microsoft.Extensions.Logging open Telegram.Bot open Telegram.Bot.Types +open VahterBanBot.ML open VahterBanBot.Types open VahterBanBot.Utils open VahterBanBot.Antispam @@ -377,19 +378,29 @@ let unban logger.LogInformation logMsg } -let warnSpamDetection +let killSpammerAutomated (botClient: ITelegramBotClient) (botConfig: BotConfiguration) (message: Message) (logger: ILogger) + (deleteMessage: bool) score = task { - use banOnReplyActivity = botActivity.StartActivity("warnSpamDetection") + use banOnReplyActivity = botActivity.StartActivity("killAutomated") %banOnReplyActivity .SetTag("spammerId", message.From.Id) .SetTag("spammerUsername", message.From.Username) + + if deleteMessage then + // delete message + do! botClient.DeleteMessageAsync(ChatId(message.Chat.Id), message.MessageId) + |> safeTaskAwait (fun e -> logger.LogError ($"Failed to delete message {message.MessageId} from chat {message.Chat.Id}", e)) + // 0 here is the bot itself + do! 
DbBanned.banMessage 0 message + |> DB.banUserByBot + + let msgType = if deleteMessage then "Deleted" else "Detected" + let logMsg = $"""{msgType} spam (score: {score}) in {prependUsername message.Chat.Username} ({message.Chat.Id}) from {prependUsername message.From.Username} ({message.From.Id}) with text:\n{message.Text}""" - let logMsg = $"Detected spam (score: {score}) in {prependUsername message.Chat.Username} ({message.Chat.Id}) from {prependUsername message.From.Username} ({message.From.Id}) with text:\n{message.Text}" - // log both to logger and to logs channel do! botClient.SendTextMessageAsync(ChatId(botConfig.LogsChannelId), logMsg) |> taskIgnore logger.LogInformation logMsg @@ -399,18 +410,48 @@ let justMessage (botClient: ITelegramBotClient) (botConfig: BotConfiguration) (logger: ILogger) + (ml: MachineLearning) (message: Message) = task { - let spamScore = if message.Text <> null then calcSpamScore message.Text else 0 - - if spamScore > 100 then - do! warnSpamDetection botClient botConfig message logger spamScore - use _ = + use justMessageActivity = botActivity .StartActivity("justMessage") .SetTag("fromUserId", message.From.Id) .SetTag("fromUsername", message.From.Username) - .SetTag("spamScore", spamScore) + + + if botConfig.MlEnabled && message.Text <> null then + use mlActivity = botActivity.StartActivity("mlPrediction") + + let shouldBeSkipped = + match botConfig.MlStopWordsInChats.TryGetValue message.Chat.Id with + | true, stopWords -> + stopWords + |> Seq.exists (fun sw -> message.Text.Contains(sw, StringComparison.OrdinalIgnoreCase)) + | _ -> false + %mlActivity.SetTag("skipPrediction", shouldBeSkipped) + + if not shouldBeSkipped then + match ml.Predict message.Text with + | Some prediction -> + %mlActivity.SetTag("spamScoreMl", prediction.Score) + + if prediction.Score >= botConfig.MlSpamThreshold then + // delete message + do! 
killSpammerAutomated botClient botConfig message logger botConfig.MlSpamDeletionEnabled prediction.Score + elif prediction.Score > 0.0f then + // just warn + do! killSpammerAutomated botClient botConfig message logger false prediction.Score + else + // not a spam + () + | None -> + // no prediction (error or not ready yet) + () + + let spamScore = if message.Text <> null then calcSpamScore message.Text else 0 + %justMessageActivity.SetTag("spamScore", spamScore) + do! message |> DbMessage.newMessage @@ -497,6 +538,7 @@ let onUpdate (botClient: ITelegramBotClient) (botConfig: BotConfiguration) (logger: ILogger) + (ml: MachineLearning) (message: Message) = task { use banOnReplyActivity = botActivity.StartActivity("onUpdate") @@ -530,5 +572,5 @@ let onUpdate // if message is not a command from authorized user, just save it ID to DB else - do! justMessage botClient botConfig logger message + do! justMessage botClient botConfig logger ml message } diff --git a/src/VahterBanBot/DB.fs b/src/VahterBanBot/DB.fs index 1125651..1b5c7f5 100644 --- a/src/VahterBanBot/DB.fs +++ b/src/VahterBanBot/DB.fs @@ -2,6 +2,7 @@ open System open System.Threading.Tasks +open Microsoft.ML.Data open Npgsql open VahterBanBot.Types open Dapper @@ -64,6 +65,21 @@ VALUES (@message_id, @message_text, @banned_user_id, @banned_at, @banned_in_chat return banned } +let banUserByBot (banned: DbBanned) : Task = + task { + use conn = new NpgsqlConnection(connString) + + //language=postgresql + let sql = + """ +INSERT INTO banned_by_bot (message_id, message_text, banned_user_id, banned_at, banned_in_chat_id, banned_in_chat_username) +VALUES (@message_id, @message_text, @banned_user_id, @banned_at, @banned_in_chat_id, @banned_in_chat_username) + """ + + let! 
_ = conn.ExecuteAsync(sql, banned) + return banned + } + let getUserMessages (userId: int64): Task = task { use conn = new NpgsqlConnection(connString) @@ -101,13 +117,18 @@ let getVahterStats(banInterval: TimeSpan option): Task = //language=postgresql let sql = """ -SELECT vahter.username AS vahter - , COUNT(*) AS killCountTotal - , COUNT(*) FILTER (WHERE b.banned_at > NOW() - @banInterval::INTERVAL) AS killCountInterval -FROM banned b - JOIN "user" vahter ON vahter.id = b.banned_by -GROUP BY b.banned_by, vahter.username -ORDER BY killCountTotal DESC +(SELECT vahter.username AS vahter + , COUNT(*) AS killCountTotal + , COUNT(*) FILTER (WHERE b.banned_at > NOW() - @banInterval::INTERVAL) AS killCountInterval + FROM banned b + JOIN "user" vahter ON vahter.id = b.banned_by + GROUP BY b.banned_by, vahter.username + UNION + SELECT 'bot' AS vahter + , COUNT(*) AS killCountTotal + , COUNT(*) FILTER (WHERE bbb.banned_at > NOW() - @banInterval::INTERVAL) AS killCountInterval + FROM banned_by_bot bbb) + ORDER BY killCountTotal DESC """ let! stats = conn.QueryAsync(sql, {| banInterval = banInterval |}) @@ -123,3 +144,48 @@ let getUserById (userId: int64): Task = let! 
users = conn.QueryAsync(sql, {| userId = userId |}) return users |> Seq.tryHead } + +type SpamOrHam = + { [] + text: string + [] + spam: bool } + +let mlData(criticalDate: DateTime) : Task = + task { + use conn = new NpgsqlConnection(connString) + + //language=postgresql + let sql = + """ +WITH really_banned AS (SELECT * + FROM banned b + -- known false positive spam messages + WHERE NOT EXISTS(SELECT 1 FROM false_positive_users fpu WHERE fpu.user_id = b.banned_user_id) + AND NOT EXISTS(SELECT 1 FROM false_positive_messages fpm WHERE fpm.id = b.id) + AND b.message_text IS NOT NULL + AND b.banned_at <= @criticalDate), + spam_or_ham AS (SELECT DISTINCT COALESCE(m.text, re_id.message_text) AS text, + CASE + -- known false negative spam messages + WHEN EXISTS(SELECT 1 + FROM false_negative_messages fnm + WHERE fnm.chat_id = m.chat_id + AND fnm.message_id = m.message_id) + THEN TRUE + WHEN re_id.banned_user_id IS NULL AND re_text.banned_user_id IS NULL + THEN FALSE + ELSE TRUE + END AS spam + FROM (SELECT * FROM message WHERE text IS NOT NULL AND created_at <= @criticalDate) m + FULL OUTER JOIN really_banned re_id + ON m.message_id = re_id.message_id AND m.chat_id = re_id.banned_in_chat_id + LEFT JOIN really_banned re_text ON m.text = re_text.message_text) +SELECT * +FROM spam_or_ham +ORDER BY RANDOM(); +""" + + let! 
data = conn.QueryAsync(sql, {| criticalDate = criticalDate |}) + return Array.ofSeq data + } diff --git a/src/VahterBanBot/ML.fs b/src/VahterBanBot/ML.fs new file mode 100644 index 0000000..95ff504 --- /dev/null +++ b/src/VahterBanBot/ML.fs @@ -0,0 +1,101 @@ +module VahterBanBot.ML + +open System +open System.Diagnostics +open System.Text +open System.Threading.Tasks +open Microsoft.Extensions.Hosting +open Microsoft.Extensions.Logging +open Microsoft.ML +open Microsoft.ML.Data +open Telegram.Bot +open Telegram.Bot.Types +open VahterBanBot.DB +open VahterBanBot.Types +open VahterBanBot.Utils + +[] +type Prediction = + { Score: single + text: string + spam: bool } + +type MachineLearning( + logger: ILogger, + telegramClient: ITelegramBotClient, + botConf: BotConfiguration +) = + let metricsToString(metrics: CalibratedBinaryClassificationMetrics) (duration: TimeSpan) = + let sb = StringBuilder() + %sb.AppendLine($"Model trained in {duration.TotalSeconds} seconds with following metrics:") + %sb.AppendLine($"Accuracy: {metrics.Accuracy}") + %sb.AppendLine($"AreaUnderPrecisionRecallCurve: {metrics.AreaUnderPrecisionRecallCurve}") + %sb.AppendLine($"ConfusionMatrix:\n```\n{metrics.ConfusionMatrix.GetFormattedConfusionTable()}\n```") + %sb.AppendLine($"Entropy:{metrics.Entropy}") + %sb.AppendLine($"F1Score:{metrics.F1Score}") + %sb.AppendLine($"LogLoss:{metrics.LogLoss}") + %sb.AppendLine($"LogLossReduction:{metrics.LogLossReduction}") + %sb.AppendLine($"NegativePrecision:{metrics.NegativePrecision}") + %sb.AppendLine($"NegativeRecall:{metrics.NegativeRecall}") + %sb.AppendLine($"PositivePrecision:{metrics.PositivePrecision}") + %sb.AppendLine($"PositiveRecall:{metrics.PositiveRecall}") + sb.ToString() + + let mutable predictionEngine: PredictionEngine option = None + + let trainModel() = task { + // switch to thread pool + do! Task.Yield() + + let sw = Stopwatch.StartNew() + + let mlContext = MLContext(botConf.MlSeed) + + let! 
data = DB.mlData botConf.MlTrainBeforeDate + + let dataView = mlContext.Data.LoadFromEnumerable data + let trainTestSplit = mlContext.Data.TrainTestSplit(dataView, testFraction = botConf.MlTrainingSetFraction) + let trainingData = trainTestSplit.TrainSet + let testData = trainTestSplit.TestSet + + let dataProcessPipeline = mlContext.Transforms.Text.FeaturizeText(outputColumnName = "Features", inputColumnName = "text") + let trainer = mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(labelColumnName = "spam", featureColumnName = "Features") + let trainingPipeline = dataProcessPipeline.Append(trainer) + + let trainedModel = trainingPipeline.Fit(trainingData) + predictionEngine <- Some(mlContext.Model.CreatePredictionEngine(trainedModel)) + + let predictions = trainedModel.Transform(testData) + let metrics = mlContext.BinaryClassification.Evaluate(data = predictions, labelColumnName = "spam", scoreColumnName = "Score") + + sw.Stop() + + let metricsStr = metricsToString metrics sw.Elapsed + logger.LogInformation metricsStr + do! telegramClient.SendTextMessageAsync(ChatId(botConf.LogsChannelId), metricsStr) + |> taskIgnore + } + + member _.Predict(text: string) = + try + match predictionEngine with + | Some predictionEngine -> + predictionEngine.Predict({ text = text; spam = false }) + |> Some + | None -> + logger.LogInformation "Model not trained yet" + None + with ex -> + logger.LogError(ex, "Error predicting") + None + + interface IHostedService with + member this.StartAsync _ = task { + if botConf.MlEnabled then + try + do! 
trainModel() + with ex -> + logger.LogError(ex, "Error training model") + } + + member this.StopAsync _ = Task.CompletedTask diff --git a/src/VahterBanBot/Program.fs b/src/VahterBanBot/Program.fs index f2ee6ae..5689be0 100644 --- a/src/VahterBanBot/Program.fs +++ b/src/VahterBanBot/Program.fs @@ -16,6 +16,7 @@ open Giraffe open Microsoft.Extensions.DependencyInjection open Telegram.Bot.Types.Enums open VahterBanBot.Cleanup +open VahterBanBot.ML open VahterBanBot.Utils open VahterBanBot.Bot open VahterBanBot.Types @@ -44,7 +45,14 @@ let botConf = UseFakeTgApi = getEnvOr "USE_FAKE_TG_API" "false" |> bool.Parse CleanupOldMessages = getEnvOr "CLEANUP_OLD_MESSAGES" "true" |> bool.Parse CleanupInterval = getEnvOr "CLEANUP_INTERVAL_SEC" "86400" |> int |> TimeSpan.FromSeconds - CleanupOldLimit = getEnvOr "CLEANUP_OLD_LIMIT_SEC" "259200" |> int |> TimeSpan.FromSeconds } + CleanupOldLimit = getEnvOr "CLEANUP_OLD_LIMIT_SEC" "259200" |> int |> TimeSpan.FromSeconds + MlEnabled = getEnvOr "ML_ENABLED" "false" |> bool.Parse + MlSeed = getEnvOrWith "ML_SEED" (Nullable()) (int >> Nullable) + MlSpamDeletionEnabled = getEnvOr "ML_SPAM_DELETION_ENABLED" "false" |> bool.Parse + MlTrainBeforeDate = getEnvOrWith "ML_TRAIN_BEFORE_DATE" DateTime.UtcNow (DateTimeOffset.Parse >> _.UtcDateTime) + MlTrainingSetFraction = getEnvOr "ML_TRAINING_SET_FRACTION" "0.2" |> float + MlSpamThreshold = getEnvOr "ML_SPAM_THRESHOLD" "0.5" |> single + MlStopWordsInChats = getEnvOr "ML_STOP_WORDS_IN_CHATS" "{}" |> JsonConvert.DeserializeObject<_> } let validateApiKey (ctx : HttpContext) = match ctx.TryGetRequestHeader "X-Telegram-Bot-Api-Secret-Token" with @@ -60,6 +68,8 @@ let builder = WebApplication.CreateBuilder() .AddGiraffe() .AddHostedService() .AddHostedService() + .AddSingleton() + .AddHostedService(fun sp -> sp.GetRequiredService()) .AddHttpClient("telegram_bot_client") .AddTypedClient(fun httpClient sp -> let options = TelegramBotClientOptions(botConf.BotToken) @@ -132,9 +142,10 @@ let webApp = 
choose [ use scope = ctx.RequestServices.CreateScope() let telegramClient = scope.ServiceProvider.GetRequiredService() + let ml = scope.ServiceProvider.GetRequiredService() let logger = ctx.GetLogger() try - do! onUpdate telegramClient botConf (ctx.GetLogger "VahterBanBot.Bot") update.Message + do! onUpdate telegramClient botConf (ctx.GetLogger "VahterBanBot.Bot") ml update.Message %topActivity.SetTag("update-error", false) with e -> logger.LogError(e, $"Unexpected error while processing update: {updateBodyJson}") @@ -160,7 +171,8 @@ if botConf.UsePolling then let ctx = app.Services.CreateScope() let logger = ctx.ServiceProvider.GetRequiredService>() let client = ctx.ServiceProvider.GetRequiredService() - do! onUpdate client botConf logger update.Message + let ml = ctx.ServiceProvider.GetRequiredService() + do! onUpdate client botConf logger ml update.Message } member x.HandlePollingErrorAsync (botClient: ITelegramBotClient, ex: Exception, cancellationToken: CancellationToken) = Task.CompletedTask diff --git a/src/VahterBanBot/Types.fs b/src/VahterBanBot/Types.fs index 1da2bff..e5f0099 100644 --- a/src/VahterBanBot/Types.fs +++ b/src/VahterBanBot/Types.fs @@ -20,7 +20,14 @@ type BotConfiguration = UsePolling: bool CleanupOldMessages: bool CleanupInterval: TimeSpan - CleanupOldLimit: TimeSpan } + CleanupOldLimit: TimeSpan + MlEnabled: bool + MlSeed: Nullable + MlSpamDeletionEnabled: bool + MlTrainBeforeDate: DateTime + MlTrainingSetFraction: float + MlSpamThreshold: single + MlStopWordsInChats: Dictionary } [] type DbUser = diff --git a/src/VahterBanBot/Utils.fs b/src/VahterBanBot/Utils.fs index 6213479..6178d4c 100644 --- a/src/VahterBanBot/Utils.fs +++ b/src/VahterBanBot/Utils.fs @@ -22,6 +22,12 @@ let getEnvWith name action = if value <> null then action value +let getEnvOrWith name defaultValue action = + let value = Environment.GetEnvironmentVariable name + if value <> null then + action value + else defaultValue + let prependUsername (s: string) = if isNull s 
then null diff --git a/src/VahterBanBot/VahterBanBot.fsproj b/src/VahterBanBot/VahterBanBot.fsproj index 5054a1a..7075343 100644 --- a/src/VahterBanBot/VahterBanBot.fsproj +++ b/src/VahterBanBot/VahterBanBot.fsproj @@ -10,6 +10,7 @@ + @@ -35,6 +36,7 @@ + diff --git a/src/migrations/V8__ml-stuff.sql b/src/migrations/V8__ml-stuff.sql new file mode 100644 index 0000000..5ddbb47 --- /dev/null +++ b/src/migrations/V8__ml-stuff.sql @@ -0,0 +1,20 @@ +CREATE TABLE banned_by_bot +( + id BIGSERIAL PRIMARY KEY, + message_id INTEGER NULL, + message_text TEXT, + banned_user_id BIGINT NOT NULL + REFERENCES "user" (id), + banned_at TIMESTAMPTZ NOT NULL, + banned_in_chat_id BIGINT NULL, + banned_in_chat_username TEXT NULL +); + +CREATE INDEX banned_by_bot_banned_user_id_idx + ON banned_by_bot (banned_user_id); + +CREATE INDEX banned_by_bot_banned_in_chat_id_idx + ON banned_by_bot (banned_in_chat_id); + +CREATE INDEX banned_by_bot_message_id_idx + ON banned_by_bot (message_id);