diff --git a/Sources/CassandraClient/Data.swift b/Sources/CassandraClient/Data.swift index d00a546..9d732c8 100644 --- a/Sources/CassandraClient/Data.swift +++ b/Sources/CassandraClient/Data.swift @@ -17,6 +17,8 @@ import Foundation import Logging import NIO +public protocol PagingStateToken: ContiguousBytes {} + extension CassandraClient { /// Resulting row(s) of a Cassandra query. Data are returned all at once. public final class Rows: Sequence { @@ -46,6 +48,35 @@ extension CassandraClient { Iterator(rows: self) } + /// Returns a reusable paging token. + /// + /// - Warning: This token is not suitable or safe for sharing externally. + public func opaquePagingStateToken() throws -> OpaquePagingStateToken { + try OpaquePagingStateToken(token: self.rawPagingStateToken()) + } + + private func rawPagingStateToken() throws -> [UInt8] { + var buffer: UnsafePointer? + var length = 0 + + // The underlying memory is freed with the Rows result + let result = cass_result_paging_state_token(self.rawPointer, &buffer, &length) + guard result == CASS_OK, let bytesPointer = buffer else { + throw CassandraClient.Error(result) + } + + let tokenBytes: [UInt8] = bytesPointer.withMemoryRebound(to: UInt8.self, capacity: length) { + let bufferPointer = UnsafeBufferPointer(start: $0, count: length) + return Array(unsafeUninitializedCapacity: length) { storagePointer, storageCount in + var (unwritten, endIndex) = storagePointer.initialize(from: bufferPointer) + precondition(unwritten.next() == nil) + storageCount = storagePointer.distance(from: storagePointer.startIndex, to: endIndex) + } + } + + return tokenBytes + } + public final class Iterator: IteratorProtocol { public typealias Element = Row @@ -283,6 +314,20 @@ extension CassandraClient { cass_value_is_null(self.rawPointer) == cass_true } } + + /// A reusable page token that can be used by `Statement` to resume querying + /// at a specific position. + public struct OpaquePagingStateToken: PagingStateToken { + let token: [UInt8] + + internal init(token: [UInt8]) { + self.token = token + } + + public func withUnsafeBytes(_ body: (UnsafeRawBufferPointer) throws -> R) rethrows -> R { + try self.token.withUnsafeBytes(body) + } + } } // MARK: - Int8 diff --git a/Sources/CassandraClient/Statement.swift b/Sources/CassandraClient/Statement.swift index 569b4be..0ba7d74 100644 --- a/Sources/CassandraClient/Statement.swift +++ b/Sources/CassandraClient/Statement.swift @@ -91,6 +91,23 @@ extension CassandraClient { try checkResult { cass_statement_set_paging_size(self.rawPointer, pagingSize) } } + /// Sets the starting page of the returned paginated results. + /// + /// The paging state token can be obtained by the `pagingStateToken()` + /// function on `Rows`. + /// + /// - Warning: The paging state should not be exposed to or come from + /// untrusted environments. The paging state could be spoofed and + /// potentially used to gain access to other data. + public func setPagingStateToken(_ pagingStateToken: PagingStateToken) throws { + try checkResult { + pagingStateToken.withUnsafeBytes { + let buffer = $0.bindMemory(to: CChar.self) + return cass_statement_set_paging_state_token(self.rawPointer, buffer.baseAddress, buffer.count) + } + } + } + deinit { cass_statement_free(self.rawPointer) } diff --git a/Tests/CassandraClientTests/CassandraClientTests.swift b/Tests/CassandraClientTests/CassandraClientTests.swift index 4ba360a..c2ba36b 100644 --- a/Tests/CassandraClientTests/CassandraClientTests.swift +++ b/Tests/CassandraClientTests/CassandraClientTests.swift @@ -233,6 +233,44 @@ final class Tests: XCTestCase { } } + func testPagingToken() throws { + let tableName = "test_\(DispatchTime.now().uptimeNanoseconds)" + try self.cassandraClient.run("create table \(tableName) (id int primary key, data text);").wait() + + let options = CassandraClient.Statement.Options(consistency: .localQuorum) + + let count = Int.random(in: 5000 ... 6000) + var futures = [EventLoopFuture]() + (0 ..< count).forEach { index in + futures.append( + self.cassandraClient.run( + "insert into \(tableName) (id, data) values (?, ?);", + parameters: [.int32(Int32(index)), .string(UUID().uuidString)], + options: options + ) + ) + } + + let initialPages = try self.cassandraClient.query("select id, data from \(tableName);", pageSize: Int32(5)).wait() + + for _ in 0 ..< Int.random(in: 10 ... 20) { + _ = try! initialPages.nextPage().wait() + } + + let page = try initialPages.nextPage().wait() + let pageToken = try page.opaquePagingStateToken() + let row = try initialPages.nextPage().wait().first! + + let statement = try CassandraClient.Statement(query: "select id, data from \(tableName);") + try! statement.setPagingStateToken(pageToken) + let offsetPages = try self.cassandraClient.execute(statement: statement, pageSize: Int32(5), on: nil).wait() + let pagedRow: CassandraClient.Row = try offsetPages.nextPage().wait().first! + + let id1: CassandraClient.Column = pagedRow.column(0)! + let id2: CassandraClient.Column = row.column(0)! + XCTAssertEqual(id1.int32, id2.int32) + } + @available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) func testQueryAsyncIterator() throws { #if !(compiler(>=5.5) && canImport(_Concurrency))