Skip to content

Commit

Permalink
✨ Allow setting model / language on new doc page
Browse files Browse the repository at this point in the history
Fixes #264
Fixes #208
  • Loading branch information
pajowu committed Jun 26, 2023
1 parent bffc574 commit b78bbda
Show file tree
Hide file tree
Showing 11 changed files with 295 additions and 12 deletions.
49 changes: 49 additions & 0 deletions backend/openapi-schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,19 @@ components:
format: binary
title: File
type: string
language:
title: Language
type: string
model:
title: Model
type: string
name:
title: Name
type: string
required:
- name
- model
- language
- file
title: Body_create_document_api_v1_documents__post
type: object
Expand Down Expand Up @@ -178,6 +186,36 @@ components:
- token
title: LoginResponse
type: object
ModelConfig:
properties:
id:
title: Id
type: string
languages:
items:
type: string
title: Languages
type: array
name:
title: Name
type: string
required:
- id
- name
- languages
title: ModelConfig
type: object
PublicConfig:
properties:
models:
items:
$ref: '#/components/schemas/ModelConfig'
title: Models
type: array
required:
- models
title: PublicConfig
type: object
SetDurationRequest:
properties:
duration:
Expand Down Expand Up @@ -351,6 +389,17 @@ paths:
schema: {}
description: Successful Response
summary: Root
/api/v1/config/:
get:
operationId: get_config_api_v1_config__get
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/PublicConfig'
description: Successful Response
summary: Get Config
/api/v1/documents/:
get:
operationId: list_documents_api_v1_documents__get
Expand Down
8 changes: 6 additions & 2 deletions backend/tests/test_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
@pytest.fixture
def document(memory_session: Session, logged_in_client: TestClient):
req = logged_in_client.post(
"/api/v1/documents/", files={"file": b""}, data={"name": "test document"}
"/api/v1/documents/",
files={"file": b""},
data={"name": "test document", "model": "tiny", "language": "auto"},
)
assert req.status_code == 200
document_id = req.json()["id"]
Expand Down Expand Up @@ -48,7 +50,9 @@ def test_doc_delete(
files = set(str(x) for x in settings.storage_path.glob("*"))

req = logged_in_client.post(
"/api/v1/documents/", files={"file": b""}, data={"name": "test document"}
"/api/v1/documents/",
files={"file": b""},
data={"name": "test document", "model": "tiny", "language": "auto"},
)
assert req.status_code == 200
document_id = req.json()["id"]
Expand Down
4 changes: 3 additions & 1 deletion backend/tests/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

def create_doc(logged_in_client: TestClient):
req = logged_in_client.post(
"/api/v1/documents/", files={"file": b""}, data={"name": "test document"}
"/api/v1/documents/",
files={"file": b""},
data={"name": "test document", "model": "tiny", "language": "auto"},
)
assert req.status_code == 200

Expand Down
24 changes: 23 additions & 1 deletion backend/transcribee_backend/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pathlib import Path
from typing import List

from pydantic import BaseSettings
from pydantic import BaseSettings, parse_file_as
from pydantic.main import BaseModel


class Settings(BaseSettings):
Expand All @@ -12,5 +14,25 @@ class Settings(BaseSettings):

media_url_base = "http://localhost:8000/"

model_config_path: Path = Path("data/models.json")


class ModelConfig(BaseModel):
id: str
name: str
languages: List[str]


class PublicConfig(BaseModel):
models: List[ModelConfig]


def get_model_config():
return parse_file_as(List[ModelConfig], settings.model_config_path)


def get_public_config():
return PublicConfig(models=get_model_config())


settings = Settings()
2 changes: 2 additions & 0 deletions backend/transcribee_backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from transcribee_backend.config import settings
from transcribee_backend.helpers.periodic_tasks import run_periodic
from transcribee_backend.helpers.tasks import timeout_attempts
from transcribee_backend.routers.config import config_router
from transcribee_backend.routers.document import document_router
from transcribee_backend.routers.task import task_router
from transcribee_backend.routers.user import user_router
Expand All @@ -27,6 +28,7 @@
app.include_router(user_router, prefix="/api/v1/users")
app.include_router(document_router, prefix="/api/v1/documents")
app.include_router(task_router, prefix="/api/v1/tasks")
app.include_router(config_router, prefix="/api/v1/config")


@app.get("/")
Expand Down
9 changes: 9 additions & 0 deletions backend/transcribee_backend/routers/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from fastapi import APIRouter
from transcribee_backend.config import PublicConfig, get_public_config

config_router = APIRouter()


@config_router.get("/")
def get_config() -> PublicConfig:
return get_public_config()
37 changes: 33 additions & 4 deletions backend/transcribee_backend/routers/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
WebSocketException,
status,
)
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel
from pydantic.error_wrappers import ErrorWrapper
from sqlalchemy.sql.expression import desc
from sqlmodel import Session, select
from transcribee_backend.auth import (
get_authorized_worker,
validate_user_authorization,
validate_worker_authorization,
)
from transcribee_backend.config import settings
from transcribee_backend.config import get_model_config, settings
from transcribee_backend.db import get_session
from transcribee_backend.helpers.sync import DocumentSyncConsumer
from transcribee_backend.helpers.time import now_tz_aware
Expand Down Expand Up @@ -100,7 +102,9 @@ def ws_get_document_from_url(
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)


