Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 添加模型mucgec_bart #507

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions README.md

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion examples/evaluate_models/evaluate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sys

sys.path.append("../..")

from pycorrector import eval_sighan2015_by_model_batch


Expand Down Expand Up @@ -61,6 +60,14 @@ def main(args):
model = GptCorrector()
eval_sighan2015_by_model_batch(model.correct_batch)
# chatglm3-6b-csc: Sentence Level: acc:0.5564, precision:0.5574, recall:0.4917, f1:0.5225, cost time:1572.49 s, total num: 1100
elif args.model=="mucgec_bart":
import sys
sys.path.append("./")
from pycorrector.mucgec_bart.mucgec_bart_corrector import MuCGECBartCorrector
model = MuCGECBartCorrector()
eval_sighan2015_by_model_batch(model.correct_batch)
# 该数据集无法体现模型的能力, 表现高于正确标准答案
# Sentence Level: acc:0.2645, precision:0.2442, recall:0.2339, f1:0.2389, cost time:346.37 s, total num: 1100
else:
raise ValueError('model name error.')

Expand Down
19 changes: 19 additions & 0 deletions examples/mucgec_bart/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import sys

sys.path.append('../..')
from pycorrector import MuCGECBartCorrector
from pycorrector.utils.sentence_utils import is_not_chinese_error


if __name__ == "__main__":
    # Demo: batch correction, long-text correction, and post-filtering of
    # model edits with ``ignore_function``.
    bc = MuCGECBartCorrector()

    # 1) Batch correction of short sentences.
    sentences = [
        '这洋的话,下一年的福气来到自己身上。',
        '在拥挤时间,为了让人们尊守交通规律,派至少两个警察或者交通管理者。',
        '随着中国经济突飞猛近,建造工业与日俱增',
        "北京是中国的都。",
        "他说:”我最爱的运动是打蓝球“",
        "我每天大约喝5次水左右。",
        "今天,我非常开开心。",
    ]
    result = bc.correct_batch(sentences)
    print(result)

    # 2) Long text: ``correct`` splits it into short sentences internally.
    long_text = "在一个充满生活热闹和忙碌的城市中,有一个年轻人名叫李华。他生活在北京,这座充满着现代化建筑和繁忙街道的都市。每天,他都要穿行在拥挤的人群中,追逐着自己的梦想和生活节奏。\n\n李华从小就听祖辈讲述关于福气和努力的故事。他相信,“这洋的话,下一年的福气来到自己身上”。因此,尽管每天都很忙碌,他总是尽力保持乐观和积极。\n\n某天早晨,李华骑着自行车准备去上班。北京的交通总是非常繁忙,尤其是在早高峰时段。他经过一个交通路口,看到至少两个交警正在维持交通秩序。这些交警穿着整齐的制服,手势有序而又果断,让整个路口的车辆有条不紊地行驶着。这让李华想起了他父亲曾经告诫过他的话:“在拥挤的时间里,为了让人们遵守交通规则,至少要派两个警察或者交通管理者。”\n\n李华心中感慨万千,他想要在自己的生活中也如此积极地影响他人。他虽然只是一名普通的白领,却希望能够通过自己的努力和行动,为这座城市的安全与和谐贡献一份力量。\n\n随着时间的推移,中国的经济不断发展,北京的建设也日益繁荣。李华所在的公司也因为他的努力和创新精神而蓬勃发展。他喜欢打篮球,每周都会和朋友们一起去运动场,放松身心。他也十分重视健康,每天都保持适量的饮水量,大约喝五次左右。\n\n今天,李华觉得格外开心。他意识到,自己虽然只是一个普通人,却通过日复一日的努力,终于在生活中找到了属于自己的那份福气。他明白了祖辈们口中的那句话的含义——“这洋的话,下一年的福气来到自己身上”,并且深信不疑。\n\n在这个充满希望和机遇的时代里,李华将继续努力工作,为自己的梦想而奋斗,也希望能够在这座城市中留下自己的一份足迹,为他人带来更多的希望和正能量。\n\n这就是李华的故事,一个在现代城市中追寻梦想和福气的普通青年。"
    result = bc.correct(long_text)
    print(result)

    # 3) Post-process the model output: edits accepted by
    #    ``is_not_chinese_error`` are reverted instead of being reported.
    result = bc.correct(long_text, ignore_function=is_not_chinese_error)
    print(result)

    # Sanity check: every reported position must point at the erroneous span
    # in the source. The wrong span may be empty (insertion) or longer than one
    # character, so compare a slice rather than a single character.
    for e in result["errors"]:
        wrong, _, pos = e
        assert result["source"][pos:pos + len(wrong)] == wrong
