diff --git a/cmd/workspace/models.go b/cmd/workspace/models.go
index e25d62f01..c272dacc5 100644
--- a/cmd/workspace/models.go
+++ b/cmd/workspace/models.go
@@ -9,4 +9,5 @@ import (
 	_ "github.com/kaito-project/kaito/presets/workspace/models/mistral"
 	_ "github.com/kaito-project/kaito/presets/workspace/models/phi2"
 	_ "github.com/kaito-project/kaito/presets/workspace/models/phi3"
+	_ "github.com/kaito-project/kaito/presets/workspace/models/qwen"
 )
diff --git a/docs/inference/README.md b/docs/inference/README.md
index ef7032a75..bf28ef835 100644
--- a/docs/inference/README.md
+++ b/docs/inference/README.md
@@ -96,6 +96,89 @@ For detailed `InferenceSpec` API definitions, refer to the [documentation](https
 
 The OpenAPI specification for the inference API is available at [vLLM API](../../presets/workspace/inference/vllm/api_spec.json), [transformers API](../../presets/workspace/inference/text-generation/api_spec.json).
 
+#### vLLM inference API
+
+vLLM supports OpenAI-compatible inference APIs. See the [vLLM documentation](https://docs.vllm.ai/en/stable/serving/openai_compatible_server.html) for details.
+
+#### Transformers inference API
+
+The inference service endpoint is `/chat`.
+
+**Basic example**
+```
+curl -X POST "http://<CLUSTERIP>:80/chat" -H "accept: application/json" -H "Content-Type: application/json" -d '{"prompt":"YOUR_PROMPT_HERE"}'
+```
+
+**Example with full configurable parameters**
+```
+curl -X POST \
+    -H "accept: application/json" \
+    -H "Content-Type: application/json" \
+    -d '{
+        "prompt":"YOUR_PROMPT_HERE",
+        "return_full_text": false,
+        "clean_up_tokenization_spaces": false,
+        "prefix": null,
+        "handle_long_generation": null,
+        "generate_kwargs": {
+            "max_length":200,
+            "min_length":0,
+            "do_sample":true,
+            "early_stopping":false,
+            "num_beams":1,
+            "num_beam_groups":1,
+            "diversity_penalty":0.0,
+            "temperature":1.0,
+            "top_k":10,
+            "top_p":1,
+            "typical_p":1,
+            "repetition_penalty":1,
+            "length_penalty":1,
+            "no_repeat_ngram_size":0,
+            "encoder_no_repeat_ngram_size":0,
+            "bad_words_ids":null,
+            "num_return_sequences":1,
+            "output_scores":false,
+            "return_dict_in_generate":false,
+            "forced_bos_token_id":null,
+            "forced_eos_token_id":null,
+            "remove_invalid_values":null
+        }
+    }' \
+    "http://<CLUSTERIP>:80/chat"
+```
+
+**Parameters**
+- `prompt`: The initial text provided by the user, from which the model will continue generating text.
+- `return_full_text`: If False, only the generated text is returned; otherwise the full text is returned.
+- `clean_up_tokenization_spaces`: True/False, determines whether to remove potential extra spaces from the text output.
+- `prefix`: Prefix added to the prompt.
+- `handle_long_generation`: Provides strategies to address generations beyond the model's maximum length capacity.
+- `max_length`: The maximum total number of tokens in the generated text.
+- `min_length`: The minimum total number of tokens that should be generated.
+- `do_sample`: If True, sampling methods will be used for text generation, which can introduce randomness and variation.
+- `early_stopping`: If True, the generation will stop early if certain conditions are met, for example, when a satisfactory number of candidates have been found in beam search.
+- `num_beams`: The number of beams to be used in beam search. More beams can lead to better results but are more computationally expensive.
+- `num_beam_groups`: Divides the number of beams into groups to promote diversity in the generated results.
+- `diversity_penalty`: Penalizes the score of tokens that make the current generation too similar to other groups, encouraging diverse outputs. +- `temperature`: Controls the randomness of the output by scaling the logits before sampling. +- `top_k`: Restricts sampling to the k most likely next tokens. +- `top_p`: Uses nucleus sampling to restrict the sampling pool to tokens comprising the top p probability mass. +- `typical_p`: Adjusts the probability distribution to favor tokens that are "typically" likely, given the context. +- `repetition_penalty`: Penalizes tokens that have been generated previously, aiming to reduce repetition. +- `length_penalty`: Modifies scores based on sequence length to encourage shorter or longer outputs. +- `no_repeat_ngram_size`: Prevents the generation of any n-gram more than once. +- `encoder_no_repeat_ngram_size`: Similar to `no_repeat_ngram_size` but applies to the encoder part of encoder-decoder models. +- `bad_words_ids`: A list of token ids that should not be generated. +- `num_return_sequences`: The number of different sequences to generate. +- `output_scores`: Whether to output the prediction scores. +- `return_dict_in_generate`: If True, the method will return a dictionary containing additional information. +- `pad_token_id`: The token ID used for padding sequences to the same length. +- `eos_token_id`: The token ID that signifies the end of a sequence. +- `forced_bos_token_id`: The token ID that is forcibly used as the beginning of a sequence token. +- `forced_eos_token_id`: The token ID that is forcibly used as the end of a sequence when max_length is reached. +- `remove_invalid_values`: If True, filters out invalid values like NaNs or infs from model outputs to prevent crashes. + # Inference workload Depending on whether the specified model supports distributed inference or not, the Kaito controller will choose to use either Kubernetes **apps.deployment** workload (by default) or Kubernetes **apps.statefulset** workload (if the model supports distributed inference) to manage the inference service, which is exposed using a Cluster-IP type of Kubernetes `service`. diff --git a/examples/inference/kaito_workspace_qwen_2.5_coder_7b-instruct.yaml b/examples/inference/kaito_workspace_qwen_2.5_coder_7b-instruct.yaml new file mode 100644 index 000000000..15586cdae --- /dev/null +++ b/examples/inference/kaito_workspace_qwen_2.5_coder_7b-instruct.yaml @@ -0,0 +1,12 @@ +apiVersion: kaito.sh/v1alpha1 +kind: Workspace +metadata: + name: workspace-qwen-2.5-coder-7b-instruct +resource: + instanceType: "Standard_NC24ads_A100_v4" + labelSelector: + matchLabels: + apps: qwen-2.5-coder +inference: + preset: + name: qwen2.5-coder-7b-instruct diff --git a/presets/workspace/models/falcon/README.md b/presets/workspace/models/falcon/README.md index b71a59e5d..c99a0242e 100644 --- a/presets/workspace/models/falcon/README.md +++ b/presets/workspace/models/falcon/README.md @@ -11,79 +11,4 @@ ## Usage -The inference service endpoint is `/chat`. 
- -### Basic example -``` -curl -X POST "http://:80/chat" -H "accept: application/json" -H "Content-Type: application/json" -d '{"prompt":"YOUR_PROMPT_HERE"}' -``` - -### Example with full configurable parameters -``` -curl -X POST \ - -H "accept: application/json" \ - -H "Content-Type: application/json" \ - -d '{ - "prompt":"YOUR_PROMPT_HERE", - "return_full_text": false, - "clean_up_tokenization_spaces": false, - "prefix": null, - "handle_long_generation": null, - "generate_kwargs": { - "max_length":200, - "min_length":0, - "do_sample":true, - "early_stopping":false, - "num_beams":1, - "num_beam_groups":1, - "diversity_penalty":0.0, - "temperature":1.0, - "top_k":10, - "top_p":1, - "typical_p":1, - "repetition_penalty":1, - "length_penalty":1, - "no_repeat_ngram_size":0, - "encoder_no_repeat_ngram_size":0, - "bad_words_ids":null, - "num_return_sequences":1, - "output_scores":false, - "return_dict_in_generate":false, - "forced_bos_token_id":null, - "forced_eos_token_id":null, - "remove_invalid_values":null - } - }' \ - "http://:80/chat" -``` - -### Parameters -- `prompt`: The initial text provided by the user, from which the model will continue generating text. -- `return_full_text`: If False only generated text is returned, else full text is returned. -- `clean_up_tokenization_spaces`: True/False, determines whether to remove potential extra spaces in the text output. -- `prefix`: Prefix added to the prompt. -- `handle_long_generation`: Provides strategies to address generations beyond the model's maximum length capacity. -- `max_length`: The maximum total number of tokens in the generated text. -- `min_length`: The minimum total number of tokens that should be generated. -- `do_sample`: If True, sampling methods will be used for text generation, which can introduce randomness and variation. -- `early_stopping`: If True, the generation will stop early if certain conditions are met, for example, when a satisfactory number of candidates have been found in beam search. -- `num_beams`: The number of beams to be used in beam search. More beams can lead to better results but are more computationally expensive. -- `num_beam_groups`: Divides the number of beams into groups to promote diversity in the generated results. -- `diversity_penalty`: Penalizes the score of tokens that make the current generation too similar to other groups, encouraging diverse outputs. -- `temperature`: Controls the randomness of the output by scaling the logits before sampling. -- `top_k`: Restricts sampling to the k most likely next tokens. -- `top_p`: Uses nucleus sampling to restrict the sampling pool to tokens comprising the top p probability mass. -- `typical_p`: Adjusts the probability distribution to favor tokens that are "typically" likely, given the context. -- `repetition_penalty`: Penalizes tokens that have been generated previously, aiming to reduce repetition. -- `length_penalty`: Modifies scores based on sequence length to encourage shorter or longer outputs. -- `no_repeat_ngram_size`: Prevents the generation of any n-gram more than once. -- `encoder_no_repeat_ngram_size`: Similar to `no_repeat_ngram_size` but applies to the encoder part of encoder-decoder models. -- `bad_words_ids`: A list of token ids that should not be generated. -- `num_return_sequences`: The number of different sequences to generate. -- `output_scores`: Whether to output the prediction scores. -- `return_dict_in_generate`: If True, the method will return a dictionary containing additional information. 
-- `pad_token_id`: The token ID used for padding sequences to the same length. -- `eos_token_id`: The token ID that signifies the end of a sequence. -- `forced_bos_token_id`: The token ID that is forcibly used as the beginning of a sequence token. -- `forced_eos_token_id`: The token ID that is forcibly used as the end of a sequence when max_length is reached. -- `remove_invalid_values`: If True, filters out invalid values like NaNs or infs from model outputs to prevent crashes. +See [document](../../../../docs/inference/README.md). diff --git a/presets/workspace/models/mistral/README.md b/presets/workspace/models/mistral/README.md index c1aa47e0e..62a5ae3d0 100644 --- a/presets/workspace/models/mistral/README.md +++ b/presets/workspace/models/mistral/README.md @@ -10,79 +10,4 @@ ## Usage -The inference service endpoint is `/chat`. - -### Basic example -``` -curl -X POST "http://:80/chat" -H "accept: application/json" -H "Content-Type: application/json" -d '{"prompt":"YOUR_PROMPT_HERE"}' -``` - -### Example with full configurable parameters -``` -curl -X POST \ - -H "accept: application/json" \ - -H "Content-Type: application/json" \ - -d '{ - "prompt":"YOUR_PROMPT_HERE", - "return_full_text": false, - "clean_up_tokenization_spaces": false, - "prefix": null, - "handle_long_generation": null, - "generate_kwargs": { - "max_length":200, - "min_length":0, - "do_sample":true, - "early_stopping":false, - "num_beams":1, - "num_beam_groups":1, - "diversity_penalty":0.0, - "temperature":1.0, - "top_k":10, - "top_p":1, - "typical_p":1, - "repetition_penalty":1, - "length_penalty":1, - "no_repeat_ngram_size":0, - "encoder_no_repeat_ngram_size":0, - "bad_words_ids":null, - "num_return_sequences":1, - "output_scores":false, - "return_dict_in_generate":false, - "forced_bos_token_id":null, - "forced_eos_token_id":null, - "remove_invalid_values":null - } - }' \ - "http://:80/chat" -``` - -### Parameters -- `prompt`: The initial text provided by the user, from which the model will continue generating text. -- `return_full_text`: If False only generated text is returned, else full text is returned. -- `clean_up_tokenization_spaces`: True/False, determines whether to remove potential extra spaces in the text output. -- `prefix`: Prefix added to the prompt. -- `handle_long_generation`: Provides strategies to address generations beyond the model's maximum length capacity. -- `max_length`: The maximum total number of tokens in the generated text. -- `min_length`: The minimum total number of tokens that should be generated. -- `do_sample`: If True, sampling methods will be used for text generation, which can introduce randomness and variation. -- `early_stopping`: If True, the generation will stop early if certain conditions are met, for example, when a satisfactory number of candidates have been found in beam search. -- `num_beams`: The number of beams to be used in beam search. More beams can lead to better results but are more computationally expensive. -- `num_beam_groups`: Divides the number of beams into groups to promote diversity in the generated results. -- `diversity_penalty`: Penalizes the score of tokens that make the current generation too similar to other groups, encouraging diverse outputs. -- `temperature`: Controls the randomness of the output by scaling the logits before sampling. -- `top_k`: Restricts sampling to the k most likely next tokens. -- `top_p`: Uses nucleus sampling to restrict the sampling pool to tokens comprising the top p probability mass. 
-- `typical_p`: Adjusts the probability distribution to favor tokens that are "typically" likely, given the context. -- `repetition_penalty`: Penalizes tokens that have been generated previously, aiming to reduce repetition. -- `length_penalty`: Modifies scores based on sequence length to encourage shorter or longer outputs. -- `no_repeat_ngram_size`: Prevents the generation of any n-gram more than once. -- `encoder_no_repeat_ngram_size`: Similar to `no_repeat_ngram_size` but applies to the encoder part of encoder-decoder models. -- `bad_words_ids`: A list of token ids that should not be generated. -- `num_return_sequences`: The number of different sequences to generate. -- `output_scores`: Whether to output the prediction scores. -- `return_dict_in_generate`: If True, the method will return a dictionary containing additional information. -- `pad_token_id`: The token ID used for padding sequences to the same length. -- `eos_token_id`: The token ID that signifies the end of a sequence. -- `forced_bos_token_id`: The token ID that is forcibly used as the beginning of a sequence token. -- `forced_eos_token_id`: The token ID that is forcibly used as the end of a sequence when max_length is reached. -- `remove_invalid_values`: If True, filters out invalid values like NaNs or infs from model outputs to prevent crashes. +See [document](../../../../docs/inference/README.md). diff --git a/presets/workspace/models/phi2/README.md b/presets/workspace/models/phi2/README.md index 7dfae1274..7e3588c13 100644 --- a/presets/workspace/models/phi2/README.md +++ b/presets/workspace/models/phi2/README.md @@ -9,79 +9,4 @@ ## Usage -The inference service endpoint is `/chat`. - -### Basic example -``` -curl -X POST "http://:80/chat" -H "accept: application/json" -H "Content-Type: application/json" -d '{"prompt":"YOUR_PROMPT_HERE"}' -``` - -### Example with full configurable parameters -``` -curl -X POST \ - -H "accept: application/json" \ - -H "Content-Type: application/json" \ - -d '{ - "prompt":"YOUR_PROMPT_HERE", - "return_full_text": false, - "clean_up_tokenization_spaces": false, - "prefix": null, - "handle_long_generation": null, - "generate_kwargs": { - "max_length":200, - "min_length":0, - "do_sample":true, - "early_stopping":false, - "num_beams":1, - "num_beam_groups":1, - "diversity_penalty":0.0, - "temperature":1.0, - "top_k":10, - "top_p":1, - "typical_p":1, - "repetition_penalty":1, - "length_penalty":1, - "no_repeat_ngram_size":0, - "encoder_no_repeat_ngram_size":0, - "bad_words_ids":null, - "num_return_sequences":1, - "output_scores":false, - "return_dict_in_generate":false, - "forced_bos_token_id":null, - "forced_eos_token_id":null, - "remove_invalid_values":null - } - }' \ - "http://:80/chat" -``` - -### Parameters -- `prompt`: The initial text provided by the user, from which the model will continue generating text. -- `return_full_text`: If False only generated text is returned, else full text is returned. -- `clean_up_tokenization_spaces`: True/False, determines whether to remove potential extra spaces in the text output. -- `prefix`: Prefix added to the prompt. -- `handle_long_generation`: Provides strategies to address generations beyond the model's maximum length capacity. -- `max_length`: The maximum total number of tokens in the generated text. -- `min_length`: The minimum total number of tokens that should be generated. -- `do_sample`: If True, sampling methods will be used for text generation, which can introduce randomness and variation. 
-- `early_stopping`: If True, the generation will stop early if certain conditions are met, for example, when a satisfactory number of candidates have been found in beam search. -- `num_beams`: The number of beams to be used in beam search. More beams can lead to better results but are more computationally expensive. -- `num_beam_groups`: Divides the number of beams into groups to promote diversity in the generated results. -- `diversity_penalty`: Penalizes the score of tokens that make the current generation too similar to other groups, encouraging diverse outputs. -- `temperature`: Controls the randomness of the output by scaling the logits before sampling. -- `top_k`: Restricts sampling to the k most likely next tokens. -- `top_p`: Uses nucleus sampling to restrict the sampling pool to tokens comprising the top p probability mass. -- `typical_p`: Adjusts the probability distribution to favor tokens that are "typically" likely, given the context. -- `repetition_penalty`: Penalizes tokens that have been generated previously, aiming to reduce repetition. -- `length_penalty`: Modifies scores based on sequence length to encourage shorter or longer outputs. -- `no_repeat_ngram_size`: Prevents the generation of any n-gram more than once. -- `encoder_no_repeat_ngram_size`: Similar to `no_repeat_ngram_size` but applies to the encoder part of encoder-decoder models. -- `bad_words_ids`: A list of token ids that should not be generated. -- `num_return_sequences`: The number of different sequences to generate. -- `output_scores`: Whether to output the prediction scores. -- `return_dict_in_generate`: If True, the method will return a dictionary containing additional information. -- `pad_token_id`: The token ID used for padding sequences to the same length. -- `eos_token_id`: The token ID that signifies the end of a sequence. -- `forced_bos_token_id`: The token ID that is forcibly used as the beginning of a sequence token. -- `forced_eos_token_id`: The token ID that is forcibly used as the end of a sequence when max_length is reached. -- `remove_invalid_values`: If True, filters out invalid values like NaNs or infs from model outputs to prevent crashes. +See [document](../../../../docs/inference/README.md). diff --git a/presets/workspace/models/phi3/README.md b/presets/workspace/models/phi3/README.md index fd0d3d526..a7aa09613 100644 --- a/presets/workspace/models/phi3/README.md +++ b/presets/workspace/models/phi3/README.md @@ -5,95 +5,11 @@ | phi-3-mini-128k-instruct | [microsoft](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) | [link](../../../../examples/inference/kaito_workspace_phi_3_mini_128k.yaml) | Deployment | false | | phi-3-medium-4k-instruct | [microsoft](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct) | [link](../../../../examples/inference/kaito_workspace_phi_3_medium_4k.yaml) | Deployment | false | | phi-3-medium-128k-instruct | [microsoft](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct) | [link](../../../../examples/inference/kaito_workspace_phi_3_medium_128k.yaml) | Deployment | false | +| phi-3.5-mini-instruct | [microsoft](https://huggingface.co/microsoft/Phi-3.5-mini-instruct) | [link](../../../../examples/inference/kaito_workspace_phi_3.5-instruct.yaml) | Deployment | false | ## Image Source - **Public**: Kaito maintainers manage the lifecycle of the inference service images that contain model weights. The images are available in Microsoft Container Registry (MCR). 
## Usage -Phi-3 Mini models are best suited for prompts using the chat format as follows. You can provide the prompt as a question with a generic template as follows: - -``` -<|user|>\nQuestion<|end|>\n<|assistant|> -``` - -For more information on usage, check the phi-3 repo: [Phi-3 Repo](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - -The inference service endpoint is `/chat`. - - -### Basic example -``` -curl -X POST "http://:80/chat" -H "accept: application/json" -H "Content-Type: application/json" -d '{"prompt":"<|user|> How to explain Internet for a medieval knight?<|end|><|assistant|>"}' -``` - - -### Example with full configurable parameters -``` -curl -X POST \ - -H "accept: application/json" \ - -H "Content-Type: application/json" \ - -d '{ - "prompt":"<|user|> What is the meaning of life?<|end|><|assistant|>", - "return_full_text": false, - "clean_up_tokenization_spaces": false, - "prefix": null, - "handle_long_generation": null, - "generate_kwargs": { - "max_length":200, - "min_length":0, - "do_sample":true, - "early_stopping":false, - "num_beams":1, - "num_beam_groups":1, - "diversity_penalty":0.0, - "temperature":1.0, - "top_k":10, - "top_p":1, - "typical_p":1, - "repetition_penalty":1, - "length_penalty":1, - "no_repeat_ngram_size":0, - "encoder_no_repeat_ngram_size":0, - "bad_words_ids":null, - "num_return_sequences":1, - "output_scores":false, - "return_dict_in_generate":false, - "forced_bos_token_id":null, - "forced_eos_token_id":null, - "remove_invalid_values":null - } - }' \ - "http://:80/chat" -``` - -### Parameters -- `prompt`: The initial text provided by the user, from which the model will continue generating text. -- `return_full_text`: If False only generated text is returned, else full text is returned. -- `clean_up_tokenization_spaces`: True/False, determines whether to remove potential extra spaces in the text output. -- `prefix`: Prefix added to the prompt. -- `handle_long_generation`: Provides strategies to address generations beyond the model's maximum length capacity. -- `max_length`: The maximum total number of tokens in the generated text. -- `min_length`: The minimum total number of tokens that should be generated. -- `do_sample`: If True, sampling methods will be used for text generation, which can introduce randomness and variation. -- `early_stopping`: If True, the generation will stop early if certain conditions are met, for example, when a satisfactory number of candidates have been found in beam search. -- `num_beams`: The number of beams to be used in beam search. More beams can lead to better results but are more computationally expensive. -- `num_beam_groups`: Divides the number of beams into groups to promote diversity in the generated results. -- `diversity_penalty`: Penalizes the score of tokens that make the current generation too similar to other groups, encouraging diverse outputs. -- `temperature`: Controls the randomness of the output by scaling the logits before sampling. -- `top_k`: Restricts sampling to the k most likely next tokens. -- `top_p`: Uses nucleus sampling to restrict the sampling pool to tokens comprising the top p probability mass. -- `typical_p`: Adjusts the probability distribution to favor tokens that are "typically" likely, given the context. -- `repetition_penalty`: Penalizes tokens that have been generated previously, aiming to reduce repetition. -- `length_penalty`: Modifies scores based on sequence length to encourage shorter or longer outputs. 
-- `no_repeat_ngram_size`: Prevents the generation of any n-gram more than once. -- `encoder_no_repeat_ngram_size`: Similar to `no_repeat_ngram_size` but applies to the encoder part of encoder-decoder models. -- `bad_words_ids`: A list of token ids that should not be generated. -- `num_return_sequences`: The number of different sequences to generate. -- `output_scores`: Whether to output the prediction scores. -- `return_dict_in_generate`: If True, the method will return a dictionary containing additional information. -- `pad_token_id`: The token ID used for padding sequences to the same length. -- `eos_token_id`: The token ID that signifies the end of a sequence. -- `forced_bos_token_id`: The token ID that is forcibly used as the beginning of a sequence token. -- `forced_eos_token_id`: The token ID that is forcibly used as the end of a sequence when max_length is reached. -- `remove_invalid_values`: If True, filters out invalid values like NaNs or infs from model outputs to prevent crashes. +See [document](../../../../docs/inference/README.md). diff --git a/presets/workspace/models/qwen/README.md b/presets/workspace/models/qwen/README.md new file mode 100644 index 000000000..9ca6aa50a --- /dev/null +++ b/presets/workspace/models/qwen/README.md @@ -0,0 +1,11 @@ +## Supported Models +| Model name | Model source | Sample workspace | Kubernetes Workload | Distributed inference | +|---------------------|:----------------------------------------------------------------------:|:-------------------------------------------------------------------------------:|:-------------------:|:---------------------:| +| qwen2.5-coder-7b-instruct | [qwen](https://huggingface.co/Qwen/Qwen2.5-Coder-7B-Instruct) | [link](../../../../examples/inference/kaito_workspace_qwen_2.5_coder_7b-instruct.yaml) | Deployment | false | + +## Image Source +- **Public**: Kaito maintainers manage the lifecycle of the inference service images that contain model weights. The images are available in Microsoft Container Registry (MCR). + +## Usage + +See [document](../../../../docs/inference/README.md). diff --git a/presets/workspace/models/qwen/model.go b/presets/workspace/models/qwen/model.go new file mode 100644 index 000000000..20a09df74 --- /dev/null +++ b/presets/workspace/models/qwen/model.go @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
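+
+// Package qwen registers the Qwen2.5-Coder-7B-Instruct model preset with the
+// Kaito model plugin registry.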
+package qwen
+
+import (
+	"time"
+
+	kaitov1alpha1 "github.com/kaito-project/kaito/api/v1alpha1"
+	"github.com/kaito-project/kaito/pkg/model"
+	"github.com/kaito-project/kaito/pkg/utils/plugin"
+	"github.com/kaito-project/kaito/pkg/workspace/inference"
+)
+
+func init() {
+	plugin.KaitoModelRegister.Register(&plugin.Registration{
+		Name:     PresetQwen2_5Coder7BInstructModel,
+		Instance: &qwen2_5coder7bInst,
+	})
+}
+
+var (
+	PresetQwen2_5Coder7BInstructModel = "qwen2.5-coder-7b-instruct"
+
+	PresetTagMap = map[string]string{
+		"Qwen2.5-Coder-7B-Instruct": "0.0.1",
+	}
+
+	baseCommandPresetQwenInference = "accelerate launch"
+	baseCommandPresetQwenTuning    = "cd /workspace/tfs/ && python3 metrics_server.py & accelerate launch"
+	qwenRunParams                  = map[string]string{
+		"torch_dtype": "bfloat16",
+		"pipeline":    "text-generation",
+	}
+	qwenRunParamsVLLM = map[string]string{
+		"dtype": "float16",
+	}
+)
+
+var qwen2_5coder7bInst qwen2_5Coder7BInstruct
+
+type qwen2_5Coder7BInstruct struct{}
+
+// GetInferenceParameters returns the preset parameters used to serve
+// qwen2.5-coder-7b-instruct for inference.
+func (*qwen2_5Coder7BInstruct) GetInferenceParameters() *model.PresetParam {
+	return &model.PresetParam{
+		ModelFamilyName:           "Qwen",
+		ImageAccessMode:           string(kaitov1alpha1.ModelImageAccessModePublic),
+		DiskStorageRequirement:    "100Gi",
+		GPUCountRequirement:       "1",
+		TotalGPUMemoryRequirement: "24Gi",
+		PerGPUMemoryRequirement:   "0Gi", // Qwen runs with native vertical model parallelism, so there is no per-GPU memory requirement.
+		RuntimeParam: model.RuntimeParam{
+			Transformers: model.HuggingfaceTransformersParam{
+				TorchRunParams:    inference.DefaultAccelerateParams,
+				ModelRunParams:    qwenRunParams,
+				BaseCommand:       baseCommandPresetQwenInference,
+				InferenceMainFile: inference.DefautTransformersMainFile,
+			},
+			VLLM: model.VLLMParam{
+				BaseCommand:    inference.DefaultVLLMCommand,
+				ModelName:      PresetQwen2_5Coder7BInstructModel,
+				ModelRunParams: qwenRunParamsVLLM,
+			},
+		},
+		ReadinessTimeout: time.Duration(30) * time.Minute,
+		Tag:              PresetTagMap["Qwen2.5-Coder-7B-Instruct"],
+	}
+}
+
+// GetTuningParameters returns the preset parameters used to fine-tune
+// qwen2.5-coder-7b-instruct.
+func (*qwen2_5Coder7BInstruct) GetTuningParameters() *model.PresetParam {
+	return &model.PresetParam{
+		ModelFamilyName:           "qwen",
+		ImageAccessMode:           string(kaitov1alpha1.ModelImageAccessModePublic),
+		DiskStorageRequirement:    "100Gi",
+		GPUCountRequirement:       "1",
+		TotalGPUMemoryRequirement: "24Gi",
+		PerGPUMemoryRequirement:   "24Gi",
+		RuntimeParam: model.RuntimeParam{
+			Transformers: model.HuggingfaceTransformersParam{
+				//TorchRunParams: tuning.DefaultAccelerateParams,
+				//ModelRunParams: qwenRunParams,
+				BaseCommand: baseCommandPresetQwenTuning,
+			},
+		},
+		ReadinessTimeout: time.Duration(30) * time.Minute,
+		Tag:              PresetTagMap["Qwen2.5-Coder-7B-Instruct"],
+	}
+}
+
+func (*qwen2_5Coder7BInstruct) SupportDistributedInference() bool {
+	return false
+}
+
+func (*qwen2_5Coder7BInstruct) SupportTuning() bool {
+	return true
+}