def create_default_tasks_for_document(session: Session, document: Document):
def create_default_tasks_for_document(
session: Session, document: Document, model: str, language: str
):
reencode_task = Task(
task_type=TaskType.REENCODE,
task_parameters={},
Expand All @@ -110,7 +114,7 @@ def create_default_tasks_for_document(session: Session, document: Document):

transcribe_task = Task(
task_type=TaskType.TRANSCRIBE,
task_parameters={"lang": "auto", "model": "base"},
task_parameters={"lang": language, "model": model},
document_id=document.id,
dependencies=[reencode_task],
)
Expand All @@ -136,10 +140,35 @@ def create_default_tasks_for_document(session: Session, document: Document):
@document_router.post("/")
async def create_document(
name: str = Form(),
model: str = Form(),
language: str = Form(),
file: UploadFile = File(),
session: Session = Depends(get_session),
token: UserToken = Depends(get_user_token),
) -> ApiDocument:
model_configs = get_model_config()
selected_model = None
for model_config in model_configs:
if model_config.id == model:
selected_model = model_config

if selected_model is None:
raise RequestValidationError(
[ErrorWrapper(ValueError(f"Unknown model '{model}'"), ("body", "model"))]
)

if language not in selected_model.languages:
raise RequestValidationError(
[
ErrorWrapper(
ValueError(
f"Model '{model}' does not support language '{language}'"
),
("body", "language"),
)
]
)

document = Document(
name=name,
user_id=token.user_id,
Expand All @@ -165,7 +194,7 @@ async def create_document(
tag = DocumentMediaTag(media_file_id=media_file.id, tag="original")
session.add(tag)

create_default_tasks_for_document(session, document)
create_default_tasks_for_document(session, document, model, language)

session.commit()
return document.as_api_document()
Expand Down
19 changes: 19 additions & 0 deletions frontend/src/api/config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { defaultConfig } from 'swr/_internal';
import { fetcher, makeSwrHook } from '../api';
import { SWRConfiguration } from 'swr';

export const getConfig = fetcher.path('/api/v1/config/').method('get').create();

const useGetConfigWithRetry = makeSwrHook('getConfig', getConfig);

export const useGetConfig = (
params: Parameters<typeof useGetConfigWithRetry>[0],
options?: Partial<SWRConfiguration>,
) =>
useGetConfigWithRetry(params, {
onErrorRetry: (err, key, config, revalidate, opts) => {
if (err.status === 422) return;
else defaultConfig.onErrorRetry(err, key, config, revalidate, opts);
},
...options,
});
33 changes: 33 additions & 0 deletions frontend/src/openapi-schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ export interface paths {
/** Root */
get: operations["root__get"];
};
"/api/v1/config/": {
/** Get Config */
get: operations["get_config_api_v1_config__get"];
};
"/api/v1/documents/": {
/** List Documents */
get: operations["list_documents_api_v1_documents__get"];
Expand Down Expand Up @@ -133,6 +137,10 @@ export interface components {
* Format: binary
*/
file: string;
/** Language */
language: string;
/** Model */
model: string;
/** Name */
name: string;
};
Expand Down Expand Up @@ -185,6 +193,20 @@ export interface components {
/** Token */
token: string;
};
/** ModelConfig */
ModelConfig: {
/** Id */
id: string;
/** Languages */
languages: (string)[];
/** Name */
name: string;
};
/** PublicConfig */
PublicConfig: {
/** Models */
models: (components["schemas"]["ModelConfig"])[];
};
/** SetDurationRequest */
SetDurationRequest: {
/** Duration */
Expand Down Expand Up @@ -309,6 +331,17 @@ export interface operations {
};
};
};
/** Get Config */
get_config_api_v1_config__get: {
responses: {
/** @description Successful Response */
200: {
content: {
"application/json": components["schemas"]["PublicConfig"];
};
};
};
};
/** List Documents */
list_documents_api_v1_documents__get: {
parameters: {
Expand Down
Loading

0 comments on commit b78bbda

Please sign in to comment.