19 changes: 19 additions & 0 deletions examples/nasgec_bart/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import sys

sys.path.append('../..')
from pycorrector import NaSGECBartCorrector
from pycorrector.utils.sentence_utils import is_not_chinese_error


if __name__ == "__main__":
    # Demo: batch correction, long-text correction, and post-filtering of
    # model edits with ``ignore_function``.
    bc = NaSGECBartCorrector()

    # 1) Batch correction of short sentences.
    sentences = [
        '这洋的话,下一年的福气来到自己身上。',
        '在拥挤时间,为了让人们尊守交通规律,派至少两个警察或者交通管理者。',
        '随着中国经济突飞猛近,建造工业与日俱增',
        "北京是中国的都。",
        "他说:”我最爱的运动是打蓝球“",
        "我每天大约喝5次水左右。",
        "今天,我非常开开心。",
    ]
    result = bc.correct_batch(sentences)
    print(result)

    # 2) Long text: ``correct`` splits it into short sentences internally.
    long_text = "在一个充满生活热闹和忙碌的城市中,有一个年轻人名叫李华。他生活在北京,这座充满着现代化建筑和繁忙街道的都市。每天,他都要穿行在拥挤的人群中,追逐着自己的梦想和生活节奏。\n\n李华从小就听祖辈讲述关于福气和努力的故事。他相信,“这洋的话,下一年的福气来到自己身上”。因此,尽管每天都很忙碌,他总是尽力保持乐观和积极。\n\n某天早晨,李华骑着自行车准备去上班。北京的交通总是非常繁忙,尤其是在早高峰时段。他经过一个交通路口,看到至少两个交警正在维持交通秩序。这些交警穿着整齐的制服,手势有序而又果断,让整个路口的车辆有条不紊地行驶着。这让李华想起了他父亲曾经告诫过他的话:“在拥挤的时间里,为了让人们遵守交通规则,至少要派两个警察或者交通管理者。”\n\n李华心中感慨万千,他想要在自己的生活中也如此积极地影响他人。他虽然只是一名普通的白领,却希望能够通过自己的努力和行动,为这座城市的安全与和谐贡献一份力量。\n\n随着时间的推移,中国的经济不断发展,北京的建设也日益繁荣。李华所在的公司也因为他的努力和创新精神而蓬勃发展。他喜欢打篮球,每周都会和朋友们一起去运动场,放松身心。他也十分重视健康,每天都保持适量的饮水量,大约喝五次左右。\n\n今天,李华觉得格外开心。他意识到,自己虽然只是一个普通人,却通过日复一日的努力,终于在生活中找到了属于自己的那份福气。他明白了祖辈们口中的那句话的含义——“这洋的话,下一年的福气来到自己身上”,并且深信不疑。\n\n在这个充满希望和机遇的时代里,李华将继续努力工作,为自己的梦想而奋斗,也希望能够在这座城市中留下自己的一份足迹,为他人带来更多的希望和正能量。\n\n这就是李华的故事,一个在现代城市中追寻梦想和福气的普通青年。"
    result = bc.correct(long_text)
    print(result)

    # 3) Post-process the model output: edits accepted by
    #    ``is_not_chinese_error`` are reverted instead of being reported.
    result = bc.correct(long_text, ignore_function=is_not_chinese_error)
    print(result)

    # Sanity check: every reported position must point at the erroneous span
    # in the source. The wrong span may be empty (insertion) or longer than one
    # character, so compare a slice rather than a single character.
    for e in result["errors"]:
        wrong, _, pos = e
        assert result["source"][pos:pos + len(wrong)] == wrong
2 changes: 2 additions & 0 deletions pycorrector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from pycorrector.proper_corrector import ProperCorrector
from pycorrector.seq2seq.conv_seq2seq_corrector import ConvSeq2SeqCorrector
from pycorrector.t5.t5_corrector import T5Corrector
from pycorrector.mucgec_bart.mucgec_bart_corrector import MuCGECBartCorrector
from pycorrector.nasgec_bart.nasgec_bart_corrector import NaSGECBartCorrector
from pycorrector.utils import text_utils, tokenizer, io_utils, math_utils, evaluate_utils
from pycorrector.utils.evaluate_utils import eval_sighan2015_by_model_batch, eval_sighan2015_by_model
from pycorrector.utils.get_file import get_file
Expand Down
Empty file.
70 changes: 70 additions & 0 deletions pycorrector/mucgec_bart/monkey_pack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import torch
from modelscope.pipelines import Pipeline
from typing import Any, Dict, List
from modelscope.utils.constant import Frameworks
from modelscope.utils.device import device_placement

