Commit

Merge branch 'main' into corpus-dev
pmeier committed Aug 1, 2024
2 parents c328403 + 19d326b commit ad9c130
Showing 13 changed files with 268 additions and 63 deletions.
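The core change this merge brings in from main: HttpApiCaller.__call__ (the _call_api helper behind every assistant) is no longer an async iterator itself but returns an async context manager that yields the stream, so the underlying HTTP response is opened and closed by an `async with` block. Every assistant's answer() switches to the new calling convention, OpenaiLikeHttpApiAssistant._stream is renamed to _call_openai_api (and reused by the Ollama assistant), httpx_sse, ijson, and sse_starlette join the dev environment, and a small FastAPI echo server lands under tests/assistants/ to exercise the three streaming protocols. Schematically (a sketch of the call sites only, with placeholder url/payload, not runnable on its own):

# Before: _call_api was itself an async generator.
async for data in self._call_api("POST", url, json=payload):
    ...

# After: _call_api returns an async context manager that yields the
# chunk iterator; the HTTP response stays open only inside the block.
async with self._call_api("POST", url, json=payload) as stream:
    async for data in stream:
        ...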
3 changes: 3 additions & 0 deletions environment-dev.yml
@@ -6,6 +6,9 @@ dependencies:
- pip
- git-lfs
- pip:
- httpx_sse
- ijson
- sse_starlette
- python-dotenv
- pytest >=6
- pytest-mock
2 changes: 1 addition & 1 deletion ragna/_utils.py
@@ -94,7 +94,7 @@ def timeout_after(
seconds: float = 30, *, message: str = ""
) -> Callable[[Callable], Callable]:
timeout = f"Timeout after {seconds:.1f} seconds"
message = timeout if message else f"{timeout}: {message}"
message = f"{timeout}: {message}" if message else timeout

def decorator(fn: Callable) -> Callable:
if is_debugging():
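The ragna/_utils.py fix swaps two inverted ternary branches: previously a non-empty message was silently discarded, while an empty one produced a dangling ": ". A standalone check of the corrected expression (format_timeout is a hypothetical wrapper added here for illustration):

def format_timeout(seconds: float = 30, *, message: str = "") -> str:
    timeout = f"Timeout after {seconds:.1f} seconds"
    # Corrected branch order: append the suffix only when a message was given.
    return f"{timeout}: {message}" if message else timeout

assert format_timeout() == "Timeout after 30.0 seconds"
assert format_timeout(message="indexing") == "Timeout after 30.0 seconds: indexing"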
7 changes: 4 additions & 3 deletions ragna/assistants/_ai21labs.py
@@ -29,7 +29,7 @@ async def answer(
# See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
# See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._call_api(
async with self._call_api(
"POST",
f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
headers={
@@ -49,8 +49,9 @@
],
"system": self._make_system_content(sources),
},
):
yield cast(str, data["outputs"][0]["text"])
) as stream:
async for data in stream:
yield cast(str, data["outputs"][0]["text"])


# The Jurassic2Mid assistant receives a 500 internal service error from the remote
23 changes: 12 additions & 11 deletions ragna/assistants/_anthropic.py
@@ -42,7 +42,7 @@ async def answer(
# See https://docs.anthropic.com/claude/reference/messages_post
# See https://docs.anthropic.com/claude/reference/streaming
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._call_api(
async with self._call_api(
"POST",
"https://api.anthropic.com/v1/messages",
headers={
@@ -59,16 +59,17 @@
"temperature": 0.0,
"stream": True,
},
):
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
if "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["type"] == "message_stop":
break
elif data["type"] != "content_block_delta":
continue

yield cast(str, data["delta"].pop("text"))
) as stream:
async for data in stream:
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
if "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["type"] == "message_stop":
break
elif data["type"] != "content_block_delta":
continue

yield cast(str, data["delta"].pop("text"))


class ClaudeOpus(AnthropicAssistant):
19 changes: 10 additions & 9 deletions ragna/assistants/_cohere.py
@@ -31,7 +31,7 @@ async def answer(
# See https://docs.cohere.com/reference/chat
# See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
prompt, sources = (message := messages[-1]).content, message.sources
async for event in self._call_api(
async with self._call_api(
"POST",
"https://api.cohere.ai/v1/chat",
headers={
@@ -48,14 +48,15 @@
"max_tokens": max_new_tokens,
"documents": self._make_source_documents(sources),
},
):
if event["event_type"] == "stream-end":
if event["event_type"] == "COMPLETE":
break

raise RagnaException(event["error_message"])
if "text" in event:
yield cast(str, event["text"])
) as stream:
async for event in stream:
if event["event_type"] == "stream-end":
if event["event_type"] == "COMPLETE":
break

raise RagnaException(event["error_message"])
if "text" in event:
yield cast(str, event["text"])


class Command(CohereAssistant):
7 changes: 4 additions & 3 deletions ragna/assistants/_google.py
@@ -29,7 +29,7 @@ async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
prompt, sources = (message := messages[-1]).content, message.sources
async for chunk in self._call_api(
async with self._call_api(
"POST",
f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
params={"key": self._api_key},
@@ -58,8 +58,9 @@
},
},
parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"),
):
yield chunk
) as stream:
async for chunk in stream:
yield chunk


class GeminiPro(GoogleAssistant):
40 changes: 29 additions & 11 deletions ragna/assistants/_http_api.py
@@ -2,7 +2,7 @@
import enum
import json
import os
from typing import Any, AsyncIterator, Optional
from typing import Any, AsyncContextManager, AsyncIterator, Optional

import httpx

@@ -47,7 +47,7 @@ def __call__(
*,
parse_kwargs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> AsyncIterator[Any]:
) -> AsyncContextManager[AsyncIterator[Any]]:
if self._protocol is None:
call_method = self._no_stream
else:
@@ -56,8 +56,10 @@
HttpStreamingProtocol.JSONL: self._stream_jsonl,
HttpStreamingProtocol.JSON: self._stream_json,
}[self._protocol]

return call_method(method, url, parse_kwargs=parse_kwargs or {}, **kwargs)

@contextlib.asynccontextmanager
async def _no_stream(
self,
method: str,
@@ -68,8 +70,13 @@ async def _no_stream(
) -> AsyncIterator[Any]:
response = await self._client.request(method, url, **kwargs)
await self._assert_api_call_is_success(response)
yield response.json()

async def stream() -> AsyncIterator[Any]:
yield response.json()

yield stream()

@contextlib.asynccontextmanager
async def _stream_sse(
self,
method: str,
@@ -85,9 +92,13 @@ async def _stream_sse(
) as event_source:
await self._assert_api_call_is_success(event_source.response)

async for sse in event_source.aiter_sse():
yield json.loads(sse.data)
async def stream() -> AsyncIterator[Any]:
async for sse in event_source.aiter_sse():
yield json.loads(sse.data)

yield stream()

@contextlib.asynccontextmanager
async def _stream_jsonl(
self,
method: str,
@@ -99,8 +110,11 @@
async with self._client.stream(method, url, **kwargs) as response:
await self._assert_api_call_is_success(response)

async for chunk in response.aiter_lines():
yield json.loads(chunk)
async def stream() -> AsyncIterator[Any]:
async for chunk in response.aiter_lines():
yield json.loads(chunk)

yield stream()

# ijson does not support reading from an (async) iterator, but only from file-like
# objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects.
@@ -120,6 +134,7 @@ async def read(self, n: int) -> bytes:
return b""
return await anext(self._ait, b"") # type: ignore[call-arg]

@contextlib.asynccontextmanager
async def _stream_json(
self,
method: str,
@@ -136,10 +151,13 @@ async def _stream_json(
async with self._client.stream(method, url, **kwargs) as response:
await self._assert_api_call_is_success(response)

async for chunk in ijson.items(
self._AsyncIteratorReader(response.aiter_bytes(chunk_size)), item
):
yield chunk
async def stream() -> AsyncIterator[Any]:
async for chunk in ijson.items(
self._AsyncIteratorReader(response.aiter_bytes(chunk_size)), item
):
yield chunk

yield stream()

async def _assert_api_call_is_success(self, response: httpx.Response) -> None:
if response.is_success:
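All four helpers in _http_api.py now share one shape, dictated by @contextlib.asynccontextmanager: the decorated body may yield exactly once, so the per-chunk loop moves into a nested stream() async generator that closes over the still-open response, and the single yield hands that generator to the caller. A minimal self-contained sketch of the shape, with a plain list standing in for an httpx response:

import asyncio
import contextlib
import json
from typing import Any, AsyncIterator

@contextlib.asynccontextmanager
async def stream_jsonl(raw_lines: list[str]) -> AsyncIterator[AsyncIterator[Any]]:
    # Stands in for `async with self._client.stream(...) as response:`;
    # the "response" must outlive the single yield below.
    async def stream() -> AsyncIterator[Any]:
        for line in raw_lines:
            yield json.loads(line)

    # The one allowed yield: hand the chunk iterator to the caller.
    yield stream()

async def main() -> None:
    async with stream_jsonl(['{"text": "Hello"}', '{"text": " world"}']) as stream:
        async for obj in stream:
            print(obj["text"], end="")

asyncio.run(main())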
17 changes: 10 additions & 7 deletions ragna/assistants/_ollama.py
@@ -33,13 +33,16 @@ async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
# Modeled after
# https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62
if "error" in data:
raise RagnaException(data["error"])
if not data["done"]:
yield cast(str, data["message"]["content"])
async with self._call_openai_api(
prompt, sources, max_new_tokens=max_new_tokens
) as stream:
async for data in stream:
# Modeled after
# https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62
if "error" in data:
raise RagnaException(data["error"])
if not data["done"]:
yield cast(str, data["message"]["content"])


class OllamaGemma2B(OllamaAssistant):
21 changes: 12 additions & 9 deletions ragna/assistants/_openai.py
@@ -1,6 +1,6 @@
import abc
from functools import cached_property
from typing import Any, AsyncIterator, Optional, cast
from typing import Any, AsyncContextManager, AsyncIterator, Optional, cast

from ragna.core import Message, Source

@@ -23,9 +23,9 @@ def _make_system_content(self, sources: list[Source]) -> str:
)
return instruction + "\n\n".join(source.content for source in sources)

def _stream(
def _call_openai_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> AsyncIterator[dict[str, Any]]:
) -> AsyncContextManager[AsyncIterator[dict[str, Any]]]:
# See https://platform.openai.com/docs/api-reference/chat/create
# and https://platform.openai.com/docs/api-reference/chat/streaming
headers = {
@@ -58,12 +58,15 @@ async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
choice = data["choices"][0]
if choice["finish_reason"] is not None:
break

yield cast(str, choice["delta"]["content"])
async with self._call_openai_api(
prompt, sources, max_new_tokens=max_new_tokens
) as stream:
async for data in stream:
choice = data["choices"][0]
if choice["finish_reason"] is not None:
break

yield cast(str, choice["delta"]["content"])


class OpenaiAssistant(OpenaiLikeHttpApiAssistant):
49 changes: 49 additions & 0 deletions tests/assistants/streaming_server.py
@@ -0,0 +1,49 @@
import json
import random

import sse_starlette
from fastapi import FastAPI, Request, Response, status
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/health")
async def health():
return Response(b"", status_code=status.HTTP_200_OK)


@app.post("/sse")
async def sse(request: Request):
data = await request.json()

async def stream():
for obj in data:
yield sse_starlette.ServerSentEvent(json.dumps(obj))

return sse_starlette.EventSourceResponse(stream())


@app.post("/jsonl")
async def jsonl(request: Request):
data = await request.json()

async def stream():
for obj in data:
yield f"{json.dumps(obj)}\n"

return StreamingResponse(stream())


@app.post("/json")
async def json_(request: Request):
data = await request.body()

async def stream():
start = 0
while start < len(data):
end = start + random.randint(1, 10)
yield data[start:end]
start = end

return StreamingResponse(stream())
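The new test server echoes whatever JSON array it receives back over the three streaming protocols the caller implements: SSE, JSON Lines, and raw JSON re-chunked at random byte boundaries, the latter exercising the ijson-based incremental parser. A /health route supports readiness checks. A sketch of hitting the /sse endpoint with httpx_sse (one of the dev dependencies added above), assuming the app is served locally, e.g. via `uvicorn tests.assistants.streaming_server:app --port 8000`; host and port are assumptions, not part of this diff:

import asyncio
import json

import httpx
import httpx_sse

async def main() -> None:
    payload = [{"text": "Hello"}, {"text": " world"}]
    async with httpx.AsyncClient(base_url="http://127.0.0.1:8000") as client:
        async with httpx_sse.aconnect_sse(
            client, "POST", "/sse", json=payload
        ) as event_source:
            async for sse in event_source.aiter_sse():
                print(json.loads(sse.data))

asyncio.run(main())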