community: add new parameter default_headers (#28700)
Thank you for contributing to LangChain!

- [x] **PR title**: "package: description"
- "community: 1. Add a new parameter `default_headers` for OCI model
deployments and OCI chat model deployments. 2. Update the default value of
the `k` parameter in the `OCIModelDeploymentLLM` class."


- [x] **PR message**:
- **Description:** 1. Add a new parameter `default_headers` for OCI model
deployments and OCI chat model deployments (a usage sketch follows the
checklist below). 2. Update the default value of the `k` parameter in the
`OCIModelDeploymentLLM` class.


- [x] **Add tests and docs**:
  1. unit tests
  2. notebook
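
A minimal usage sketch of the new parameter (illustrative only, not taken from the diff; it assumes `oracle-ads` and `langchain-openai` are installed and that the endpoint placeholders are replaced with a real OCI Data Science model deployment):

# pip install -U oracle-ads langchain-openai
from langchain_community.chat_models import ChatOCIModelDeploymentVLLM

chat = ChatOCIModelDeploymentVLLM(
    endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
    model="odsc-llm",  # default model name when deployed with AQUA
    default_headers={
        "route": "/v1/chat/completions",  # header added to every inference request
    },
)
chat.invoke("Tell me a joke.")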

---------

Co-authored-by: Erick Friis <[email protected]>
lu-ohai and efriis authored Dec 18, 2024
1 parent 1e88ada commit 50afa7c
Showing 5 changed files with 82 additions and 6 deletions.
6 changes: 5 additions & 1 deletion docs/docs/integrations/chat/oci_data_science.ipynb
@@ -137,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -156,6 +156,10 @@
" \"temperature\": 0.2,\n",
" \"max_tokens\": 512,\n",
" }, # other model params...\n",
" default_headers={\n",
" \"route\": \"/v1/chat/completions\",\n",
" # other request headers ...\n",
" },\n",
")"
]
},
37 changes: 35 additions & 2 deletions libs/community/langchain_community/chat_models/oci_data_science.py
@@ -47,6 +47,7 @@
)

logger = logging.getLogger(__name__)
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"


def _is_pydantic_class(obj: Any) -> bool:
@@ -56,6 +57,13 @@ def _is_pydantic_class(obj: Any) -> bool:
class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
"""OCI Data Science Model Deployment chat model integration.
Prerequisite
The OCI Model Deployment plugins are installable only on
python version 3.9 and above. If you're working inside the notebook,
try installing the python 3.10 based conda pack and running the
following setup.
Setup:
Install ``oracle-ads`` and ``langchain-openai``.
@@ -90,6 +98,8 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
Key init args — client params:
auth: dict
ADS auth dictionary for OCI authentication.
default_headers: Optional[Dict]
The headers to be added to the Model Deployment request.
Instantiate:
.. code-block:: python
@@ -98,14 +108,18 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
chat = ChatOCIModelDeployment(
endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
model="odsc-llm",
model="odsc-llm", # this is the default model name if deployed with AQUA
streaming=True,
max_retries=3,
model_kwargs={
"max_token": 512,
"temperature": 0.2,
# other model parameters ...
},
default_headers={
"route": "/v1/chat/completions",
# other request headers ...
},
)
Invocation:
@@ -288,6 +302,25 @@ def _default_params(self) -> Dict[str, Any]:
"stream": self.streaming,
}

def _headers(
self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
"""Construct and return the headers for a request.
Args:
is_async (bool, optional): Indicates if the request is asynchronous.
Defaults to `False`.
body (optional): The request body to be included in the headers if
the request is asynchronous.
Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
return {
"route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
**super()._headers(is_async=is_async, body=body),
}

def _generate(
self,
messages: List[BaseMessage],
@@ -701,7 +734,7 @@ def _process_response(self, response_json: dict) -> ChatResult:

for choice in choices:
message = _convert_dict_to_message(choice["message"])
generation_info = dict(finish_reason=choice.get("finish_reason"))
generation_info = {"finish_reason": choice.get("finish_reason")}
if "logprobs" in choice:
generation_info["logprobs"] = choice["logprobs"]

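To illustrate the precedence implied by the `_headers` override above (a standalone sketch, not part of the committed code): the chat subclass seeds "route" with `DEFAULT_INFERENCE_ENDPOINT_CHAT` and then spreads the base-class headers, which already include any user-supplied `default_headers`, so a caller-provided "route" takes precedence.

# Standalone dict-merge illustration with hypothetical values.
seeded = {"route": "/v1/chat/completions"}                     # injected by the chat subclass
from_base = {"route": "/custom/route", "x-request-id": "abc"}  # base headers incl. user default_headers

merged = {**seeded, **from_base}
assert merged["route"] == "/custom/route"  # keys spread last win
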
@@ -32,6 +32,7 @@
from langchain_community.utilities.requests import Requests

logger = logging.getLogger(__name__)
DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"


DEFAULT_TIME_OUT = 300
@@ -81,6 +82,9 @@ class BaseOCIModelDeployment(Serializable):
max_retries: int = 3
"""Maximum number of retries to make when generating."""

default_headers: Optional[Dict[str, Any]] = None
"""The headers to be added to the Model Deployment request."""

@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
@@ -120,12 +124,12 @@ def _headers(
Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
headers = self.default_headers or {}
if is_async:
signer = self.auth["signer"]
_req = requests.Request("POST", self.endpoint, json=body)
req = _req.prepare()
req = signer(req)
headers = {}
for key, value in req.headers.items():
headers[key] = value

@@ -135,7 +139,7 @@
)
return headers

return (
headers.update(
{
"Content-Type": DEFAULT_CONTENT_TYPE_JSON,
"enable-streaming": "true",
@@ -147,6 +151,8 @@
}
)

return headers

def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
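
With the change above, the synchronous branch now starts from `default_headers` and layers the fixed fields on top of it. A standalone sketch of that composition (hypothetical values; a fresh dict is used here for illustration, whereas the committed code updates the dict it starts from):

default_headers = {"route": "/v1/completions", "opc-request-id": "demo-123"}  # hypothetical caller input
headers = dict(default_headers)  # seed with the caller-supplied defaults
headers.update(
    {
        "Content-Type": "application/json",  # stands in for DEFAULT_CONTENT_TYPE_JSON
        "enable-streaming": "true",
    }
)
assert headers["route"] == "/v1/completions"
assert headers["Content-Type"] == "application/json"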
@@ -383,6 +389,10 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
model="odsc-llm",
streaming=True,
model_kwargs={"frequency_penalty": 1.0},
default_headers={
"route": "/v1/completions",
# other request headers ...
}
)
llm.invoke("tell me a joke.")
@@ -426,7 +436,7 @@ def _construct_json_body(self, prompt: str, param:dict) -> dict:
temperature: float = 0.2
"""A non-negative float that tunes the degree of randomness in generation."""

k: int = -1
k: int = 50
"""Number of most likely tokens to consider at each step."""

p: float = 0.75
@@ -472,6 +482,25 @@ def _identifying_params(self) -> Dict[str, Any]:
**self._default_params,
}

def _headers(
self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
"""Construct and return the headers for a request.
Args:
is_async (bool, optional): Indicates if the request is asynchronous.
Defaults to `False`.
body (optional): The request body to be included in the headers if
the request is asynchronous.
Returns:
Dict: A dictionary containing the appropriate headers for the request.
"""
return {
"route": DEFAULT_INFERENCE_ENDPOINT,
**super()._headers(is_async=is_async, body=body),
}

def _generate(
self,
prompts: List[str],
@@ -19,6 +19,7 @@
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
CONST_PROMPT = "This is a prompt."
CONST_COMPLETION = "This is a completion."
CONST_COMPLETION_ROUTE = "/v1/chat/completions"
CONST_COMPLETION_RESPONSE = {
"id": "chat-123456789",
"object": "chat.completion",
@@ -120,6 +121,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
def test_invoke_vllm(*args: Any) -> None:
"""Tests invoking vLLM endpoint."""
llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert isinstance(output, AIMessage)
assert output.content == CONST_COMPLETION
@@ -132,6 +134,7 @@ def test_invoke_vllm(*args: Any) -> None:
def test_invoke_tgi(*args: Any) -> None:
"""Tests invoking TGI endpoint using OpenAI Spec."""
llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert isinstance(output, AIMessage)
assert output.content == CONST_COMPLETION
@@ -146,6 +149,7 @@ def test_stream_vllm(*args: Any) -> None:
llm = ChatOCIModelDeploymentVLLM(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = None
count = 0
for chunk in llm.stream(CONST_PROMPT):
@@ -184,6 +188,7 @@ async def test_stream_async(*args: Any) -> None:
llm = ChatOCIModelDeploymentVLLM(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
with mock.patch.object(
llm,
"_aiter_sse",
@@ -18,6 +18,7 @@
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
CONST_PROMPT = "This is a prompt."
CONST_COMPLETION = "This is a completion."
CONST_COMPLETION_ROUTE = "/v1/completions"
CONST_COMPLETION_RESPONSE = {
"choices": [
{
@@ -114,6 +115,7 @@ async def mocked_async_streaming_response(
def test_invoke_vllm(*args: Any) -> None:
"""Tests invoking vLLM endpoint."""
llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert output == CONST_COMPLETION

@@ -126,6 +128,7 @@ def test_stream_tgi(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = ""
count = 0
for chunk in llm.stream(CONST_PROMPT):
@@ -143,6 +146,7 @@ def test_generate_tgi(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
output = llm.invoke(CONST_PROMPT)
assert output == CONST_COMPLETION

@@ -161,6 +165,7 @@ async def test_stream_async(*args: Any) -> None:
llm = OCIModelDeploymentTGI(
endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
)
assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
with mock.patch.object(
llm,
"_aiter_sse",