Skip to content

Commit

Permalink
Pass details param into client (#265)
Browse files (browse the repository at this point in the history)
  • Loading branch information
magdyksaleh authored Feb 21, 2024
1 parent dd68924 commit b407adc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
20 changes: 16 additions & 4 deletions clients/python/lorax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def generate(
watermark: bool = False,
response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
decoder_input_details: bool = False,
details: bool = True,
) -> Response:
"""
Given a prompt, generate the following text
Expand Down Expand Up @@ -138,6 +139,8 @@ def generate(
```
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
details (`bool`):
Return the token logprobs and ids for generated tokens
Returns:
Response: generated response
Expand All @@ -149,7 +152,7 @@ def generate(
merged_adapters=merged_adapters,
api_token=api_token,
best_of=best_of,
details=True,
details=details,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
Expand Down Expand Up @@ -202,6 +205,7 @@ def generate_stream(
typical_p: Optional[float] = None,
watermark: bool = False,
response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
details: bool = True,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
Expand Down Expand Up @@ -255,6 +259,8 @@ def generate_stream(
}
}
```
details (`bool`):
Return the token logprobs and ids for generated tokens
Returns:
Iterator[StreamResponse]: stream of generated tokens
Expand All @@ -266,7 +272,7 @@ def generate_stream(
merged_adapters=merged_adapters,
api_token=api_token,
best_of=None,
details=True,
details=details,
decoder_input_details=False,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
Expand Down Expand Up @@ -384,6 +390,7 @@ async def generate(
watermark: bool = False,
response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
decoder_input_details: bool = False,
details: bool = True,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
Expand Down Expand Up @@ -441,6 +448,8 @@ async def generate(
```
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
details (`bool`):
Return the token logprobs and ids for generated tokens
Returns:
Response: generated response
Expand All @@ -452,7 +461,7 @@ async def generate(
merged_adapters=merged_adapters,
api_token=api_token,
best_of=best_of,
details=True,
details=details,
decoder_input_details=decoder_input_details,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
Expand Down Expand Up @@ -500,6 +509,7 @@ async def generate_stream(
typical_p: Optional[float] = None,
watermark: bool = False,
response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
details: bool = True,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
Expand Down Expand Up @@ -553,6 +563,8 @@ async def generate_stream(
}
}
```
details (`bool`):
Return the token logprobs and ids for generated tokens
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
Expand All @@ -564,7 +576,7 @@ async def generate_stream(
merged_adapters=merged_adapters,
api_token=api_token,
best_of=None,
details=True,
details=details,
decoder_input_details=False,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
Expand Down
2 changes: 1 addition & 1 deletion clients/python/lorax/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ class Response(BaseModel):
# Generated text
generated_text: str
# Generation details
details: Details
details: Optional[Details]


# `generate_stream` details
Expand Down

0 comments on commit b407adc

Please sign in to comment.