Skip to content

Commit

Permalink
fix: revert breaking Vertex AI changes
Browse files Browse the repository at this point in the history
  • Loading branch information
danny-avila committed Dec 18, 2024
1 parent 6735773 commit 003b3ad
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions api/app/clients/GoogleClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const BaseClient = require('./BaseClient');

const loc = process.env.GOOGLE_LOC || 'us-central1';
const publisher = 'google';
const endpointPrefix = `https://${loc}-aiplatform.googleapis.com`;
const endpointPrefix = `${loc}-aiplatform.googleapis.com`;
const tokenizersCache = {};

const settings = endpointSettings[EModelEndpoint.google];
Expand Down Expand Up @@ -67,6 +67,11 @@ class GoogleClient extends BaseClient {
this.setOptions(options);
}

/* Google specific methods */
constructUrl() {
return `https://${endpointPrefix}/v1/projects/${this.project_id}/locations/${loc}/publishers/${publisher}/models/${this.modelOptions.model}:serverStreamingPredict`;
}

async getClient() {
const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
Expand Down Expand Up @@ -204,6 +209,12 @@ class GoogleClient extends BaseClient {
this.modelOptions.stop = stopTokens;
}

if (this.options.reverseProxyUrl) {
this.completionsUrl = this.options.reverseProxyUrl;
} else {
this.completionsUrl = this.constructUrl();
}

return this;
}

Expand Down Expand Up @@ -597,29 +608,23 @@ class GoogleClient extends BaseClient {

if (this.authHeader) {
requestOptions.customHeaders = {
'Authorization': `Bearer ${this.apiKey}`,
Authorization: `Bearer ${this.apiKey}`,
};
}
}

if (this.project_id && this.isTextModel) {
logger.debug('Creating Google VertexAI client');
return new GoogleVertexAI(clientOptions);
}
else if (this.project_id && this.isChatModel) {
} else if (this.project_id && this.isChatModel) {
logger.debug('Creating Chat Google VertexAI client');
return new ChatGoogleVertexAI(clientOptions);
}
else if (this.project_id) {
} else if (this.project_id) {
logger.debug('Creating VertexAI client');
return new ChatVertexAI(clientOptions);
}
else if (!EXCLUDED_GENAI_MODELS.test(model)) {
} else if (!EXCLUDED_GENAI_MODELS.test(model)) {
logger.debug('Creating GenAI client');
return new GenAI(this.apiKey).getGenerativeModel(
{ ...clientOptions, model },
requestOptions,
);
return new GenAI(this.apiKey).getGenerativeModel({ ...clientOptions, model }, requestOptions);
}

logger.debug('Creating Chat Google Generative AI client');
Expand Down Expand Up @@ -899,13 +904,11 @@ class GoogleClient extends BaseClient {
},
{
category: 'HARM_CATEGORY_HATE_SPEECH',
threshold:
process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_HARASSMENT',
threshold:
process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
Expand All @@ -914,8 +917,7 @@ class GoogleClient extends BaseClient {
},
{
category: 'HARM_CATEGORY_CIVIC_INTEGRITY',
threshold:
process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
threshold: process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
];
}
Expand Down

0 comments on commit 003b3ad

Please sign in to comment.