Skip to content

Commit

Permalink
batch processing for translations
Browse files Browse the repository at this point in the history
Signed-off-by: Shashank Mittal <[email protected]>
  • Loading branch information
shashank-iitbhu committed Mar 4, 2024
1 parent 42bd042 commit ac7488e
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions src/scribe_data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,9 @@ def get_target_langcodes(source_lang)->list[str]:
continue
return target_langcodes

def translate_to_other_languages(source_language, word_list, translations):
def translate_to_other_languages(source_language, word_list, translations, batch_size=10):
"""
Translates a list of words from the source language to other target languages.
Translates a list of words from the source language to other target languages using batch processing.
Parameters
----------
Expand All @@ -499,24 +499,30 @@ def translate_to_other_languages(source_language, word_list, translations):
word_list : list[str]
The list of words to translate.
translations : list[dict]
The current list of translations.
translations : dict
The current dictionary of translations.
batch_size : int
The number of words to translate in each batch.
"""
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

for word in word_list[len(translations):]:
word_translations = {word: {}}
for i in range(0, len(word_list), batch_size):
batch_words = word_list[i:i+batch_size]
print(f"Translating batch {i//batch_size + 1}: {batch_words}")
for lang_code in get_target_langcodes(source_language):
tokenizer.src_lang = get_language_iso(source_language)
encoded_word = tokenizer(word, return_tensors="pt")
generated_tokens = model.generate(**encoded_word, forced_bos_token_id=tokenizer.get_lang_id(lang_code))
translated_word = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
word_translations[word][lang_code] = translated_word
translations.append(word_translations)
encoded_words = tokenizer(batch_words, return_tensors="pt", padding=True)
generated_tokens = model.generate(**encoded_words, forced_bos_token_id=tokenizer.get_lang_id(lang_code))
translated_words = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
for word, translation in zip(batch_words, translated_words):
if word not in translations:
translations[word] = {}
translations[word][lang_code] = translation
print(f"Batch {i//batch_size + 1} translation completed.")

with open(f"{get_language_dir_path(source_language)}/formatted_data/translated_words.json", 'w', encoding='utf-8') as file:
json.dump(translations, file, ensure_ascii=False, indent=4)
print(f"Translation results for the word '{word}' have been saved.")

print("Translation results for all words are saved to the translated_words.json file.")

0 comments on commit ac7488e

Please sign in to comment.