From 71d8e92d168333f79ee29a555d32fee086e66ca5 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 3 Sep 2024 21:49:51 -0400 Subject: [PATCH] [Vertex AI] Add `responseSchema` to `GenerationConfig` (#13576) --- FirebaseVertexAI/CHANGELOG.md | 4 ++ .../Sources/GenerationConfig.swift | 17 ++++++- .../Tests/Unit/GenerationConfigTests.swift | 49 ++++++++++++++++--- 3 files changed, 63 insertions(+), 7 deletions(-) diff --git a/FirebaseVertexAI/CHANGELOG.md b/FirebaseVertexAI/CHANGELOG.md index a848362a3a8..7390ef0c57c 100644 --- a/FirebaseVertexAI/CHANGELOG.md +++ b/FirebaseVertexAI/CHANGELOG.md @@ -9,6 +9,10 @@ - [changed] **Breaking Change**: The source image in the `ImageConversionError.couldNotConvertToJPEG` error case is now an enum value instead of the `Any` type. (#13575) +- [added] Added support for specifying a JSON `responseSchema` in + `GenerationConfig`; see + [control generated output](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/control-generated-output) + for more details. (#13576) # 10.29.0 - [feature] Added community support for watchOS. (#13215) diff --git a/FirebaseVertexAI/Sources/GenerationConfig.swift b/FirebaseVertexAI/Sources/GenerationConfig.swift index ec29b708c75..3f3a4b6f214 100644 --- a/FirebaseVertexAI/Sources/GenerationConfig.swift +++ b/FirebaseVertexAI/Sources/GenerationConfig.swift @@ -70,6 +70,17 @@ public struct GenerationConfig { /// - `application/json`: JSON response in the candidates. public let responseMIMEType: String? + /// Output schema of the generated candidate text. + /// If set, a compatible ``responseMIMEType`` must also be set. + /// + /// Compatible MIME types: + /// - `application/json`: Schema for JSON response. + /// + /// Refer to the [Control generated + /// output](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/control-generated-output) + /// guide for more details. + public let responseSchema: Schema? + /// Creates a new `GenerationConfig` value. /// /// - Parameter temperature: See ``temperature`` @@ -78,9 +89,12 @@ public struct GenerationConfig { /// - Parameter candidateCount: See ``candidateCount`` /// - Parameter maxOutputTokens: See ``maxOutputTokens`` /// - Parameter stopSequences: See ``stopSequences`` + /// - Parameter responseMIMEType: See ``responseMIMEType`` + /// - Parameter responseSchema: See ``responseSchema`` public init(temperature: Float? = nil, topP: Float? = nil, topK: Int? = nil, candidateCount: Int? = nil, maxOutputTokens: Int? = nil, - stopSequences: [String]? = nil, responseMIMEType: String? = nil) { + stopSequences: [String]? = nil, responseMIMEType: String? = nil, + responseSchema: Schema? = nil) { // Explicit init because otherwise if we re-arrange the above variables it changes the API // surface. self.temperature = temperature @@ -90,6 +104,7 @@ public struct GenerationConfig { self.maxOutputTokens = maxOutputTokens self.stopSequences = stopSequences self.responseMIMEType = responseMIMEType + self.responseSchema = responseSchema } } diff --git a/FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift b/FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift index e2bfe0d4fb6..35450c03758 100644 --- a/FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift +++ b/FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift @@ -48,7 +48,7 @@ final class GenerationConfigTests: XCTestCase { let candidateCount = 2 let maxOutputTokens = 256 let stopSequences = ["END", "DONE"] - let responseMIMEType = "text/plain" + let responseMIMEType = "application/json" let generationConfig = GenerationConfig( temperature: temperature, topP: topP, @@ -56,7 +56,8 @@ final class GenerationConfigTests: XCTestCase { candidateCount: candidateCount, maxOutputTokens: maxOutputTokens, stopSequences: stopSequences, - responseMIMEType: responseMIMEType + responseMIMEType: responseMIMEType, + responseSchema: Schema(type: .array, items: Schema(type: .string)) ) let jsonData = try encoder.encode(generationConfig) @@ -67,6 +68,12 @@ final class GenerationConfigTests: XCTestCase { "candidateCount" : \(candidateCount), "maxOutputTokens" : \(maxOutputTokens), "responseMIMEType" : "\(responseMIMEType)", + "responseSchema" : { + "items" : { + "type" : "STRING" + }, + "type" : "ARRAY" + }, "stopSequences" : [ "END", "DONE" @@ -78,16 +85,46 @@ final class GenerationConfigTests: XCTestCase { """) } - func testEncodeGenerationConfig_responseMIMEType() throws { - let mimeType = "image/jpeg" - let generationConfig = GenerationConfig(responseMIMEType: mimeType) + func testEncodeGenerationConfig_jsonResponse() throws { + let mimeType = "application/json" + let generationConfig = GenerationConfig( + responseMIMEType: mimeType, + responseSchema: Schema( + type: .object, + properties: [ + "firstName": Schema(type: .string), + "lastName": Schema(type: .string), + "age": Schema(type: .integer), + ], + requiredProperties: ["firstName", "lastName", "age"] + ) + ) let jsonData = try encoder.encode(generationConfig) let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) XCTAssertEqual(json, """ { - "responseMIMEType" : "\(mimeType)" + "responseMIMEType" : "\(mimeType)", + "responseSchema" : { + "properties" : { + "age" : { + "type" : "INTEGER" + }, + "firstName" : { + "type" : "STRING" + }, + "lastName" : { + "type" : "STRING" + } + }, + "required" : [ + "firstName", + "lastName", + "age" + ], + "type" : "OBJECT" + } } """) }