-
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3ac2003
commit e6ff23f
Showing
5 changed files
with
214 additions
and
1 deletion.
There are no files selected for viewing
14 changes: 14 additions & 0 deletions
14
...dreams/loritta/helper/serverresponses/sparklypower/HowToBuyPesadelosNaiveBayesResponse.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package net.perfectdreams.loritta.helper.serverresponses.sparklypower | ||
|
||
import net.perfectdreams.loritta.api.messages.LorittaReply | ||
|
||
/**
 * Canned answer for questions classified as [SparklyNaiveBayes.QuestionCategory.BUY_PESADELOS].
 *
 * The category and the shared [SparklyNaiveBayes] instance are forwarded to
 * [SparklyNaiveBayesResponse]; this class only supplies the reply content.
 */
class HowToBuyPesadelosNaiveBayesResponse(
    sparklyNaiveBayes: SparklyNaiveBayes
) : SparklyNaiveBayesResponse(SparklyNaiveBayes.QuestionCategory.BUY_PESADELOS, sparklyNaiveBayes) {
    /**
     * Builds the reply that points the user to the SparklyPower store.
     *
     * @param message the user's message (not used when building the reply)
     * @return a single-element list containing the store reply
     */
    override fun getResponse(message: String): List<LorittaReply> {
        val storeReply = LorittaReply(
            "Você pode comprar pesadelos acessando o meu website! https://sparklypower.net/loja",
            "<:pantufa_coffee:853048446981111828>"
        )
        return listOf(storeReply)
    }
}
91 changes: 91 additions & 0 deletions
91
...kotlin/net/perfectdreams/loritta/helper/serverresponses/sparklypower/SparklyNaiveBayes.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
package net.perfectdreams.loritta.helper.serverresponses.sparklypower | ||
|
||
import net.perfectdreams.loritta.helper.utils.NaiveBayes | ||
import java.text.Normalizer | ||
|
||
/**
 * Classifies SparklyPower support questions into a [QuestionCategory] with a [NaiveBayes] classifier.
 *
 * [setup] feeds the built-in example questions to [classifier]. Both the training phrases and
 * the messages to classify are expected to go through [normalizeNaiveBayesInput], so that
 * accents and punctuation do not prevent otherwise identical words from matching.
 */
class SparklyNaiveBayes {
    val classifier = NaiveBayes<QuestionCategory>()

    /**
     * Trains [classifier] with the built-in example questions for every [QuestionCategory].
     *
     * Every phrase is routed through the private [train] helper so it is normalized the same way
     * incoming messages are (for example "é" becomes "e"). Training on the raw phrases would leave
     * accented tokens in the model that a normalized message can never match.
     */
    fun setup() {
        train(
            QuestionCategory.BUY_PESADELOS,
            listOf(
                "como ganho pesadelos",
                "como consigo pesadelos",
                "como compra pesadelos"
            )
        )

        train(
            QuestionCategory.SPARKLY_IP,
            listOf(
                "qual é o IP do SparklyPower"
            )
        )

        train(
            QuestionCategory.BUY_VIP,
            listOf(
                "como compra VIP no SparklyPower"
            )
        )
    }

    /** Normalizes [documents] with [normalizeNaiveBayesInput] and trains [classifier] on them under [category]. */
    private fun train(category: QuestionCategory, documents: List<String>) = classifier.train(
        category,
        documents.map { normalizeNaiveBayesInput(it) }
    )

    /**
     * Scratch experiment: trains a throwaway, locally scoped classifier and discards it.
     *
     * This is a member function, not a program entry point, and it does not touch [classifier].
     * NOTE(review): the two "terreno" questions use different labels ("terreno" and "claim"),
     * and `documents` is trained without normalization - confirm whether this method is still needed.
     */
    fun main() {
        val documents = listOf(
            Pair("qual é o IP do SparklyPower", "ip"),
            Pair("manda o IP do SparklyPower", "ip"),
            Pair("como comprar VIP", "vip"),
            Pair("quero comprar VIP", "vip"),
            Pair("como eu protejo um terreno?", "terreno"),
            Pair("como proteger um terreno?", "claim"),
        )

        val classifier = NaiveBayes<String>()
        classifier.train(
            "pesadelos",
            listOf(
                "como ganho pesadelos",
                "como consigo pesadelos",
                "como compra pesadelos",
                "como comprar pesadelos",
                "como posso comprar pesadelos"
            ).map { normalizeNaiveBayesInput(it) }
        )
        classifier.train(documents)
    }

