Skip to content

Commit

Permalink
Add stt data request
Browse files Browse the repository at this point in the history
  • Loading branch information
bgoncal committed Mar 14, 2024
1 parent 0c97129 commit b29f5cc
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 31 deletions.
6 changes: 3 additions & 3 deletions Extensions/Mocks/HAConnection+Mock.swift
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public class HAMockConnection: HAConnection {
public var automaticallyTransitionToConnecting = true

/// Data request received to be written
public var writeDataRequest: HARequest?
public var sttDataRequestReceived: HARequest?

// MARK: - Mock Implementation

Expand Down Expand Up @@ -230,7 +230,7 @@ public class HAMockConnection: HAConnection {
}
}

public func write(_ dataRequest: HARequest) {
writeDataRequest = dataRequest
public func sendSttAudio(_ request: HARequest) {
sttDataRequestReceived = request
}
}
6 changes: 3 additions & 3 deletions Source/HAConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ public protocol HAConnection: AnyObject {
handler: @escaping (HACancellable, T) -> Void
) -> HACancellable

/// Write data to websocket connection
/// Send audio to Assist
/// - Parameters:
/// - request: The data request containing data to be written
func write(_ dataRequest: HARequest)
/// - request: The data request containing sttBinaryHandlerId and data (as base64 string) to be written
func sendSttAudio(_ request: HARequest)
}
25 changes: 13 additions & 12 deletions Source/Internal/HAConnectionImpl.swift
Original file line number Diff line number Diff line change
Expand Up @@ -305,14 +305,10 @@ internal class HAConnectionImpl: HAConnection {

// MARK: - Write

public func write(_ dataRequest: HARequest) {
if case .data = dataRequest.type {
defer { connectAutomaticallyIfNeeded() }
let invocation = HARequestInvocationSingle(request: dataRequest) { _ in }
requestController.add(invocation)
} else {
HAGlobal.log(.error, "Write operation can only be executed by data HARequest")
}
public func sendSttAudio(_ request: HARequest) {
defer { connectAutomaticallyIfNeeded() }
let invocation = HARequestInvocationSingle(request: request) { _ in }
requestController.add(invocation)
}
}

Expand Down Expand Up @@ -404,9 +400,14 @@ extension HAConnectionImpl {
}
}

