From 742d7b085b089b16d2271bcd1bd79fc2ef215975 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Mon, 9 Dec 2024 11:28:21 -0500 Subject: [PATCH] [Vertex AI] Bypass proxy for testing against staging --- FirebaseVertexAI/Sources/Constants.swift | 2 +- .../Sources/CountTokensRequest.swift | 5 +++- .../Sources/GenerateContentRequest.swift | 3 ++- .../Sources/GenerativeAIRequest.swift | 2 +- .../Sources/GenerativeAIService.swift | 24 +++++++------------ .../Sources/GenerativeModel.swift | 7 ++++++ FirebaseVertexAI/Sources/VertexAI.swift | 1 + 7 files changed, 24 insertions(+), 20 deletions(-) diff --git a/FirebaseVertexAI/Sources/Constants.swift b/FirebaseVertexAI/Sources/Constants.swift index 8f410c8768a..a88f95520d3 100644 --- a/FirebaseVertexAI/Sources/Constants.swift +++ b/FirebaseVertexAI/Sources/Constants.swift @@ -17,7 +17,7 @@ import Foundation /// Constants associated with the Vertex AI for Firebase SDK. enum Constants { /// The Vertex AI backend endpoint URL. - static let baseURL = "https://firebasevertexai.googleapis.com" + static let baseURL = "staging-aiplatform.sandbox.googleapis.com" /// The base reverse-DNS name for `NSError` or `CustomNSError` error domains. /// diff --git a/FirebaseVertexAI/Sources/CountTokensRequest.swift b/FirebaseVertexAI/Sources/CountTokensRequest.swift index 6c36d96b4c0..78f678030b9 100644 --- a/FirebaseVertexAI/Sources/CountTokensRequest.swift +++ b/FirebaseVertexAI/Sources/CountTokensRequest.swift @@ -17,6 +17,7 @@ import Foundation @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) struct CountTokensRequest { let model: String + let location: String let contents: [ModelContent] let systemInstruction: ModelContent? @@ -31,7 +32,9 @@ extension CountTokensRequest: GenerativeAIRequest { typealias Response = CountTokensResponse var url: URL { - URL(string: "\(Constants.baseURL)/\(options.apiVersion)/\(model):countTokens")! + URL( + string: "https://\(location)-\(Constants.baseURL)/\(options.apiVersion)/\(model):countTokens" + )! } } diff --git a/FirebaseVertexAI/Sources/GenerateContentRequest.swift b/FirebaseVertexAI/Sources/GenerateContentRequest.swift index ffa98fe4159..9f4cd8d0317 100644 --- a/FirebaseVertexAI/Sources/GenerateContentRequest.swift +++ b/FirebaseVertexAI/Sources/GenerateContentRequest.swift @@ -18,6 +18,7 @@ import Foundation struct GenerateContentRequest { /// Model name. let model: String + let location: String let contents: [ModelContent] let generationConfig: GenerationConfig? let safetySettings: [SafetySetting]? @@ -45,7 +46,7 @@ extension GenerateContentRequest: GenerativeAIRequest { typealias Response = GenerateContentResponse var url: URL { - let modelURL = "\(Constants.baseURL)/\(options.apiVersion)/\(model)" + let modelURL = "https://\(location)-\(Constants.baseURL)/\(options.apiVersion)/\(model)" if isStreaming { return URL(string: "\(modelURL):streamGenerateContent?alt=sse")! } else { diff --git a/FirebaseVertexAI/Sources/GenerativeAIRequest.swift b/FirebaseVertexAI/Sources/GenerativeAIRequest.swift index b792830120e..9594ab3fbce 100644 --- a/FirebaseVertexAI/Sources/GenerativeAIRequest.swift +++ b/FirebaseVertexAI/Sources/GenerativeAIRequest.swift @@ -31,7 +31,7 @@ public struct RequestOptions { let timeout: TimeInterval /// The API version to use in requests to the backend. - let apiVersion = "v1beta" + let apiVersion = "v1beta1" /// Initializes a request options object. /// diff --git a/FirebaseVertexAI/Sources/GenerativeAIService.swift b/FirebaseVertexAI/Sources/GenerativeAIService.swift index 667819c5c76..37ea85c10a1 100644 --- a/FirebaseVertexAI/Sources/GenerativeAIService.swift +++ b/FirebaseVertexAI/Sources/GenerativeAIService.swift @@ -180,28 +180,20 @@ struct GenerativeAIService { private func urlRequest(request: T) async throws -> URLRequest { var urlRequest = URLRequest(url: request.url) urlRequest.httpMethod = "POST" - urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key") + guard let accessToken = ProcessInfo.processInfo.environment["GCLOUD_ACCESS_TOKEN"] else { + fatalError(""" + Missing access token; run `gcloud auth print-access-token` and add an environment variable \ + `GCLOUD_ACCESS_TOKEN` with the printed value. + Note: This value will only be valid for 60 minutes. + """) + } + urlRequest.setValue("Bearer \(accessToken)", forHTTPHeaderField: "Authorization") urlRequest.setValue( "\(GenerativeAIService.languageTag) \(GenerativeAIService.firebaseVersionTag)", forHTTPHeaderField: "x-goog-api-client" ) urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") - if let appCheck { - let tokenResult = await appCheck.getToken(forcingRefresh: false) - urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck") - if let error = tokenResult.error { - VertexLog.error( - code: .appCheckTokenFetchFailed, - "Failed to fetch AppCheck token. Error: \(error)" - ) - } - } - - if let auth, let authToken = try await auth.getToken(forcingRefresh: false) { - urlRequest.setValue("Firebase \(authToken)", forHTTPHeaderField: "Authorization") - } - let encoder = JSONEncoder() encoder.keyEncodingStrategy = .convertToSnakeCase urlRequest.httpBody = try encoder.encode(request) diff --git a/FirebaseVertexAI/Sources/GenerativeModel.swift b/FirebaseVertexAI/Sources/GenerativeModel.swift index 0d2ea829f55..e972f1d8861 100644 --- a/FirebaseVertexAI/Sources/GenerativeModel.swift +++ b/FirebaseVertexAI/Sources/GenerativeModel.swift @@ -26,6 +26,8 @@ public final class GenerativeModel { /// The backing service responsible for sending and receiving model requests to the backend. let generativeAIService: GenerativeAIService + let location: String + /// Configuration parameters used for the MultiModalModel. let generationConfig: GenerationConfig? @@ -61,6 +63,7 @@ public final class GenerativeModel { init(name: String, projectID: String, apiKey: String, + location: String = "us-central1", generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil, tools: [Tool]?, @@ -78,6 +81,7 @@ public final class GenerativeModel { auth: auth, urlSession: urlSession ) + self.location = location self.generationConfig = generationConfig self.safetySettings = safetySettings self.tools = tools @@ -125,6 +129,7 @@ public final class GenerativeModel { try content.throwIfError() let response: GenerateContentResponse let generateContentRequest = GenerateContentRequest(model: modelResourceName, + location: location, contents: content, generationConfig: generationConfig, safetySettings: safetySettings, @@ -182,6 +187,7 @@ public final class GenerativeModel { -> AsyncThrowingStream { try content.throwIfError() let generateContentRequest = GenerateContentRequest(model: modelResourceName, + location: location, contents: content, generationConfig: generationConfig, safetySettings: safetySettings, @@ -253,6 +259,7 @@ public final class GenerativeModel { public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse { let countTokensRequest = CountTokensRequest( model: modelResourceName, + location: location, contents: content, systemInstruction: systemInstruction, tools: tools, diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index c0cd2cb66a3..00907b63843 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -93,6 +93,7 @@ public class VertexAI { name: modelResourceName(modelName: modelName), projectID: projectID, apiKey: apiKey, + location: location, generationConfig: generationConfig, safetySettings: safetySettings, tools: tools,