# 批量推理问题
def _process_batch(self, input: List, batch_size,
**kwargs) -> Dict[str, Any]:
preprocess_params = kwargs.get('preprocess_params')
forward_params = kwargs.get('forward_params')
postprocess_params = kwargs.get('postprocess_params')

# batch data
output_list = []
for i in range(0, len(input), batch_size):
end = min(i + batch_size, len(input))
real_batch_size = end - i
preprocessed_list = [
self.preprocess(i, **preprocess_params) for i in input[i:end]
]

with device_placement(self.framework, self.device_name):
if self.framework == Frameworks.torch:
with torch.no_grad():
batched_out = self._batch(preprocessed_list)
if self._auto_collate:
batched_out = self._collate_fn(batched_out)
batched_out = self.forward(batched_out,
**forward_params)
else:
batched_out = self._batch(preprocessed_list)
batched_out = self.forward(batched_out, **forward_params)
model_name = kwargs.get("model_name")
# print("model_name", model_name)
if model_name=="batch_correct":
for batch_idx in range(real_batch_size):
out = {}
for k, element in batched_out.items():
if element is not None:
if isinstance(element, (tuple, list)):
out[k] = element[batch_idx]
else:
out[k] = element[batch_idx:batch_idx + 1]
out = self.postprocess(out, **postprocess_params)
self._check_output(out)
output_list.append(out)
else:
for batch_idx in range(real_batch_size):
out = {}
for k, element in batched_out.items():
if element is not None:
if isinstance(element, (tuple, list)):
if isinstance(element[0], torch.Tensor):
out[k] = type(element)(
e[batch_idx:batch_idx + 1]
for e in element)
else:
# Compatible with traditional pipelines
out[k] = element[batch_idx]
else:
out[k] = element[batch_idx:batch_idx + 1]
out = self.postprocess(out, **postprocess_params)
self._check_output(out)
output_list.append(out)

return output_list


# Monkey-patch: importing this module rebinds ``Pipeline._process_batch`` on the
# modelscope ``Pipeline`` class to the ``_process_batch`` function defined in
# this file.
Pipeline._process_batch = _process_batch
91 changes: 91 additions & 0 deletions pycorrector/mucgec_bart/mucgec_bart_corrector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
import os
import time
from typing import List

import torch
from loguru import logger
from tqdm import tqdm


# Prefer CUDA when it is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): presumably a workaround for the duplicate OpenMP runtime abort
# seen with some torch/MKL installs — confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import sys
sys.path.append('../..')
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from pycorrector.mucgec_bart.monkey_pack import Pipeline
from pycorrector.utils.sentence_utils import long_sentence_split
import difflib


