Skip to content

Commit

Permalink
Remove unknown static var and refactor implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Oct 8, 2024
1 parent 8cfaf2e commit 5dc2415
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@ extension HarmCategory: CustomStringConvertible {
case .harassment: "Harassment"
case .hateSpeech: "Hate speech"
case .sexuallyExplicit: "Sexually explicit"
case .unknown: "Unknown"
case .civicIntegrity: "Civic integrity"
default:
if isUnknown() {
"Unknown HarmCategory: \(rawValue)"
} else {
"Unhandled HarmCategory: \(rawValue)"
}
}
}
}
Expand Down
60 changes: 49 additions & 11 deletions FirebaseVertexAI/Sources/Safety.swift
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,61 @@ public struct SafetySetting {
}

/// Categories describing the potential harm a piece of content may pose.
public enum HarmCategory: String, Sendable {
/// Unknown. A new server value that isn't recognized by the SDK.
case unknown = "HARM_CATEGORY_UNKNOWN"
public struct HarmCategory: Sendable, Equatable, Hashable {
enum Kind: String {
case harassment = "HARM_CATEGORY_HARASSMENT"
case hateSpeech = "HARM_CATEGORY_HATE_SPEECH"
case sexuallyExplicit = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
case dangerousContent = "HARM_CATEGORY_DANGEROUS_CONTENT"
case civicIntegrity = "HARM_CATEGORY_CIVIC_INTEGRITY"
}

/// Harassment content.
case harassment = "HARM_CATEGORY_HARASSMENT"
public static var harassment: HarmCategory {
return self.init(rawValue: Kind.harassment.rawValue)
}

/// Negative or harmful comments targeting identity and/or protected attributes.
case hateSpeech = "HARM_CATEGORY_HATE_SPEECH"
public static var hateSpeech: HarmCategory {
return self.init(rawValue: Kind.hateSpeech.rawValue)
}

/// Contains references to sexual acts or other lewd content.
case sexuallyExplicit = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
public static var sexuallyExplicit: HarmCategory {
return self.init(rawValue: Kind.sexuallyExplicit.rawValue)
}

/// Promotes or enables access to harmful goods, services, or activities.
case dangerousContent = "HARM_CATEGORY_DANGEROUS_CONTENT"
public static var dangerousContent: HarmCategory {
return self.init(rawValue: Kind.dangerousContent.rawValue)
}

/// Content related to civic integrity.
public static var civicIntegrity: HarmCategory {
return self.init(rawValue: Kind.civicIntegrity.rawValue)
}

/// Returns true if the HarmCategory's `rawValue` is unknown to the SDK.
///
/// > Important: If an unknown value is encountered, check for updates to the SDK as support for
/// > the new value may have been added; see
/// > [Release Notes](https://firebase.google.com/support/release-notes/ios). Alternatively,
/// > search for the `rawValue` in the Firebase Apple SDK
/// > [Issue Tracker](https://github.com/firebase/firebase-ios-sdk/issues) and file a
/// > [Bug Report](https://github.com/firebase/firebase-ios-sdk/issues/new/choose) if none found.
public func isUnknown() -> Bool {
return Kind(rawValue: rawValue) == nil
}

/// Returns the raw string representation of the `HarmCategory` value.
///
/// > Note: This value directly corresponds to the values in the
/// > [REST API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/HarmCategory).
public let rawValue: String

init(rawValue: String) {
self.rawValue = rawValue
}
}

// MARK: - Codable Conformances
Expand Down Expand Up @@ -140,15 +180,13 @@ extension SafetyRating: Decodable {}
extension HarmCategory: Codable {
public init(from decoder: Decoder) throws {
let value = try decoder.singleValueContainer().decode(String.self)
guard let decodedCategory = HarmCategory(rawValue: value) else {
let decodedCategory = HarmCategory(rawValue: value)
if decodedCategory.isUnknown() {
VertexLog.error(
code: .generateContentResponseUnrecognizedHarmCategory,
"Unrecognized HarmCategory with value \"\(value)\"."
)
self = .unknown
return
}

self = decodedCategory
}
}
Expand Down
4 changes: 2 additions & 2 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ final class GenerativeModelTests: XCTestCase {
let expectedSafetyRatings = [
SafetyRating(category: .harassment, probability: .medium),
SafetyRating(category: .dangerousContent, probability: .unknown),
SafetyRating(category: .unknown, probability: .high),
SafetyRating(category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"), probability: .high),
]
MockURLProtocol
.requestHandler = try httpRequestHandler(
Expand Down Expand Up @@ -978,7 +978,7 @@ final class GenerativeModelTests: XCTestCase {
for try await content in stream {
XCTAssertNotNil(content.text)
if let ratings = content.candidates.first?.safetyRatings,
ratings.contains(where: { $0.category == .unknown }) {
ratings.contains(where: { $0.category.isUnknown() }) {
hadUnknown = true
}
}
Expand Down

0 comments on commit 5dc2415

Please sign in to comment.