Skip to content

Commit

Permalink
adds test, remove debugging code -- finetune
Browse files Browse the repository at this point in the history
  • Loading branch information
akash5100 committed Nov 1, 2023
1 parent 67d20bb commit 1988531
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 68 deletions.
5 changes: 5 additions & 0 deletions contentcuration/automation/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from torch.cuda import is_available as is_gpu_available

DEVICE = "cuda:0" if is_gpu_available() else "cpu"
Expand All @@ -17,6 +18,10 @@
DEV_TRANSCRIPTION_MODEL = WHISPER_MODELS['TINY']
TRANSCRIPTION_MODEL = WHISPER_MODELS['TINY']

class WhisperTask(Enum):
    """Task identifiers for the Whisper model.

    NOTE(review): presumably these are the values passed as the model's
    ``task`` option -- confirm against the transcription backend.
    """

    TRANSLATE = "translate"
    TRANSCRIBE = "transcribe"

# https://huggingface.co/docs/transformers/v4.29.1/en/generation_strategies#customize-text-generation
MAX_TOKEN_LENGTH = 448
CHUNK_LENGTH = 10
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# https://github.com/learningequality/studio/blob/unstable/contentcuration/contentcuration/frontend/shared/leUtils/TranscriptionLanguages.js

import json
from typing import List, Dict

import le_utils.resources as resources

WHISPER_LANGUAGES = {
Expand Down Expand Up @@ -113,22 +115,27 @@
"su": "sundanese",
}

def _load_kolibri_languages():
"""Load Kolibri languages from JSON file and return the language codes as a list."""

def _load_kolibri_languages() -> List[str]:
    """Load the language codes from le_utils' ``languagelookup.json``.

    :return: the top-level keys of the JSON document, as a list.
    """
    filepath = resources.__path__[0]
    # Read explicitly as UTF-8 (the JSON interchange encoding) so the result
    # does not depend on the platform's locale default encoding.
    with open(f"{filepath}/languagelookup.json", encoding="utf-8") as f:
        return list(json.load(f).keys())

def _load_model_languages(languages):
"""Load languages supported by the speech-to-text model."""

def _load_model_languages(languages: Dict[str, str]) -> List[str]:
"""Load languages supported by the speech-to-text model.
:param: languages: dict mapping language codes to language names"""
return list(languages.keys())

def create_captions_languages():
"""Create the intersection of transcription model and Kolibri languages."""

def create_captions_languages() -> List[str]:
    """Return the language codes present in both the Kolibri and model language sets."""
    kolibri_codes = set(_load_kolibri_languages())
    model_codes = set(_load_model_languages(languages=WHISPER_LANGUAGES))
    return list(kolibri_codes & model_codes)


CAPTIONS_LANGUAGES = create_captions_languages()
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@
{{ $tr('generateBtn') }}
</VBtn>
</div>

<button @click="logState">
log state
</button>

