-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ff95cce
commit 676c690
Showing
15 changed files
with
314 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
@author: mwahdan | ||
""" | ||
|
||
# diasable the GPU | ||
import os | ||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" | ||
|
||
|
||
from dialognlu import TransformerNLU | ||
from dialognlu.readers.goo_format_reader import Reader | ||
import time | ||
|
||
|
||
num_process = 2 | ||
|
||
|
||
model_path = "../saved_models/joint_distilbert_model" | ||
# model_path = "../saved_models/joint_trans_bert_model" | ||
# model_path = "../saved_models/joint_trans_albert_model" | ||
# model_path = "../saved_models/joint_trans_roberta_model" | ||
|
||
print("Loading model ...") | ||
nlu = TransformerNLU.load(model_path, quantized=True, num_process=num_process) | ||
|
||
print("Loading dataset ...") | ||
test_path = "../data/snips/test" | ||
test_dataset = Reader.read(test_path) | ||
|
||
print("Evaluating model ...") | ||
t1 = time.time() | ||
token_f1_score, tag_f1_score, report, acc = nlu.evaluate(test_dataset) | ||
t2 = time.time() | ||
|
||
print('Slot Classification Report:', report) | ||
print('Slot token f1_score = %f' % token_f1_score) | ||
print('Slot tag f1_score = %f' % tag_f1_score) | ||
print('Intent accuracy = %f' % acc) | ||
|
||
print("Using %d processes took %f seconds" % (num_process, t2 - t1)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
@author: mwahdan | ||
""" | ||
|
||
# diasable the GPU | ||
import os | ||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" | ||
|
||
|
||
from dialognlu import TransformerNLU | ||
|
||
|
||
# model_path = "../saved_models/joint_distilbert_model" | ||
# model_path = "../saved_models/joint_trans_bert_model" | ||
# model_path = "../saved_models/joint_trans_albert_model" | ||
model_path = "../saved_models/joint_trans_roberta_model" | ||
|
||
print("Loading model ...") | ||
nlu = TransformerNLU.load(model_path, quantized=True, num_process=1) | ||
|
||
print("Prediction ...") | ||
utterance = "add sabrina salerno to the grime instrumentals playlist" | ||
print ("utterance: {}".format(utterance)) | ||
result = nlu.predict(utterance) | ||
print ("result: {}".format(result)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ | |
|
||
setup( | ||
name="dialognlu", | ||
version="0.1.0", | ||
version="0.2.0", | ||
author="Mahmoud Wahdan", | ||
author_email="[email protected]", | ||
description="State-of-the-art Dialog NLU (Natural Language Understanding) Library with TensorFlow 2.x and keras", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
|
||
__version__ = "0.1.0" | ||
__version__ = "0.2.0" | ||
|
||
from .nlu_components import TransformerNLU, BertNLU | ||
from .auto_nlu import AutoNLU |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.