Skip to content

Commit

Permalink
Expose the paging token API from the C/C++ driver (#22)
Browse files Browse the repository at this point in the history
Co-authored-by: Erich Menge <[email protected]>
  • Loading branch information
erichmenge and Erich Menge authored Mar 4, 2023
1 parent da2a404 commit 1da99aa
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 0 deletions.
45 changes: 45 additions & 0 deletions Sources/CassandraClient/Data.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import Foundation
import Logging
import NIO

public protocol PagingStateToken: ContiguousBytes {}

extension CassandraClient {
/// Resulting row(s) of a Cassandra query. Data are returned all at once.
public final class Rows: Sequence {
Expand Down Expand Up @@ -46,6 +48,35 @@ extension CassandraClient {
Iterator(rows: self)
}

/// Returns a reusable paging token.
///
/// - Warning: This token is not suitable or safe for sharing externally.
public func opaquePagingStateToken() throws -> OpaquePagingStateToken {
try OpaquePagingStateToken(token: self.rawPagingStateToken())
}

private func rawPagingStateToken() throws -> [UInt8] {
var buffer: UnsafePointer<CChar>?
var length = 0

// The underlying memory is freed with the Rows result
let result = cass_result_paging_state_token(self.rawPointer, &buffer, &length)
guard result == CASS_OK, let bytesPointer = buffer else {
throw CassandraClient.Error(result)
}

let tokenBytes: [UInt8] = bytesPointer.withMemoryRebound(to: UInt8.self, capacity: length) {
let bufferPointer = UnsafeBufferPointer(start: $0, count: length)
return Array(unsafeUninitializedCapacity: length) { storagePointer, storageCount in
var (unwritten, endIndex) = storagePointer.initialize(from: bufferPointer)
precondition(unwritten.next() == nil)
storageCount = storagePointer.distance(from: storagePointer.startIndex, to: endIndex)
}
}

return tokenBytes
}

public final class Iterator: IteratorProtocol {
public typealias Element = Row

Expand Down Expand Up @@ -283,6 +314,20 @@ extension CassandraClient {
cass_value_is_null(self.rawPointer) == cass_true
}
}

/// A reusable page token that can be used by `Statement` to resume querying
/// at a specific position.
public struct OpaquePagingStateToken: PagingStateToken {
let token: [UInt8]

internal init(token: [UInt8]) {
self.token = token
}

public func withUnsafeBytes<R>(_ body: (UnsafeRawBufferPointer) throws -> R) rethrows -> R {
try self.token.withUnsafeBytes(body)
}
}
}

// MARK: - Int8
Expand Down
17 changes: 17 additions & 0 deletions Sources/CassandraClient/Statement.swift
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,23 @@ extension CassandraClient {
try checkResult { cass_statement_set_paging_size(self.rawPointer, pagingSize) }
}

/// Sets the starting page of the returned paginated results.
///
/// The paging state token can be obtained by the `pagingStateToken()`
/// function on `Rows`.
///
/// - Warning: The paging state should not be exposed to or come from
/// untrusted environments. The paging state could be spoofed and
/// potentially used to gain access to other data.
public func setPagingStateToken(_ pagingStateToken: PagingStateToken) throws {
try checkResult {
pagingStateToken.withUnsafeBytes {
let buffer = $0.bindMemory(to: CChar.self)
return cass_statement_set_paging_state_token(self.rawPointer, buffer.baseAddress, buffer.count)
}
}
}

deinit {
cass_statement_free(self.rawPointer)
}
Expand Down
38 changes: 38 additions & 0 deletions Tests/CassandraClientTests/CassandraClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,44 @@ final class Tests: XCTestCase {
}
}

func testPagingToken() throws {
let tableName = "test_\(DispatchTime.now().uptimeNanoseconds)"
try self.cassandraClient.run("create table \(tableName) (id int primary key, data text);").wait()

let options = CassandraClient.Statement.Options(consistency: .localQuorum)

let count = Int.random(in: 5000 ... 6000)
var futures = [EventLoopFuture<Void>]()
(0 ..< count).forEach { index in
futures.append(
self.cassandraClient.run(
"insert into \(tableName) (id, data) values (?, ?);",
parameters: [.int32(Int32(index)), .string(UUID().uuidString)],
options: options
)
)
}

let initialPages = try self.cassandraClient.query("select id, data from \(tableName);", pageSize: Int32(5)).wait()

for _ in 0 ..< Int.random(in: 10 ... 20) {
_ = try! initialPages.nextPage().wait()
}

let page = try initialPages.nextPage().wait()
let pageToken = try page.opaquePagingStateToken()
let row = try initialPages.nextPage().wait().first!

let statement = try CassandraClient.Statement(query: "select id, data from \(tableName);")
try! statement.setPagingStateToken(pageToken)
let offsetPages = try self.cassandraClient.execute(statement: statement, pageSize: Int32(5), on: nil).wait()
let pagedRow: CassandraClient.Row = try offsetPages.nextPage().wait().first!

let id1: CassandraClient.Column = pagedRow.column(0)!
let id2: CassandraClient.Column = row.column(0)!
XCTAssertEqual(id1.int32, id2.int32)
}

@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
func testQueryAsyncIterator() throws {
#if !(compiler(>=5.5) && canImport(_Concurrency))
Expand Down

0 comments on commit 1da99aa

Please sign in to comment.