class MuCGECBartCorrector:
    """Chinese text corrector built on a ModelScope text-error-correction pipeline.

    ``correct_batch`` corrects a list of short sentences; ``correct`` accepts
    arbitrarily long text, splits it into short sentences and merges the
    per-sentence results back into one record.
    """

    def __init__(self, model_name_or_path: str = "damo/nlp_bart_text-error-correction_chinese"):
        """Load the correction pipeline.

        :param model_name_or_path: model id or local path handed to ``pipeline``.
        """
        t1 = time.time()
        self.model = pipeline(Tasks.text_error_correction, model=model_name_or_path)
        logger.debug("Device: {}".format(device))
        logger.debug('Loaded mucgec bart correction model: %s, spend: %.3f s.' % (model_name_or_path, time.time() - t1))

    def _predict(self, sentences, batch_size=32, max_length=128, silent=True):
        raise NotImplementedError

    @staticmethod
    def _diff_errors(source, predicted, ignore_function=None):
        """Diff ``source`` against ``predicted`` and collect the edits.

        :param source: original sentence.
        :param predicted: sentence produced by the model.
        :param ignore_function: optional callable receiving
            ``[error_word, correct_word, position]``; a truthy return value
            means the edit is not a real error and must be reverted.
        :return: ``(target, errors)`` where ``target`` is ``predicted`` with all
            ignored edits reverted and ``errors`` is a list of
            ``(error_word, correct_word, position)`` tuples; ``position`` is an
            index into ``source``.
        """
        target = predicted
        errors = []
        # Opcode indices refer to the unmodified prediction. ``offset`` maps
        # them onto ``target``, whose length changes whenever an edit is reverted.
        offset = 0
        matcher = difflib.SequenceMatcher(None, source, predicted)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "equal":
                continue
            start, end = j1 + offset, j2 + offset
            error = [source[i1:i2], target[start:end], i1]
            if ignore_function and ignore_function(error):
                # Not considered an error: restore the source text in place.
                target = target[:start] + source[i1:i2] + target[end:]
                offset += (i2 - i1) - (j2 - j1)
                continue
            errors.append(tuple(error))
        return target, errors

    def correct_batch(self, sentences: List[str], max_length: int = 128, batch_size: int = 32, silent: bool = True, ignore_function=None):
        """
        Correct a batch of sentences.
        :param sentences: list[str], sentence list
        :param max_length: int, max length of each sentence (accepted for
            interface compatibility; not used by this implementation)
        :param batch_size: int, number of sentences forwarded to the pipeline at once
        :param silent: bool, show log (accepted for interface compatibility; not used)
        :param ignore_function: function, custom predicate that marks an edit
            as "not an error" so it is skipped without retraining the model
        :return: list of dict, {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
        """
        if not sentences:
            return []
        result = self.model(sentences, batch_size=batch_size, model_name="batch_correct")
        predictions = [r["output"] for r in result]
        data = []
        for source, predicted in zip(sentences, predictions):
            if not source or not predicted or source == "\n":
                # Nothing to diff: pass the sentence through unchanged so the
                # output stays aligned with the input (and long texts can be
                # re-assembled without losing blank segments).
                data.append({"source": source, "target": source, "errors": []})
                continue
            target, errors = self._diff_errors(source, predicted, ignore_function)
            data.append({"source": source, "target": target, "errors": errors})
        return data

    def correct(self, sentence: str, **kwargs):
        """Correct a (possibly long) text by splitting it into short sentences.

        :param sentence: text to correct.
        :param kwargs: ``max_length`` (default 128), ``period`` and ``comma``
            are consumed by ``long_sentence_split``; everything else is
            forwarded to ``correct_batch``.
        :return: dict with the concatenated ``source`` and ``target`` and the
            ``errors`` list, whose positions are relative to the full ``source``.
        """
        sentences = long_sentence_split(
            sentence,
            max_length=kwargs.pop("max_length", 128),
            period=kwargs.pop("period", None),
            comma=kwargs.pop("comma", None),
        )
        batch_results = self.correct_batch(sentences, **kwargs)
        source, target, errors = "", "", []
        for sr in batch_results:
            # Length of the text assembled so far = shift for this segment.
            shift = len(source)
            source += sr["source"]
            target += sr["target"]
            for wrong, right, pos in sr["errors"]:
                # Rewrite the position so it is relative to the whole text.
                errors.append((wrong, right, pos + shift))

        return {"source": source, "target": target, "errors": errors}






Empty file.
96 changes: 96 additions & 0 deletions pycorrector/nasgec_bart/nasgec_bart_corrector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
import os
import time
from typing import List

import torch
from loguru import logger
from tqdm import tqdm


# Prefer CUDA when it is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): presumably a workaround for the duplicate OpenMP runtime abort
# seen with some torch/MKL installs — confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import sys
sys.path.append('../..')
from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline
from pycorrector.utils.sentence_utils import long_sentence_split
import difflib


