Skip to content

Commit

Permalink
✨ Allow setting model / language on new doc page
Browse files Browse the repository at this point in the history
Fixes #264
Fixes #208
  • Loading branch information
pajowu committed Jun 27, 2023
1 parent 4e9756e commit d5cefd4
Show file tree
Hide file tree
Showing 12 changed files with 422 additions and 12 deletions.
127 changes: 127 additions & 0 deletions backend/data/models.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
[
{
"id": "tiny",
"name": "Tiny",
"languages": [
"auto",
"ar",
"de",
"el",
"en",
"es",
"fa",
"fi",
"fr",
"hu",
"it",
"ja",
"nl",
"pl",
"pt",
"ru",
"tr",
"uk",
"zh"
]
},
{
"id": "base",
"name": "Base",
"languages": [
"auto",
"ar",
"de",
"el",
"en",
"es",
"fa",
"fi",
"fr",
"hu",
"it",
"ja",
"nl",
"pl",
"pt",
"ru",
"tr",
"uk",
"zh"
]
},
{
"id": "small",
"name": "Small",
"languages": [
"auto",
"ar",
"de",
"el",
"en",
"es",
"fa",
"fi",
"fr",
"hu",
"it",
"ja",
"nl",
"pl",
"pt",
"ru",
"tr",
"uk",
"zh"
]
},
{
"id": "medium",
"name": "Medium",
"languages": [
"auto",
"ar",
"de",
"el",
"en",
"es",
"fa",
"fi",
"fr",
"hu",
"it",
"ja",
"nl",
"pl",
"pt",
"ru",
"tr",
"uk",
"zh"
]
},
{
"id": "large",
"name": "Large",
"languages": [
"auto",
"ar",
"de",
"el",
"en",
"es",
"fa",
"fi",
"fr",
"hu",
"it",
"ja",
"nl",
"pl",
"pt",
"ru",
"tr",
"uk",
"zh"
]
}
]
49 changes: 49 additions & 0 deletions backend/openapi-schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,19 @@ components:
format: binary
title: File
type: string
language:
title: Language
type: string
model:
title: Model
type: string
name:
title: Name
type: string
required:
- name
- model
- language
- file
title: Body_create_document_api_v1_documents__post
type: object
Expand Down Expand Up @@ -178,6 +186,36 @@ components:
- token
title: LoginResponse
type: object
ModelConfig:
properties:
id:
title: Id
type: string
languages:
items:
type: string
title: Languages
type: array
name:
title: Name
type: string
required:
- id
- name
- languages
title: ModelConfig
type: object
PublicConfig:
properties:
models:
items:
$ref: '#/components/schemas/ModelConfig'
title: Models
type: array
required:
- models
title: PublicConfig
type: object
SetDurationRequest:
properties:
duration:
Expand Down Expand Up @@ -351,6 +389,17 @@ paths:
schema: {}
description: Successful Response
summary: Root
/api/v1/config/:
get:
operationId: get_config_api_v1_config__get
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/PublicConfig'
description: Successful Response
summary: Get Config
/api/v1/documents/:
get:
operationId: list_documents_api_v1_documents__get
Expand Down
8 changes: 6 additions & 2 deletions backend/tests/test_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
@pytest.fixture
def document(memory_session: Session, logged_in_client: TestClient):
req = logged_in_client.post(
"/api/v1/documents/", files={"file": b""}, data={"name": "test document"}
"/api/v1/documents/",
files={"file": b""},
data={"name": "test document", "model": "tiny", "language": "auto"},
)
assert req.status_code == 200
document_id = req.json()["id"]
Expand Down Expand Up @@ -48,7 +50,9 @@ def test_doc_delete(
files = set(str(x) for x in settings.storage_path.glob("*"))

req = logged_in_client.post(
"/api/v1/documents/", files={"file": b""}, data={"name": "test document"}
"/api/v1/documents/",
files={"file": b""},
data={"name": "test document", "model": "tiny", "language": "auto"},
)
assert req.status_code == 200
document_id = req.json()["id"]
Expand Down
4 changes: 3 additions & 1 deletion backend/tests/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

def create_doc(logged_in_client: TestClient):
req = logged_in_client.post(
"/api/v1/documents/", files={"file": b""}, data={"name": "test document"}
"/api/v1/documents/",
files={"file": b""},
data={"name": "test document", "model": "tiny", "language": "auto"},
)
assert req.status_code == 200

Expand Down
24 changes: 23 additions & 1 deletion backend/transcribee_backend/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pathlib import Path
from typing import List

from pydantic import BaseSettings
from pydantic import BaseSettings, parse_file_as
from pydantic.main import BaseModel


class Settings(BaseSettings):
Expand All @@ -12,5 +14,25 @@ class Settings(BaseSettings):

media_url_base = "http://localhost:8000/"

model_config_path: Path = Path("data/models.json")


class ModelConfig(BaseModel):
id: str
name: str
languages: List[str]


class PublicConfig(BaseModel):
models: List[ModelConfig]


def get_model_config():
return parse_file_as(List[ModelConfig], settings.model_config_path)


def get_public_config():
return PublicConfig(models=get_model_config())


settings = Settings()
2 changes: 2 additions & 0 deletions backend/transcribee_backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from transcribee_backend.config import settings
from transcribee_backend.helpers.periodic_tasks import run_periodic
from transcribee_backend.helpers.tasks import timeout_attempts
from transcribee_backend.routers.config import config_router
from transcribee_backend.routers.document import document_router
from transcribee_backend.routers.task import task_router
from transcribee_backend.routers.user import user_router
Expand All @@ -28,6 +29,7 @@
app.include_router(user_router, prefix="/api/v1/users")
app.include_router(document_router, prefix="/api/v1/documents")
app.include_router(task_router, prefix="/api/v1/tasks")
app.include_router(config_router, prefix="/api/v1/config")


@app.get("/")
Expand Down
9 changes: 9 additions & 0 deletions backend/transcribee_backend/routers/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from fastapi import APIRouter
from transcribee_backend.config import PublicConfig, get_public_config

config_router = APIRouter()


@config_router.get("/")
def get_config() -> PublicConfig:
return get_public_config()
37 changes: 33 additions & 4 deletions backend/transcribee_backend/routers/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
WebSocketException,
status,
)
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel
from pydantic.error_wrappers import ErrorWrapper
from sqlalchemy.sql.expression import desc
from sqlmodel import Session, col, select
from transcribee_proto.api import Document as ApiDocument
Expand All @@ -25,7 +27,7 @@
validate_user_authorization,
validate_worker_authorization,
)
from transcribee_backend.config import settings
from transcribee_backend.config import get_model_config, settings
from transcribee_backend.db import get_session
from transcribee_backend.helpers.sync import DocumentSyncConsumer
from transcribee_backend.helpers.time import now_tz_aware
Expand Down Expand Up @@ -101,7 +103,9 @@ def ws_get_document_from_url(
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)


