diff --git a/nlu/components/embeddings/mxbai/MxbaiEmbeddings.py b/nlu/components/embeddings/mxbai/MxbaiEmbeddings.py new file mode 100644 index 00000000..09385b67 --- /dev/null +++ b/nlu/components/embeddings/mxbai/MxbaiEmbeddings.py @@ -0,0 +1,19 @@ +from sparknlp.annotator import * + + +class MxbaiEmbeddings: + @staticmethod + def get_default_model(): + from sparknlp.annotator import MxbaiEmbeddings + return MxbaiEmbeddings.pretrained() \ + .setInputCols(["document"]) \ + .setOutputCol("mxbai_embeddings") + + # @staticmethod + # def get_pretrained_model(name, language, bucket=None): + # return MxbaiEmbeddings.pretrained(name,language,bucket) \ + # .setInputCols(["document"]) \ + # .setOutputCol("sentence_embeddings") + + + diff --git a/nlu/components/embeddings/mxbai/__init__.py b/nlu/components/embeddings/mxbai/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nlu/spellbook.py b/nlu/spellbook.py index dd2225bb..e3396d9b 100644 --- a/nlu/spellbook.py +++ b/nlu/spellbook.py @@ -4556,6 +4556,7 @@ class Spellbook: 'en.dep.untyped': 'dependency_conllu', 'en.dep.untyped.conllu': 'dependency_conllu', 'en.e2e': 'multiclassifierdl_use_e2e', + 'en.embed.mxbai' : 'mxbai_large_v1', 'en.embed': 'glove_100d', 'en.embed.Bible_roberta_base': 'roberta_embeddings_Bible_roberta_base', 'en.embed.COVID_SciBERT': 'bert_embeddings_COVID_SciBERT', @@ -16051,6 +16052,7 @@ class Spellbook: 'genericclassifier_sdoh_tobacco_usage_sbiobert_cased_mli': 'GenericClassifierModel', 'github_issues_mpnet_southern_sotho_e10': 'MPNetEmbeddings', 'github_issues_preprocessed_mpnet_southern_sotho_e10': 'MPNetEmbeddings', + 'mxbai_large_v1': 'MxbaiEmbeddings', 'glove_100d': 'WordEmbeddingsModel', 'glove_6B_100': 'WordEmbeddingsModel', 'glove_6B_300': 'WordEmbeddingsModel', diff --git a/nlu/universe/annotator_class_universe.py b/nlu/universe/annotator_class_universe.py index 0eeecc13..5cf1848a 100644 --- a/nlu/universe/annotator_class_universe.py +++ b/nlu/universe/annotator_class_universe.py @@ -97,6 +97,7 @@ class AnnoClassRef: A_N.ALBERT_EMBEDDINGS: 'AlbertEmbeddings', A_N.ALBERT_FOR_TOKEN_CLASSIFICATION: 'AlbertForTokenClassification', A_N.BERT_EMBEDDINGS: 'BertEmbeddings', + A_N.MXBAI_EMBEDDINGS: 'MxbaiEmbeddings', A_N.BERT_FOR_TOKEN_CLASSIFICATION: 'BertForTokenClassification', A_N.BERT_SENTENCE_EMBEDDINGS: 'BertSentenceEmbeddings', A_N.DISTIL_BERT_EMBEDDINGS: 'DistilBertEmbeddings', diff --git a/nlu/universe/component_universes.py b/nlu/universe/component_universes.py index 0ace75e6..d385bd69 100644 --- a/nlu/universe/component_universes.py +++ b/nlu/universe/component_universes.py @@ -96,6 +96,7 @@ from nlu.components.embeddings.xlm.xlm import XLM from nlu.components.embeddings.xlnet.spark_nlp_xlnet import SparkNLPXlnet from nlu.components.embeddings_chunks.chunk_embedder.chunk_embedder import ChunkEmbedder +from nlu.components.embeddings.mxbai.MxbaiEmbeddings import MxbaiEmbeddings from nlu.components.lemmatizers.lemmatizer.spark_nlp_lemmatizer import SparkNLPLemmatizer from nlu.components.matchers.regex_matcher.regex_matcher import RegexMatcher from nlu.components.normalizers.document_normalizer.spark_nlp_document_normalizer import SparkNLPDocumentNormalizer @@ -1988,6 +1989,27 @@ class ComponentUniverse: is_storage_ref_producer=True, has_storage_ref=True ), + + A.MXBAI_EMBEDDINGS: partial(NluComponent, + name=A.MXBAI_EMBEDDINGS, + type=T.DOCUMENT_EMBEDDING, + get_default_model=MxbaiEmbeddings.get_default_model, + pdf_extractor_methods={'default': default_sentence_embedding_config, + 'default_full': default_full_config, }, + pdf_col_name_substitutor=substitute_sent_embed_cols, + output_level=L.INPUT_DEPENDENT_DOCUMENT_EMBEDDING, + node=NLP_FEATURE_NODES.nodes[A.MXBAI_EMBEDDINGS], + description='Converts Word Embeddings to Sentence/Document Embeddings', + provider=ComponentBackends.open_source, + license=Licenses.open_source, + computation_context=ComputeContexts.spark, + output_context=ComputeContexts.spark, + jsl_anno_class_id=A.MXBAI_EMBEDDINGS, + jsl_anno_py_class=ACR.JSL_anno2_py_class[ + A.MXBAI_EMBEDDINGS], + is_storage_ref_producer=True, + has_storage_ref=True + ), A.STEMMER: partial(NluComponent, name=A.STEMMER, type=T.TOKEN_NORMALIZER, diff --git a/nlu/universe/feature_node_ids.py b/nlu/universe/feature_node_ids.py index 49617140..257095ba 100644 --- a/nlu/universe/feature_node_ids.py +++ b/nlu/universe/feature_node_ids.py @@ -70,6 +70,7 @@ class NLP_NODE_IDS: SENTENCE_DETECTOR = JslAnnoId('sentence_detector') SENTENCE_DETECTOR_DL = JslAnnoId('sentence_detector_dl') SENTENCE_EMBEDDINGS_CONVERTER = JslAnnoId('sentence_embeddings_converter') + MXBAI_EMBEDDINGS = JslAnnoId('mxbai_embeddings') STEMMER = JslAnnoId('stemmer') STOP_WORDS_CLEANER = JslAnnoId('stop_words_cleaner') SYMMETRIC_DELETE_SPELLCHECKER = JslAnnoId('symmetric_delete_spellchecker') diff --git a/nlu/universe/feature_node_universes.py b/nlu/universe/feature_node_universes.py index 21a0d3ae..2c34d929 100644 --- a/nlu/universe/feature_node_universes.py +++ b/nlu/universe/feature_node_universes.py @@ -76,6 +76,7 @@ class NLP_FEATURE_NODES: # or Mode Node? A.INSTRUCTOR_SENTENCE_EMBEDDINGS: NlpFeatureNode(A.INSTRUCTOR_SENTENCE_EMBEDDINGS, [F.DOCUMENT], [F.SENTENCE_EMBEDDINGS]), A.E5_SENTENCE_EMBEDDINGS: NlpFeatureNode(A.E5_SENTENCE_EMBEDDINGS, [F.DOCUMENT],[F.SENTENCE_EMBEDDINGS]), + A.MXBAI_EMBEDDINGS: NlpFeatureNode(A.MXBAI_EMBEDDINGS, [F.DOCUMENT],[F.SENTENCE_EMBEDDINGS]), A.BGE_SENTENCE_EMBEDDINGS: NlpFeatureNode(A.BGE_SENTENCE_EMBEDDINGS, [F.DOCUMENT], [F.SENTENCE_EMBEDDINGS]), A.MPNET_SENTENCE_EMBEDDINGS: NlpFeatureNode(A.MPNET_SENTENCE_EMBEDDINGS, [F.DOCUMENT], [F.SENTENCE_EMBEDDINGS]), diff --git a/tests/nlu_core_tests/component_tests/embed_tests/sentence_embeddings/mxbai_embeddings.py b/tests/nlu_core_tests/component_tests/embed_tests/sentence_embeddings/mxbai_embeddings.py new file mode 100644 index 00000000..8faa70a5 --- /dev/null +++ b/tests/nlu_core_tests/component_tests/embed_tests/sentence_embeddings/mxbai_embeddings.py @@ -0,0 +1,27 @@ +# import tests.secrets as sct + +import os +import sys + +# sys.path.append(os.getcwd()) +import unittest +import nlu + +os.environ["PYTHONPATH"] = "F:/Work/repos/nlu_new/nlu" +os.environ['PYSPARK_PYTHON'] = sys.executable +os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable +from johnsnowlabs import nlp, visual + +# nlp.install(json_license_path="license.json") + +nlp.start() + +class EmbeddingTests(unittest.TestCase): + def test_mxbai_embeddings_model(self): + + res = nlu.load("en.embed.mxbai").predict('This is an example sentence', output_level='document') + print(res) + + +if __name__ == "__main__": + EmbeddingTests().test_mxbai_embeddings_model() diff --git a/tests/nlu_core_tests/component_tests/embed_tests/sentence_embeddings/sentence_e5_tests.py b/tests/nlu_core_tests/component_tests/embed_tests/sentence_embeddings/sentence_e5_tests.py index 5c8dda98..edec64e2 100644 --- a/tests/nlu_core_tests/component_tests/embed_tests/sentence_embeddings/sentence_e5_tests.py +++ b/tests/nlu_core_tests/component_tests/embed_tests/sentence_embeddings/sentence_e5_tests.py @@ -1,7 +1,17 @@ +import sys +import os +# sys.path.append(os.getcwd()) import unittest +import nlu -from nlu import * +os.environ["PYTHONPATH"] = "F:/Work/repos/nlu_new/nlu" +os.environ['PYSPARK_PYTHON'] = sys.executable +os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable +from johnsnowlabs import nlp, visual +# nlp.install(json_license_path="license.json") + +nlp.start() class TestE5SentenceEmbeddings(unittest.TestCase): def test_e5_embeds(self): diff --git a/tests/nlu_hc_tests/component_tests/few_shot_assertion_classifier/assertion_tests.py b/tests/nlu_hc_tests/component_tests/few_shot_assertion_classifier/assertion_tests.py index a2e61e77..c9395770 100644 --- a/tests/nlu_hc_tests/component_tests/few_shot_assertion_classifier/assertion_tests.py +++ b/tests/nlu_hc_tests/component_tests/few_shot_assertion_classifier/assertion_tests.py @@ -7,7 +7,7 @@ import unittest import nlu -# os.environ["PYTHONPATH"] = "F:/Work/repos/nlu_new/nlu" +os.environ["PYTHONPATH"] = "F:/Work/repos/nlu_new/nlu" os.environ['PYSPARK_PYTHON'] = sys.executable os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable from johnsnowlabs import nlp, visual