From b407adca931d7e02a33c865179bf6aa06cce09dd Mon Sep 17 00:00:00 2001
From: Magdy Saleh <17618143+magdyksaleh@users.noreply.github.com>
Date: Wed, 21 Feb 2024 15:05:39 -0500
Subject: [PATCH] Pass details param into client (#265)

---
 clients/python/lorax/client.py | 20 ++++++++++++++++----
 clients/python/lorax/types.py  |  2 +-
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py
index 130b6c44e..a52e74bab 100644
--- a/clients/python/lorax/client.py
+++ b/clients/python/lorax/client.py
@@ -81,6 +81,7 @@ def generate(
         watermark: bool = False,
         response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
         decoder_input_details: bool = False,
+        details: bool = True,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -138,6 +139,8 @@ def generate(
             ```
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            details (`bool`):
+                Return the token logprobs and ids for generated tokens
 
         Returns:
             Response: generated response
@@ -149,7 +152,7 @@ def generate(
             merged_adapters=merged_adapters,
             api_token=api_token,
             best_of=best_of,
-            details=True,
+            details=details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
@@ -202,6 +205,7 @@ def generate_stream(
         typical_p: Optional[float] = None,
         watermark: bool = False,
         response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
+        details: bool = True,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -255,6 +259,8 @@
                     }
                 }
                 ```
+            details (`bool`):
+                Return the token logprobs and ids for generated tokens
 
         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -266,7 +272,7 @@
             merged_adapters=merged_adapters,
             api_token=api_token,
             best_of=None,
-            details=True,
+            details=details,
             decoder_input_details=False,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
@@ -384,6 +390,7 @@ async def generate(
         watermark: bool = False,
         response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
         decoder_input_details: bool = False,
+        details: bool = True,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -441,6 +448,8 @@ async def generate(
             ```
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            details (`bool`):
+                Return the token logprobs and ids for generated tokens
 
         Returns:
             Response: generated response
@@ -452,7 +461,7 @@ async def generate(
             merged_adapters=merged_adapters,
             api_token=api_token,
             best_of=best_of,
-            details=True,
+            details=details,
             decoder_input_details=decoder_input_details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
@@ -500,6 +509,7 @@ async def generate_stream(
         typical_p: Optional[float] = None,
         watermark: bool = False,
         response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
+        details: bool = True,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -553,6 +563,8 @@
                     }
                 }
                 ```
+            details (`bool`):
+                Return the token logprobs and ids for generated tokens
 
         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -564,7 +576,7 @@
             merged_adapters=merged_adapters,
             api_token=api_token,
             best_of=None,
-            details=True,
+            details=details,
             decoder_input_details=False,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py
index 8dd024c7c..e16a37761 100644
--- a/clients/python/lorax/types.py
+++ b/clients/python/lorax/types.py
@@ -289,7 +289,7 @@ class Response(BaseModel):
     # Generated text
     generated_text: str
     # Generation details
-    details: Details
+    details: Optional[Details]
 
 
 # `generate_stream` details