Skip to content

Commit

Permalink
Add "write" operation to HAConnection (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
bgoncal authored Apr 2, 2024
1 parent 8d1361d commit b602e94
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 2 deletions.
26 changes: 26 additions & 0 deletions Source/Internal/HAConnectionImpl.swift
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,26 @@ extension HAConnectionImpl {
}
}

private func sendWrite(identifier: HARequestIdentifier?, sttBinaryHandlerId: UInt8, audioDataString: String?) {
// If there is no audioData, handlerID will be the payload alone indicating end of audio
var audioData = Data(base64Encoded: audioDataString ?? "") ?? Data()

// Prefix audioData with handler ID so the API can map the binary data
audioData.insert(sttBinaryHandlerId, at: 0)
workQueue.async { [connection] in
connection?.write(data: audioData) { [weak self] in
guard let self else { return }
self.responseController.didWrite()
if let identifier, let request = self.requestController.single(for: identifier) {
callbackQueue.async {
request.resolve(.success(.empty))
}
requestController.clear(invocation: request)
}
}
}
}

func sendRaw(
identifier: HARequestIdentifier?,
request: HARequest
Expand All @@ -401,6 +421,12 @@ extension HAConnectionImpl {
sendWebSocket(identifier: identifier, request: request, command: command)
case let .rest(method, command):
sendRest(identifier: identifier!, request: request, method: method, command: command)
case let .sttData(data):
sendWrite(
identifier: identifier,
sttBinaryHandlerId: data.rawValue,
audioDataString: request.data["audioData"] as? String
)
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion Source/Internal/RequestController/HARequestController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@ internal struct HARequestControllerAllowedSendKind: OptionSet {

static let webSocket: Self = .init(rawValue: 0b1)
static let rest: Self = .init(rawValue: 0b10)
static let all: Self = [.webSocket, .rest]
static let sttData: Self = .init(rawValue: 0b11)
static let all: Self = [.webSocket, .rest, .sttData]

func allows(requestType: HARequestType) -> Bool {
switch requestType {
case .webSocket:
return contains(.webSocket)
case .rest:
return contains(.rest)
case .sttData:
return contains(.sttData)
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions Source/Internal/ResponseController/HAResponseController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ internal protocol HAResponseController: AnyObject {
var phase: HAResponseControllerPhase { get }

func reset()
func didWrite()
func didReceive(event: Starscream.WebSocketEvent)
func didReceive(
for identifier: HARequestIdentifier,
Expand Down Expand Up @@ -187,4 +188,8 @@ internal class HAResponseControllerImpl: HAResponseController {
}
}
}

func didWrite() {
HAGlobal.log(.info, "Data written")
}
}
10 changes: 9 additions & 1 deletion Source/Requests/HARequestType.swift
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import Foundation

/// The command to issue
public enum HARequestType: Hashable, Comparable, ExpressibleByStringLiteral {
/// Sent over WebSocket, the command of the request
case webSocket(String)
/// Sent over REST, the HTTP method to use and the post-`api/` path
case rest(HAHTTPMethod, String)
/// Sent over WebSocket, the stt binary handler id
case sttData(HASttHandlerId)

/// Create a WebSocket request type by string literal
/// - Parameter value: The name of the WebSocket command
Expand All @@ -16,13 +20,15 @@ public enum HARequestType: Hashable, Comparable, ExpressibleByStringLiteral {
switch self {
case let .webSocket(command), let .rest(_, command):
return command
case .sttData:
return ""
}
}

/// The request is issued outside of the lifecycle of a connection
public var isPerpetual: Bool {
switch self {
case .webSocket: return false
case .webSocket, .sttData: return false
case .rest: return true
}
}
Expand All @@ -39,6 +45,8 @@ public enum HARequestType: Hashable, Comparable, ExpressibleByStringLiteral {
case let (.webSocket(lhsCommand), .webSocket(rhsCommand)),
let (.rest(_, lhsCommand), .rest(_, rhsCommand)):
return lhsCommand < rhsCommand
case (.sttData, _), (_, .sttData):
return false
}
}

Expand Down
21 changes: 21 additions & 0 deletions Source/Requests/HASttData.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/// Write audio data to websocket, sttBinaryHandlerId is provided by run-start in Assist pipeline
public struct HASttHandlerId: Hashable {
var rawValue: UInt8

public init(rawValue: UInt8) {
self.rawValue = rawValue
}
}

public extension HATypedRequest {
/// Send binary stream STT data
/// - Parameters:
/// - sttHandlerId: Handler Id provided by run-start event from Assist pipeline
/// - audioDataBase64Encoded: Audio data base 64 encoded
/// - Returns: A typed request that can be sent via `HAConnection`
static func sendSttData(sttHandlerId: UInt8, audioDataBase64Encoded: String) -> HATypedRequest<HAResponseVoid> {
.init(request: .init(type: .sttData(.init(rawValue: sttHandlerId)), data: [
"audioData": audioDataBase64Encoded,
]))
}
}
1 change: 1 addition & 0 deletions Tests/FakeEngine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ internal class FakeEngine: Engine {

func write(data: Data, opcode: FrameOpCode, completion: (() -> Void)?) {
events.append(.writeData(data, opcode: opcode))
completion?()
}

func write(string: String, completion: (() -> Void)?) {
Expand Down
119 changes: 119 additions & 0 deletions Tests/HAConnectionImpl.test.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1461,6 +1461,121 @@ internal class HAConnectionImplTests: XCTestCase {
let container = connection.caches
XCTAssertEqual(ObjectIdentifier(container.connection), ObjectIdentifier(connection))
}

func testWriteDataRequestAddsToRequestController() {
connection.connectAutomatically = true
let expectedData = "Fake data".data(using: .utf8)!
let request = HARequest(
type: .sttData(.init(rawValue: 1)),
data: [
"audioData": expectedData.base64EncodedString(),
]
)
connection.connect()
connection.send(request) { _ in }
XCTAssertNotNil(requestController.added.first(where: { invocation in
if case let .sttData(sttBinaryHandlerId) = invocation.request.type {
return sttBinaryHandlerId == .init(rawValue: 1) && invocation.request
.data["audioData"] as? String == expectedData
.base64EncodedString()
}
return false
}))
}

func testWriteDataRequestsCallCompletion() {
let expectation = expectation(description: "Waiting for completion")
let expectedData = "Fake data".data(using: .utf8)!
let request = HARequest(
type: .sttData(.init(rawValue: 1)),
data: [
"audioData": expectedData.base64EncodedString(),
]
)
connection.connect()
responseController.receivedWaitExpectation = expectation
connection.requestController(requestController, didPrepareRequest: request, with: .init(integerLiteral: 1))

wait(for: [expectation], timeout: 5.0)
}

func testSendRawWithSttRequest() {
connection.connect()
let expectedData = "Fake data".data(using: .utf8)!
let request = HARequest(
type: .sttData(.init(rawValue: 1)),
data: [
"audioData": expectedData.base64EncodedString(),
]
)
var expectedAudioData = expectedData
expectedAudioData.insert(1, at: 0)

connection.sendRaw(identifier: nil, request: request)
waitForWorkQueue()
XCTAssertNotNil(engine.events.first { event in
event == .writeData(expectedAudioData, opcode: .binaryFrame)
})
}

func testSendSttRequestCommand() {
let expectedData = "Fake data".data(using: .utf8)!
let request = HARequest(
type: .sttData(.init(rawValue: 1)),
data: [
"audioData": expectedData.base64EncodedString(),
]
)
let request2 = HARequest(
type: .sttData(.init(rawValue: 1)),
data: [
"audioData": expectedData.base64EncodedString(),
]
)

XCTAssertEqual(request.type.command, "")
XCTAssertEqual(request.type < request2.type, false)
}

func testSendSttRequestSentSuccessful() throws {
let expectation = self.expectation(description: "completion")
responseController.phase = .command(version: "2024.4")
_ = connection.send(.sendSttData(sttHandlerId: 1, audioDataBase64Encoded: ""), completion: { _ in
expectation.fulfill()
})
let added = try XCTUnwrap(requestController.added.first as? HARequestInvocationSingle)
added.resolve(.success(.empty))
waitForExpectations(timeout: 10.0)
}

func testSendSttDataTypedRequest() {
let request: HATypedRequest<HAResponseVoid> = .sendSttData(sttHandlerId: 1, audioDataBase64Encoded: "a")

XCTAssertEqual(request.request.data as? [String: String], ["audioData": "a"])
XCTAssertEqual(request.request.type, .sttData(.init(rawValue: 1)))
}

func testSendSettDataClearsOnCompletion() {
connection.connect()
let expectation = expectation(description: "Completion")
let expectedData = "Fake data".data(using: .utf8)!
let request = HARequest(
type: .sttData(.init(rawValue: 1)),
data: [
"audioData": expectedData.base64EncodedString(),
]
)
let invocation: HARequestInvocationSingle = .init(
request: request,
completion: { _ in
XCTAssertEqual(self.requestController.cleared.count, 1)
expectation.fulfill()
}
)
requestController.singles[1] = invocation
connection.sendRaw(identifier: 1, request: request)
waitForExpectations(timeout: 5.0)
}
}

extension WebSocketEvent: Equatable {
Expand Down Expand Up @@ -1614,6 +1729,10 @@ private class FakeHAResponseController: HAResponseController {
receivedWaitExpectation?.fulfill()
}

func didWrite() {
receivedWaitExpectation?.fulfill()
}

var receivedRestWaitExpectation: XCTestExpectation?
var receivedRest: [Swift.Result<(HTTPURLResponse, Data?), Error>] = []
func didReceive(for identifier: HARequestIdentifier, response: Swift.Result<(HTTPURLResponse, Data?), Error>) {
Expand Down
6 changes: 6 additions & 0 deletions Tests/HARequestController.test.swift
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ internal class HARequestControllerTests: XCTestCase {
XCTAssertEqual(types, Set(["test3", "test4"]))
}

func testAddingWriteRequestAllowed() {
delegate.allowedSendKinds = .all
controller.add(.init(request: .init(type: .sttData(.init(rawValue: 1)))))
XCTAssertEqual(delegate.didPrepare.count, 1)
}

func testCancelSingleBeforeSent() {
let invocation = HARequestInvocationSingle(
request: .init(type: "test1", data: [:]),
Expand Down
12 changes: 12 additions & 0 deletions Tests/HAResponseController.test.swift
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,18 @@ internal class HAResponseControllerTests: XCTestCase {
waitForCallback()
XCTAssertEqual(delegate.lastReceived, .result(identifier: 2, result: .success(.dictionary(resultDictionary))))
}

func testDidWriteEventLogs() {
let expectation = expectation(description: "Receive log")
HAGlobal.log = { level, message in
XCTAssertEqual(level, .info)
XCTAssertEqual(message, "Data written")
HAGlobal.log = { _, _ in }
expectation.fulfill()
}
controller.didWrite()
wait(for: [expectation], timeout: 2)
}
}

private extension HAResponseControllerTests {
Expand Down

0 comments on commit b602e94

Please sign in to comment.