Skip to content
This repository has been archived by the owner on Jul 21, 2024. It is now read-only.

Commit

Permalink
[chore] bump version 1.2.0 => 1.3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
synacktraa committed Apr 11, 2024
1 parent 089b913 commit ad91068
Show file tree
Hide file tree
Showing 12 changed files with 837 additions and 80 deletions.
313 changes: 312 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

13 changes: 10 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "swiftrank"
version = "1.2.0"
version = "1.3.0"
description = "Compact, ultra-fast SoTA reranker enhancing retrieval pipelines and terminal applications."
authors = ["Harsh Verma <[email protected]>"]
license = "Apache Software License (Apache 2.0)"
Expand All @@ -17,10 +17,17 @@ tqdm = "4.66.1"
cyclopts = "2.1.2"
pyyaml = "6.0.1"
orjson = "3.9.10"
pydantic = "2.6.4"
fastapi = "0.110.1"
uvicorn = "0.29.0"

[tool.poetry.scripts]
swiftrank = "swiftrank.cli:app.meta"
srank = "swiftrank.cli:app.meta"
swiftrank = "swiftrank.interface.cli:app.meta"
srank = "swiftrank.interface.cli:app.meta"

[tool.poetry.group.dev.dependencies]
pytest = "8.1.1"
requests = "2.31.0"

[build-system]
requires = ["poetry-core"]
Expand Down
27 changes: 26 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
⌨️ **Terminal Integration**:
- Pipe your output into `swiftrank` cli tool and get reranked output

🌐 **API Integration**:
- Deploy `swiftrank` as an API service for seamless integration into your workflow.

---