    /**
     * Expands common abbreviations, synonyms and misspellings into the words used by the training data.
     *
     * Replacements are whole-word, case-insensitive, and applied in declaration order.
     *
     * @param source the raw user message
     * @return [source] with every known short form replaced by its long form
     */
    fun replaceShortenedWordsWithLongWords(source: String): String =
        SHORTENED_WORD_REPLACEMENTS.fold(source) { text, (pattern, replacement) ->
            text.replace(pattern, replacement)
        }

    /**
     * Prepares text for the classifier: strips diacritics, removes `?`, `!`, `.` and `,`,
     * and trims surrounding whitespace.
     *
     * @param source the text to normalize
     * @return the normalized text
     */
    fun normalizeNaiveBayesInput(source: String): String = source
        .normalize()
        .replace(PUNCTUATION, "")
        .trim()

    /** Removes diacritics by decomposing to NFD and dropping the combining marks. */
    private fun String.normalize(): String {
        val decomposed = Normalizer.normalize(this, Normalizer.Form.NFD)
        return COMBINING_MARKS.replace(decomposed, "")
    }

    /** The kinds of question this classifier is trained to recognize. */
    enum class QuestionCategory {
        BUY_PESADELOS,
        BUY_VIP,
        SPARKLY_IP
    }

    private companion object {
        // Compiled once instead of on every call.
        private val COMBINING_MARKS = "\\p{InCombiningDiacriticalMarks}+".toRegex()
        private val PUNCTUATION = Regex("[?!.,]")

        // Whole-word, case-insensitive short form -> long form, applied in this order.
        private val SHORTENED_WORD_REPLACEMENTS: List<Pair<Regex, String>> = listOf(
            "Sparkly" to "SparklyPower",
            "servidor" to "SparklyPower",
            "server" to "SparklyPower",
            "pesa" to "pesadelos",
            "eh" to "é",
            "adissiona" to "adiciona",
            "adissiono" to "adiciono"
        ).map { (word, replacement) -> Regex("\\b$word\\b", RegexOption.IGNORE_CASE) to replacement }
    }
}
33 changes: 33 additions & 0 deletions
33
...et/perfectdreams/loritta/helper/serverresponses/sparklypower/SparklyNaiveBayesResponse.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package net.perfectdreams.loritta.helper.serverresponses.sparklypower | ||
|
||
import mu.KotlinLogging | ||
import net.perfectdreams.loritta.helper.serverresponses.LorittaResponse | ||
|
||
/**
 * Base class for responses that are triggered when a message is classified as [category].
 *
 * @param category the question category this response answers
 * @param sparklyNaiveBayes the classifier wrapper used to normalize and classify messages
 */
