Skip to content

Commit

Permalink
[Vertex AI] Make generateContentStream/sendMessageStream throws (#13573)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Sep 3, 2024
1 parent 20ec9a3 commit 8f82f5d
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 47 deletions.
6 changes: 4 additions & 2 deletions FirebaseVertexAI/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
- [fixed] Resolved a decoding error for citations without a `uri` and added
support for decoding `title` fields, which were previously ignored. (#13518)
- [changed] **Breaking Change**: The methods for starting streaming requests
(`generateContentStream` and `sendMessageStream`) and creating a chat instance
(`startChat`) are now asynchronous and must be called with `await`. (#13545)
(`generateContentStream` and `sendMessageStream`) are now throwing and
asynchronous and must be called with `try await`. (#13545, #13573)
- [changed] **Breaking Change**: Creating a chat instance (`startChat`) is now
asynchronous and must be called with `await`. (#13545)

# 10.29.0
- [feature] Added community support for watchOS. (#13215)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class ConversationViewModel: ObservableObject {
guard let chat else {
throw ChatError.notInitialized
}
let responseStream = await chat.sendMessageStream(text)
let responseStream = try await chat.sendMessageStream(text)
for try await chunk in responseStream {
messages[messages.count - 1].pending = false
if let text = chunk.text {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ class FunctionCallingViewModel: ObservableObject {
}
let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
if functionResponses.isEmpty {
responseStream = await chat.sendMessageStream(text)
responseStream = try await chat.sendMessageStream(text)
} else {
for functionResponse in functionResponses {
messages.insert(functionResponse.chatMessage(), at: messages.count - 1)
}
responseStream = await chat.sendMessageStream(functionResponses.modelContent())
responseStream = try await chat.sendMessageStream(functionResponses.modelContent())
}
for try await chunk in responseStream {
processResponseContent(content: chunk)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class PhotoReasoningViewModel: ObservableObject {
}
}

let outputContentStream = await model.generateContentStream(prompt, images)
let outputContentStream = try await model.generateContentStream(prompt, images)

// stream response
for try await outputContent in outputContentStream {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class SummarizeViewModel: ObservableObject {

let prompt = "Summarize the following text for me: \(inputText)"

let outputContentStream = await model.generateContentStream(prompt)
let outputContentStream = try await model.generateContentStream(prompt)

// stream response
for try await outputContent in outputContentStream {
Expand Down
17 changes: 6 additions & 11 deletions FirebaseVertexAI/Sources/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public actor Chat {
/// - Parameter parts: The new content to send as a single chat message.
/// - Returns: A stream containing the model's response or an error if an error occurred.
@available(macOS 12.0, *)
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...)
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
return try sendMessageStream([ModelContent(parts: parts)])
}
Expand All @@ -95,21 +95,16 @@ public actor Chat {
/// - Parameter content: The new content to send as a single chat message.
/// - Returns: A stream containing the model's response or an error if an error occurred.
@available(macOS 12.0, *)
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent])
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent]) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
let resolvedContent: [ModelContent]
do {
resolvedContent = try content()
} catch let underlying {
return AsyncThrowingStream { continuation in
let error: Error
if let contentError = underlying as? ImageConversionError {
error = GenerateContentError.promptImageContentError(underlying: contentError)
} else {
error = GenerateContentError.internalError(underlying: underlying)
}
continuation.finish(throwing: error)
if let contentError = underlying as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: contentError)
}
throw GenerateContentError.internalError(underlying: underlying)
}

return AsyncThrowingStream { continuation in
Expand All @@ -121,7 +116,7 @@ public actor Chat {

// Send the history alongside the new message as context.
let request = history + newContent
let stream = await model.generateContentStream(request)
let stream = try await model.generateContentStream(request)
do {
for try await chunk in stream {
// Capture any content that's streaming. This should be populated if there's no error.
Expand Down
15 changes: 5 additions & 10 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ public final actor GenerativeModel {
/// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
/// error if an error occurred.
@available(macOS 12.0, *)
public func generateContentStream(_ parts: any ThrowingPartsRepresentable...)
public func generateContentStream(_ parts: any ThrowingPartsRepresentable...) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
return try generateContentStream([ModelContent(parts: parts)])
}
Expand All @@ -190,21 +190,16 @@ public final actor GenerativeModel {
/// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
/// error if an error occurred.
@available(macOS 12.0, *)
public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent])
public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent]) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
let evaluatedContent: [ModelContent]
do {
evaluatedContent = try content()
} catch let underlying {
return AsyncThrowingStream { continuation in
let error: Error
if let contentError = underlying as? ImageConversionError {
error = GenerateContentError.promptImageContentError(underlying: contentError)
} else {
error = GenerateContentError.internalError(underlying: underlying)
}
continuation.finish(throwing: error)
if let contentError = underlying as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: contentError)
}
throw GenerateContentError.internalError(underlying: underlying)
}

let generateContentRequest = GenerateContentRequest(model: modelResourceName,
Expand Down
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Tests/Unit/ChatTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ final class ChatTests: XCTestCase {
)
let chat = Chat(model: model, history: [])
let input = "Test input"
let stream = await chat.sendMessageStream(input)
let stream = try await chat.sendMessageStream(input)

// Ensure the values are parsed correctly
for try await value in stream {
Expand Down
36 changes: 18 additions & 18 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ final class GenerativeModelTests: XCTestCase {
)

do {
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await _ in stream {
XCTFail("No content is there, this shouldn't happen.")
}
Expand All @@ -784,7 +784,7 @@ final class GenerativeModelTests: XCTestCase {
)

do {
let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
for try await _ in stream {
XCTFail("No content is there, this shouldn't happen.")
}
Expand All @@ -807,7 +807,7 @@ final class GenerativeModelTests: XCTestCase {
)

do {
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await _ in stream {
XCTFail("No content is there, this shouldn't happen.")
}
Expand All @@ -827,7 +827,7 @@ final class GenerativeModelTests: XCTestCase {
)

do {
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await _ in stream {
XCTFail("Content shouldn't be shown, this shouldn't happen.")
}
Expand All @@ -847,7 +847,7 @@ final class GenerativeModelTests: XCTestCase {
)

do {
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await _ in stream {
XCTFail("Content shouldn't be shown, this shouldn't happen.")
}
Expand All @@ -866,7 +866,7 @@ final class GenerativeModelTests: XCTestCase {
withExtension: "txt"
)

let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
do {
for try await content in stream {
XCTAssertNotNil(content.text)
Expand All @@ -887,7 +887,7 @@ final class GenerativeModelTests: XCTestCase {
)

var responses = 0
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await content in stream {
XCTAssertNotNil(content.text)
responses += 1
Expand All @@ -904,7 +904,7 @@ final class GenerativeModelTests: XCTestCase {
)

var responses = 0
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await content in stream {
XCTAssertNotNil(content.text)
responses += 1
Expand All @@ -921,7 +921,7 @@ final class GenerativeModelTests: XCTestCase {
)

var hadUnknown = false
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await content in stream {
XCTAssertNotNil(content.text)
if let ratings = content.candidates.first?.safetyRatings,
Expand All @@ -940,7 +940,7 @@ final class GenerativeModelTests: XCTestCase {
withExtension: "txt"
)

let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
var citations = [Citation]()
var responses = [GenerateContentResponse]()
for try await content in stream {
Expand Down Expand Up @@ -996,7 +996,7 @@ final class GenerativeModelTests: XCTestCase {
appCheckToken: appCheckToken
)

let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
for try await _ in stream {}
}

Expand All @@ -1018,7 +1018,7 @@ final class GenerativeModelTests: XCTestCase {
appCheckToken: AppCheckInteropFake.placeholderTokenValue
)

let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
for try await _ in stream {}
}

Expand All @@ -1030,7 +1030,7 @@ final class GenerativeModelTests: XCTestCase {
)
var responses = [GenerateContentResponse]()

let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
for try await response in stream {
responses.append(response)
}
Expand All @@ -1056,7 +1056,7 @@ final class GenerativeModelTests: XCTestCase {

var responseCount = 0
do {
let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
for try await content in stream {
XCTAssertNotNil(content.text)
responseCount += 1
Expand All @@ -1076,7 +1076,7 @@ final class GenerativeModelTests: XCTestCase {
func testGenerateContentStream_nonHTTPResponse() async throws {
MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

let stream = await model.generateContentStream("Hi")
let stream = try await model.generateContentStream("Hi")
do {
for try await content in stream {
XCTFail("Unexpected content in stream: \(content)")
Expand All @@ -1096,7 +1096,7 @@ final class GenerativeModelTests: XCTestCase {
withExtension: "txt"
)

let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
do {
for try await content in stream {
XCTFail("Unexpected content in stream: \(content)")
Expand All @@ -1120,7 +1120,7 @@ final class GenerativeModelTests: XCTestCase {
withExtension: "txt"
)

let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
do {
for try await content in stream {
XCTFail("Unexpected content in stream: \(content)")
Expand Down Expand Up @@ -1159,7 +1159,7 @@ final class GenerativeModelTests: XCTestCase {
)

var responses = 0
let stream = await model.generateContentStream(testPrompt)
let stream = try await model.generateContentStream(testPrompt)
for try await content in stream {
XCTAssertNotNil(content.text)
responses += 1
Expand Down

0 comments on commit 8f82f5d

Please sign in to comment.