### 🚀 Installation
Expand All @@ -57,6 +60,7 @@ Rerank contexts provided on stdin.
╭─ Commands ─────────────────────────────────────────────────────╮
│ process STDIN processor. [ json | jsonl | yaml ] │
│ serve Startup a swiftrank server │
│ --help,-h Display this message and exit. │
│ --version Display application version. │
╰────────────────────────────────────────────────────────────────╯
Expand Down Expand Up @@ -180,6 +184,27 @@ STDIN processor. [ json | jsonl | yaml ]
Monogatari Series: Second Season
```

#### Start up a FastAPI server instance

```
Usage: swiftrank serve [OPTIONS]
Startup a swiftrank server
╭─ Parameters ──────────────────────────────╮
│ --host Host name [default: 0.0.0.0] │
│ --port Port number. [default: 12345] │
╰───────────────────────────────────────────╯
```

```sh
swiftrank serve
```
```
[GET] /models - List Models
[POST] /rerank - Rerank Endpoint
```

### Library Usage 🤗

- Build a `ReRankPipeline` instance
Expand Down Expand Up @@ -311,4 +336,4 @@ url = {https://github.com/PrithivirajDamodaran/FlashRank},
version = {1.0.0},
year = {2023}
}
```
```
Empty file added swiftrank/interface/__init__.py
Empty file.
112 changes: 112 additions & 0 deletions swiftrank/interface/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from typing import Any, Optional

from fastapi import FastAPI, status
from fastapi.responses import ORJSONResponse
from fastapi.exceptions import HTTPException
from pydantic import BaseModel, Field

from .utils import ObjectCollection, api_object_parser
from ..settings import MODEL_MAP
from ..ranker import ReRankPipeline

server = FastAPI()
pipeline_map: dict[str, ReRankPipeline] = {}

def get_pipeline(__id: str):
    """Return the rerank pipeline for a model id, building it on first use.

    Pipelines are memoised in the module-level ``pipeline_map`` so each model
    id is constructed through ``ReRankPipeline.from_model_id`` at most once.
    """
    cached = pipeline_map.get(__id)
    if cached is not None:
        return cached

    # Cache miss: build the pipeline, remember it, then hand back the stored entry.
    pipeline_map[__id] = ReRankPipeline.from_model_id(__id)
    return pipeline_map[__id]


class SchemaContext(BaseModel):
    # Optional schema strings for the three stages of a rerank request.
    # Each one defaults to None when the client omits it.
    pre: Optional[str] = Field(
        default=None, description="schema for pre-processing input."
    )
    ctx: Optional[str] = Field(
        default=None, description="schema for extracting context."
    )
    post: Optional[str] = Field(
        default=None, description="schema for extracting field after reranking."
    )

class RerankContext(BaseModel):
    # Request body consumed by the rerank endpoint.
    model: str = Field(
        default="ms-marco-TinyBERT-L-2-v2",
        description="model to use for reranking.",
    )
    # `contexts` and `query` are required (declared with `...`).
    contexts: ObjectCollection = Field(..., description="contexts to rerank.")
    query: str = Field(..., description="query for reranking evaluation.")
    threshold: Optional[float] = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="filter contexts using threshold.",
    )
    map_score: bool = Field(
        default=False, description="map relevance score with context"
    )
    # Declared with the alias 'schema'; the trailing underscore is only the
    # Python attribute name.
    schema_: Optional[SchemaContext] = Field(default=None, alias='schema')


@server.get('/models', response_class=ORJSONResponse)
def list_models():
    """Return the keys of ``MODEL_MAP`` as a list."""
    model_ids = [*MODEL_MAP.keys()]
    return model_ids

@server.post('/rerank')
def rerank_endpoint(ctx: RerankContext):
    """Rerank the submitted contexts against the query.

    Processing steps:
      1. Reject an empty ``contexts`` payload (422) and unknown models (404).
      2. If a ``pre`` schema is given, run it over the payload; the result
         must be a non-empty list (422 otherwise). Without ``pre`` the raw
         payload itself must already be a list.
      3. Rerank using the ``ctx`` schema (default ``'.'``) as the key that
         extracts the text of each context.
      4. Shape every reranked item with the ``post`` schema (default ``'.'``).

    Returns:
        A list of post-processed contexts, or, when ``map_score`` is set, a
        list of ``{'score': ..., 'context': ...}`` objects.

    Raises:
        HTTPException: 422 for invalid/empty input or when a ``TypeError`` is
            raised while reranking or post-processing; 404 when the requested
            model is not in ``MODEL_MAP``.
    """
    if not ctx.contexts:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="contexts field cannot be an empty array or object"
        )

    if ctx.model not in MODEL_MAP:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"{ctx.model!r} model is not available"
        )

    schema = ctx.schema_ or SchemaContext()
    if schema.pre is not None:
        contexts = api_object_parser(ctx.contexts, schema=schema.pre)
        if isinstance(contexts, list) and not contexts:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="Empty array after pre-processing"
            )
        no_list_err = "Pre-processing must result into an array of objects"
    else:
        contexts = ctx.contexts
        no_list_err = "Expected an array of string or object. 'pre' schema might help"

    if not isinstance(contexts, list):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=no_list_err
        )

    ctx_schema = schema.ctx or '.'
    post_schema = schema.post or '.'
    pipeline = get_pipeline(ctx.model)

    def extract_context(item):
        # Shared key function for both invoke variants.
        return api_object_parser(item, ctx_schema)

    try:
        if not ctx.map_score:
            reranked = pipeline.invoke(
                query=ctx.query,
                contexts=contexts,
                threshold=ctx.threshold,
                key=extract_context
            )
            return [api_object_parser(context, post_schema) for context in reranked]

        reranked_tup = pipeline.invoke_with_score(
            query=ctx.query,
            contexts=contexts,
            threshold=ctx.threshold,
            key=extract_context
        )
        return [
            {'score': score, 'context': api_object_parser(context, post_schema)}
            for (score, context) in reranked_tup
        ]
    except TypeError as exc:
        # Chain the original error so the root cause stays visible in
        # tracebacks/logs instead of showing up as an unrelated exception
        # raised during handling.
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail='Context processing must result into string'
        ) from exc

def _serve(host: str, port: int):
    """Run the FastAPI ``server`` under uvicorn on ``host``:``port``.

    uvicorn is imported inside the function, so importing this module does
    not itself import uvicorn.

    Raises:
        SystemExit: with status 0 when a ``KeyboardInterrupt`` reaches this
            function.
    """
    import uvicorn
    try:
        uvicorn.run(server, host=host, port=port)
    except KeyboardInterrupt:
        # The bare `exit()` helper is injected by the `site` module and is not
        # guaranteed to exist (e.g. `python -S`); raise SystemExit directly.
        raise SystemExit(0) from None


if __name__ == "__main__":
    _serve(host='0.0.0.0', port=12345)
68 changes: 42 additions & 26 deletions swiftrank/cli/__init__.py → swiftrank/interface/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Annotated

from cyclopts import App, Parameter
from cyclopts import App, Parameter, validators

try:
from signal import signal, SIGPIPE, SIG_DFL
Expand All @@ -23,22 +23,25 @@ def build_processing_parameters(
help="schema for extracting field after reranking.", show_default=False
)] = None
):
from .utils import object_parser, print_and_exit
from .utils import cli_object_parser, print_and_exit
def preprocessor(_input: str):
if _input.startswith(('{', '[')):
from orjson import loads, JSONDecodeError
try:
return object_parser(loads(_input), pre)
return cli_object_parser(loads(_input), pre)
except JSONDecodeError:
from io import StringIO
with StringIO(_input) as handler:
return list(map(loads, handler))
try:
from io import StringIO
with StringIO(_input) as handler:
return list(map(loads, handler))
except (JSONDecodeError, Exception):
print_and_exit("Input data format not valid.", code=1)
except Exception:
print_and_exit("Malformed JSON object not parseable.", code=1)
print_and_exit("Input data format not valid.", code=1)

import yaml
try:
return object_parser(yaml.safe_load(_input), pre)
return cli_object_parser(yaml.safe_load(_input), pre)
except yaml.MarkedYAMLError:
return list(yaml.safe_load_all(_input))
except yaml.YAMLError:
Expand All @@ -53,11 +56,11 @@ def __entry__(
query: Annotated[str, Parameter(
name=("-q", "--query"), help="query for reranking evaluation.")],
threshold: Annotated[float, Parameter(
name=("-t", "--threshold"), help="filter contexts using threshold.")] = None,
name=("-t", "--threshold"), help="filter contexts using threshold.", validator=validators.Number(gte=0.0, lte=1.0))] = None,
first: Annotated[bool, Parameter(
name=("-f", "--first"), help="get most relevant context.", negative="", show_default=False)] = False,
):
from .utils import read_stdin, object_parser, print_and_exit
from .utils import read_stdin, cli_object_parser, print_and_exit

processing_params: dict = {}
if tokens:
Expand All @@ -66,18 +69,17 @@ def __entry__(
if not _input:
return
contexts = processing_params['preprocessor'](_input)

else:
contexts = read_stdin(readlines=True)

ctx_schema = processing_params.get('ctx_schema', '.')
post_schema = processing_params.get('post_schema') or ctx_schema

if not isinstance(contexts, list):
print_and_exit(object_parser(contexts, ctx_schema))
print_and_exit(cli_object_parser(contexts, post_schema))

if not all(contexts):
print_and_exit("No contexts found on stdin", code=1)
if len(contexts) == 1:
print_and_exit(contexts[0])
if not contexts:
print_and_exit("Nothing to rerank!", code=1)

from .. import settings
from ..ranker import ReRankPipeline
Expand All @@ -88,18 +90,32 @@ def __entry__(
query=query,
contexts=contexts,
threshold=threshold,
key=lambda x: object_parser(x, ctx_schema)
key=lambda x: cli_object_parser(x, ctx_schema)
)

if reranked and first:
print_and_exit(
cli_object_parser(reranked[0], post_schema)
)

for context in reranked:
print(cli_object_parser(context, post_schema))

except TypeError:
print_and_exit(
'Context processing must result into string. Hint: `--ctx` flag might help.', code=1
'Context processing must result into string.', code=1
)

post_schema = processing_params.get('post_schema') or ctx_schema
if reranked and first:
print_and_exit(
object_parser(reranked[0], post_schema)
)

for context in reranked:
print(object_parser(context, post_schema))
@app.meta.command(name="serve", help="Startup a swiftrank server")
def serve(
    *,
    # `('--host',)` needs the trailing comma to be a tuple; `('--host')` is
    # just the string '--host'. Both options now use the same tuple form.
    host: Annotated[str, Parameter(
        name=('--host',), help="Host name")] = '0.0.0.0',
    port: Annotated[int, Parameter(
        name=('--port',), help="Port number.")] = 12345
):
    # Command help comes from the decorator's `help=`; comments are used here
    # rather than a docstring so the CLI help text is left untouched.
    # The API module is imported only when the command actually runs.
    from .api import _serve
    _serve(host=host, port=port)

if __name__ == "__main__":
app.meta()
54 changes: 35 additions & 19 deletions swiftrank/cli/utils.py → swiftrank/interface/utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,6 @@
import sys
from typing import TypeAlias, Any

def print_and_exit(msg: str, code: int = 0):
stream = sys.stdout if not code else sys.stderr
print(msg, file=stream)
exit(code)

def read_stdin(readlines: bool = False):
"""Read values from standard input (stdin). """
if sys.stdin.isatty():
return
try:
if readlines is False:
return sys.stdin.read().rstrip('\n')
return [_.strip('\n') for _ in sys.stdin if _]
except KeyboardInterrupt:
return


ObjectCollection: TypeAlias = dict[str, Any] | list[Any]
ObjectScalar: TypeAlias = bool | float | int | str
Expand All @@ -33,7 +17,7 @@ def object_parser(obj: ObjectValue, schema: str) -> ObjectValue:
if not re.match(
pattern=r"^(?:(?:[.](?:[\w]+|\[\d?\]))+)$", string=usable_schema
):
print_and_exit(f'{schema!r} is not a valid schema.', code=1)
raise ValueError(f'{schema!r} is not a valid schema.')

def __inner__(_in: ObjectValue, keys: list[str]):
for idx, key in enumerate(keys):
Expand All @@ -48,7 +32,7 @@ def __inner__(_in: ObjectValue, keys: list[str]):
_in = _in[int(obj_idx)]
continue
except (KeyError, IndexError):
print_and_exit(f'{schema!r} schema not compatible with input data.', code=1)
raise ValueError(f'{schema!r} schema not compatible with input data.')

_keys = keys[idx + 1:]
if not _keys:
Expand All @@ -58,4 +42,36 @@ def __inner__(_in: ObjectValue, keys: list[str]):
return _in
return __inner__(
obj, [k for k, _ in groupby(usable_schema.lstrip('.').split('.'))]
)
)

def read_stdin(readlines: bool = False):
    """Read piped data from standard input (stdin).

    Args:
        readlines: When ``False`` (default) return the whole input as one
            string with trailing newlines removed. Otherwise return a list
            of lines with surrounding newlines stripped and blank lines
            dropped.

    Returns:
        ``None`` when stdin is a TTY (nothing was piped in) or when reading
        is interrupted by ``KeyboardInterrupt``; otherwise the string or list
        described above.
    """
    if sys.stdin.isatty():
        return None
    try:
        if readlines is False:
            return sys.stdin.read().rstrip('\n')
        # Filter AFTER stripping: a raw line always carries its newline and is
        # therefore truthy, so testing it before stripping never drops
        # blank lines.
        stripped_lines = (line.strip('\n') for line in sys.stdin)
        return [line for line in stripped_lines if line]
    except KeyboardInterrupt:
        return None

def print_and_exit(msg: str, code: int = 0):
    """Print ``msg`` and terminate the process with exit status ``code``.

    Args:
        msg: Text to print.
        code: Exit status. A falsy code (success) prints to stdout; any other
            code prints to stderr.

    Raises:
        SystemExit: always, carrying ``code``.
    """
    stream = sys.stderr if code else sys.stdout
    print(msg, file=stream)
    # `sys.exit` is always available, unlike the `exit` helper that the
    # `site` module injects for interactive use.
    sys.exit(code)

def cli_object_parser(obj: ObjectValue, schema: str):
    """Apply ``schema`` to ``obj`` for the CLI.

    A ``ValueError`` from ``object_parser`` is turned into a call to
    ``print_and_exit`` with the error's first argument and exit code 1.
    """
    try:
        parsed = object_parser(obj=obj, schema=schema)
    except ValueError as err:
        message = err.args[0]
        print_and_exit(message, code=1)
    else:
        return parsed

def api_object_parser(obj: ObjectValue, schema: str):
    """Apply ``schema`` to ``obj`` for the HTTP API.

    Returns:
        Whatever ``object_parser`` returns for ``obj`` and ``schema``.

    Raises:
        HTTPException: status 422 with the original message as ``detail``
            when ``object_parser`` raises ``ValueError``.
    """
    # fastapi is imported inside the function, so importing this module does
    # not itself import fastapi.
    from fastapi import status, HTTPException
    try:
        return object_parser(obj=obj, schema=schema)
    except ValueError as e:
        # Chain explicitly so the originating ValueError is recorded as the
        # cause rather than appearing as an error raised during handling.
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=e.args[0]
        ) from e
Loading

0 comments on commit ad91068

Please sign in to comment.