From 40a8064ad95252b4cbdca6cc2f6ac426aa694d0c Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Sat, 6 Apr 2024 16:11:13 +0200 Subject: [PATCH 1/2] more explicit explanation of what happens in between cell calls in the extension properties api in notebooks --- .../dataframe/documentation/AccessApi.kt | 2 +- .../dataframe/samples/api/ApiLevels.kt | 2 +- .../dataframe/samples/api/ApiLevels.kt | 2 +- .../topics/extensionPropertiesApi.md | 28 ++++++++++++------- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/AccessApi.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/AccessApi.kt index 960b24b79..1fa61a0b2 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/AccessApi.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/AccessApi.kt @@ -122,7 +122,7 @@ internal interface AccessApi { * * For example: * ```kotlin - * val df = DataFrame.read("titanic.csv") + * val df /* : AnyFrame */ = DataFrame.read("titanic.csv") * ``` */ interface ExtensionPropertiesApi diff --git a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt index f37615b63..f579f53bd 100644 --- a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt +++ b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt @@ -140,7 +140,7 @@ class ApiLevels { @TransformDataFrameExpressions fun extensionProperties1() { // SampleStart - val df = DataFrame.read("titanic.csv") + val df /* : AnyFrame */ = DataFrame.read("titanic.csv") // SampleEnd } } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt 
b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt index f37615b63..f579f53bd 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/ApiLevels.kt @@ -140,7 +140,7 @@ class ApiLevels { @TransformDataFrameExpressions fun extensionProperties1() { // SampleStart - val df = DataFrame.read("titanic.csv") + val df /* : AnyFrame */ = DataFrame.read("titanic.csv") // SampleEnd } } diff --git a/docs/StardustDocs/topics/extensionPropertiesApi.md b/docs/StardustDocs/topics/extensionPropertiesApi.md index 00bfba268..ccbe2a08a 100644 --- a/docs/StardustDocs/topics/extensionPropertiesApi.md +++ b/docs/StardustDocs/topics/extensionPropertiesApi.md @@ -2,19 +2,30 @@ -When [`DataFrame`](DataFrame.md) is used within Jupyter Notebooks or Datalore with Kotlin Kernel, -after every cell execution all new global variables of type DataFrame are analyzed and replaced -with typed [`DataFrame`](DataFrame.md) wrapper with auto-generated extension properties for data access: +When [`DataFrame`](DataFrame.md) is used within Jupyter/Kotlin Notebook or Datalore with the Kotlin Kernel, +something special happens: +After every cell execution, all new global variables of type DataFrame are analyzed and replaced +with a typed [`DataFrame`](DataFrame.md) wrapper along with auto-generated extension properties for data access. +For instance, say we run: ```kotlin -val df = DataFrame.read("titanic.csv") +val df /* : AnyFrame */ = DataFrame.read("titanic.csv") ``` -Now data can be accessed by `.` member accessor +In normal Kotlin code, we would now have a variable of type [`AnyFrame` (=`DataFrame<*>`)](DataFrame.md) that doesn't have any +extension properties to access its columns. We would either have to define them manually or use the +[`@DataSchema`](schemas.md) annotation to [generate them](schemasGradle.md#configuration). 
+ +By contrast, after this cell is run in a notebook, the columns of the dataframe are used as a basis +to generate a hidden `@DataSchema interface TypeX`, +along with extension properties like `val DataFrame<TypeX>.age` etc. +Next, the `df` variable is shadowed by a new version cast to `DataFrame<TypeX>`. + +As a result, columns can now be accessed directly on `df`! @@ -28,12 +39,9 @@ df.add("lastName") { name.split(",").last() } The `titanic.csv` file could be found [here](https://github.com/Kotlin/dataframe/blob/master/data/titanic.csv). -In notebooks, extension properties are generated for [`DataSchema`](schemas.md) that is extracted from [`DataFrame`](DataFrame.md) -instance after REPL line execution. -After that [`DataFrame`](DataFrame.md) variable is typed with its own [`DataSchema`](schemas.md), so only valid extension properties corresponding to actual columns in DataFrame will be allowed by the compiler and suggested by completion. - Extension properties can be generated in IntelliJ IDEA using the [Kotlin Dataframe Gradle plugin](schemasGradle.md#configuration). -In notebooks generated properties won't appear and be updated until the cell has been executed. It often means that you have to introduce new variable frequently to sync extension properties with actual schema +In notebooks, generated properties won't appear or be updated until the cell has been executed. +This often means that you have to introduce new variables frequently to sync extension properties with the actual schema. From 57f504f1b6d19ea727396965afea008ee6f50cd6 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 15 Apr 2024 17:57:48 +0200 Subject: [PATCH 2/2] added proof-of-concept `keys` addition to `groupBy {}.aggregate {}` to provide access to the keys: AnyRow used to group the df by. 
--- .../aggregation/AggregateGroupedDsl.kt | 6 ++- .../kotlinx/dataframe/impl/GroupByImpl.kt | 21 ++++++-- .../impl/aggregation/GroupByReceiverImpl.kt | 48 +++++++++++++++++-- .../kotlinx/dataframe/api/groupBy.kt | 17 +++++++ .../aggregation/AggregateGroupedDsl.kt | 6 ++- .../kotlinx/dataframe/impl/GroupByImpl.kt | 21 ++++++-- .../impl/aggregation/GroupByReceiverImpl.kt | 48 +++++++++++++++++-- .../kotlinx/dataframe/api/groupBy.kt | 17 +++++++ 8 files changed, 162 insertions(+), 22 deletions(-) diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt index 43b529522..105c56aa2 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt @@ -1,3 +1,7 @@ package org.jetbrains.kotlinx.dataframe.aggregation -public abstract class AggregateGroupedDsl : AggregateDsl() +import org.jetbrains.kotlinx.dataframe.AnyRow + +public abstract class AggregateGroupedDsl : AggregateDsl() { + public abstract val keys: AnyRow +} diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt index bf58b3a66..6860343d5 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt @@ -48,14 +48,24 @@ internal class GroupByImpl( override fun updateGroups(transform: Selector, DataFrame>) = df.convert(groups) { transform(it, it) }.asGroupBy(groups.name()) as GroupBy - override fun toDataFrame(groupedColumnName: String?) 
= if (groupedColumnName == null || groupedColumnName == groups.name()) df else df.rename(groups).into(groupedColumnName) + override fun toDataFrame(groupedColumnName: String?) = + if (groupedColumnName == null || groupedColumnName == groups.name()) { + df + } else { + df.rename(groups).into(groupedColumnName) + } override fun toString() = df.toString() override fun remainingColumnsSelector(): ColumnsSelector<*, *> = keyColumnsInGroups.toColumnSet().let { groupCols -> { all().except(groupCols) } } - override fun aggregate(body: AggregateGroupedBody) = aggregateGroupBy(toDataFrame(), { groups }, removeColumns = true, body).cast() + override fun aggregate(body: AggregateGroupedBody) = aggregateGroupBy( + df = toDataFrame(), + selector = { groups }, + removeColumns = true, + body = body, + ).cast() override fun filter(predicate: GroupedRowFilter): GroupBy { val indices = (0 until df.nrow).filter { @@ -78,12 +88,13 @@ internal fun aggregateGroupBy( val removed = df.removeImpl(columns = selector) - val hasKeyColumns = removed.df.ncol > 0 + val keys = removed.df + val hasKeyColumns = keys.ncol > 0 - val groupedFrame = column.values.map { + val groupedFrame = column.values.mapIndexed { i, it -> if (it == null) null else { - val builder = GroupByReceiverImpl(it, hasKeyColumns) + val builder = GroupByReceiverImpl(it, hasKeyColumns) { keys[i] } val result = body(builder, builder) if (result != Unit && result !is NamedValue && result !is AggregatedPivot<*>) builder.yield( NamedValue.create( diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt index 9b857c254..5860b29a6 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt +++ 
b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt @@ -18,12 +18,22 @@ import org.jetbrains.kotlinx.dataframe.impl.createTypeWithArgument import org.jetbrains.kotlinx.dataframe.impl.getListType import kotlin.reflect.KType -internal class GroupByReceiverImpl(override val df: DataFrame, override val hasGroupingKeys: Boolean) : +internal class GroupByReceiverImpl( + override val df: DataFrame, + override val hasGroupingKeys: Boolean, + private val retrieveKey: () -> AnyRow = { + error("This property can only be used inside 'groupBy { }.aggregate { }' clause") + } +) : AggregateGroupedDsl(), AggregateInternalDsl, AggregatableInternal by df as AggregatableInternal, DataFrame by df { + override val keys by lazy { + retrieveKey() + } + private val values = mutableListOf() internal fun child(): GroupByReceiverImpl { @@ -41,16 +51,41 @@ internal class GroupByReceiverImpl(override val df: DataFrame, override va allValues.add(it) } } + is ValueColumn<*> -> { - allValues.add(NamedValue.create(it.path, it.value.toList(), getListType(it.value.type()), emptyList())) + allValues.add( + NamedValue.create( + it.path, + it.value.toList(), + getListType(it.value.type()), + emptyList() + ) + ) } + is ColumnGroup<*> -> { val frameType = it.value.type().arguments.singleOrNull()?.type - allValues.add(NamedValue.create(it.path, it.value.asDataFrame(), DataFrame::class.createTypeWithArgument(frameType), DataFrame.Empty)) + allValues.add( + NamedValue.create( + it.path, + it.value.asDataFrame(), + DataFrame::class.createTypeWithArgument(frameType), + DataFrame.Empty + ) + ) } + is FrameColumn<*> -> { - allValues.add(NamedValue.create(it.path, it.value.toList(), getListType(it.value.type()), emptyList())) + allValues.add( + NamedValue.create( + it.path, + it.value.toList(), + getListType(it.value.type()), + emptyList() + ) + ) } + else -> { allValues.add(it) } @@ -70,7 +105,9 @@ internal class GroupByReceiverImpl(override val df: 
DataFrame, override va when (value.value) { is AggregatedPivot<*> -> { val pivot = value.value - val dropFirstNameInPath = pivot.inward == true && value.path.isNotEmpty() && pivot.aggregator.values.distinctBy { it.path.firstOrNull() }.count() == 1 + val dropFirstNameInPath = + pivot.inward == true && value.path.isNotEmpty() && pivot.aggregator.values.distinctBy { it.path.firstOrNull() } + .count() == 1 pivot.aggregator.values.forEach { val targetPath = if (dropFirstNameInPath && it.path.size > 0) value.path + it.path.dropFirst() @@ -80,6 +117,7 @@ internal class GroupByReceiverImpl(override val df: DataFrame, override va } pivot.aggregator.values.clear() } + is AggregateInternalDsl<*> -> yield(value.copy(value = value.value.df)) else -> values.add(value) } diff --git a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt index 4f7e78748..59770fa72 100644 --- a/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt +++ b/core/generated-sources/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt @@ -55,4 +55,21 @@ class GroupByTests { getFrameColumn("d") into "e" }["e"].type() shouldBe typeOf>() } + + @Test + fun `aggregate based on the key column`() { + val df = dataFrameOf( + "a", "b", "c" + )( + 1, 2, 3, + 4, 5, 6, + ) + val grouped = df.groupBy { expr("test") { "a"() + "b"() } } + .aggregate { + count() into "count" + keys into "keys" + } + + grouped.print() + } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt index 43b529522..105c56aa2 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateGroupedDsl.kt @@ -1,3 +1,7 @@ package 
org.jetbrains.kotlinx.dataframe.aggregation -public abstract class AggregateGroupedDsl : AggregateDsl() +import org.jetbrains.kotlinx.dataframe.AnyRow + +public abstract class AggregateGroupedDsl : AggregateDsl() { + public abstract val keys: AnyRow +} diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt index bf58b3a66..6860343d5 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt @@ -48,14 +48,24 @@ internal class GroupByImpl( override fun updateGroups(transform: Selector, DataFrame>) = df.convert(groups) { transform(it, it) }.asGroupBy(groups.name()) as GroupBy - override fun toDataFrame(groupedColumnName: String?) = if (groupedColumnName == null || groupedColumnName == groups.name()) df else df.rename(groups).into(groupedColumnName) + override fun toDataFrame(groupedColumnName: String?) 
= + if (groupedColumnName == null || groupedColumnName == groups.name()) { + df + } else { + df.rename(groups).into(groupedColumnName) + } override fun toString() = df.toString() override fun remainingColumnsSelector(): ColumnsSelector<*, *> = keyColumnsInGroups.toColumnSet().let { groupCols -> { all().except(groupCols) } } - override fun aggregate(body: AggregateGroupedBody) = aggregateGroupBy(toDataFrame(), { groups }, removeColumns = true, body).cast() + override fun aggregate(body: AggregateGroupedBody) = aggregateGroupBy( + df = toDataFrame(), + selector = { groups }, + removeColumns = true, + body = body, + ).cast() override fun filter(predicate: GroupedRowFilter): GroupBy { val indices = (0 until df.nrow).filter { @@ -78,12 +88,13 @@ internal fun aggregateGroupBy( val removed = df.removeImpl(columns = selector) - val hasKeyColumns = removed.df.ncol > 0 + val keys = removed.df + val hasKeyColumns = keys.ncol > 0 - val groupedFrame = column.values.map { + val groupedFrame = column.values.mapIndexed { i, it -> if (it == null) null else { - val builder = GroupByReceiverImpl(it, hasKeyColumns) + val builder = GroupByReceiverImpl(it, hasKeyColumns) { keys[i] } val result = body(builder, builder) if (result != Unit && result !is NamedValue && result !is AggregatedPivot<*>) builder.yield( NamedValue.create( diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt index 9b857c254..5860b29a6 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/GroupByReceiverImpl.kt @@ -18,12 +18,22 @@ import org.jetbrains.kotlinx.dataframe.impl.createTypeWithArgument import org.jetbrains.kotlinx.dataframe.impl.getListType import kotlin.reflect.KType -internal class GroupByReceiverImpl(override val df: DataFrame, 
override val hasGroupingKeys: Boolean) : +internal class GroupByReceiverImpl( + override val df: DataFrame, + override val hasGroupingKeys: Boolean, + private val retrieveKey: () -> AnyRow = { + error("This property can only be used inside 'groupBy { }.aggregate { }' clause") + } +) : AggregateGroupedDsl(), AggregateInternalDsl, AggregatableInternal by df as AggregatableInternal, DataFrame by df { + override val keys by lazy { + retrieveKey() + } + private val values = mutableListOf() internal fun child(): GroupByReceiverImpl { @@ -41,16 +51,41 @@ internal class GroupByReceiverImpl(override val df: DataFrame, override va allValues.add(it) } } + is ValueColumn<*> -> { - allValues.add(NamedValue.create(it.path, it.value.toList(), getListType(it.value.type()), emptyList())) + allValues.add( + NamedValue.create( + it.path, + it.value.toList(), + getListType(it.value.type()), + emptyList() + ) + ) } + is ColumnGroup<*> -> { val frameType = it.value.type().arguments.singleOrNull()?.type - allValues.add(NamedValue.create(it.path, it.value.asDataFrame(), DataFrame::class.createTypeWithArgument(frameType), DataFrame.Empty)) + allValues.add( + NamedValue.create( + it.path, + it.value.asDataFrame(), + DataFrame::class.createTypeWithArgument(frameType), + DataFrame.Empty + ) + ) } + is FrameColumn<*> -> { - allValues.add(NamedValue.create(it.path, it.value.toList(), getListType(it.value.type()), emptyList())) + allValues.add( + NamedValue.create( + it.path, + it.value.toList(), + getListType(it.value.type()), + emptyList() + ) + ) } + else -> { allValues.add(it) } @@ -70,7 +105,9 @@ internal class GroupByReceiverImpl(override val df: DataFrame, override va when (value.value) { is AggregatedPivot<*> -> { val pivot = value.value - val dropFirstNameInPath = pivot.inward == true && value.path.isNotEmpty() && pivot.aggregator.values.distinctBy { it.path.firstOrNull() }.count() == 1 + val dropFirstNameInPath = + pivot.inward == true && value.path.isNotEmpty() && 
pivot.aggregator.values.distinctBy { it.path.firstOrNull() } + .count() == 1 pivot.aggregator.values.forEach { val targetPath = if (dropFirstNameInPath && it.path.size > 0) value.path + it.path.dropFirst() @@ -80,6 +117,7 @@ internal class GroupByReceiverImpl(override val df: DataFrame, override va } pivot.aggregator.values.clear() } + is AggregateInternalDsl<*> -> yield(value.copy(value = value.value.df)) else -> values.add(value) } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt index 4f7e78748..59770fa72 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt @@ -55,4 +55,21 @@ class GroupByTests { getFrameColumn("d") into "e" }["e"].type() shouldBe typeOf>() } + + @Test + fun `aggregate based on the key column`() { + val df = dataFrameOf( + "a", "b", "c" + )( + 1, 2, 3, + 4, 5, 6, + ) + val grouped = df.groupBy { expr("test") { "a"() + "b"() } } + .aggregate { + count() into "count" + keys into "keys" + } + + grouped.print() + } }