Skip to content

Commit

Permalink
Run ci api key (#2315)
Browse files Browse the repository at this point in the history
* Add API_Key for Auth and conditionally add authorisation for non info/health endpoints.

* change name to info routes

* Fix comment

* convert strings to lowercase for case insensitive comparison

* convert header to string

* fixes and update docs

* update docs again

* revert wrong update

---------

Co-authored-by: Kevin Duffy <[email protected]>
  • Loading branch information
ErikKaum and KevinDuffy94 authored Jul 29, 2024
1 parent fd2e063 commit 583d37a
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 4 deletions.
6 changes: 6 additions & 0 deletions docs/source/basic_tutorials/launcher.md
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,12 @@ Options:
--cors-allow-origin <CORS_ALLOW_ORIGIN>
[env: CORS_ALLOW_ORIGIN=]
```
## API_KEY
```shell
--api-key <API_KEY>
[env: API_KEY=]
```
## WATERMARK_GAMMA
```shell
Expand Down
9 changes: 9 additions & 0 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,10 @@ struct Args {

#[clap(long, env)]
cors_allow_origin: Vec<String>,

#[clap(long, env)]
api_key: Option<String>,

#[clap(long, env)]
watermark_gamma: Option<f32>,
#[clap(long, env)]
Expand Down Expand Up @@ -1271,6 +1275,11 @@ fn spawn_webserver(
router_args.push(origin);
}

// API Key
if let Some(api_key) = args.api_key {
router_args.push("--api-key".to_string());
router_args.push(api_key);
}
// Ngrok
if args.ngrok {
router_args.push("--ngrok".to_string());
Expand Down
4 changes: 4 additions & 0 deletions router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ struct Args {
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
Expand Down Expand Up @@ -127,6 +129,7 @@ async fn main() -> Result<(), RouterError> {
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
api_key,
ngrok,
ngrok_authtoken,
ngrok_edge,
Expand Down Expand Up @@ -446,6 +449,7 @@ async fn main() -> Result<(), RouterError> {
validation_workers,
addr,
cors_allow_origin,
api_key,
ngrok,
ngrok_authtoken,
ngrok_edge,
Expand Down
37 changes: 33 additions & 4 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use futures::stream::StreamExt;
use futures::stream::{FuturesOrdered, FuturesUnordered};
use futures::Stream;
use futures::TryStreamExt;
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value;
use std::convert::Infallible;
Expand Down Expand Up @@ -1417,6 +1418,7 @@ pub async fn run(
validation_workers: usize,
addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
api_key: Option<String>,
ngrok: bool,
_ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>,
Expand Down Expand Up @@ -1810,16 +1812,42 @@ pub async fn run(
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);

// Define base and health routes
let base_routes = Router::new()
let mut base_routes = Router::new()
.route("/", post(compat_generate))
.route("/", get(health))
.route("/info", get(get_model_info))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/completions", post(completions))
.route("/vertex", post(vertex_compatibility))
.route("/tokenize", post(tokenize))
.route("/tokenize", post(tokenize));

if let Some(api_key) = api_key {
let mut prefix = "Bearer ".to_string();
prefix.push_str(&api_key);

// Leak to allow FnMut
let api_key: &'static str = prefix.leak();

let auth = move |headers: HeaderMap,
request: axum::extract::Request,
next: axum::middleware::Next| async move {
match headers.get(AUTHORIZATION) {
Some(token) => match token.to_str() {
Ok(token_str) if token_str.to_lowercase() == api_key.to_lowercase() => {
let response = next.run(request).await;
Ok(response)
}
_ => Err(StatusCode::UNAUTHORIZED),
},
None => Err(StatusCode::UNAUTHORIZED),
}
};

base_routes = base_routes.layer(axum::middleware::from_fn(auth))
}
let info_routes = Router::new()
.route("/", get(health))
.route("/info", get(get_model_info))
.route("/health", get(health))
.route("/ping", get(health))
.route("/metrics", get(metrics));
Expand All @@ -1838,6 +1866,7 @@ pub async fn run(
let mut app = Router::new()
.merge(swagger_ui)
.merge(base_routes)
.merge(info_routes)
.merge(aws_sagemaker_route);

#[cfg(feature = "google")]
Expand Down

0 comments on commit 583d37a

Please sign in to comment.