def create_default_tasks_for_document(session: Session, document: Document):
def create_default_tasks_for_document(
session: Session, document: Document, model: str, language: str
):
reencode_task = Task(
task_type=TaskType.REENCODE,
task_parameters={},
Expand All @@ -111,7 +115,7 @@ def create_default_tasks_for_document(session: Session, document: Document):

transcribe_task = Task(
task_type=TaskType.TRANSCRIBE,
task_parameters={"lang": "auto", "model": "base"},
task_parameters={"lang": language, "model": model},
document_id=document.id,
dependencies=[reencode_task],
)
Expand All @@ -137,10 +141,35 @@ def create_default_tasks_for_document(session: Session, document: Document):
@document_router.post("/")
async def create_document(
name: str = Form(),
model: str = Form(),
language: str = Form(),
file: UploadFile = File(),
session: Session = Depends(get_session),
token: UserToken = Depends(get_user_token),
) -> ApiDocument:
model_configs = get_model_config()
selected_model = None
for model_config in model_configs:
if model_config.id == model:
selected_model = model_config

if selected_model is None:
raise RequestValidationError(
[ErrorWrapper(ValueError(f"Unknown model '{model}'"), ("body", "model"))]
)

if language not in selected_model.languages:
raise RequestValidationError(
[
ErrorWrapper(
ValueError(
f"Model '{model}' does not support language '{language}'"
),
("body", "language"),
)
]
)

document = Document(
name=name,
user_id=token.user_id,
Expand All @@ -166,7 +195,7 @@ async def create_document(
tag = DocumentMediaTag(media_file_id=media_file.id, tag="original")
session.add(tag)

create_default_tasks_for_document(session, document)
create_default_tasks_for_document(session, document, model, language)

session.commit()
return document.as_api_document()
Expand Down
Loading

0 comments on commit d5cefd4

Please sign in to comment.