class NaSGECBartCorrector:
    """Chinese text corrector built on a BART seq2seq model loaded via transformers.

    ``correct_batch`` corrects a list of short sentences; ``correct`` accepts
    arbitrarily long text, splits it into short sentences and merges the
    per-sentence results back into one record.
    """

    def __init__(self, model_name_or_path: str = "HillZhang/real_learner_bart_CGEC"):
        """Load tokenizer and model and place the model on ``device``.

        :param model_name_or_path: model id or local path for ``from_pretrained``.
        """
        # https://github.com/HillZhang1999/NaSGEC
        t1 = time.time()
        self.tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
        self.model = BartForConditionalGeneration.from_pretrained(model_name_or_path)
        # Actually use the selected device instead of only logging it.
        self.model.to(device)
        logger.debug("Device: {}".format(device))
        logger.debug('Loaded nasgec bart correction model: %s, spend: %.3f s.' % (model_name_or_path, time.time() - t1))

    def _predict(self, sentences, batch_size=32, max_length=128, silent=True):
        raise NotImplementedError

    @staticmethod
    def _diff_errors(source, predicted, ignore_function=None):
        """Diff ``source`` against ``predicted`` and collect the edits.

        :param source: original sentence.
        :param predicted: sentence produced by the model.
        :param ignore_function: optional callable receiving
            ``[error_word, correct_word, position]``; a truthy return value
            means the edit is not a real error and must be reverted.
        :return: ``(target, errors)`` where ``target`` is ``predicted`` with all
            ignored edits reverted and ``errors`` is a list of
            ``(error_word, correct_word, position)`` tuples; ``position`` is an
            index into ``source``.
        """
        target = predicted
        errors = []
        # Opcode indices refer to the unmodified prediction. ``offset`` maps
        # them onto ``target``, whose length changes whenever an edit is reverted.
        offset = 0
        matcher = difflib.SequenceMatcher(None, source, predicted)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "equal":
                continue
            start, end = j1 + offset, j2 + offset
            error = [source[i1:i2], target[start:end], i1]
            if ignore_function and ignore_function(error):
                # Not considered an error: restore the source text in place.
                target = target[:start] + source[i1:i2] + target[end:]
                offset += (i2 - i1) - (j2 - j1)
                continue
            errors.append(tuple(error))
        return target, errors

    def correct_batch(self, sentences: List[str], max_length: int = 128, batch_size: int = 32, silent: bool = True, ignore_function=None):
        """
        Correct a batch of sentences.
        :param sentences: list[str], sentence list
        :param max_length: int, max length of each sentence (accepted for
            interface compatibility; not used by this implementation)
        :param batch_size: int, number of sentences tokenized and generated together
        :param silent: bool, hide the progress bar when True
        :param ignore_function: function, custom predicate that marks an edit
            as "not an error" so it is skipped without retraining the model
        :return: list of dict, {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
        """
        if not sentences:
            return []
        step = max(1, batch_size)
        predictions = []
        for begin in tqdm(range(0, len(sentences), step), desc="Generating outputs", disable=silent):
            batch = sentences[begin:begin + step]
            encoded_input = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
            if "token_type_ids" in encoded_input:
                del encoded_input["token_type_ids"]
            # Inputs must live on the same device as the model weights.
            encoded_input = encoded_input.to(self.model.device)
            output = self.model.generate(**encoded_input)
            decoded = self.tokenizer.batch_decode(output, skip_special_tokens=True)
            # The tokenizer decodes with spaces between tokens; strip them.
            predictions.extend(text.replace(" ", "") for text in decoded)

        data = []
        for source, predicted in zip(sentences, predictions):
            if not source or not predicted or source == "\n":
                # Nothing to diff: pass the sentence through unchanged so the
                # output stays aligned with the input (and long texts can be
                # re-assembled without losing blank segments).
                data.append({"source": source, "target": source, "errors": []})
                continue
            target, errors = self._diff_errors(source, predicted, ignore_function)
            data.append({"source": source, "target": target, "errors": errors})
        return data

    def correct(self, sentence: str, **kwargs):
        """Correct a (possibly long) text by splitting it into short sentences.

        :param sentence: text to correct.
        :param kwargs: ``max_length`` (default 128), ``period`` and ``comma``
            are consumed by ``long_sentence_split``; everything else is
            forwarded to ``correct_batch``.
        :return: dict with the concatenated ``source`` and ``target`` and the
            ``errors`` list, whose positions are relative to the full ``source``.
        """
        sentences = long_sentence_split(
            sentence,
            max_length=kwargs.pop("max_length", 128),
            period=kwargs.pop("period", None),
            comma=kwargs.pop("comma", None),
        )
        batch_results = self.correct_batch(sentences, **kwargs)
        source, target, errors = "", "", []
        for sr in batch_results:
            # Length of the text assembled so far = shift for this segment.
            shift = len(source)
            source += sr["source"]
            target += sr["target"]
            for wrong, right, pos in sr["errors"]:
                # Rewrite the position so it is relative to the whole text.
                errors.append((wrong, right, pos + shift))

        return {"source": source, "target": target, "errors": errors}






Loading