private func sendWrite(_ data: Data) {
private func sendWrite(_ sttBinaryHandlerId: UInt8, audioDataString: String?) {
// If there is no audioData, handlerID will be the payload alone indicating end of audio
var audioData = Data(base64Encoded: audioDataString ?? "") ?? Data()

// Prefix audioData with handler ID so the API can map the binary data
audioData.insert(sttBinaryHandlerId, at: 0)
workQueue.async { [connection] in
connection?.write(data: data)
connection?.write(data: audioData)
}
}

Expand All @@ -419,8 +420,8 @@ extension HAConnectionImpl {
sendWebSocket(identifier: identifier, request: request, command: command)
case let .rest(method, command):
sendRest(identifier: identifier!, request: request, method: method, command: command)
case let .data(data):
sendWrite(data)
case let .sttData(sttBinaryHandlerId):
sendWrite(sttBinaryHandlerId, audioDataString: request.data["audioData"] as? String)
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions Source/Internal/RequestController/HARequestController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@ internal struct HARequestControllerAllowedSendKind: OptionSet {

static let webSocket: Self = .init(rawValue: 0b1)
static let rest: Self = .init(rawValue: 0b10)
static let data: Self = .init(rawValue: 0b10)
static let all: Self = [.webSocket, .rest, .data]
static let sttData: Self = .init(rawValue: 0b10)
static let all: Self = [.webSocket, .rest, .sttData]

func allows(requestType: HARequestType) -> Bool {
switch requestType {
case .webSocket:
return contains(.webSocket)
case .rest:
return contains(.rest)
case .data:
return contains(.data)
case .sttData:
return contains(.sttData)
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions Source/Requests/HARequestType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ public enum HARequestType: Hashable, Comparable, ExpressibleByStringLiteral {
case webSocket(String)
/// Sent over REST, the HTTP method to use and the post-`api/` path
case rest(HAHTTPMethod, String)
/// Sent over WebSocket, the binary data to write
case data(Data)
/// Sent over WebSocket, the stt binary handler id
case sttData(UInt8)

/// Create a WebSocket request type by string literal
/// - Parameter value: The name of the WebSocket command
Expand All @@ -20,15 +20,15 @@ public enum HARequestType: Hashable, Comparable, ExpressibleByStringLiteral {
switch self {
case let .webSocket(command), let .rest(_, command):
return command
case .data:
case .sttData:
return ""
}
}

/// The request is issued outside of the lifecycle of a connection
public var isPerpetual: Bool {
switch self {
case .webSocket, .data: return false
case .webSocket, .sttData: return false
case .rest: return true
}
}
Expand All @@ -45,7 +45,7 @@ public enum HARequestType: Hashable, Comparable, ExpressibleByStringLiteral {
case let (.webSocket(lhsCommand), .webSocket(rhsCommand)),
let (.rest(_, lhsCommand), .rest(_, rhsCommand)):
return lhsCommand < rhsCommand
case (.data, _), (_, .data):
case (.sttData, _), (_, .sttData):
return false
}
}
Expand Down
53 changes: 49 additions & 4 deletions Tests/HAConnectionImpl.test.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1462,17 +1462,62 @@ internal class HAConnectionImplTests: XCTestCase {
XCTAssertEqual(ObjectIdentifier(container.connection), ObjectIdentifier(connection))
}

func testWriteDataWritesData() {
func testWriteDataRequestAddsToRequestController() {
let expectedData = "Fake data".data(using: .utf8)!
let request = HARequest(
type: .sttData(1),
data: [
"audioData": expectedData.base64EncodedString(),
]
)
connection.connect()
connection.write(.init(type: .data(expectedData)))
connection.sendSttAudio(request)
XCTAssertNotNil(requestController.added.first(where: { invocation in
if case let .data(data) = invocation.request.type {
return data == expectedData
if case let .sttData(sttBinaryHandlerId) = invocation.request.type {
return sttBinaryHandlerId == 1 && invocation.request.data["audioData"] as? String == expectedData
.base64EncodedString()
}
return false
}))
}

func testSendRawWithSttRequest() {
connection.connect()
let expectedData = "Fake data".data(using: .utf8)!
let request = HARequest(
type: .sttData(1),
data: [
"audioData": expectedData.base64EncodedString(),
]
)
var expectedAudioData = expectedData
expectedAudioData.insert(1, at: 0)

connection.sendRaw(identifier: nil, request: request)
waitForWorkQueue()
XCTAssertNotNil(engine.events.first { event in
event == .writeData(expectedAudioData, opcode: .binaryFrame)
})
}

func testWriteSttRequestCommand() {
let expectedData = "Fake data".data(using: .utf8)!
let request = HARequest(
type: .sttData(1),
data: [
"audioData": expectedData.base64EncodedString(),
]
)
let request2 = HARequest(
type: .sttData(2),
data: [
"audioData": expectedData.base64EncodedString(),
]
)

XCTAssertEqual(request.type.command, "")
XCTAssertEqual(request.type < request2.type, false)
}
}

extension WebSocketEvent: Equatable {
Expand Down
6 changes: 6 additions & 0 deletions Tests/HARequestController.test.swift
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ internal class HARequestControllerTests: XCTestCase {
XCTAssertEqual(types, Set(["test3", "test4"]))
}

func testAddingWriteRequestAllowed() {
delegate.allowedSendKinds = .all
controller.add(.init(request: .init(type: .sttData(1))))
XCTAssertEqual(delegate.didPrepare.count, 1)
}

func testCancelSingleBeforeSent() {
let invocation = HARequestInvocationSingle(
request: .init(type: "test1", data: [:]),
Expand Down

0 comments on commit b29f5cc

Please sign in to comment.