From 27cffd92bcae0198f68b899d769c7c76c9715c56 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 9 Oct 2024 18:29:40 -0400 Subject: [PATCH] [Vertex AI] Refactor `HarmBlockThreshold` as a struct and add `.off` (#13863) --- FirebaseVertexAI/CHANGELOG.md | 6 ++++ FirebaseVertexAI/Sources/Safety.swift | 35 +++++++++++++++---- .../Tests/Integration/IntegrationTests.swift | 12 +++++++ 3 files changed, 47 insertions(+), 6 deletions(-) diff --git a/FirebaseVertexAI/CHANGELOG.md b/FirebaseVertexAI/CHANGELOG.md index f5f072975c7..b6444fb81f0 100644 --- a/FirebaseVertexAI/CHANGELOG.md +++ b/FirebaseVertexAI/CHANGELOG.md @@ -46,6 +46,12 @@ - [changed] The response from `GenerativeModel.countTokens(...)` now includes `systemInstruction`, `tools` and `generationConfig` in the `totalTokens` and `totalBillableCharacters` counts, where applicable. (#13813) +- [added] Added a new `HarmCategory` `.civicIntegrity` for filtering content + that may be used to harm civic integrity. (#13728) +- [added] Added a new `HarmBlockThreshold` `.off`, which turns off the safety + filter. (#13863) +- [added] Added new `FinishReason` values `.blocklist`, `.prohibitedContent`, + `.spii` and `.malformedFunctionCall` that may be reported. (#13860) # 11.3.0 - [added] Added `Decodable` conformance for `FunctionResponse`. (#13606) diff --git a/FirebaseVertexAI/Sources/Safety.swift b/FirebaseVertexAI/Sources/Safety.swift index a3d548fd524..4e93a94bf45 100644 --- a/FirebaseVertexAI/Sources/Safety.swift +++ b/FirebaseVertexAI/Sources/Safety.swift @@ -90,18 +90,41 @@ public struct SafetyRating: Equatable, Hashable, Sendable { @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct SafetySetting { /// Block at and beyond a specified ``SafetyRating/HarmProbability``. - public enum HarmBlockThreshold: String, Sendable { - // Content with `.negligible` will be allowed. - case blockLowAndAbove = "BLOCK_LOW_AND_ABOVE" + public struct HarmBlockThreshold: EncodableProtoEnum, Sendable { + enum Kind: String { + case blockLowAndAbove = "BLOCK_LOW_AND_ABOVE" + case blockMediumAndAbove = "BLOCK_MEDIUM_AND_ABOVE" + case blockOnlyHigh = "BLOCK_ONLY_HIGH" + case blockNone = "BLOCK_NONE" + case off = "OFF" + } + + /// Content with `.negligible` will be allowed. + public static var blockLowAndAbove: HarmBlockThreshold { + return self.init(kind: .blockLowAndAbove) + } /// Content with `.negligible` and `.low` will be allowed. - case blockMediumAndAbove = "BLOCK_MEDIUM_AND_ABOVE" + public static var blockMediumAndAbove: HarmBlockThreshold { + return self.init(kind: .blockMediumAndAbove) + } /// Content with `.negligible`, `.low`, and `.medium` will be allowed. - case blockOnlyHigh = "BLOCK_ONLY_HIGH" + public static var blockOnlyHigh: HarmBlockThreshold { + return self.init(kind: .blockOnlyHigh) + } /// All content will be allowed. - case blockNone = "BLOCK_NONE" + public static var blockNone: HarmBlockThreshold { + return self.init(kind: .blockNone) + } + + /// Turn off the safety filter. + public static var off: HarmBlockThreshold { + return self.init(kind: .off) + } + + let rawValue: String } enum CodingKeys: String, CodingKey { diff --git a/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift b/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift index a1ee926273f..51241c915c2 100644 --- a/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift +++ b/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift @@ -84,6 +84,18 @@ final class IntegrationTests: XCTestCase { func testCountTokens_text() async throws { let prompt = "Why is the sky blue?" + model = vertex.generativeModel( + modelName: "gemini-1.5-pro", + generationConfig: generationConfig, + safetySettings: [ + SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove), + SafetySetting(harmCategory: .hateSpeech, threshold: .blockMediumAndAbove), + SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockOnlyHigh), + SafetySetting(harmCategory: .dangerousContent, threshold: .blockNone), + SafetySetting(harmCategory: .civicIntegrity, threshold: .off), + ], + systemInstruction: systemInstruction + ) let response = try await model.countTokens(prompt)