diff --git a/.env.example b/.env.example index 2bbb84c..48b06c7 100644 --- a/.env.example +++ b/.env.example @@ -31,5 +31,6 @@ ML_TRAINING_SET_FRACTION=0.2 ML_SPAM_THRESHOLD=0.5 ML_WARNING_THRESHOLD=0.0 ML_STOP_WORDS_IN_CHATS={"-123":["word1","word2"]} +ML_MAX_NUMBER_OF_ITERATIONS=100 UPDATE_CHAT_ADMINS_INTERVAL_SEC=86400 UPDATE_CHAT_ADMINS=true diff --git a/src/VahterBanBot.Tests/ContainerTestBase.fs b/src/VahterBanBot.Tests/ContainerTestBase.fs index 338362e..e87e36c 100644 --- a/src/VahterBanBot.Tests/ContainerTestBase.fs +++ b/src/VahterBanBot.Tests/ContainerTestBase.fs @@ -30,6 +30,7 @@ type VahterTestContainers() = let mutable publicConnectionString: string = null // base image for the app, we'll build exactly how we build it in Azure + let buildLogger = StringLogger() let image = ImageFromDockerfileBuilder() .WithDockerfileDirectory(solutionDir, String.Empty) @@ -39,6 +40,8 @@ type VahterTestContainers() = .WithBuildArgument("RESOURCE_REAPER_SESSION_ID", ResourceReaper.DefaultSessionId.ToString("D")) // it might speed up the process to not clean up the base image .WithCleanUp(false) + .WithDeleteIfExists(true) + .WithLogger(buildLogger) .Build() // private network for the containers @@ -113,17 +116,27 @@ type VahterTestContainers() = .DependsOn(flywayContainer) .WithWaitStrategy(Wait.ForUnixContainer().UntilPortIsAvailable(80)) .Build() + + let startContainers() = task { + try + // start building the image and spin up db at the same time + let imageTask = image.CreateAsync() + let dbTask = dbContainer.StartAsync() + + // wait for both to finish + do! imageTask + do! dbTask + with + | e -> + let logs = buildLogger.ExtractMessages() + let errorMessage = "Container startup failure, logs:\n" + if String.IsNullOrWhiteSpace logs then "" else logs + raise <| Exception(errorMessage, e) + } interface IAsyncLifetime with member this.InitializeAsync() = task { try - // start building the image and spin up db at the same time - let imageTask = image.CreateAsync() - let dbTask = dbContainer.StartAsync() - - // wait for both to finish - do! imageTask - do! dbTask + do! startContainers() publicConnectionString <- $"Server=127.0.0.1;Database=vahter_bot_ban;Port={dbContainer.GetMappedPublicPort(5432)};User Id=vahter_bot_ban_service;Password=vahter_bot_ban_service;Include Error Detail=true;Minimum Pool Size=1;Maximum Pool Size=20;Max Auto Prepare=100;Auto Prepare Min Usages=1;Trust Server Certificate=true;" // initialize DB with the schema, database and a DB user @@ -158,9 +171,10 @@ type VahterTestContainers() = httpClient.BaseAddress <- uri httpClient.DefaultRequestHeaders.Add("X-Telegram-Bot-Api-Secret-Token", "OUR_SECRET") finally - let struct (_, err) = appContainer.GetLogsAsync().Result - if err <> "" then - failwith err + if appContainer.State <> TestcontainersStates.Undefined then + let struct (_stdout, err) = appContainer.GetLogsAsync().Result + if err <> "" then + failwith err } member this.DisposeAsync() = task { // stop all the containers, flyway might be dead already diff --git a/src/VahterBanBot.Tests/Logging.fs b/src/VahterBanBot.Tests/Logging.fs new file mode 100644 index 0000000..85c6da3 --- /dev/null +++ b/src/VahterBanBot.Tests/Logging.fs @@ -0,0 +1,16 @@ +namespace VahterBanBot.Tests + +open System +open Microsoft.Extensions.Logging + +type StringLogger() = + let lockObj = obj() + let messages = ResizeArray() + interface ILogger with + member this.BeginScope _ = null + member this.IsEnabled _ = true + member this.Log(logLevel, _eventId, state, ex, formatter) = + lock lockObj (fun() -> + messages.Add($"[{logLevel}] {formatter.Invoke(state, ex)}")) + + member _.ExtractMessages(): string = lock lockObj (fun() -> String.Join("\n", messages)) \ No newline at end of file diff --git a/src/VahterBanBot.Tests/VahterBanBot.Tests.fsproj b/src/VahterBanBot.Tests/VahterBanBot.Tests.fsproj index ffa6474..9ed5220 100644 --- a/src/VahterBanBot.Tests/VahterBanBot.Tests.fsproj +++ b/src/VahterBanBot.Tests/VahterBanBot.Tests.fsproj @@ -11,6 +11,7 @@ + diff --git a/src/VahterBanBot/ML.fs b/src/VahterBanBot/ML.fs index 3518992..4a69409 100644 --- a/src/VahterBanBot/ML.fs +++ b/src/VahterBanBot/ML.fs @@ -65,6 +65,8 @@ type MachineLearning( let trainDate = DateTime.UtcNow - botConf.MlTrainInterval let! rawData = DB.mlData botConf.MlTrainCriticalMsgCount trainDate + logger.LogInformation $"Training data count: {rawData.Length}" + let data = rawData |> Array.map (fun x -> @@ -87,20 +89,32 @@ type MachineLearning( mlContext.Transforms.Text .FeaturizeText(outputColumnName = "TextFeaturized", inputColumnName = "text") .Append(mlContext.Transforms.Concatenate(outputColumnName = "Features", inputColumnNames = [|"TextFeaturized"; "lessThanNMessagesF";|])) - .Append(mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(labelColumnName = "spam", featureColumnName = "Features")) + .Append(mlContext.BinaryClassification.Trainers.SdcaLogisticRegression( + labelColumnName = "spam", + featureColumnName = "Features", + maximumNumberOfIterations = botConf.MlMaxNumberOfIterations + )) + + logger.LogInformation "Fitting model..." let trainedModel = dataProcessPipeline.Fit(trainingData) + + logger.LogInformation "Evaluating model..." + predictionEngine <- Some(mlContext.Model.CreatePredictionEngine(trainedModel)) let predictions = trainedModel.Transform(testData) let metrics = mlContext.BinaryClassification.Evaluate(data = predictions, labelColumnName = "spam", scoreColumnName = "Score") + logger.LogInformation "Model transformation complete" + sw.Stop() let metricsStr = metricsToString metrics sw.Elapsed logger.LogInformation metricsStr do! telegramClient.SendTextMessageAsync(ChatId.Int(botConf.LogsChannelId), metricsStr, parseMode = ParseMode.Markdown) |> taskIgnore + logger.LogInformation "Model trained" with ex -> logger.LogError(ex, "Error training model") } diff --git a/src/VahterBanBot/Program.fs b/src/VahterBanBot/Program.fs index 8bcf5ac..74399e9 100644 --- a/src/VahterBanBot/Program.fs +++ b/src/VahterBanBot/Program.fs @@ -71,6 +71,7 @@ let botConf = MlTrainingSetFraction = getEnvOr "ML_TRAINING_SET_FRACTION" "0.2" |> float MlSpamThreshold = getEnvOr "ML_SPAM_THRESHOLD" "0.5" |> single MlWarningThreshold = getEnvOr "ML_WARNING_THRESHOLD" "0.0" |> single + MlMaxNumberOfIterations = getEnvOr "ML_MAX_NUMBER_OF_ITERATIONS" "50" |> int MlStopWordsInChats = getEnvOr "ML_STOP_WORDS_IN_CHATS" "{}" |> fromJson } let validateApiKey (ctx : HttpContext) = diff --git a/src/VahterBanBot/Types.fs b/src/VahterBanBot/Types.fs index 87635c4..36f6dc3 100644 --- a/src/VahterBanBot/Types.fs +++ b/src/VahterBanBot/Types.fs @@ -39,6 +39,7 @@ type BotConfiguration = MlTrainingSetFraction: float MlSpamThreshold: single MlWarningThreshold: single + MlMaxNumberOfIterations: int MlStopWordsInChats: Dictionary } []