diff --git a/common/common.cpp b/common/common.cpp
index 382d585a5e6f9..59e8296604c9c 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -77,6 +77,41 @@ using json = nlohmann::ordered_json;
 
+//
+// Environment variable utils
+//
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::string(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stoi(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_floating_point<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stof(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, bool>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    if (value) {
+        std::string val(value);
+        target = val == "1" || val == "true";
+    }
+}
+
 //
 // CPU utils
 //
@@ -220,12 +255,6 @@ int32_t cpu_get_num_math() {
 //
 // CLI argument parsing
 //
 
-void gpt_params_handle_hf_token(gpt_params & params) {
-    if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
-        params.hf_token = std::getenv("HF_TOKEN");
-    }
-}
-
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
@@ -273,7 +302,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 
     gpt_params_handle_model_default(params);
 
-    gpt_params_handle_hf_token(params);
+    if (params.hf_token.empty()) {
+        get_env("HF_TOKEN", params.hf_token);
+    }
 
     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -293,6 +324,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     return true;
 }
 
+void gpt_params_parse_from_env(gpt_params & params) {
+    // we only care about server-related params for now
+    get_env("LLAMA_ARG_MODEL", params.model);
+    get_env("LLAMA_ARG_THREADS", params.n_threads);
+    get_env("LLAMA_ARG_CTX_SIZE", params.n_ctx);
+    get_env("LLAMA_ARG_N_PARALLEL", params.n_parallel);
+    get_env("LLAMA_ARG_BATCH", params.n_batch);
+    get_env("LLAMA_ARG_UBATCH", params.n_ubatch);
+    get_env("LLAMA_ARG_N_GPU_LAYERS", params.n_gpu_layers);
+    get_env("LLAMA_ARG_THREADS_HTTP", params.n_threads_http);
+    get_env("LLAMA_ARG_CHAT_TEMPLATE", params.chat_template);
+    get_env("LLAMA_ARG_N_PREDICT", params.n_predict);
+    get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics);
+    get_env("LLAMA_ARG_ENDPOINT_SLOTS", params.endpoint_slots);
+    get_env("LLAMA_ARG_EMBEDDINGS", params.embedding);
+    get_env("LLAMA_ARG_FLASH_ATTN", params.flash_attn);
+    get_env("LLAMA_ARG_DEFRAG_THOLD", params.defrag_thold);
+}
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     const auto params_org = params; // the example can modify the default params
diff --git a/common/common.h b/common/common.h
index df23460a50fe0..f603ba2be1d35 100644
--- a/common/common.h
+++ b/common/common.h
@@ -267,7 +267,7 @@ struct gpt_params {
     std::string lora_outfile = "ggml-lora-merged-f16.gguf";
 };
 
-void gpt_params_handle_hf_token(gpt_params & params);
+void gpt_params_parse_from_env(gpt_params & params);
 void gpt_params_handle_model_default(gpt_params & params);
 
 bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params);
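Aside: the `get_env()` overloads above dispatch on the target's type via `std::enable_if` (SFINAE), so one call site works for strings, integers, floats, and bools. The sketch below exercises the string and integral overloads in isolation; the `main()` harness and the POSIX `setenv()` call are illustrative assumptions for the demo, not part of the patch.

```cpp
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>
#include <type_traits>

// Same pattern as the patch: only the overload whose trait matches T
// participates in overload resolution.
template<typename T>
static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
get_env(std::string name, T & target) {
    char * value = std::getenv(name.c_str());
    target = value ? std::string(value) : target; // unset -> keep the default
}

template<typename T>
static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
get_env(std::string name, T & target) {
    char * value = std::getenv(name.c_str());
    target = value ? std::stoi(value) : target;   // throws on non-numeric input
}

int main() {
    setenv("LLAMA_ARG_CTX_SIZE", "4096", 1);      // POSIX, demo only

    int32_t n_ctx = 512;                          // compiled-in default
    get_env("LLAMA_ARG_CTX_SIZE", n_ctx);         // integral overload -> 4096

    std::string model = "model.gguf";
    get_env("LLAMA_ARG_MODEL", model);            // unset -> keeps "model.gguf"

    std::cout << n_ctx << " " << model << "\n";   // prints: 4096 model.gguf
}
```

Note that the integral and floating-point overloads inherit `std::stoi`/`std::stof` semantics, so a malformed value throws at startup rather than being silently ignored.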
diff --git a/examples/server/README.md b/examples/server/README.md
index 930ae15f64d8b..abe245271195b 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -247,6 +247,25 @@ logging:
   --log-append             Don't truncate the old log file.
 ```
 
+Available environment variables (if set, they override the corresponding command-line arguments):
+
+- `LLAMA_CACHE` (cache directory, used by `--hf-repo`)
+- `HF_TOKEN` (Hugging Face access token, used when accessing a gated model with `--hf-repo`)
+- `LLAMA_ARG_MODEL`
+- `LLAMA_ARG_THREADS`
+- `LLAMA_ARG_CTX_SIZE`
+- `LLAMA_ARG_N_PARALLEL`
+- `LLAMA_ARG_BATCH`
+- `LLAMA_ARG_UBATCH`
+- `LLAMA_ARG_N_GPU_LAYERS`
+- `LLAMA_ARG_THREADS_HTTP`
+- `LLAMA_ARG_CHAT_TEMPLATE`
+- `LLAMA_ARG_N_PREDICT`
+- `LLAMA_ARG_ENDPOINT_METRICS`
+- `LLAMA_ARG_ENDPOINT_SLOTS`
+- `LLAMA_ARG_EMBEDDINGS`
+- `LLAMA_ARG_FLASH_ATTN`
+- `LLAMA_ARG_DEFRAG_THOLD`
+
 ## Build
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ce711eadd29ac..e79e7aa2cb846 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2507,6 +2507,9 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // parse arguments from environment variables
+    gpt_params_parse_from_env(params);
+
     // TODO: not great to use extern vars
     server_log_json = params.log_json;
     server_verbose = params.verbosity > 0;
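Aside: one detail of the bool overload is worth noting. When the variable is set, the target becomes `true` only for the exact strings `"1"` or `"true"`; anything else, including `"TRUE"` or `"yes"`, resolves to `false`, which also makes it possible to switch a defaulted-on flag off. A minimal sketch of just that overload (again, the `main()` harness and POSIX `setenv()` are illustrative, not part of the patch):

```cpp
#include <cstdlib>
#include <iostream>
#include <string>
#include <type_traits>

// The bool overload only writes to the target when the variable exists.
template<typename T>
static typename std::enable_if<std::is_same<T, bool>::value, void>::type
get_env(std::string name, T & target) {
    char * value = std::getenv(name.c_str());
    if (value) {
        std::string val(value);
        target = val == "1" || val == "true";
    }
}

int main() {
    bool flash_attn = false;

    setenv("LLAMA_ARG_FLASH_ATTN", "true", 1);   // POSIX, demo only
    get_env("LLAMA_ARG_FLASH_ATTN", flash_attn);
    std::cout << flash_attn << "\n";             // 1: "true" matches exactly

    setenv("LLAMA_ARG_FLASH_ATTN", "TRUE", 1);
    get_env("LLAMA_ARG_FLASH_ATTN", flash_attn);
    std::cout << flash_attn << "\n";             // 0: comparison is case-sensitive
}
```

Because `gpt_params_parse_from_env()` runs after `gpt_params_parse()` in `server.cpp`, a set `LLAMA_ARG_*` variable wins over the corresponding command-line flag, matching the README note above; `HF_TOKEN` is the exception, as it is consulted only when `params.hf_token` is still empty after argument parsing.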