diff --git a/README.md b/README.md
index 633f1e6..1d0e699 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,13 @@
 ```
 ## 1.3 模型文件
+### 版本v0.3.0
+ - 新增一种生成反义词/近义词的算法, 构建提示词prompt, 基于BERT-MLM等继续训练, 类似beam_search方法, 生成反义词/近义词;
+ ```
+ prompt: "xx"的反义词是"[MASK][MASK]"。
+ ```
+ - 模型权重在[Macropodus/mlm_antonym_model](https://huggingface.co/Macropodus/mlm_antonym_model), 国内镜像[Macropodus/mlm_antonym_model](https://hf-mirror.com/Macropodus/mlm_antonym_model)
+
 ### 版本v0.1.0
  - github项目源码自带模型文件只有1w+词向量, 完整模型文件在near_synonym/near_synonym_model,
  - pip下载pypi包里边没有数据和模型(只有代码), 第一次加载使用huggingface_hub下载, 大约为420M;
@@ -28,10 +35,9 @@
  - 或完整的词向量详见百度网盘分享链接[https://pan.baidu.com/s/1lDSCtpr0r2hKrGrK8ZLlFQ](https://pan.baidu.com/s/1lDSCtpr0r2hKrGrK8ZLlFQ), 密码: ff0y
-
 # 二、使用方式
-## 2.1 快速使用, 反义词, 近义词, 相似度
+## 2.1 快速使用方法一, 反义词, 近义词, 相似度
 ```python3
 import near_synonym
@@ -59,7 +65,7 @@
 print(w1, w2, score)
 ```

-## 2.2 详细使用, 反义词, 相似度
+## 2.2 详细使用方法一, 反义词, 相似度
 ```python3
 import near_synonym
@@ -78,16 +84,48 @@
 print(score)
 ```

+## 2.3 使用方法二, 基于继续训练 + prompt的bert-mlm形式
+```python3
+import traceback
+import os
+os.environ["FLAG_MLM_ANTONYM"] = "1" # 必须先指定
+
+from near_synonym import mlm_synonyms, mlm_antonyms
+
+
+word = "喜欢"
+word_antonyms = mlm_antonyms(word)
+word_synonyms = mlm_synonyms(word)
+print("反义词:")
+print(word_antonyms)
+print("近义词:")
+print(word_synonyms)
+
+"""
+反义词:
+[('厌恶', 0.77), ('讨厌', 0.72), ('憎恶', 0.56), ('反恶', 0.49), ('忌恶', 0.48), ('反厌', 0.46), ('厌烦', 0.46), ('反感', 0.45)]
+近义词:
+[('喜好', 0.75), ('喜爱', 0.64), ('爱好', 0.54), ('倾爱', 0.5), ('爱爱', 0.49), ('喜慕', 0.49), ('向好', 0.48), ('倾向', 0.48)]
+"""
+```
+
 # 三、技术原理
 ## 3.1 技术详情
 ```
 near-synonym, 中文反义词/近义词工具包.
-流程: Word2vec -> ANN -> NLI -> Length
+流程一(neg_antonym): Word2vec -> ANN -> NLI -> Length
 # Word2vec, 词向量, 使用skip-gram的词向量;
 # ANN, 近邻搜索, 使用annoy检索召回;
 # NLI, 自然语言推断, 使用Roformer-sim的v2版本, 区分反义词/近义词;
 # Length, 惩罚项, 词语的文本长度惩罚;
+
+流程二(mlm_antonym): 构建提示词prompt重新训练BERT类模型(prompt中用引号等着重标注, 并带句号; 不继续训练效果会很差) -> BERT-MLM(第一个char取topk, 然后从左往右依次beam_search)
+# 构建prompt:
+ - "xxx"的反义词是"[MASK][MASK][MASK]"。
+ - "xxx"的近义词是"[MASK][MASK][MASK]"。
+# 训练MLM
+# 一个char一个char地预测, 同beam_search(简化的解码示意见下文代码块)
 ```
 ## 3.2 TODO
@@ -101,6 +139,7 @@ near-synonym, 中文反义词/近义词工具包.
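下面是流程二(mlm_antonym)解码思路的一个最小示意: 把原实现中保留topk候选的类beam_search简化为从左到右逐个[MASK]贪心填充; 其中 bert-base-chinese 仅为占位假设的模型名, 未按上述prompt继续训练时效果会明显变差, 完整实现以 near_synonym/mlm_antonym.py 和 Macropodus/mlm_antonym_model 为准。

```python3
# 最小示意: 构建prompt后, 从左到右逐个[MASK]贪心填充(原实现为保留topk候选并重排的类beam_search)
# 注意: bert-base-chinese 仅为占位假设, 未经继续训练; 实际请换成训练好的MLM权重
import torch
from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = BertForMaskedLM.from_pretrained("bert-base-chinese")
model.eval()


def mlm_antonym_greedy(word):
    # 构建prompt, 例如: "喜欢"的反义词是"[MASK][MASK]"。
    text = '"{}"的反义词是"{}"。'.format(word, "[MASK]" * len(word))
    input_ids = tokenizer(text, return_tensors="pt")["input_ids"]
    for _ in range(len(word)):
        with torch.no_grad():
            logits = model(input_ids=input_ids).logits
        # 每轮只贪心填充最左边的一个[MASK], 其余[MASK]留到下一轮
        mask_pos = torch.where(input_ids == tokenizer.mask_token_id)[1][0]
        input_ids[0, mask_pos] = int(logits[0, mask_pos].argmax(-1))
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist(), skip_special_tokens=True)
    return "".join(tokens[-2 - len(word):-2])  # 取后一对引号内的预测字符


print(mlm_antonym_greedy("喜欢"))
```

示意中未做中文字符过滤与多候选重排序, 这两步在原实现中用于保证输出成词且可控。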
## 3.3 其他实验 ``` +choice, prompt + bert-mlm; choice, 如何处理数据/模型文件, 1.huggingface_hub("√") 2.gzip compress whitin 100M in pypi("×"); fail, 使用情感识别, 取得不同情感下的词语(失败, 例如可爱/漂亮同为积极情感); fail, 使用NLI自然推理, 已有的语料是句子, 不是太适配; @@ -137,6 +176,7 @@ fail, 使用NLI自然推理, 已有的语料是句子, 不是太适配; # 六、日志 ``` +2024.10.06, 完成prompt + bert-mlm形式生成反义词/近义词; 2024.04.14, 修改词向量计算方式(句子级别), 使得句向量的相似度/近义词/反义词更准确一些(依旧很不准, 待改进); 2024.04.13, 使用huggface_hub下载数据, 即near_synonym_model目录, 在[Macropodus/near_synonym_model](https://huggingface.co/Macropodus/near_synonym_model); 2024.04.07, qwen-7b-chat模型构建28w+词典的近义词/反义词表, 即ci_atmnonym_synonym.json, v0.1.0版本; diff --git a/near_synonym/__init__.py b/near_synonym/__init__.py index d6c92d1..49e205d 100644 --- a/near_synonym/__init__.py +++ b/near_synonym/__init__.py @@ -5,9 +5,16 @@ # @function: init +import os + from near_synonym.neg_antonym import NS synonyms = NS.near_synonym antonyms = NS.near_antonym sim = NS.similarity + +if os.environ.get("FLAG_MLM_ANTONYM") == "1": + from near_synonym.mlm_antonym import MA + mlm_synonyms = MA.near_synonym + mlm_antonyms = MA.near_antonym diff --git a/near_synonym/mlm_antonym.py b/near_synonym/mlm_antonym.py new file mode 100644 index 0000000..5ab91b2 --- /dev/null +++ b/near_synonym/mlm_antonym.py @@ -0,0 +1,386 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# @time : 2021/1/18 21:34 +# @author : Mo +# @function: BERT-MLM to Antonym + + +from __future__ import absolute_import, division, print_function +import traceback +import time +import copy +import sys +import os +path_sys = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(path_sys) +print(path_sys) +from near_synonym.tools import download_model_from_huggface, load_json +from transformers import BertForMaskedLM, BertTokenizer, BertConfig +import torch + + +class MLM4Antonym: + def __init__(self, path_pretrain_model_dir="", + path_trained_model_dir="", + device="cuda:0"): + if not path_pretrain_model_dir: + self.path_pretrain_model_dir = os.path.join(path_sys, "near_synonym/mlm_antonym_model") + self.path_trained_model_dir = os.path.join(path_sys, "near_synonym/mlm_antonym_model") + else: + self.path_pretrain_model_dir = path_pretrain_model_dir + self.path_trained_model_dir = path_trained_model_dir + self.path_w2i = os.path.join(path_sys, "near_synonym/near_synonym_model/word2vec.w2i") + self.flag_skip = False + self.device = device + self.topk_times = 5 # topk重复次数, 避免非中文的情况 + self.topk = 8 # beam-search + self.check_or_download_hf_model() # 检测模型目录是否存在, 不存在就下载模型 + self.load_trained_model() + self.flag_filter_word = False # 用于过滤词汇, [MASK]有时候可能不成词 + self.w2i = {} + if os.path.exists(self.path_w2i): + self.w2i = load_json(self.path_w2i) + if not self.w2i: + self.flag_filter_word = False + + def check_or_download_hf_model(self): + """ 从hf国内镜像加载数据 """ + if os.path.exists(self.path_pretrain_model_dir): + pass + else: + # dowload model from hf + download_model_from_huggface(repo_id="Macropodus/mlm_antonym_model") + + def load_trained_model(self): + """ 加载训练好的模型 """ + if "mlm_antonym" in self.path_trained_model_dir: + self.tokenizer = BertTokenizer.from_pretrained(self.path_trained_model_dir) + config = BertConfig.from_pretrained(pretrained_model_name_or_path=self.path_trained_model_dir) + self.model = BertForMaskedLM(config) + path_model_real = os.path.join(self.path_trained_model_dir, "pytorch_model.bin") + model_dict_new = torch.load(path_model_real, map_location=torch.device(self.device)) + model_dict_new = {k.replace("pretrain_model.", ""): v for k, v in 
model_dict_new.items()} + self.model.load_state_dict(model_dict_new, strict=False) + else: + self.tokenizer = BertTokenizer.from_pretrained(self.path_trained_model_dir) + self.model = BertForMaskedLM.from_pretrained(self.path_trained_model_dir) + self.model.to(self.device) + self.model.eval() + + def prompt(self, word, category="antonym"): + """ 组装提示词 """ + if category == "antonym": + input_text = '"{}"的反义词是"{}"。'.format(word, len(word) * '[MASK]') + elif category == "same": + input_text = '"{}"的同义词是"{}"。'.format(word, len(word) * '[MASK]') + else: + input_text = '"{}"的近义词是"{}"。'.format(word, len(word) * '[MASK]') + # input_text = '"{}"的相似词是"{}"。'.format(word, len(word) * '[MASK]') + # input_text = '你觉得{}的近义词是{}。'.format(len_word * '[MASK]', word) + return input_text + + def decode(self, input_ids): + """ 解码 """ + return self.tokenizer.convert_ids_to_tokens(input_ids, skip_special_tokens=self.flag_skip) + + def predict(self, word, topk=8, category="antonym"): + """ 推理 + category可选: "antonym", "synonym", "same" + """ + topk_times = self.topk_times * topk + len_word = len(word) + text = self.prompt(word, category=category) + count_mask = text.count("[MASK]") + # 对输入句子进行编码 + inputs = self.tokenizer(text, return_tensors="pt") + input_ids = inputs["input_ids"] + input_ids = input_ids.to(self.device) + input_bs = input_ids.repeat(topk, 1) + score_bs = [0] * topk + res = [] + # 进行预测第一个char, 保证取得topk个不一样的char, 类似beam-search + with torch.no_grad(): + outputs = self.model(input_ids=input_ids) + # 获取预测结果中的logits + logits = outputs.logits + # 获取[MASK]标记位置的索引(假设句子中只有一个[MASK]标记) + mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1] + # 获取[MASK]标记位置上的预测结果 + mask_token_logits = logits[0, mask_token_index, :] + # 获取概率最高的token及其对应的id和文本 + top5_softmax = torch.softmax(mask_token_logits, dim=-1) + largest, indices = torch.topk(top5_softmax, topk_times, dim=1, largest=True, sorted=True) + topk_ids = indices[0] + topk_prob = largest[0] + count = 0 + ### input_bs填充第一个[MASK], 获取word_list + for idx, topk_id in enumerate(topk_ids): + topk_id_token = self.tokenizer.decode([topk_id], skip_special_tokens=True) + # token存在且是中文, 取topk个 + if topk_id_token.strip() and "\u4e00" <= topk_id_token <= "\u9fa5" and count <= topk-1: # 必须是中文 + input_bs[count][mask_token_index[0]] = topk_id + input_bs_idx = [self.tokenizer.decode([id]) for id in input_bs[count] + ][mask_token_index[0] + 1 - len_word: mask_token_index[0] + 1] + score_bs[count] = torch.log(topk_prob[idx]) + score_count = torch.exp(score_bs[count]) + res.append(("".join(input_bs_idx), float(score_count.detach().cpu().numpy()))) + count += 1 + + ### 进行预测第一个k(1>1)个char, 保证取得topk个不一样的char, 类似beam-search + ### 得分取得topk个数值(torch.log), 然后重排序; + for i in range(count_mask-1): + # 进行预测 + with torch.no_grad(): + outputs = self.model(input_ids=input_bs) + # 获取预测结果中的logits + logits = outputs.logits + # 获取[MASK]标记位置的索引(假设句子中只有一个[MASK]标记) + mask_token_index = torch.where(input_bs == self.tokenizer.mask_token_id)[1] + # input_bs_topk = input_bs.repeat(topk, 1) + score_bs_topk = [0] * topk * topk + input_bs_topk_all = [] + res = [] + for tdx in range(topk): + # 获取[MASK]标记位置上的预测结果 + mask_token_logits = logits[tdx, mask_token_index, :] + input_bs_tdx = copy.deepcopy(input_bs[tdx]) + input_bs_tdx_topk = input_bs_tdx.repeat(topk, 1) + # 获取概率最高的token及其对应的id和文本 + top5_softmax = torch.softmax(mask_token_logits, dim=-1) + largest, indices = torch.topk(top5_softmax, topk_times, + dim=-1, largest=True, sorted=True) + topk_ids = indices[0] + topk_prob = largest[0] + 
count = 0 + for jdx, topk_id in enumerate(topk_ids): + topk_id_token = self.tokenizer.decode([topk_id], skip_special_tokens=True) + if topk_id_token.strip() and "\u4e00" <= topk_id_token <= "\u9fa5" and count <= topk-1: # 必须是中文 + input_bs_tdx_topk[count][mask_token_index[0]] = topk_id + input_bs_topk_idx = [self.tokenizer.decode([id]) for id in input_bs_tdx_topk[count] + ][1:mask_token_index[0] + 1] + score_bs_topk[tdx * topk + count] = score_bs[tdx] + torch.log(topk_prob[count]) + score_count = torch.exp(score_bs_topk[tdx * topk + count]) + res.append(("".join(input_bs_topk_idx[mask_token_index[0] - len_word:mask_token_index[0] + 1]), + float(score_count.detach().cpu().numpy()))) + input_bs_topk_all.append([input_bs_tdx_topk[count], score_count]) + count += 1 + input_bs_topk_all_sort = [a[0].unsqueeze(0) for a in sorted(iter(input_bs_topk_all), + key=lambda x: x[-1], reverse=True)[:topk]] + input_bs = torch.cat(input_bs_topk_all_sort, dim=0) + ### 词典过滤 + if self.flag_filter_word: + res_sort = sorted(iter(res), key=lambda x: x[-1], reverse=True) + res = [s for s in res_sort + if s[0] != word and s[0] in self.w2i][:topk] + ### 如果全部不在词典就保留一个 + if not res: + res = res_sort[:1] + else: + res_sort = sorted(iter(res), key=lambda x: x[-1], reverse=True) + res = [s for s in res_sort + if s[0] != word][:topk] + return res + + def near_antonym(self, word, topk=8, flag_normalize=True): + """ 获取反义词 """ + topk = topk or self.topk + word_list = self.predict(word, topk=topk, category="antonym") + if flag_normalize: + score_total = sum([w[-1] for w in word_list]) + word_list = [(w[0], round(0.4 + 0.59*min(1, w[1]/score_total*2), 2)) for w in word_list] + return word_list + + def near_synonym(self, word, topk=8, flag_normalize=True): + """ 获取近义词 """ + topk = topk or self.topk + word_list = self.predict(word, topk=topk, category="synonym") + if flag_normalize: + score_total = sum([w[-1] for w in word_list]) + word_list = [(w[0], round(0.4 + 0.59*min(1, w[1]/score_total*2), 2)) for w in word_list] + return word_list + + def near_same(self, word, topk=8, flag_normalize=True): + """ 获取同义词 """ + topk = topk or self.topk + word_list = self.predict(word, topk=topk, category="same") + if flag_normalize: + score_total = sum([w[-1] for w in word_list]) + word_list = [(w[0], round(0.4 + 0.59 * min(1, w[1] / score_total * 2), 2)) for w in word_list] + return word_list + + +def tet_predict(): + """ 测试 predict 接口 """ + path_trained_model_dir = "./mlm_antonym_model" + model = MLM4Antonym(path_trained_model_dir, path_trained_model_dir) + model.flag_skip = False + # model.topk = 16 # beam-search + word = "喜欢" + categorys = ["antonym", "synonym", "same"] + for category in categorys: + time_start = time.time() + res = model.predict(word, category=category) + time_end = time.time() + print(f"{word}的{category}: ") + for r in res: + print(r) + print(time_end - time_start) + while 1: + try: + print("请输入:") + word = input() + word = word.strip() + time_start = time.time() + categorys = ["antonym", "synonym", "same"] + for category in categorys: + res = model.predict(word, category=category) + time_end = time.time() + print(f"{word}的{category}: ") + for r in res: + print(r) + print(time_end - time_start) + except Exception as e: + print(traceback.print_exc()) +def tet_antonym(): + """ 测试 对外函数 """ + ### 可以放训练好的模型, 也可以放开源的bert类模型(效果差一些) + path_trained_model_dir = "./mlm_antonym_model" + # path_trained_model_dir = "E:/DATA/bert-model/00_pytorch/MacBERT-chinese_finetuned_correction" + # path_trained_model_dir = 
"E:/DATA/bert-model/00_pytorch/LLM/hfl_chinese-macbert-base" + # path_trained_model_dir = "E:/DATA/bert-model/00_pytorch/bert-base-chinese" + # path_trained_model_dir = "E:/DATA/bert-model/00_pytorch/chinese-roberta-wwm-ext" + # path_trained_model_dir = "E:/DATA/bert-model/00_pytorch/pai-ckbert-base-zh" + model = MLM4Antonym(path_trained_model_dir, path_trained_model_dir) + model.flag_skip = False + # model.topk = 16 # beam-search + word = "喜欢" + + ### antonym + time_start = time.time() + res = model.near_antonym(word) + time_end = time.time() + print(f"{word}的antonym: ") + for r in res: + print(r) + print(time_end - time_start) + ### synonym + time_start = time.time() + res = model.near_synonym(word) + time_end = time.time() + print(f"{word}的synonym: ") + for r in res: + print(r) + print(time_end - time_start) + + + while 1: + try: + print("请输入:") + word = input() + word = word.strip() + ### antonym + time_start = time.time() + res = model.near_antonym(word) + time_end = time.time() + print(f"{word}的antonym: ") + for r in res: + print(r) + print(time_end - time_start) + ### synonym + time_start = time.time() + res = model.near_synonym(word) + time_end = time.time() + print(f"{word}的synonym: ") + for r in res: + print(r) + print(time_end - time_start) + except Exception as e: + print(traceback.print_exc()) + + +### 初始化bert-mlm +MA = MLM4Antonym() + + +if __name__ == "__main__": + yz = 0 + + """ 测试 predict 函数 """ + # tet_predict() + """ 测试 对外函数 """ + # tet_antonym() + + """ 测试 初始化模型 """ + # MA.flag_filter_word = True # 用于过滤词汇, [MASK]有时候可能不成词 + # MA.flag_skip = False # decode的时候, 特殊字符是否跳过 + # MA.topk_times = 5 # topk重复次数, 避免非中文的情况 + # MA.topk = 8 # eg.5, 16, 32; 类似beam-search, 但是第一个char的topk必须全选 + + word = "喜欢" + ### antonym + time_start = time.time() + res = MA.near_antonym(word, topk=8) + time_end = time.time() + print(f"{word}的antonym: ") + for r in res: + print(r) + print(time_end - time_start) + ### synonym + time_start = time.time() + res = MA.near_synonym(word, topk=8) + time_end = time.time() + print(f"{word}的synonym: ") + for r in res: + print(r) + print(time_end - time_start) + + + while 1: + try: + print("请输入:") + word = input() + word = word.strip() + ### antonym + time_start = time.time() + res = MA.near_antonym(word) + time_end = time.time() + print(f"{word}的antonym: ") + for r in res: + print(r) + print(time_end - time_start) + ### synonym + time_start = time.time() + res = MA.near_synonym(word) + time_end = time.time() + print(f"{word}的synonym: ") + for r in res: + print(r) + print(time_end - time_start) + except Exception as e: + print(traceback.print_exc()) + + +""" +喜欢的antonym: +('厌恶', 0.77) +('讨厌', 0.72) +('憎恶', 0.56) +('反恶', 0.49) +('忌恶', 0.48) +('反厌', 0.46) +('厌烦', 0.46) +('反感', 0.45) +4.830108404159546 +喜欢的synonym: +('喜好', 0.75) +('喜爱', 0.64) +('爱好', 0.54) +('倾爱', 0.5) +('爱爱', 0.49) +('喜慕', 0.49) +('向好', 0.48) +('倾向', 0.48) +0.15957283973693848 +""" + + diff --git a/near_synonym/version.py b/near_synonym/version.py index f9690d2..d978806 100644 --- a/near_synonym/version.py +++ b/near_synonym/version.py @@ -5,5 +5,5 @@ # @function: -__version__ = "0.1.0" +__version__ = "0.3.0" diff --git a/requirements-all.txt b/requirements-all.txt index 5f8743d..2fcfce0 100644 --- a/requirements-all.txt +++ b/requirements-all.txt @@ -1,5 +1,7 @@ huggingface_hub>=0.20.3 onnxruntime-gpu==1.12.1 +transformers>=4.30.2 smart_open>=4.2.0 annoy>=1.17.3 -numpy>=1.20.3 \ No newline at end of file +numpy>=1.20.3 +torch>=1.8.0 \ No newline at end of file diff --git a/requirements.txt 
b/requirements.txt index 43f1fe8..9858240 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ huggingface_hub>=0.20.3 onnxruntime-gpu==1.12.1 +transformers>=4.30.2 smart_open>=4.2.0 annoy>=1.17.3 diff --git a/tet/tet_mlm_antonym.py b/tet/tet_mlm_antonym.py new file mode 100644 index 0000000..2af365b --- /dev/null +++ b/tet/tet_mlm_antonym.py @@ -0,0 +1,75 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# @time : 2021/2/29 21:52 +# @author : Mo +# @function: test units + + +import traceback +import os +os.environ["FLAG_MLM_ANTONYM"] = "1" + +from near_synonym import mlm_synonyms, mlm_antonyms + + + +word = "喜欢" +word_antonyms = mlm_antonyms(word, topk=5) +word_synonyms = mlm_synonyms(word, topk=5) +print("反义词:") +print(word_antonyms) +print("近义词:") +print(word_synonyms) + +word_antonym = [("前", "后"), + ("冷", "热"), + ("高", "矮"), + ("进", "退"), + ("死", "活"), + ("快", "慢"), + ("轻", "重"), + ("缓", "急"), + ("宽", "窄"), + ("强", "弱"), + ("宽阔", "狭窄"), + ("平静", "动荡"), + ("加重", "减轻"), + ("缓慢", "快速"), + ("节省", "浪费"), + ("分散", "聚拢"), + ("茂盛", "枯萎"), + ("美丽", "丑陋"), + ("静寂", "热闹"), + ("清楚", "模糊"), + ("恍恍惚惚", "清清楚楚"), + ("一模一样", "截然不同"), + ("柳暗花明", "山穷水尽"), + ("风平浪静", "风号浪啸"), + ("人声鼎沸", "鸦雀无声"), + ("勤勤恳恳", "懒懒散散"), + ("一丝不苟", "敷衍了事"), + ("隐隐约约", "清清楚楚"), + ("享誉世界", "默默无闻"), + ("相背而行", "相向而行"), + ] +for w in word_antonym: + w_ant = mlm_antonyms(w[0]) + print(w[0], w[1], w_ant[0][0], w_ant) + + + + +while True: + try: + print("请输入word: ") + word = input() + if word.strip(): + word_antonyms = mlm_antonyms(word) + word_synonyms = mlm_synonyms(word) + print("反义词:") + print(word_antonyms) + print("近义词:") + print(word_synonyms) + except Exception as e: + print(traceback.print_exc()) +
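另外给一个假设性的小评测脚本示意(仓库中没有, 仅供参考): 复用上面 tet/tet_mlm_antonym.py 里的反义词词对, 统计 mlm_antonyms 的 top1/topk 命中率; 反义词本身不唯一, 命中率只作粗略参考, 同样需要在 import 之前设置 FLAG_MLM_ANTONYM=1。

```python3
# 假设性评测示意: 统计mlm_antonyms在人工反义词对上的top1/topk命中率
import os
os.environ["FLAG_MLM_ANTONYM"] = "1"  # 必须在import之前设置

from near_synonym import mlm_antonyms

# 词对子集, 可直接换成tet_mlm_antonym.py里的完整word_antonym列表
word_antonym = [("前", "后"), ("冷", "热"), ("高", "矮"),
                ("宽阔", "狭窄"), ("美丽", "丑陋"), ("清楚", "模糊")]

topk = 8
hit_top1, hit_topk = 0, 0
for src, tgt in word_antonym:
    preds = [w for w, _ in mlm_antonyms(src, topk=topk)]
    hit_top1 += int(bool(preds) and preds[0] == tgt)
    hit_topk += int(tgt in preds)
print("top1命中率:", round(hit_top1 / len(word_antonym), 3))
print("top{}命中率:".format(topk), round(hit_topk / len(word_antonym), 3))
```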