From ec4bd9ab3433fe4a642189b61ade542ccd4c3c16 Mon Sep 17 00:00:00 2001 From: hermannm Date: Mon, 21 Oct 2024 18:47:36 +0200 Subject: [PATCH] Add list return to batchCreate/batchUpdate Now that we support database-generated IDs, it is useful to get this ID returned from batchCreate. Previously, we did not return results here out of fear of allocating too much, but I now believe this is a premature optimization. If we do see cases where this causes memory issues, we can consider adding specialized versions of the batch methods that do not return results. --- .../documentstore/repository/Repository.kt | 26 +- .../repository/RepositoryJdbi.kt | 229 +++++++++++------- .../documentstore/utils/BatchOperation.kt | 60 ++++- .../documentstore/repository/BatchTest.kt | 37 ++- .../documentstore/repository/CrudTest.kt | 3 +- 5 files changed, 231 insertions(+), 124 deletions(-) diff --git a/src/main/kotlin/no/liflig/documentstore/repository/Repository.kt b/src/main/kotlin/no/liflig/documentstore/repository/Repository.kt index 74e2f2a..e1a446f 100644 --- a/src/main/kotlin/no/liflig/documentstore/repository/Repository.kt +++ b/src/main/kotlin/no/liflig/documentstore/repository/Repository.kt @@ -52,20 +52,11 @@ interface Repository> { * The implementation in [RepositoryJdbi.batchCreate] uses * [Prepared Batches from JDBI](https://jdbi.org/releases/3.45.1/#_prepared_batches) to make the * implementation as efficient as possible. - * - * This method does not return a list of the created entities. This is because it takes an - * [Iterable], which may be a lazy view into a large collection, and we don't want to assume for - * library users that it's OK to allocate a list of that size. If you find a use-case for getting - * the [Version] of created entities, you should either list them out afterwards with - * [listByIds]/[listAll]/[RepositoryJdbi.getByPredicate], or alert the library authors so we may - * consider returning results here. */ - fun batchCreate(entities: Iterable) { + fun batchCreate(entities: Iterable): List> { // A default implementation is provided here on the interface, so that implementors don't have // to implement this themselves (for e.g. mock repositories). - for (entity in entities) { - create(entity) - } + return entities.map { create(it) } } /** @@ -76,22 +67,13 @@ interface Repository> { * [Prepared Batches from JDBI](https://jdbi.org/releases/3.45.1/#_prepared_batches) to make the * implementation as efficient as possible. * - * This method does not return a list of the updated entities. This is because it takes an - * [Iterable], which may be a lazy view into a large collection, and we don't want to assume for - * library users that it's OK to allocate a list of that size. If you find a use-case for getting - * the new [Version] of updated entities, you should either list them out afterwards with - * [listByIds]/[listAll]/[RepositoryJdbi.getByPredicate], or alert the library authors so we may - * consider returning results here. - * * @throws ConflictRepositoryException If the version on any of the given entities does not match * the current version of the entity in the database. */ - fun batchUpdate(entities: Iterable>) { + fun batchUpdate(entities: Iterable>): List> { // A default implementation is provided here on the interface, so that implementors don't have // to implement this themselves (for e.g. mock repositories). - for (entity in entities) { - update(entity.item, entity.version) - } + return entities.map { update(it.item, it.version) } } /** diff --git a/src/main/kotlin/no/liflig/documentstore/repository/RepositoryJdbi.kt b/src/main/kotlin/no/liflig/documentstore/repository/RepositoryJdbi.kt index e64cd88..7ddb650 100644 --- a/src/main/kotlin/no/liflig/documentstore/repository/RepositoryJdbi.kt +++ b/src/main/kotlin/no/liflig/documentstore/repository/RepositoryJdbi.kt @@ -11,7 +11,6 @@ import no.liflig.documentstore.entity.Version import no.liflig.documentstore.entity.Versioned import no.liflig.documentstore.entity.getEntityIdType import no.liflig.documentstore.utils.executeBatchOperation -import org.jdbi.v3.core.Handle import org.jdbi.v3.core.Jdbi import org.jdbi.v3.core.mapper.RowMapper import org.jdbi.v3.core.statement.Query @@ -41,8 +40,8 @@ import org.jdbi.v3.core.statement.Query * @param jdbi Must have the [DocumentStorePlugin] installed for the queries in this class to work. * @param serializationAdapter See [SerializationAdapter] for an example of how to implement this. * @param idsGeneratedByDatabase When using [IntegerEntityId], one often wants the entity IDs to be - * generated by the database. This affects how we perform [create], so if your `id` column is - * `PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY`, you must set this flag to true. + * generated by the database. This affects how we perform [create] (and [batchCreate]), so if your + * `id` column is `PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY`, you must set this flag to true. * * The `create` method still takes an entity with an `id`, but when this flag is set, the `id` * given to `create` is ignored (you can use [IntegerEntityId.GENERATED] as a dummy ID value). The @@ -70,14 +69,14 @@ open class RepositoryJdbi>( override fun create(entity: EntityT): Versioned { try { - useHandle(jdbi) { handle -> - val createdAt = Instant.now() - val version = Version.initial() + val createdAt = Instant.now() + val version = Version.initial() - if (idsGeneratedByDatabase) { - return createEntityWithGeneratedId(handle, entity, createdAt, version) - } + if (idsGeneratedByDatabase) { + return createEntityWithGeneratedId(entity, createdAt, version) + } + useHandle(jdbi) { handle -> handle .createUpdate( """ @@ -92,8 +91,9 @@ open class RepositoryJdbi>( .bind("createdAt", createdAt) .bind("modifiedAt", createdAt) .execute() - return Versioned(entity, version, createdAt = createdAt, modifiedAt = createdAt) } + + return Versioned(entity, version, createdAt = createdAt, modifiedAt = createdAt) } catch (e: Exception) { // Call mapDatabaseException first to handle connection-related exceptions, before calling // mapCreateOrUpdateException (which may be overridden by users for custom error handling). @@ -130,41 +130,42 @@ open class RepositoryJdbi>( * although the behavior we have here is more akin to `GENERATED ALWAYS`. */ private fun createEntityWithGeneratedId( - handle: Handle, entity: EntityT, createdAt: Instant, version: Version ): Versioned { - val createdEntity = - handle - .createQuery( - """ - WITH generated_id AS ( - SELECT nextval(pg_get_serial_sequence('${tableName}', 'id')) AS value - ) - INSERT INTO "${tableName}" (id, data, version, created_at, modified_at) - SELECT - generated_id.value, - jsonb_set(:data::jsonb, '{id}', to_jsonb(generated_id.value)), - :version, - :createdAt, - :modifiedAt - FROM generated_id - RETURNING data - """ - .trimIndent(), - ) - .bind("data", serializationAdapter.toJson(entity)) - .bind("version", version) - .bind("createdAt", createdAt) - .bind("modifiedAt", createdAt) - .map(entityDataMapper) - .firstOrNull() - ?: throw IllegalStateException( - "INSERT query for entity with generated ID did not return entity data", - ) + val entityWithGeneratedId = + useHandle(jdbi) { handle -> + handle + .createQuery( + """ + WITH generated_id AS ( + SELECT nextval(pg_get_serial_sequence('${tableName}', 'id')) AS value + ) + INSERT INTO "${tableName}" (id, data, version, created_at, modified_at) + SELECT + generated_id.value, + jsonb_set(:data::jsonb, '{id}', to_jsonb(generated_id.value)), + :version, + :createdAt, + :modifiedAt + FROM generated_id + RETURNING data + """ + .trimIndent(), + ) + .bind("data", serializationAdapter.toJson(entity)) + .bind("version", version) + .bind("createdAt", createdAt) + .bind("modifiedAt", createdAt) + .map(entityDataMapper) + .firstOrNull() + ?: throw IllegalStateException( + "INSERT query for entity with generated ID did not return entity data", + ) + } - return Versioned(createdEntity, version, createdAt = createdAt, modifiedAt = createdAt) + return Versioned(entityWithGeneratedId, version, createdAt = createdAt, modifiedAt = createdAt) } override fun get(id: EntityIdT, forUpdate: Boolean): Versioned? { @@ -279,17 +280,21 @@ open class RepositoryJdbi>( return getByPredicate() // Defaults to all } - override fun batchCreate(entities: Iterable) { - transactional { - useHandle(jdbi) { handle -> - val createdAt = Instant.now() - val version = Version.initial() + override fun batchCreate(entities: Iterable): List> { + val size = entities.sizeIfKnown() + if (size == 0) { + return emptyList() + } - if (idsGeneratedByDatabase) { - batchCreateEntitiesWithGeneratedId(handle, entities, createdAt, version) - return@transactional - } + val createdAt = Instant.now() + val version = Version.initial() + if (idsGeneratedByDatabase) { + return batchCreateEntitiesWithGeneratedId(entities, size, createdAt, version) + } + + transactional { + useHandle(jdbi) { handle -> executeBatchOperation( handle, entities, @@ -310,47 +315,87 @@ open class RepositoryJdbi>( ) } } + + // We wait until here to create the result list, which may be large, to avoid allocating it + // before calling the database. That would keep the list in memory while we are waiting for the + // database, needlessly reducing throughput. + return entities.map { entity -> + Versioned(entity, version, createdAt = createdAt, modifiedAt = createdAt) + } } + /** + * See: + * - [createEntityWithGeneratedId] for the challenges with generated IDs, which we also face here + * - [no.liflig.documentstore.utils.executeBatch] for how we handle returning the data from our + * created entities (which we need here in order to get the generated IDs) + */ private fun batchCreateEntitiesWithGeneratedId( - handle: Handle, entities: Iterable, + size: Int?, createdAt: Instant, version: Version - ) { - executeBatchOperation( - handle, - entities, - statement = - """ - WITH generated_id AS ( - SELECT nextval(pg_get_serial_sequence('${tableName}', 'id')) AS value - ) - INSERT INTO "${tableName}" (id, data, version, created_at, modified_at) - SELECT - generated_id.value, - jsonb_set(:data::jsonb, '{id}', to_jsonb(generated_id.value)), - :version, - :createdAt, - :modifiedAt - FROM generated_id - """ - .trimIndent(), - bindParameters = { batch, entity -> - batch - .bind("data", serializationAdapter.toJson(entity)) - .bind("version", version) - .bind("createdAt", createdAt) - .bind("modifiedAt", createdAt) - }, - ) - } + ): List> { + // If we know the size of the given entities, we want to pre-allocate capacity for the result + val entitiesWithGeneratedId: ArrayList> = + if (size != null) ArrayList(size) else ArrayList() - override fun batchUpdate(entities: Iterable>) { transactional { useHandle(jdbi) { handle -> - val now = Instant.now() + executeBatchOperation( + handle, + entities, + statement = + """ + WITH generated_id AS ( + SELECT nextval(pg_get_serial_sequence('${tableName}', 'id')) AS value + ) + INSERT INTO "${tableName}" (id, data, version, created_at, modified_at) + SELECT + generated_id.value, + jsonb_set(:data::jsonb, '{id}', to_jsonb(generated_id.value)), + :version, + :createdAt, + :modifiedAt + FROM generated_id + """ + .trimIndent(), + bindParameters = { batch, entity -> + batch + .bind("data", serializationAdapter.toJson(entity)) + .bind("version", version) + .bind("createdAt", createdAt) + .bind("modifiedAt", createdAt) + }, + columnsToReturn = arrayOf(Columns.DATA), + handleReturnedColumns = { resultSet -> + for (entityWithGeneratedId in resultSet.map(entityDataMapper)) { + entitiesWithGeneratedId.add( + Versioned( + entityWithGeneratedId, + version, + createdAt = createdAt, + modifiedAt = createdAt, + ), + ) + } + }, + ) + } + } + + return entitiesWithGeneratedId + } + + override fun batchUpdate(entities: Iterable>): List> { + if (entities.sizeIfKnown() == 0) { + return emptyList() + } + + val modifiedAt = Instant.now() + transactional { + useHandle(jdbi) { handle -> executeBatchOperation( handle, entities, @@ -372,7 +417,7 @@ open class RepositoryJdbi>( batch .bind("data", serializationAdapter.toJson(entity.item)) .bind("nextVersion", nextVersion) - .bind("modifiedAt", now) + .bind("modifiedAt", modifiedAt) .bind("id", entity.item.id) .bind("previousVersion", entity.version) }, @@ -382,9 +427,20 @@ open class RepositoryJdbi>( ) } } + + // We wait until here to create the result list, which may be large, to avoid allocating it + // before calling the database. That would keep the list in memory while we are waiting for the + // database, needlessly reducing throughput. + return entities.map { entity -> + entity.copy(modifiedAt = modifiedAt, version = entity.version.next()) + } } override fun batchDelete(entities: Iterable>) { + if (entities.sizeIfKnown() == 0) { + return + } + transactional { useHandle(jdbi) { handle -> executeBatchOperation( @@ -596,3 +652,14 @@ open class RepositoryJdbi>( } } } + +/** + * An Iterable may or may not have a known size. But in some of our methods on RepositoryJdbi, we + * can make optimizations, such as returning early or pre-allocating results, if we know the size. + * So we use this extension function to see if we can get a size from the Iterable. + * + * This is the same way that the Kotlin standard library does it for e.g. [Iterable.map]. + */ +private fun Iterable<*>.sizeIfKnown(): Int? { + return if (this is Collection<*>) this.size else null +} diff --git a/src/main/kotlin/no/liflig/documentstore/utils/BatchOperation.kt b/src/main/kotlin/no/liflig/documentstore/utils/BatchOperation.kt index d090b61..3f3cdc5 100644 --- a/src/main/kotlin/no/liflig/documentstore/utils/BatchOperation.kt +++ b/src/main/kotlin/no/liflig/documentstore/utils/BatchOperation.kt @@ -1,6 +1,7 @@ package no.liflig.documentstore.utils import org.jdbi.v3.core.Handle +import org.jdbi.v3.core.result.BatchResultBearing import org.jdbi.v3.core.statement.PreparedBatch /** @@ -18,6 +19,10 @@ import org.jdbi.v3.core.statement.PreparedBatch * executed batch, which may be more than 1 if the number of items exceeds [batchSize]. A second * parameter is provided to [handleModifiedRowCounts] with the start index of the current batch, * which can then be used to get the corresponding entity for diagnostics purposes. + * + * If you need to return something from the query, pass columns names in [columnsToReturn]. This + * will append `RETURNING` to the SQL statement with the given column names. You can then iterate + * over the results with [handleReturnedColumns]. */ internal fun executeBatchOperation( handle: Handle, @@ -25,6 +30,8 @@ internal fun executeBatchOperation( statement: String, bindParameters: (PreparedBatch, BatchItemT) -> PreparedBatch, handleModifiedRowCounts: ((IntArray, Int) -> Unit)? = null, + columnsToReturn: Array? = null, + handleReturnedColumns: ((BatchResultBearing) -> Unit)? = null, batchSize: Int = 50, ) { runWithAutoCommitDisabled(handle) { @@ -43,10 +50,13 @@ internal fun executeBatchOperation( elementCountInCurrentBatch++ if (elementCountInCurrentBatch >= batchSize) { - val modifiedRowCounts = currentBatch.execute() - if (handleModifiedRowCounts != null) { - handleModifiedRowCounts(modifiedRowCounts, startIndexOfCurrentBatch) - } + executeBatch( + currentBatch, + startIndexOfCurrentBatch, + handleModifiedRowCounts, + columnsToReturn, + handleReturnedColumns, + ) currentBatch = null elementCountInCurrentBatch = 0 @@ -55,10 +65,44 @@ internal fun executeBatchOperation( // If currentBatch is non-null here, that means we still have remaining entities to update if (currentBatch != null) { - val executeResult = currentBatch.execute() - if (handleModifiedRowCounts != null) { - handleModifiedRowCounts(executeResult, startIndexOfCurrentBatch) - } + executeBatch( + currentBatch, + startIndexOfCurrentBatch, + handleModifiedRowCounts, + columnsToReturn, + handleReturnedColumns, + ) + } + } +} + +/** + * We have 2 different variants here: + * - If the batch query is not returning anything, we can call [PreparedBatch.execute], which just + * returns the modified row counts. + * - If the batch query does return something (i.e. [columnsToReturn] is set), then we must call + * [PreparedBatch.executePreparedBatch]. That appends the given columns in a `RETURNING` clause on + * the query, and gives us a result object which we can handle in [handleReturnedColumns]. + */ +private fun executeBatch( + currentBatch: PreparedBatch, + startIndexOfCurrentBatch: Int, + handleModifiedRowCounts: ((IntArray, Int) -> Unit)?, + columnsToReturn: Array?, + handleReturnedColumns: ((BatchResultBearing) -> Unit)?, +) { + if (columnsToReturn.isNullOrEmpty()) { + val modifiedRowCounts = currentBatch.execute() + if (handleModifiedRowCounts != null) { + handleModifiedRowCounts(modifiedRowCounts, startIndexOfCurrentBatch) + } + } else { + val result = currentBatch.executePreparedBatch(*columnsToReturn) + if (handleModifiedRowCounts != null) { + handleModifiedRowCounts(result.modifiedRowCounts(), startIndexOfCurrentBatch) + } + if (handleReturnedColumns != null) { + handleReturnedColumns(result) } } } diff --git a/src/test/kotlin/no/liflig/documentstore/repository/BatchTest.kt b/src/test/kotlin/no/liflig/documentstore/repository/BatchTest.kt index 43ffe09..9b83cea 100644 --- a/src/test/kotlin/no/liflig/documentstore/repository/BatchTest.kt +++ b/src/test/kotlin/no/liflig/documentstore/repository/BatchTest.kt @@ -35,15 +35,19 @@ class BatchTest { (1..largeBatchSize).map { number -> ExampleEntity(text = "batch-test-${testNumberFormat.format(number)}") } - exampleRepo.batchCreate(entitiesToCreate) + entities = exampleRepo.batchCreate(entitiesToCreate) - entities = exampleRepo.listByIds(entitiesToCreate.map { it.id }) assertNotEquals(0, entities.size) assertEquals(entitiesToCreate.size, entities.size) for ((index, entity) in entities.withIndex()) { // We order by text in ExampleRepository.getByTexts, so they should be returned in same order assertEquals(entity, entities[index]) } + + // Verify that fetching out the created entities gives the same results as the ones we got back + // from batchCreate + val fetchedEntities = exampleRepo.listByIds(entitiesToCreate.map { it.id }) + assertEquals(fetchedEntities, entities) } @Order(2) @@ -57,14 +61,18 @@ class BatchTest { ) entity.copy(item = updatedEntity) } - exampleRepo.batchUpdate(updatedEntities) + entities = exampleRepo.batchUpdate(updatedEntities) - entities = exampleRepo.listByIds(updatedEntities.map { it.item.id }) assertEquals(updatedEntities.size, entities.size) for ((index, entity) in entities.withIndex()) { assertNotNull(entity.item.moreText) assertEquals(entity, entities[index]) } + + // Verify that fetching out the updated entities gives the same results as the ones we got back + // from batchUpdate + val fetchedEntities = exampleRepo.listByIds(updatedEntities.map { it.item.id }) + assertEquals(fetchedEntities, entities) } @Order(3) @@ -103,14 +111,21 @@ class BatchTest { text = "batch-test-with-generated-id-${testNumberFormat.format(number)}", ) } - exampleRepoWithGeneratedIntegerId.batchCreate(entitiesToCreate) + val createdEntities = exampleRepoWithGeneratedIntegerId.batchCreate(entitiesToCreate) - val entities = exampleRepoWithGeneratedIntegerId.listAll() - assertNotEquals(0, entities.size) - assert(entities.size >= entitiesToCreate.size) + assertEquals(entitiesToCreate.size, createdEntities.size) + // Verify that returned entities are in the same order that we gave them + for ((index, createdEntity) in createdEntities.withIndex()) { + assertEquals(entitiesToCreate[index].text, createdEntity.item.text) + + // After calling batchCreate, the IDs should now have been set by the database + assertNotEquals(IntegerEntityId.GENERATED, createdEntity.item.id.value) + } - val expectedTextFields = entitiesToCreate.map { it.text } - val actualTextFields = entities.map { it.item.text } - assert(actualTextFields.containsAll(expectedTextFields)) + // Verify that fetching out the created entities gives the same results as the ones we got back + // from batchCreate + val fetchedEntities = + exampleRepoWithGeneratedIntegerId.listByIds(createdEntities.map { it.item.id }) + assertEquals(fetchedEntities, createdEntities) } } diff --git a/src/test/kotlin/no/liflig/documentstore/repository/CrudTest.kt b/src/test/kotlin/no/liflig/documentstore/repository/CrudTest.kt index 3e722bc..2713837 100644 --- a/src/test/kotlin/no/liflig/documentstore/repository/CrudTest.kt +++ b/src/test/kotlin/no/liflig/documentstore/repository/CrudTest.kt @@ -143,8 +143,7 @@ class CrudTest { .map { entity -> exampleRepoWithGeneratedIntegerId.create(entity) } // After calling RepositoryJdbi.create, the IDs should now have been set by the database - val entityIds = entities.map { it.item.id } - entityIds.forEach { id -> assertNotEquals(IntegerEntityId.GENERATED, id.value) } + entities.forEach { entity -> assertNotEquals(IntegerEntityId.GENERATED, entity.item.id.value) } val getResult = exampleRepoWithGeneratedIntegerId.get(entities[0].item.id) assertNotNull(getResult)