Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

groupBy {}.aggregate { keys } #662

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
package org.jetbrains.kotlinx.dataframe.aggregation

public abstract class AggregateGroupedDsl<out T> : AggregateDsl<T>()
import org.jetbrains.kotlinx.dataframe.AnyRow

/**
 * Base DSL class for aggregation bodies; the receiver that `aggregateGroupBy` passes to the
 * aggregation body (`GroupByReceiverImpl`) extends it.
 */
public abstract class AggregateGroupedDsl<out T> : AggregateDsl<T>() {
/**
 * Row with the grouping-key values of the group currently being aggregated.
 *
 * In `aggregateGroupBy` this is the row of the key columns at the index of the current group.
 * Implementations constructed without a key supplier may throw on access: the default supplier
 * in `GroupByReceiverImpl` raises an error stating that the property is only usable inside
 * `groupBy { }.aggregate { }`.
 */
public abstract val keys: AnyRow
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,24 @@ internal class GroupByImpl<T, G>(
override fun <R> updateGroups(transform: Selector<DataFrame<G>, DataFrame<R>>) =
df.convert(groups) { transform(it, it) }.asGroupBy(groups.name()) as GroupBy<T, R>

override fun toDataFrame(groupedColumnName: String?) = if (groupedColumnName == null || groupedColumnName == groups.name()) df else df.rename(groups).into(groupedColumnName)
/**
 * Returns the frame backing this group-by.
 *
 * The frame is returned as is when [groupedColumnName] is `null` or already equals the name of
 * the groups column; otherwise the groups column is renamed to [groupedColumnName].
 */
override fun toDataFrame(groupedColumnName: String?) =
    if (groupedColumnName != null && groupedColumnName != groups.name()) {
        // A different name was requested: rename the groups column.
        df.rename(groups).into(groupedColumnName)
    } else {
        df
    }

override fun toString() = df.toString()

override fun remainingColumnsSelector(): ColumnsSelector<*, *> =
keyColumnsInGroups.toColumnSet().let { groupCols -> { all().except(groupCols) } }

override fun <R> aggregate(body: AggregateGroupedBody<G, R>) = aggregateGroupBy(toDataFrame(), { groups }, removeColumns = true, body).cast<G>()
/**
 * Runs [body] for the groups of this group-by by delegating to `aggregateGroupBy`.
 *
 * The frame from `toDataFrame()` is passed with the groups column selected and
 * `removeColumns = true`; the result is cast to `G`.
 */
override fun <R> aggregate(body: AggregateGroupedBody<G, R>) = aggregateGroupBy(
df = toDataFrame(),
selector = { groups },
removeColumns = true,
body = body,
).cast<G>()

override fun filter(predicate: GroupedRowFilter<T, G>): GroupBy<T, G> {
val indices = (0 until df.nrow).filter {
Expand All @@ -78,12 +88,13 @@ internal fun <T, G, R> aggregateGroupBy(

val removed = df.removeImpl(columns = selector)

val hasKeyColumns = removed.df.ncol > 0
val keys = removed.df
val hasKeyColumns = keys.ncol > 0

val groupedFrame = column.values.map {
val groupedFrame = column.values.mapIndexed { i, it ->
if (it == null) null
else {
val builder = GroupByReceiverImpl(it, hasKeyColumns)
val builder = GroupByReceiverImpl(it, hasKeyColumns) { keys[i] }
val result = body(builder, builder)
if (result != Unit && result !is NamedValue && result !is AggregatedPivot<*>) builder.yield(
NamedValue.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,22 @@ import org.jetbrains.kotlinx.dataframe.impl.createTypeWithArgument
import org.jetbrains.kotlinx.dataframe.impl.getListType
import kotlin.reflect.KType

internal class GroupByReceiverImpl<T>(override val df: DataFrame<T>, override val hasGroupingKeys: Boolean) :
internal class GroupByReceiverImpl<T>(
override val df: DataFrame<T>,
override val hasGroupingKeys: Boolean,
private val retrieveKey: () -> AnyRow = {
error("This property can only be used inside 'groupBy { }.aggregate { }' clause")
}
) :
AggregateGroupedDsl<T>(),
AggregateInternalDsl<T>,
AggregatableInternal<T> by df as AggregatableInternal<T>,
DataFrame<T> by df {

// Resolved on first access by invoking `retrieveKey`; the default `retrieveKey` calls `error(...)`,
// so reading `keys` on a receiver created without a key supplier throws.
override val keys by lazy {
retrieveKey()
}

private val values = mutableListOf<NamedValue>()

internal fun child(): GroupByReceiverImpl<T> {
Expand All @@ -41,16 +51,41 @@ internal class GroupByReceiverImpl<T>(override val df: DataFrame<T>, override va
allValues.add(it)
}
}

is ValueColumn<*> -> {
allValues.add(NamedValue.create(it.path, it.value.toList(), getListType(it.value.type()), emptyList<Unit>()))
allValues.add(
NamedValue.create(
it.path,
it.value.toList(),
getListType(it.value.type()),
emptyList<Unit>()
)
)
}

is ColumnGroup<*> -> {
val frameType = it.value.type().arguments.singleOrNull()?.type
allValues.add(NamedValue.create(it.path, it.value.asDataFrame(), DataFrame::class.createTypeWithArgument(frameType), DataFrame.Empty))
allValues.add(
NamedValue.create(
it.path,
it.value.asDataFrame(),
DataFrame::class.createTypeWithArgument(frameType),
DataFrame.Empty
)
)
}

is FrameColumn<*> -> {
allValues.add(NamedValue.create(it.path, it.value.toList(), getListType(it.value.type()), emptyList<Unit>()))
allValues.add(
NamedValue.create(
it.path,
it.value.toList(),
getListType(it.value.type()),
emptyList<Unit>()
)
)
}

else -> {
allValues.add(it)
}
Expand All @@ -70,7 +105,9 @@ internal class GroupByReceiverImpl<T>(override val df: DataFrame<T>, override va
when (value.value) {
is AggregatedPivot<*> -> {
val pivot = value.value
val dropFirstNameInPath = pivot.inward == true && value.path.isNotEmpty() && pivot.aggregator.values.distinctBy { it.path.firstOrNull() }.count() == 1
val dropFirstNameInPath =
pivot.inward == true && value.path.isNotEmpty() && pivot.aggregator.values.distinctBy { it.path.firstOrNull() }
.count() == 1
pivot.aggregator.values.forEach {
val targetPath =
if (dropFirstNameInPath && it.path.size > 0) value.path + it.path.dropFirst()
Expand All @@ -80,6 +117,7 @@ internal class GroupByReceiverImpl<T>(override val df: DataFrame<T>, override va
}
pivot.aggregator.values.clear()
}

is AggregateInternalDsl<*> -> yield(value.copy(value = value.value.df))
else -> values.add(value)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,21 @@ class GroupByTests {
getFrameColumn("d") into "e"
}["e"].type() shouldBe typeOf<List<AnyFrame>>()
}

// Groups a two-row frame by the computed "test" key (a + b) and, inside `aggregate`, stores both
// the group size and the `keys` row of each group.
// NOTE(review): the result is only printed; no assertion pins the expected output yet.
@Test
fun `aggregate based on the key column`() {
val df = dataFrameOf(
"a", "b", "c"
)(
1, 2, 3,
4, 5, 6,
)
val grouped = df.groupBy { expr("test") { "a"<Int>() + "b"<Int>() } }
.aggregate {
count() into "count"
keys into "keys"
}

grouped.print()
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
package org.jetbrains.kotlinx.dataframe.aggregation

public abstract class AggregateGroupedDsl<out T> : AggregateDsl<T>()
import org.jetbrains.kotlinx.dataframe.AnyRow

/**
 * Base DSL class for aggregation bodies; the receiver that `aggregateGroupBy` passes to the
 * aggregation body (`GroupByReceiverImpl`) extends it.
 */
public abstract class AggregateGroupedDsl<out T> : AggregateDsl<T>() {
/**
 * Row with the grouping-key values of the group currently being aggregated.
 *
 * In `aggregateGroupBy` this is the row of the key columns at the index of the current group.
 * Implementations constructed without a key supplier may throw on access: the default supplier
 * in `GroupByReceiverImpl` raises an error stating that the property is only usable inside
 * `groupBy { }.aggregate { }`.
 */
public abstract val keys: AnyRow
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,24 @@ internal class GroupByImpl<T, G>(
override fun <R> updateGroups(transform: Selector<DataFrame<G>, DataFrame<R>>) =
df.convert(groups) { transform(it, it) }.asGroupBy(groups.name()) as GroupBy<T, R>

override fun toDataFrame(groupedColumnName: String?) = if (groupedColumnName == null || groupedColumnName == groups.name()) df else df.rename(groups).into(groupedColumnName)
/**
 * Returns the frame backing this group-by.
 *
 * The frame is returned as is when [groupedColumnName] is `null` or already equals the name of
 * the groups column; otherwise the groups column is renamed to [groupedColumnName].
 */
override fun toDataFrame(groupedColumnName: String?) =
    if (groupedColumnName != null && groupedColumnName != groups.name()) {
        // A different name was requested: rename the groups column.
        df.rename(groups).into(groupedColumnName)
    } else {
        df
    }

override fun toString() = df.toString()

override fun remainingColumnsSelector(): ColumnsSelector<*, *> =
keyColumnsInGroups.toColumnSet().let { groupCols -> { all().except(groupCols) } }

override fun <R> aggregate(body: AggregateGroupedBody<G, R>) = aggregateGroupBy(toDataFrame(), { groups }, removeColumns = true, body).cast<G>()
/**
 * Runs [body] for the groups of this group-by by delegating to `aggregateGroupBy`.
 *
 * The frame from `toDataFrame()` is passed with the groups column selected and
 * `removeColumns = true`; the result is cast to `G`.
 */
override fun <R> aggregate(body: AggregateGroupedBody<G, R>) = aggregateGroupBy(
df = toDataFrame(),
selector = { groups },
removeColumns = true,
body = body,
).cast<G>()

override fun filter(predicate: GroupedRowFilter<T, G>): GroupBy<T, G> {
val indices = (0 until df.nrow).filter {
Expand All @@ -78,12 +88,13 @@ internal fun <T, G, R> aggregateGroupBy(

val removed = df.removeImpl(columns = selector)

val hasKeyColumns = removed.df.ncol > 0
val keys = removed.df
val hasKeyColumns = keys.ncol > 0

val groupedFrame = column.values.map {
val groupedFrame = column.values.mapIndexed { i, it ->
if (it == null) null
else {
val builder = GroupByReceiverImpl(it, hasKeyColumns)
val builder = GroupByReceiverImpl(it, hasKeyColumns) { keys[i] }
val result = body(builder, builder)
if (result != Unit && result !is NamedValue && result !is AggregatedPivot<*>) builder.yield(
NamedValue.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,22 @@ import org.jetbrains.kotlinx.dataframe.impl.createTypeWithArgument
import org.jetbrains.kotlinx.dataframe.impl.getListType
import kotlin.reflect.KType

internal class GroupByReceiverImpl<T>(override val df: DataFrame<T>, override val hasGroupingKeys: Boolean) :
internal class GroupByReceiverImpl<T>(
override val df: DataFrame<T>,
override val hasGroupingKeys: Boolean,
private val retrieveKey: () -> AnyRow = {
Copy link
Collaborator

@koperagen koperagen Apr 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, does it have to be a lambda here? It could be a DataRow, I'm not sure. And what about the default parameter: can it actually throw an exception somehow?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lambda could perhaps be replaced by a nullable AnyRow.
And yes, it will throw an exception when you use dataFrame.aggregate { keys }, since for some reason, the same AggregateGroupedDsl is used there. There's also an option to get here via pivot, so for those cases I make it throw a helpful exception.

Copy link
Collaborator

@koperagen koperagen Apr 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I'd say we need to make the type of the lambda in GroupBy.aggregate more specific then, so that keys is provided as a DSL property only in that case. Also, if it makes sense, we can make aggregate an extension function and hide the existing member one (but this could be a different story).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, I remember now. aggregate should probably be an extension function on GroupBy

public interface GroupBy<out T, out G> : Grouped<G> {

    public val groups: FrameColumn<G>

    public val keys: DataFrame<T>

then keys can become DataRow<T>

Because now aggregate on GroupBy is resolved to this member and it simply doesn't know anything about keys

public interface Grouped<out T> : Aggregatable<T> {

    public fun <R> aggregate(body: AggregateGroupedBody<T, R>): DataFrame<T>
}

error("This property can only be used inside 'groupBy { }.aggregate { }' clause")
}
) :
AggregateGroupedDsl<T>(),
AggregateInternalDsl<T>,
AggregatableInternal<T> by df as AggregatableInternal<T>,
DataFrame<T> by df {

// Resolved on first access by invoking `retrieveKey`; the default `retrieveKey` calls `error(...)`,
// so reading `keys` on a receiver created without a key supplier throws.
override val keys by lazy {
retrieveKey()
}

private val values = mutableListOf<NamedValue>()

internal fun child(): GroupByReceiverImpl<T> {
Expand All @@ -41,16 +51,41 @@ internal class GroupByReceiverImpl<T>(override val df: DataFrame<T>, override va
allValues.add(it)
}
}

is ValueColumn<*> -> {
allValues.add(NamedValue.create(it.path, it.value.toList(), getListType(it.value.type()), emptyList<Unit>()))
allValues.add(
NamedValue.create(
it.path,
it.value.toList(),
getListType(it.value.type()),
emptyList<Unit>()
)
)
}

is ColumnGroup<*> -> {
val frameType = it.value.type().arguments.singleOrNull()?.type
allValues.add(NamedValue.create(it.path, it.value.asDataFrame(), DataFrame::class.createTypeWithArgument(frameType), DataFrame.Empty))
allValues.add(
NamedValue.create(
it.path,
it.value.asDataFrame(),
DataFrame::class.createTypeWithArgument(frameType),
DataFrame.Empty
)
)
}

is FrameColumn<*> -> {
allValues.add(NamedValue.create(it.path, it.value.toList(), getListType(it.value.type()), emptyList<Unit>()))
allValues.add(
NamedValue.create(
it.path,
it.value.toList(),
getListType(it.value.type()),
emptyList<Unit>()
)
)
}

else -> {
allValues.add(it)
}
Expand All @@ -70,7 +105,9 @@ internal class GroupByReceiverImpl<T>(override val df: DataFrame<T>, override va
when (value.value) {
is AggregatedPivot<*> -> {
val pivot = value.value
val dropFirstNameInPath = pivot.inward == true && value.path.isNotEmpty() && pivot.aggregator.values.distinctBy { it.path.firstOrNull() }.count() == 1
val dropFirstNameInPath =
pivot.inward == true && value.path.isNotEmpty() && pivot.aggregator.values.distinctBy { it.path.firstOrNull() }
.count() == 1
pivot.aggregator.values.forEach {
val targetPath =
if (dropFirstNameInPath && it.path.size > 0) value.path + it.path.dropFirst()
Expand All @@ -80,6 +117,7 @@ internal class GroupByReceiverImpl<T>(override val df: DataFrame<T>, override va
}
pivot.aggregator.values.clear()
}

is AggregateInternalDsl<*> -> yield(value.copy(value = value.value.df))
else -> values.add(value)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,21 @@ class GroupByTests {
getFrameColumn("d") into "e"
}["e"].type() shouldBe typeOf<List<AnyFrame>>()
}

// Groups a two-row frame by the computed "test" key (a + b) and, inside `aggregate`, stores both
// the group size and the `keys` row of each group.
// NOTE(review): the result is only printed; no assertion pins the expected output yet.
@Test
fun `aggregate based on the key column`() {
val df = dataFrameOf(
"a", "b", "c"
)(
1, 2, 3,
4, 5, 6,
)
val grouped = df.groupBy { expr("test") { "a"<Int>() + "b"<Int>() } }
.aggregate {
count() into "count"
keys into "keys"
}

grouped.print()
}
}