
merge in main and resolve conflicts
dillonroach committed Sep 5, 2024
2 parents 2e8ae2c + ef6c6a5 commit e767c1b
Showing 21 changed files with 378 additions and 128 deletions.
9 changes: 9 additions & 0 deletions .github/ISSUE_TEMPLATE/DOCKER_ISSUE_TEMPLATE.md
@@ -0,0 +1,9 @@
---
title: "Workflow Failure: docker-requirements"
labels: "bug"
---

The workflow 'docker-requirements' has failed. Please check the details at: [Workflow Run Details](https://github.com/{{ env.REPOSITORY }}/actions/runs/{{ env.RUN_ID }})

Triggered by: @{{ actor }}
56 changes: 41 additions & 15 deletions .github/workflows/update-docker-requirements.yml
@@ -1,16 +1,18 @@
name: docker-requirements

on:
pull_request:
push:
branches:
- main
paths:
- ".github/workflows/update-docker-requirements.yml"
- "pyproject.toml"
- "requirements-docker.lock"
workflow_dispatch:

permissions:
pull-requests: write
contents: write
issues: write

jobs:
update:
@@ -42,24 +44,48 @@ jobs:
NEEDS_UPDATE=$(git diff --quiet && echo 'false' || echo 'true')
echo "needs-update=${NEEDS_UPDATE}" | tee --append $GITHUB_OUTPUT
- name: Check if commit is associated with a PR
id: pr-check
env:
GH_TOKEN: ${{ github.token }}
run: |
PR_NUMBER=$(gh api \
-H 'Accept: application/vnd.github+json' \
/repos/${{ github.repository }}/commits/${{ github.sha }}/pulls \
--jq '.[0].number')
if [ $? -eq 0 ]; then
PR_INFO=$(gh pr view ${PR_NUMBER} --json author,mergedBy --jq '{author: .author.login, merger: .mergedBy.login}')
AUTHOR=$(echo ${PR_INFO} | jq -r .author)
MERGER=$(echo ${PR_INFO} | jq -r .merger)
else
AUTHOR=${{ github.actor }}
MERGER="none"
fi
echo "author=${AUTHOR}" | tee --append $GITHUB_OUTPUT
echo "merger=${MERGER}" | tee --append $GITHUB_OUTPUT
- name: Open a PR to update the requirements
# prettier-ignore
if:
${{ steps.update.outputs.needs-update && github.event_name == 'workflow_dispatch' }}
if: ${{ steps.update.outputs.needs-update }}
uses: peter-evans/create-pull-request@v5
with:
commit-message: update requirements-docker.lock
branch: update-docker-requirements
branch-suffix: timestamp
base: ${{ github.head_ref || github.ref_name }}
base: main
title: Update requirements-docker.lock
# prettier-ignore
body:
https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
reviewers: ${{ github.actor }}
body: |
Automatic update of requirements-docker.lock.
[Workflow run details](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }})
${{ steps.pr-check.outputs.author != 'none' && format('cc @{0}', steps.pr-check.outputs.author) || '' }}
reviewers: ${{ steps.pr-check.outputs.merger }}

- name: Show diff
# prettier-ignore
if:
${{ steps.update.outputs.needs-update && github.event_name != 'workflow_dispatch' }}
run: git diff --exit-code
- name: Create failure issue
if: failure()
uses: JasonEtco/create-an-issue@v2
env:
GITHUB_TOKEN: ${{ github.token }}
REPOSITORY: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
with:
filename: .github/ISSUE_TEMPLATE/DOCKER_ISSUE_TEMPLATE.md
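For reference, the new PR-check step boils down to two GitHub API calls: list the pull requests associated with the commit, then read the author and merger off the first hit. A minimal Python sketch of the same logic, assuming `httpx` and a `GITHUB_TOKEN` environment variable; the repository and SHA below are placeholders, and the workflow itself uses the `gh` CLI instead:

```python
import os

import httpx

REPO = "owner/repo"  # placeholder for ${{ github.repository }}
SHA = "0000000000000000000000000000000000000000"  # placeholder for ${{ github.sha }}

headers = {
    "Accept": "application/vnd.github+json",
    "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
}

# Same endpoint the workflow queries: PRs associated with a commit.
pulls = httpx.get(
    f"https://api.github.com/repos/{REPO}/commits/{SHA}/pulls", headers=headers
).json()

if pulls:
    pr = pulls[0]
    author = pr["user"]["login"]
    # The single-PR endpoint carries merged_by (null if the PR is not merged).
    detail = httpx.get(
        f"https://api.github.com/repos/{REPO}/pulls/{pr['number']}", headers=headers
    ).json()
    merger = (detail.get("merged_by") or {}).get("login", "none")
else:
    author, merger = os.environ.get("GITHUB_ACTOR", "unknown"), "none"

print(f"author={author}\nmerger={merger}")
```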
3 changes: 3 additions & 0 deletions environment-dev.yml
@@ -6,6 +6,9 @@ dependencies:
- pip
- git-lfs
- pip:
- httpx_sse
- ijson
- sse_starlette
- python-dotenv
- pytest >=6
- pytest-mock
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ dependencies = [
"httpx",
"importlib_metadata>=4.6; python_version<'3.10'",
"packaging",
"panel==1.4.2",
"panel==1.4.4",
"pydantic>=2",
"pydantic-core",
"pydantic-settings>=2",
9 changes: 5 additions & 4 deletions ragna/_docs.py
@@ -1,3 +1,4 @@
import atexit
import inspect
import itertools
import os
@@ -31,6 +32,9 @@
class RestApi:
def __init__(self) -> None:
self._process: Optional[subprocess.Popen] = None
# In case the documentation errors before we call RestApi.stop, we still need to
# stop the server to avoid zombie processes
atexit.register(self.stop, quiet=True)

def start(
self,
@@ -174,11 +178,8 @@ def stop(self, *, quiet: bool = False) -> None:
if self._process is None:
return

self._process.kill()
self._process.terminate()
stdout, _ = self._process.communicate()

if not quiet:
print(stdout.decode())

def __del__(self) -> None:
self.stop(quiet=True)
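The `_docs.py` change does two things: it registers `stop` with `atexit` so the server is shut down even if the docs build raises before `RestApi.stop` is reached, and it swaps `kill()` for `terminate()` so the child receives SIGTERM and can flush its output before `communicate()` collects it. A minimal standalone sketch of that pattern, with illustrative names rather than ragna's actual API:

```python
import atexit
import subprocess
from typing import Optional


class ManagedServer:
    def __init__(self) -> None:
        self._process: Optional[subprocess.Popen] = None
        # Runs at interpreter shutdown even if the caller raises first,
        # so a crashed run cannot leave a zombie child behind.
        atexit.register(self.stop, quiet=True)

    def start(self) -> None:
        self._process = subprocess.Popen(
            ["python", "-c", "import time; time.sleep(60)"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )

    def stop(self, *, quiet: bool = False) -> None:
        if self._process is None:
            return
        # terminate() (SIGTERM) instead of kill() (SIGKILL): the child gets a
        # chance to flush stdout before communicate() drains it.
        self._process.terminate()
        stdout, _ = self._process.communicate()
        self._process = None
        if not quiet:
            print(stdout.decode())


server = ManagedServer()
server.start()
server.stop(quiet=True)
```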
2 changes: 1 addition & 1 deletion ragna/_utils.py
@@ -94,7 +94,7 @@ def timeout_after(
seconds: float = 30, *, message: str = ""
) -> Callable[[Callable], Callable]:
timeout = f"Timeout after {seconds:.1f} seconds"
message = timeout if message else f"{timeout}: {message}"
message = f"{timeout}: {message}" if message else timeout

def decorator(fn: Callable) -> Callable:
if is_debugging():
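The `_utils.py` change fixes an inverted conditional: the old line discarded any custom message and rendered a dangling colon when the message was empty. A quick illustration of the corrected behavior:

```python
seconds = 30.0
timeout = f"Timeout after {seconds:.1f} seconds"

for message in ("", "server did not start"):
    result = f"{timeout}: {message}" if message else timeout
    print(repr(result))

# 'Timeout after 30.0 seconds'
# 'Timeout after 30.0 seconds: server did not start'
#
# The old, inverted expression gave the opposite:
#   ""                     -> 'Timeout after 30.0 seconds: '  (dangling colon)
#   "server did not start" -> 'Timeout after 30.0 seconds'    (message lost)
```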
7 changes: 4 additions & 3 deletions ragna/assistants/_ai21labs.py
@@ -64,7 +64,7 @@ async def generate(
# See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
# See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response

async for data in self._call_api(
async with self._call_api(
"POST",
f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
headers={
@@ -79,8 +79,9 @@
"messages": _render_prompt(prompt),
"system": system_prompt,
},
):
yield cast(str, data["outputs"][0]["text"])
) as stream:
async for data in stream:
yield cast(str, data["outputs"][0]["text"])

async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
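This is the first of several assistants converted from `async for data in self._call_api(...)` to `async with self._call_api(...) as stream`. `_call_api` now returns an async context manager that yields an async iterator, so the underlying HTTP response stays open for exactly the lifetime of the `async with` block. A minimal consumer-side sketch of that shape, with generic names standing in for ragna's API:

```python
import asyncio
import contextlib
from typing import AsyncIterator


@contextlib.asynccontextmanager
async def call_api() -> AsyncIterator[AsyncIterator[str]]:
    # Stands in for opening a streaming HTTP response.
    print("connection opened")

    async def stream() -> AsyncIterator[str]:
        for chunk in ("Hello", ", ", "world"):
            yield chunk

    try:
        yield stream()
    finally:
        # Runs when the async-with block exits, even if the consumer
        # breaks out of the loop early or raises.
        print("connection closed")


async def main() -> None:
    async with call_api() as stream:
        async for chunk in stream:
            print(chunk, end="")
    print()


asyncio.run(main())
```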
23 changes: 12 additions & 11 deletions ragna/assistants/_anthropic.py
@@ -77,7 +77,7 @@ async def generate(
# See https://docs.anthropic.com/claude/reference/messages_post
# See https://docs.anthropic.com/claude/reference/streaming

async for data in self._call_api(
async with self._call_api(
"POST",
"https://api.anthropic.com/v1/messages",
headers={
@@ -94,16 +94,17 @@
"temperature": 0.0,
"stream": True,
},
):
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
if "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["type"] == "message_stop":
break
elif data["type"] != "content_block_delta":
continue

yield cast(str, data["delta"].pop("text"))
) as stream:
async for data in stream:
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
if "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["type"] == "message_stop":
break
elif data["type"] != "content_block_delta":
continue

yield cast(str, data["delta"].pop("text"))

async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
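Anthropic's endpoint streams server-sent events, which is what the `httpx_sse` dependency added to `environment-dev.yml` above supports. A minimal sketch of consuming an SSE stream with that library; the URL is a placeholder and error handling is omitted:

```python
import asyncio
import json

import httpx
from httpx_sse import aconnect_sse

URL = "https://example.com/v1/stream"  # placeholder SSE endpoint


async def main() -> None:
    async with httpx.AsyncClient() as client:
        async with aconnect_sse(client, "POST", URL, json={"stream": True}) as source:
            async for sse in source.aiter_sse():
                # Each server-sent event carries one JSON document in its data field.
                print(json.loads(sse.data))


asyncio.run(main())
```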
19 changes: 10 additions & 9 deletions ragna/assistants/_cohere.py
@@ -64,7 +64,7 @@ async def generate(
# See https://docs.cohere.com/reference/chat
# See https://docs.cohere.com/docs/retrieval-augmented-generation-rag

async for event in self._call_api(
async with self._call_api(
"POST",
"https://api.cohere.ai/v1/chat",
headers={
@@ -81,14 +81,15 @@
"max_tokens": max_new_tokens,
"documents": self._make_source_documents(sources),
},
):
if event["event_type"] == "stream-end":
if event["finish_reason"] == "COMPLETE":
break

raise RagnaException(event["error_message"])
if "text" in event:
yield cast(str, event["text"])
) as stream:
async for event in stream:
if event["event_type"] == "stream-end":
if event["finish_reason"] == "COMPLETE":
break

raise RagnaException(event["error_message"])
if "text" in event:
yield cast(str, event["text"])

async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
7 changes: 4 additions & 3 deletions ragna/assistants/_google.py
@@ -53,7 +53,7 @@ async def generate(
Returns:
async streamed inference response string chunks
"""
async for chunk in self._call_api(
async with self._call_api(
"POST",
f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
params={"key": self._api_key},
@@ -80,8 +80,9 @@
},
},
parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"),
):
yield chunk
) as stream:
async for chunk in stream:
yield chunk

async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
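The `parse_kwargs=dict(item="item.candidates.item.content.parts.item.text")` argument is an `ijson` prefix: Google returns a JSON array of response objects, and the prefix pulls every nested `text` field out of it incrementally. A small self-contained illustration of how such prefixes address nested arrays, using invented sample data:

```python
import io
import json

import ijson

# Invented payload mimicking an array of streamed response objects.
payload = json.dumps(
    [
        {"candidates": [{"content": {"parts": [{"text": "Hello"}]}}]},
        {"candidates": [{"content": {"parts": [{"text": " world"}]}}]},
    ]
).encode()

# Every array level contributes an "item" segment to the prefix, so this
# yields each nested text field without loading the whole document first.
prefix = "item.candidates.item.content.parts.item.text"
for text in ijson.items(io.BytesIO(payload), prefix):
    print(text, end="")
print()
```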
40 changes: 29 additions & 11 deletions ragna/assistants/_http_api.py
@@ -2,7 +2,7 @@
import enum
import json
import os
from typing import Any, AsyncIterator, Optional
from typing import Any, AsyncContextManager, AsyncIterator, Optional

import httpx

@@ -47,7 +47,7 @@ def __call__(
*,
parse_kwargs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> AsyncIterator[Any]:
) -> AsyncContextManager[AsyncIterator[Any]]:
if self._protocol is None:
call_method = self._no_stream
else:
@@ -56,8 +56,10 @@
HttpStreamingProtocol.JSONL: self._stream_jsonl,
HttpStreamingProtocol.JSON: self._stream_json,
}[self._protocol]

return call_method(method, url, parse_kwargs=parse_kwargs or {}, **kwargs)

@contextlib.asynccontextmanager
async def _no_stream(
self,
method: str,
@@ -68,8 +70,13 @@ async def _no_stream(
) -> AsyncIterator[Any]:
response = await self._client.request(method, url, **kwargs)
await self._assert_api_call_is_success(response)
yield response.json()

async def stream() -> AsyncIterator[Any]:
yield response.json()

yield stream()

@contextlib.asynccontextmanager
async def _stream_sse(
self,
method: str,
@@ -85,9 +92,13 @@ async def _stream_sse(
) as event_source:
await self._assert_api_call_is_success(event_source.response)

async for sse in event_source.aiter_sse():
yield json.loads(sse.data)
async def stream() -> AsyncIterator[Any]:
async for sse in event_source.aiter_sse():
yield json.loads(sse.data)

yield stream()

@contextlib.asynccontextmanager
async def _stream_jsonl(
self,
method: str,
@@ -99,8 +110,11 @@
async with self._client.stream(method, url, **kwargs) as response:
await self._assert_api_call_is_success(response)

async for chunk in response.aiter_lines():
yield json.loads(chunk)
async def stream() -> AsyncIterator[Any]:
async for chunk in response.aiter_lines():
yield json.loads(chunk)

yield stream()

# ijson does not support reading from an (async) iterator, but only from file-like
# objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects.
@@ -120,6 +134,7 @@ async def read(self, n: int) -> bytes:
return b""
return await anext(self._ait, b"") # type: ignore[call-arg]

@contextlib.asynccontextmanager
async def _stream_json(
self,
method: str,
@@ -136,10 +151,13 @@ async def _stream_json(
async with self._client.stream(method, url, **kwargs) as response:
await self._assert_api_call_is_success(response)

async for chunk in ijson.items(
self._AsyncIteratorReader(response.aiter_bytes(chunk_size)), item
):
yield chunk
async def stream() -> AsyncIterator[Any]:
async for chunk in ijson.items(
self._AsyncIteratorReader(response.aiter_bytes(chunk_size)), item
):
yield chunk

yield stream()

async def _assert_api_call_is_success(self, response: httpx.Response) -> None:
if response.is_success:
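The `_http_api.py` refactor turns each `_stream_*` method from a bare async generator into an `@contextlib.asynccontextmanager` that yields an inner `stream()` generator: the context manager owns the `httpx` response lifetime, while the inner generator only parses chunks. A caller that stops iterating early (for example on Anthropic's `message_stop`) therefore still closes the connection deterministically instead of relying on generator finalization. A minimal standalone sketch of the JSONL variant in that style; the URL in the usage example is a placeholder:

```python
import asyncio
import contextlib
import json
from typing import Any, AsyncIterator

import httpx


@contextlib.asynccontextmanager
async def stream_jsonl(
    client: httpx.AsyncClient, method: str, url: str, **kwargs: Any
) -> AsyncIterator[AsyncIterator[Any]]:
    # The context manager owns the HTTP response lifetime...
    async with client.stream(method, url, **kwargs) as response:
        response.raise_for_status()

        # ...while the inner generator only decodes lines as they arrive.
        async def stream() -> AsyncIterator[Any]:
            async for line in response.aiter_lines():
                yield json.loads(line)

        yield stream()


async def main() -> None:
    async with httpx.AsyncClient() as client:
        async with stream_jsonl(client, "GET", "https://example.com/jsonl") as stream:
            async for obj in stream:
                print(obj)


asyncio.run(main())
```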
17 changes: 10 additions & 7 deletions ragna/assistants/_ollama.py
@@ -33,13 +33,16 @@ async def answer(
self, messages: list[Message], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
prompt, sources = (message := messages[-1]).content, message.sources
async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
# Modeled after
# https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62
if "error" in data:
raise RagnaException(data["error"])
if not data["done"]:
yield cast(str, data["message"]["content"])
async with self._call_openai_api(
prompt, sources, max_new_tokens=max_new_tokens
) as stream:
async for data in stream:
# Modeled after
# https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62
if "error" in data:
raise RagnaException(data["error"])
if not data["done"]:
yield cast(str, data["message"]["content"])


class OllamaGemma2B(OllamaAssistant):
(9 more changed files not shown.)
