Skip to content

Commit

Permalink
Merge pull request #1 from lecoqnicolas/lecoqnicolas-comet-evaluation
Browse files Browse the repository at this point in the history
Add COMET evaluation
  • Loading branch information
lecoqnicolas authored May 21, 2024
2 parents f60c8fe + c0112e0 commit 0b9bebc
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
5 changes: 5 additions & 0 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ def get_flores_dataset_path(dataset="dev"):

return flores_dataset

def get_flores_file_path(lang_code, dataset="dev"):
    """Return the path of the FLORES file for ``lang_code``.

    The directory comes from ``get_flores_dataset_path(dataset)``; the file
    name is the ``nllb_langs`` entry for ``lang_code`` followed by
    ``.<dataset>``. Nothing here checks that the file exists.

    Raises:
        KeyError: if ``lang_code`` has no entry in ``nllb_langs``.
    """
    # Resolve the dataset directory first, before the language lookup.
    dataset_dir = get_flores_dataset_path(dataset)
    file_name = nllb_langs[lang_code] + f".{dataset}"
    return os.path.join(dataset_dir, file_name)

def get_flores(lang_code, dataset="dev"):
flores_dataset = get_flores_dataset_path(dataset)
source = os.path.join(flores_dataset, nllb_langs[lang_code] + f".{dataset}")
Expand Down
25 changes: 22 additions & 3 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import argparse
import ctranslate2
import sentencepiece
import subprocess
from sacrebleu import corpus_bleu
from data import get_flores
from data import get_flores, get_flores_file_path
from tokenizer import BPETokenizer, SentencePieceTokenizer

parser = argparse.ArgumentParser(description='Evaluate LibreTranslate compatible models')
Expand Down Expand Up @@ -32,6 +33,9 @@
# Boolean switches: argparse's "store_true" leaves each one False unless the
# flag is given, so "%(default)s" in the help text renders as "False".
parser.add_argument(
    "--translate_flores",
    action="store_true",
    help="Translate the flores200 corpus into a text file with .evl extension. Default: %(default)s",
)
parser.add_argument(
    "--comet",
    action="store_true",
    help="Run COMET score command on the translated flores text. Default: %(default)s",
)
parser.add_argument(
    "--cpu",
    action="store_true",
    help="Force CPU use. Default: %(default)s",
)
Expand Down Expand Up @@ -65,7 +69,6 @@
exit(1)



def translator():
device = "cuda" if ctranslate2.get_cuda_device_count() > 0 and not args.cpu else "cpu"
model = ctranslate2.Translator(ct2_model_dir, device=device, compute_type="default")
Expand Down Expand Up @@ -98,7 +101,7 @@ def translate_flores():

data = translator()

if args.bleu or args.flores_id or args.translate_flores is not None:
if args.bleu or args.flores_id or args.translate_flores or args.comet is not None:
if args.flores_dataset:
dataset = args.flores_dataset
src_text = get_flores(config["from"]["code"], dataset)
Expand Down Expand Up @@ -127,6 +130,22 @@ def translate_flores():
# Translate the FLORES text at most once. Previously, passing both
# --translate_flores and --comet called translate_flores() twice and
# translated the whole corpus twice.
tra_f = None
if args.translate_flores:
    tra_f = translate_flores()

if args.comet:
    # comet-score is handed file paths for the source, reference and
    # translated text rather than the sentences already loaded in memory.
    src_f = get_flores_file_path(config["from"]["code"], dataset)
    ref_f = get_flores_file_path(config["to"]["code"], dataset)
    if tra_f is None:
        # No translation file from the step above: produce it now.
        tra_f = translate_flores()

    comet_cmd = [
        "comet-score",
        "--sources",
        src_f,
        "--translations",
        tra_f,
        "--references",
        ref_f,
        "--quiet",
        "--only_system",
    ]
    # Argument list (no shell), so paths with spaces are passed through intact.
    comet_run = subprocess.run(comet_cmd)
    if comet_run.returncode != 0:
        # Report the failure instead of silently discarding the exit status.
        print(f"comet-score failed with exit status {comet_run.returncode}")

if args.flores_id is not None:
print(f"({config['from']['code']})> {src_text[0]}\n(gt)> {tgt_text[0]}\n({config['to']['code']})> {' '.join(translated_text)}")
else:
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ sentencepiece==0.1.99
requests==2.31.0
PyYAML==6.0.1
sacrebleu==2.3.1
unbabel-comet==2.2.2
subword-nmt>=0.3.7
OpenNMT-py==3.4.1
tensorboard==2.14.0
six==1.16.0
iso639==0.1.4
sacremoses==0.0.53
removedup==1.0.6
fastshuffle==1.0.1
fastshuffle==1.0.1

0 comments on commit 0b9bebc

Please sign in to comment.