Skip to content

Commit

Permalink
Added ScoreAggregator
Browse files Browse the repository at this point in the history
  • Loading branch information
lucaro committed Nov 18, 2023
1 parent 18e144e commit 349f280
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ data class QueryResult(val retrievables: List<QueryResultRetrievable>) {
//map partOf relations the right way around
retrieved.forEach { retrieved: Retrieved ->
if (retrieved is Retrieved.RetrievedWithRelationship) {
retrieved.relationships.filter { it.second == "partOf" && it.first == retrieved.id }.forEach {
results[it.third.toString()]?.parts?.add(retrieved.id.toString())
retrieved.relationships.filter { it.pred == "partOf" && it.sub.first == retrieved.id }.forEach {
results[it.obj.first.toString()]?.parts?.add(retrieved.id.toString())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package org.vitrivr.engine.query.transform

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map
import org.vitrivr.engine.core.model.retrievable.Retrieved
import org.vitrivr.engine.core.operators.Operator
import org.vitrivr.engine.core.operators.retrieve.Transformer

class ScoreAggregator(
override val input: Operator<Retrieved>,
private val aggregationMode: AggregationMode = AggregationMode.MAX,
val relationshps: Set<String> = setOf("partOf")
) : Transformer<Retrieved, Retrieved.RetrievedWithScore> {

enum class AggregationMode {
MAX,
MEAN,
MIN
}

override fun toFlow(scope: CoroutineScope): Flow<Retrieved.RetrievedWithScore> =
input.toFlow(scope).map { retrieved ->
when {
retrieved is Retrieved.RetrievedWithScore -> retrieved//pass through
retrieved is Retrieved.RetrievedWithRelationship -> { //aggregate

val scores =
retrieved.relationships.filter { rel -> rel.pred in this.relationshps && rel.obj.first == retrieved.id }
.map { if (it.sub.second is Retrieved.RetrievedWithScore) (it.sub.second as Retrieved.RetrievedWithScore).score else 0f }

val score = if (scores.isEmpty()) {
0f
} else {
when (aggregationMode) {
AggregationMode.MAX -> scores.max()
AggregationMode.MEAN -> scores.sum() / scores.size
AggregationMode.MIN -> scores.min()
}
}

Retrieved.PlusScore(retrieved, score)

}

else -> Retrieved.PlusScore(retrieved, 0f)

}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.vitrivr.engine.query.transform

import org.vitrivr.engine.core.model.metamodel.Schema
import org.vitrivr.engine.core.model.retrievable.Retrieved
import org.vitrivr.engine.core.operators.Operator
import org.vitrivr.engine.core.operators.retrieve.Transformer
import org.vitrivr.engine.core.operators.retrieve.TransformerFactory

class ScoreAggregatorFactory : TransformerFactory<Retrieved, Retrieved.RetrievedWithScore> {
override fun newTransformer(
input: Operator<Retrieved>,
schema: Schema,
properties: Map<String, String>
): Transformer<Retrieved, Retrieved.RetrievedWithScore> {

val aggregation = properties["aggregation"]?.uppercase()?.let {
try {
ScoreAggregator.AggregationMode.valueOf(it)
} catch (e: IllegalArgumentException) {
null
}
} ?: ScoreAggregator.AggregationMode.MAX

val relationships = properties["relationships"]?.split(",")?.map { s -> s.trim() }?.toSet() ?: setOf("partOf")

return ScoreAggregator(input, aggregation, relationships)
}
}
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
org.vitrivr.engine.query.transform.RelationExpanderFactory
org.vitrivr.engine.query.transform.RelationExpanderFactory
org.vitrivr.engine.query.transform.ScoreAggregatorFactory

0 comments on commit 349f280

Please sign in to comment.