<!-- TODO -->
<!-- est. time, time elapsed can be good -->
<p v-if="isGeneratingCaptions">
Expand Down Expand Up @@ -101,11 +96,6 @@ import LanguageDropdown from 'shared/views/LanguageDropdown';
},
methods: {
...mapActions('caption', ['addCaptionFile']),
logState() {
console.log('nodeId ', this.nodeId);
console.log(this.captionFilesMap[this.nodeId]);
console.log(this.captionCuesMap[this.nodeId]);
},
addCaption() {
const id = uuid4();
const fileId = this.getLongestDurationFileId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ export function ADD_CAPTIONFILES(state, { captionFiles, nodeIds }) {

/* Mutations for Caption Cues */
export function ADD_CUE(state, { cue, nodeId }) {

console.log(cue, nodeId);

if (!cue && !nodeId) return;
// Check if there is Map for the current nodeId
if (!state.captionCuesMap[nodeId]) {
Expand Down Expand Up @@ -71,5 +68,4 @@ export function UPDATE_CAPTIONFILE_FROM_INDEXEDDB(state, { id, ...mods }) {
break;
}
}
console.log('done');
}
17 changes: 17 additions & 0 deletions contentcuration/contentcuration/frontend/shared/data/resources.js
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,23 @@ export const CaptionFile = new Resource({
syncable: true,
getChannelId: getChannelFromChannelScope,

updateContentNodeChanged(id) {
return this.transaction({ mode: 'rw' }, CHANGES_TABLE, TABLE_NAMES.CONTENTNODE, () => {
ContentNode.table
.where({ channel_id: id })
.and(node => node.has_children == true)
.modify({ changed: true })
});
},

_add: Resource.prototype.add,
add(obj) {
return this._add(obj).then(id => {

Check failure on line 1040 in contentcuration/contentcuration/frontend/shared/data/resources.js

View workflow job for this annotation

GitHub Actions / Frontend linting

'id' is defined but never used
const contentnodeId = this.getChannelId();
return this.updateContentNodeChanged(contentnodeId);
})
},

waitForCaptionCueGeneration(id) {
const observable = Dexie.liveQuery(() => {
return this.table
Expand Down
13 changes: 8 additions & 5 deletions contentcuration/contentcuration/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from contentcuration.models import ContentNode
from contentcuration.models import User
from contentcuration.utils.csv_writer import write_user_csv
from contentcuration.utils.transcription import WhisperAdapter, WhisperBackendFactory
from contentcuration.utils.nodes import calculate_resource_size
from contentcuration.utils.nodes import generate_diff
from contentcuration.viewsets.user import AdminUserFilter
Expand Down Expand Up @@ -144,8 +143,13 @@ def generatecaptioncues_task(caption_file_id: str, channel_id, user_id) -> None:
"""Start generating the Caption Cues for the requested Caption File"""

from contentcuration.viewsets.caption import CaptionCueSerializer
from contentcuration.viewsets.sync.constants import CAPTION_FILE, CAPTION_CUES
from contentcuration.viewsets.sync.utils import generate_update_event, generate_create_event
from contentcuration.viewsets.sync.constants import CAPTION_FILE
from contentcuration.viewsets.sync.constants import CAPTION_CUES
from contentcuration.viewsets.sync.utils import generate_update_event
from contentcuration.viewsets.sync.utils import generate_create_event
from contentcuration.utils.transcription import WhisperAdapter
from contentcuration.utils.transcription import WhisperBackendFactory


backend = WhisperBackendFactory().create_backend()
adapter = WhisperAdapter(backend=backend)
Expand All @@ -168,9 +172,8 @@ def generatecaptioncues_task(caption_file_id: str, channel_id, user_id) -> None:
},
channel_id=channel_id,
), applied=True, created_by_id=user_id)

else:
print(serializer.errors)
raise ValueError(f"Error in cue serialization: {serializer.errors}")

Change.create_change(generate_update_event(
caption_file_id,
Expand Down
41 changes: 41 additions & 0 deletions contentcuration/contentcuration/tests/test_exportchannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,47 @@ def test_publish_no_modify_legacy_exercise_extra_fields(self):
'n': 2
})

def test_vtt_on_publish(self):
    """Publishing a caption file writes a WebVTT file; publishing again refreshes it."""
    from contentcuration.utils.publish import process_webvtt_file_publishing

    # A complete video node attached to the channel's main tree
    video_node = create_node({'kind_id': 'video', 'title': 'caption creation test'})
    video_node.complete = True
    video_node.parent = self.content_channel.main_tree
    video_node.save()

    # CaptionFile that points at the node's first file
    caption_file = cc.CaptionFile(
        file_id=video_node.files.all()[0].id,
        language=cc.Language.objects.get(pk="en"),
    )
    caption_file.save()

    # A single CaptionCue belonging to that CaptionFile
    cue = cc.CaptionCue(text='a test string', starttime=0, endtime=3, caption_file=caption_file)
    cue.save()

    # The first publish creates the output (VTT) file
    assert caption_file.output_file is None
    process_webvtt_file_publishing('create', video_node, caption_file)
    assert caption_file.output_file is not None

    initial_vtt = caption_file.output_file.file_on_disk.read()
    assert initial_vtt == b'WEBVTT\n\n0:00:00.000 --> 0:00:03.000\n- a test string\n\n'

    # Change the cue text, then publish again
    stored_cue = caption_file.caption_cue.first()
    stored_cue.text = "Updated text"
    stored_cue.save()

    process_webvtt_file_publishing('update', video_node, caption_file)
    refreshed_vtt = caption_file.output_file.file_on_disk.read()

    # The VTT content must have changed to the new cue text
    assert initial_vtt != refreshed_vtt
    assert refreshed_vtt == b'WEBVTT\n\n0:00:00.000 --> 0:00:03.000\n- Updated text\n\n'

class EmptyChannelTestCase(StudioTestCase):

Expand Down
58 changes: 20 additions & 38 deletions contentcuration/contentcuration/tests/viewsets/test_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def same_file_different_language_metadata(self):
return [
{
"file_id": id,
"language": Language.objects.get(pk="en"),
"language": Language.objects.get(pk="en"),
},
{
"file_id": id,
"language": Language.objects.get(pk="ru"),
"language": Language.objects.get(pk="ru"),
},
]

Expand All @@ -43,7 +43,7 @@ def caption_cue_metadata(self):
return {
"file": {
"file_id": uuid.uuid4().hex,
"language": Language.objects.get(pk="en").pk,
"language": Language.objects.get(pk="en").pk,
},
"cue": {
"text": "This is the beginning!",
Expand All @@ -63,12 +63,16 @@ def test_create_caption(self):
self.client.force_authenticate(user=self.user)
caption_file = self.caption_file_metadata

response = self.sync_changes([generate_create_event(
uuid.uuid4().hex,
CAPTION_FILE,
caption_file,
channel_id=self.channel.id,
)],)
response = self.sync_changes(
[
generate_create_event(
uuid.uuid4().hex,
CAPTION_FILE,
caption_file,
channel_id=self.channel.id,
)
],
)
self.assertEqual(response.status_code, 200, response.content)

try:
Expand All @@ -83,27 +87,11 @@ def test_create_caption(self):
self.assertEqual(caption_file_db.file_id, caption_file["file_id"])
self.assertEqual(caption_file_db.language_id, caption_file["language"])

def test_enqueue_caption_task(self):
    """Syncing a caption file create event responds with HTTP 200."""
    self.client.force_authenticate(user=self.user)
    payload = {
        "file_id": uuid.uuid4().hex,
        "language": Language.objects.get(pk="en").pk,
    }
    create_event = generate_create_event(
        uuid.uuid4().hex,
        CAPTION_FILE,
        payload,
        channel_id=self.channel.id,
    )

    response = self.sync_changes([create_event])
    self.assertEqual(response.status_code, 200, response.content)


def test_delete_caption_file(self):
self.client.force_authenticate(user=self.user)
caption_file = self.caption_file_metadata
# Explicitly set language to model object to follow Django ORM conventions
caption_file['language'] = Language.objects.get(pk='en')
caption_file["language"] = Language.objects.get(pk="en")
caption_file_1 = CaptionFile(**caption_file)
pk = caption_file_1.pk

Expand Down Expand Up @@ -144,8 +132,7 @@ def test_delete_file_with_same_file_id_different_language(self):

def test_caption_file_serialization(self):
metadata = self.caption_file_metadata
# Explicitly set language to model object to follow Django ORM conventions
metadata['language'] = Language.objects.get(pk="en")
metadata["language"] = Language.objects.get(pk="en")
caption_file = CaptionFile.objects.create(**metadata)
serializer = CaptionFileSerializer(instance=caption_file)
try:
Expand All @@ -155,8 +142,7 @@ def test_caption_file_serialization(self):

def test_caption_cue_serialization(self):
metadata = self.caption_cue_metadata
# Explicitly set language to model object to follow Django ORM conventions
metadata['file']['language'] = Language.objects.get(pk="en")
metadata["file"]["language"] = Language.objects.get(pk="en")
caption_file = CaptionFile.objects.create(**metadata["file"])
caption_cue = metadata["cue"]
caption_cue.update(
Expand All @@ -178,8 +164,7 @@ def test_create_caption_cue(self):
self.client.force_authenticate(user=self.user)
metadata = self.caption_cue_metadata

# Explicitly set language to model object to follow Django ORM conventions
metadata['file']['language'] = Language.objects.get(pk="en")
metadata["file"]["language"] = Language.objects.get(pk="en")

caption_file_1 = CaptionFile.objects.create(**metadata["file"])
caption_cue = metadata["cue"]
Expand Down Expand Up @@ -209,8 +194,7 @@ def test_create_caption_cue(self):
def test_delete_caption_cue(self):
self.client.force_authenticate(user=self.user)
metadata = self.caption_cue_metadata
# Explicitly set language to model object to follow Django ORM conventions
metadata['file']['language'] = Language.objects.get(pk="en")
metadata["file"]["language"] = Language.objects.get(pk="en")
caption_file_1 = CaptionFile.objects.create(**metadata["file"])
caption_cue = metadata["cue"]
caption_cue.update({"caption_file": caption_file_1})
Expand Down Expand Up @@ -245,8 +229,7 @@ def test_delete_caption_cue(self):
def test_update_caption_cue(self):
self.client.force_authenticate(user=self.user)
metadata = self.caption_cue_metadata
# Explicitly set language to model object to follow Django ORM conventions
metadata['file']['language'] = Language.objects.get(pk="en")
metadata["file"]["language"] = Language.objects.get(pk="en")
caption_file_1 = CaptionFile.objects.create(**metadata["file"])

caption_cue = metadata["cue"]
Expand Down Expand Up @@ -299,8 +282,7 @@ def test_update_caption_cue(self):

def test_invalid_caption_cue_data_serialization(self):
metadata = self.caption_cue_metadata
# Explicitly set language to model object to follow Django ORM conventions
metadata['file']['language'] = Language.objects.get(pk="en")
metadata["file"]["language"] = Language.objects.get(pk="en")
caption_file = CaptionFile.objects.create(**metadata["file"])
caption_cue = metadata["cue"]
caption_cue.update(
Expand Down
6 changes: 4 additions & 2 deletions contentcuration/contentcuration/utils/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def process_webvtt_file_publishing(
:param caption_file: A CaptionFile to associate with the WebVTT file.
:param user_id: The ID of the user creating the WebVTT file (optional).
"""
logging.debug(f"{action} WebVTT for Node {ccnode.title}")
logging.debug(f"{action[:-1]}ing WebVTT for Node {ccnode.title}")
vtt_content = generate_webvtt_file(caption_cues=caption_file.caption_cue.all())
filename = "{name}_{lang}.{ext}".format(name=ccnode.title, lang=caption_file.language, ext=file_formats.VTT)
temppath = None
Expand All @@ -207,9 +207,11 @@ def process_webvtt_file_publishing(

if action == 'update' and caption_file.output_file:
caption_file.output_file.contentnode = None
caption_file.save(update_fields=['contentnode'])
caption_file.output_file.save(update_fields=['contentnode'])

caption_file.output_file = new_vtt_file
# Save only `output_file`: by default, saving the new FK would also update the
# CaptionFile's `modified` timestamp; restricting the update keeps vtt_file.modified > caption_file.modified
caption_file.save(update_fields=['output_file'])
except Exception as e:
logging.error(f"Error creating VTT file for {ccnode.title}: {str(e)}")
Expand Down
6 changes: 4 additions & 2 deletions contentcuration/contentcuration/utils/transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
from automation.settings import DEVICE
from automation.settings import DEV_TRANSCRIPTION_MODEL
from automation.settings import MAX_TOKEN_LENGTH
# from automation.settings import WhisperTask
from automation.utils.appnexus.base import Adapter
from automation.utils.appnexus.base import Backend
from automation.utils.appnexus.base import BackendFactory
from automation.utils.appnexus.base import BackendRequest
from automation.utils.appnexus.base import BackendResponse
from contentcuration.constants.transcription_languages import WHISPER_LANGUAGES as LANGS
from contentcuration.models import CaptionFile
from contentcuration.models import File
from contentcuration.not_production_settings import WHISPER_BACKEND
Expand All @@ -35,7 +37,7 @@ def _get_binary(self) -> bytes:

def get_binary_data(self) -> bytes:
    """Return the value stored on ``self.binary``."""
    return self.binary
def get_file_url(self) -> str:
    """Return the value stored on ``self.url``."""
    return self.url
def get_langauge(self) -> str: return self.language
def get_language(self) -> str:
    """Return the value stored on ``self.language``."""
    return self.language


class WhisperResponse(BackendResponse):
Expand Down Expand Up @@ -102,7 +104,7 @@ def create_backend(self) -> Backend:
class WhisperAdapter(Adapter):
def transcribe(self, caption_file_id: str) -> WhisperResponse:
f = CaptionFile.objects.get(pk=caption_file_id)
file_id, language = f.file_id, f.language # TODO: set language of transcription
file_id, language = f.file_id, LANGS[f.language.lang_code]
media_file = File.objects.get(pk=file_id).file_on_disk.url

request = WhisperRequest(url=media_file, language=language)
Expand Down

0 comments on commit 1988531

Please sign in to comment.