abstract class SparklyNaiveBayesResponse(
    private val category: SparklyNaiveBayes.QuestionCategory,
    private val sparklyNaiveBayes: SparklyNaiveBayes
) : LorittaResponse {
    private val logger = KotlinLogging.logger {}

    /**
     * Decides whether this response should handle [message].
     *
     * The message is expanded and normalized, then classified. The response applies only when
     * the top category is [category], its probability is at least [MIN_BEST_MATCH_PROBABILITY],
     * and it beats the runner-up by at least [MIN_MARGIN_OVER_SECOND_BEST].
     *
     * @param message the raw user message
     * @return `true` if this response should be used for [message]
     */
    override fun handleResponse(message: String): Boolean {
        val normalizedMessage = sparklyNaiveBayes.normalizeNaiveBayesInput(
            sparklyNaiveBayes.replaceShortenedWordsWithLongWords(message)
        )

        // Ascending by probability: the best match is last, the runner-up is second to last.
        val classifications = sparklyNaiveBayes.classifier.detailedClassification(normalizedMessage)
            .entries
            .sortedBy { it.value }

        logger.info { "Results for $normalizedMessage: $classifications" }

        // An untrained classifier yields no classifications, so there is nothing to match.
        val bestMatch = classifications.lastOrNull() ?: return false
        // Not the same category? Bail out!
        if (bestMatch.key != category)
            return false

        // With a single trained category there is no runner-up; treat its probability as 0.
        val secondBestProbability = classifications.getOrNull(classifications.size - 2)?.value ?: 0.0
        val marginOverSecondBest = bestMatch.value - secondBestProbability

        // We compare against the second best because if two categories score very similarly,
        // the question is ambiguous and we should not answer it.
        return bestMatch.value >= MIN_BEST_MATCH_PROBABILITY && marginOverSecondBest >= MIN_MARGIN_OVER_SECOND_BEST
    }

    private companion object {
        /** Minimum probability the best category must reach. */
        private const val MIN_BEST_MATCH_PROBABILITY = 0.4

        /** Minimum probability gap between the best category and the runner-up. */
        private const val MIN_MARGIN_OVER_SECOND_BEST = 0.2
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
66 changes: 66 additions & 0 deletions
66
src/main/kotlin/net/perfectdreams/loritta/helper/utils/NaiveBayes.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package net.perfectdreams.loritta.helper.utils | ||
|
||
// Thanks ChatGPT | ||
/**
 * Multinomial Naive Bayes text classifier with add-one (Laplace) smoothing.
 *
 * Text is split on whitespace and lower-cased (locale-invariant) the same way for training
 * and for classification; blank tokens are ignored.
 *
 * @param CATEGORYTYPE the type of the labels assigned to documents
 */
class NaiveBayes<CATEGORYTYPE> {
    // Number of training documents seen per category.
    private val classCounts: MutableMap<CATEGORYTYPE, Int> = HashMap()

    /** Word frequencies per category: category -> (lower-cased word -> number of occurrences). */
    val wordCounts: MutableMap<CATEGORYTYPE, MutableMap<String, Int>> = HashMap()

    // Number of training documents seen across all categories.
    private var totalDocuments: Int = 0

    /**
     * Trains the classifier with [documents] that all belong to [category].
     */
    fun train(category: CATEGORYTYPE, documents: List<String>) = train(
        documents.map { it to category }
    )

    /**
     * Trains the classifier with labelled documents.
     *
     * @param documents pairs of (document text, category)
     */
    fun train(documents: List<Pair<String, CATEGORYTYPE>>) {
        for ((text, label) in documents) {
            classCounts[label] = (classCounts[label] ?: 0) + 1
            totalDocuments++

            // Created even for a document without words, so every trained label has an entry.
            val labelWordCounts = wordCounts.getOrPut(label) { HashMap() }
            for (word in tokenize(text)) {
                labelWordCounts[word] = (labelWordCounts[word] ?: 0) + 1
            }
        }
    }

    /**
     * Returns the most probable category for [text].
     *
     * @throws NoSuchElementException if the classifier has not been trained yet
     */
    fun classify(text: String): CATEGORYTYPE {
        val bestMatch = detailedClassification(text).entries.maxByOrNull { it.value }
            ?: throw NoSuchElementException("Cannot classify: the classifier has not been trained")
        return bestMatch.key
    }

    /**
     * Computes the probability of every trained category for [text].
     *
     * @param text the text to classify
     * @return category -> probability, summing to 1; empty if the classifier has not been trained
     */
    fun detailedClassification(text: String): Map<CATEGORYTYPE, Double> {
        val words = tokenize(text)

        // Laplace smoothing needs the number of distinct words across ALL categories,
        // not the number of categories.
        val vocabularySize = wordCounts.values.flatMapTo(HashSet<String>()) { it.keys }.size

        val logScores = LinkedHashMap<CATEGORYTYPE, Double>()
        for ((label, documentCount) in classCounts) {
            val labelWordCounts = wordCounts[label].orEmpty()
            val totalWordCountForClass = labelWordCounts.values.sum()
            val denominator = (totalWordCountForClass + vocabularySize).toDouble()

            // Start from the log prior, then add the smoothed log likelihood of each word.
            var logScore = kotlin.math.ln(documentCount.toDouble() / totalDocuments)
            // With an empty vocabulary there is no word evidence (and the denominator would be 0),
            // so only the priors are used.
            if (vocabularySize > 0) {
                for (word in words) {
                    logScore += kotlin.math.ln(((labelWordCounts[word] ?: 0) + 1) / denominator)
                }
            }
            logScores[label] = logScore
        }

        // Convert log scores to probabilities; subtracting the maximum first avoids underflow in exp().
        val maxLogScore = logScores.values.maxOrNull() ?: return emptyMap()
        val weights = logScores.mapValues { (_, logScore) -> kotlin.math.exp(logScore - maxLogScore) }

        // Normalize so the probabilities sum to 1 (the total is at least 1.0, from the best category).
        val totalWeight = weights.values.sum()
        return weights.mapValues { (_, weight) -> weight / totalWeight }
    }

    /** Splits [text] on whitespace into lower-cased, non-empty tokens. */
    private fun tokenize(text: String): List<String> =
        text.split(WHITESPACE).filter { it.isNotEmpty() }.map { it.lowercase() }

    private companion object {
        // Compiled once instead of once per document.
        private val WHITESPACE = "\\s+".toRegex()
    }
}