diff --git a/include/kiwi/Dataset.h b/include/kiwi/Dataset.h index 1a266897..6ce86c38 100644 --- a/include/kiwi/Dataset.h +++ b/include/kiwi/Dataset.h @@ -44,7 +44,7 @@ namespace kiwi static constexpr int32_t nonVocab = -1; - HiddenMember, sizeof(Vector) * 2> sents; + HiddenMember, sizeof(Vector) * 2> sents; std::shared_ptr knlm; std::unique_ptr workers; std::discrete_distribution<> dropout; @@ -55,6 +55,7 @@ namespace kiwi Deque> futures; const Vector* morphemes = nullptr; const Vector* forms = nullptr; + size_t knlmVocabSize = 0; size_t batchSize = 0; size_t windowSize = 0; size_t totalTokens = 0; @@ -87,12 +88,13 @@ namespace kiwi size_t next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut); size_t vocabSize() const { return vocabToToken.size(); } + size_t getKnlmVocabSize() const; size_t ngramNodeSize() const; const MorphemeRaw& vocabInfo(uint32_t vocab) const; std::u16string vocabForm(uint32_t vocab) const; std::vector estimVocabFrequency() const; - Range::const_iterator> getSent(size_t idx) const; - std::vector getAugmentedSent(size_t idx); + Range::const_iterator> getSent(size_t idx) const; + std::vector getAugmentedSent(size_t idx); }; } diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index c8ebede1..791a5335 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -546,14 +546,19 @@ namespace kiwi FormRaw& addForm(const KString& form); size_t addForm(Vector& newForms, UnorderedMap& newFormMap, KString form) const; - using MorphemeMap = UnorderedMap, std::pair>; + using MorphemeMap = UnorderedMap, std::pair>; template MorphemeMap loadMorphemesFromTxt(std::istream& is, Fn&& filter); - MorphemeMap restoreMorphemeMap() const; + MorphemeMap restoreMorphemeMap(bool separateDefaultMorpheme = false) const; + template + void _addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const; + + void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector* splitOut = nullptr) const; void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector* splitOut = nullptr) const; + void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector* splitOut = nullptr) const; void updateForms(); void updateMorphemes(); @@ -610,6 +615,7 @@ namespace kiwi size_t lmMinCnt = 1; size_t lmLastOrderMinCnt = 2; size_t numWorkers = 1; + size_t sbgSize = 1000000; bool useLmTagHistory = true; bool quantizeLm = true; bool compressLm = true; @@ -799,6 +805,7 @@ namespace kiwi double dropoutProb = 0, const TokenFilter& tokenFilter = {}, double splitRatio = 0, + bool separateDefaultMorpheme = false, HSDataset* splitDataset = nullptr ) const; }; diff --git a/src/Dataset.cpp b/src/Dataset.cpp index 4a266ec4..bf9bd1a4 100644 --- a/src/Dataset.cpp +++ b/src/Dataset.cpp @@ -258,6 +258,11 @@ std::u16string HSDataset::vocabForm(uint32_t vocab) const return joinHangul((*forms)[(*morphemes)[vocabToToken[vocab]].kform].form); } +size_t HSDataset::getKnlmVocabSize() const +{ + return knlmVocabSize; +} + std::vector kiwi::HSDataset::estimVocabFrequency() const { std::vector ret(vocabSize()), augs(getDefaultMorphemeId(POSTag::max)); @@ -279,7 +284,7 @@ std::vector kiwi::HSDataset::estimVocabFrequency() const return ret; } -Range::const_iterator> HSDataset::getSent(size_t idx) const +Range::const_iterator> HSDataset::getSent(size_t idx) const { return sents.get()[idx]; } @@ -289,9 +294,9 @@ void HSDataset::seed(size_t newSeed) rng.seed(newSeed); } -std::vector HSDataset::getAugmentedSent(size_t idx) +std::vector HSDataset::getAugmentedSent(size_t idx) { - std::vector ret; + std::vector ret; auto sent = sents.get()[idx]; ret.emplace_back(*sent.begin()); for (auto p = sent.begin() + 1; p != sent.end() - 1; ++p) diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 4ffe5dc7..b216fff5 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -36,6 +36,7 @@ KiwiBuilder& KiwiBuilder::operator=(KiwiBuilder&&) = default; namespace kiwi { static constexpr size_t defaultFormSize = defaultTagSize + 26; + static constexpr uint8_t undefSenseId = ((uint8_t)-1); } template @@ -50,11 +51,12 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem CondVowel cvowel = CondVowel::none; CondPolarity cpolar = CondPolarity::none; bool complex = false; - uint8_t senseId = 0; + uint8_t senseId = 0, origSenseId = 0; KString origForm, groupForm; int addAlias = 0; size_t origMorphId = 0; size_t groupPriority = 0; + size_t origMorphSenseCnt = 1; LongTail(const KString& _form = {}, float _weight = 0, @@ -64,13 +66,15 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem CondPolarity _cpolar = CondPolarity::none, bool _complex = false, uint8_t _senseId = 0, + uint8_t _origSenseId = 0, const KString& _origForm = {}, const KString& _groupForm = {}, int _addAlias = 0, size_t _origMorphId = 0, size_t _groupPriority = 0 ) : - form{ _form }, weight{ _weight }, tag{ _tag }, origTag{ _origTag }, cvowel{ _cvowel }, cpolar{ _cpolar }, complex{ _complex }, senseId{ _senseId }, + form{ _form }, weight{ _weight }, tag{ _tag }, origTag{ _origTag }, cvowel{ _cvowel }, cpolar{ _cpolar }, complex{ _complex }, + senseId{ _senseId }, origSenseId{ _origSenseId }, origForm{ _origForm }, groupForm{ _groupForm }, addAlias{ _addAlias }, origMorphId{ _origMorphId }, groupPriority{ _groupPriority } { } @@ -78,8 +82,9 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem Vector longTails; UnorderedMap longTailWeights; - UnorderedMap, u16string> complexChunks; + UnorderedMap, u16string> complexChunks; MorphemeMap morphMap; + UnorderedMap, Vector> morphSenseMap; UnorderedMap, size_t> groupMap; const auto& insertMorph = [&](KString&& form, float score, POSTag tag, CondVowel cvowel, CondPolarity cpolar, bool complex, uint8_t senseId, size_t origMorphemeId = 0, size_t groupId = 0) @@ -92,7 +97,7 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem unified = true; } - auto it = morphMap.find(make_pair(form, tag)); + auto it = morphMap.find(make_tuple(form, senseId, tag)); if (it != morphMap.end()) { // 어/아 통합 대상이면서 어xx 형태소와 아xx 형태소 모두 OOV로 취급받는 경우 @@ -121,7 +126,8 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem else { size_t mid = morphemes.size(); - morphMap.emplace(make_pair(form, tag), make_pair(origMorphemeId ? origMorphemeId : mid, mid)); + morphMap.emplace(make_tuple(form, senseId, tag), make_pair(origMorphemeId ? origMorphemeId : mid, mid)); + morphSenseMap[make_pair(form, tag)].emplace_back(senseId); fm.candidate.emplace_back(mid); morphemes.emplace_back(tag, cvowel, cpolar, complex); morphemes.back().kform = &fm - &forms[0]; @@ -163,7 +169,8 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem POSTag origTag = tag; int addAlias = 0; size_t groupPriority = 0; - uint8_t senseId = 0; + Vector senseIds; + u16string complexStr; if (fields.size() > 3) { for (size_t i = 3; i < fields.size(); ++i) @@ -171,7 +178,7 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem auto& f = fields[i]; if (f == u"vowel") { - if (cvowel != CondVowel::none) throw Exception{ "wrong line: " + line }; + if (cvowel != CondVowel::none) throw FormatException{ "wrong line: " + line }; cvowel = CondVowel::vowel; if (i + 1 < fields.size()) { @@ -205,7 +212,7 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem { if (complex) throw FormatException{ "wrong line: " + line }; complex = true; - complexChunks.emplace(make_pair(form, tag), f.substr(8).to_string()); + complexStr = f.substr(8).to_string(); } else if (f[0] == u'=') { @@ -266,7 +273,7 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem } else if (f[0] == '.') { - senseId = stol(f.begin() + 1, f.end()); + senseIds.emplace_back(stol(f.begin() + 1, f.end())); } else { @@ -275,20 +282,33 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem } } - if (filter(tag, morphWeight) && origMorphemeOfAlias.empty()) - { - // groupId 삽입용 long tail에 대해서 - if (form != groupForm) - { - longTails.emplace_back(form, 0, tag, POSTag::unknown, cvowel, cpolar, complex, senseId, u"", groupForm, 0, 0, groupPriority); - } + if (senseIds.empty()) senseIds.emplace_back(0); - insertMorph(move(form), morphWeight, tag, cvowel, cpolar, complex, senseId); + if (complex) + { + if (senseIds.size() > 1) throw FormatException{ "wrong line: " + line }; + complexChunks.emplace(make_tuple(form, senseIds[0], tag), move(complexStr)); } - else + + for (auto senseId : senseIds) { - longTails.emplace_back(form, altWeight < 0 ? altWeight : morphWeight, tag, origTag, cvowel, cpolar, complex, senseId, origMorphemeOfAlias, groupForm, addAlias, 0, groupPriority); - longTailWeights[tag] += morphWeight; + if (filter(tag, morphWeight) && origMorphemeOfAlias.empty()) + { + // groupId 삽입용 long tail에 대해서 + if (form != groupForm) + { + longTails.emplace_back(form, 0, tag, POSTag::unknown, cvowel, cpolar, complex, + senseId, undefSenseId, u"", groupForm, 0, 0, groupPriority); + } + + insertMorph(move(form), morphWeight, tag, cvowel, cpolar, complex, senseId); + } + else + { + longTails.emplace_back(form, altWeight < 0 ? altWeight : morphWeight, tag, origTag, cvowel, cpolar, complex, + senseId, undefSenseId, origMorphemeOfAlias, groupForm, addAlias, 0, groupPriority); + longTailWeights[tag] += morphWeight; + } } } @@ -301,8 +321,30 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem } if (p.origForm.empty()) continue; - auto it = isIrregular(p.origTag) ? morphMap.end() : morphMap.find(make_pair(p.origForm, clearIrregular(p.origTag))); - auto it2 = morphMap.find(make_pair(p.origForm, setIrregular(p.origTag))); + auto origSenseId = p.origSenseId; + size_t senseCnt = 1; + if (origSenseId == undefSenseId) + { + auto it = isIrregular(p.origTag) ? morphSenseMap.end() : morphSenseMap.find(make_pair(p.origForm, clearIrregular(p.origTag))); + auto it2 = morphSenseMap.find(make_pair(p.origForm, setIrregular(p.origTag))); + if (it != morphSenseMap.end() && it2 != morphSenseMap.end()) + { + throw FormatException{ "ambiguous base morpheme: " + utf16To8(p.origForm) + "/" + tagToString(clearIrregular(p.origTag)) }; + } + it = (it == morphSenseMap.end()) ? it2 : it; + if (it == morphSenseMap.end() || it->second.empty()) + { + origSenseId = 0; + } + else + { + origSenseId = it->second[0]; + senseCnt = it->second.size(); + } + } + + auto it = isIrregular(p.origTag) ? morphMap.end() : morphMap.find(make_tuple(p.origForm, origSenseId, clearIrregular(p.origTag))); + auto it2 = morphMap.find(make_tuple(p.origForm, origSenseId, setIrregular(p.origTag))); if (it != morphMap.end() && it2 != morphMap.end()) { @@ -311,9 +353,17 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem it = (it == morphMap.end()) ? it2 : it; if (it == morphMap.end()) { + if (p.origSenseId != undefSenseId) + { + throw FormatException{ "cannot find base morpheme: " + utf16To8(p.origForm) + "__" + to_string(p.origSenseId) + "/" + tagToString(p.origTag)}; + } + else + { throw FormatException{ "cannot find base morpheme: " + utf16To8(p.origForm) + "/" + tagToString(p.origTag) }; + } } p.origMorphId = it->second.first; + p.origMorphSenseCnt = senseCnt; if (!p.addAlias) continue; if (p.weight > 0) morphemes[it->second.first].userScore += p.weight; } @@ -332,7 +382,7 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem } else { - auto it = morphMap.find(make_pair(p.form, p.tag)); + auto it = morphMap.find(make_tuple(p.form, p.senseId, p.tag)); if (it != morphMap.end()) { morphemes[it->second.first].groupId = groupId; @@ -363,15 +413,24 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem { if (p.addAlias) { - float normalized = p.weight / morphemes[p.origMorphId].userScore; - float score = log(normalized); - if (p.addAlias > 1) score = 0; - else if (p.weight < 0) score = p.weight; - insertMorph(move(p.form), score, p.tag, p.cvowel, p.cpolar, p.complex, p.senseId, p.origMorphId, groupId); + for (size_t i = 0; i < p.origMorphSenseCnt; ++i) + { + float normalized = p.weight / morphemes[p.origMorphId].userScore; + float score = log(normalized); + if (p.addAlias > 1) score = 0; + else if (p.weight < 0) score = p.weight; + + auto senseId = p.senseId; + if (p.origMorphSenseCnt > 1) + { + senseId = morphemes[p.origMorphId + i].senseId; + } + insertMorph(move(p.form), score, p.tag, p.cvowel, p.cpolar, p.complex, senseId, p.origMorphId + i, groupId); + } } else { - morphMap.emplace(make_pair(move(p.form), p.tag), make_pair(p.origMorphId, p.origMorphId)); + morphMap.emplace(make_tuple(move(p.form), p.senseId, p.tag), make_pair(p.origMorphId, p.origMorphId)); } } } @@ -379,7 +438,7 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem // fill chunks of complex morphemes for (auto& morph : morphemes) { - auto it = complexChunks.find(make_pair(forms[morph.kform].form, morph.tag)); + auto it = complexChunks.find(make_tuple(forms[morph.kform].form, morph.senseId, morph.tag)); if (it == complexChunks.end()) continue; auto fd = split(it->second, u' '); @@ -387,14 +446,41 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem { throw FormatException{ "wrong position information : " + utf16To8(fd[0]) + " " + utf16To8(fd.back())}; } - auto posMap = normalizeHangulWithPosition(joinHangul(it->first.first)).second; + auto posMap = normalizeHangulWithPosition(joinHangul(get<0>(it->first))).second; for (size_t i = 0; i < fd.size() - 1; ++i) { auto f = split(fd[i], u'/'); if (f.size() != 2) throw FormatException{ "wrong format of morpheme : " + utf16To8(fd[i]) }; + uint8_t senseId = undefSenseId; + if (f[0].find(u"__") != f[0].npos) + { + auto p = f[0].find(u"__"); + auto s = f[0].substr(p + 2); + senseId = stol(s.begin(), s.end()); + f[0] = f[0].substr(0, p); + } auto norm = normalizeHangul(f[0]); auto tag = toPOSTag(f[1]); - auto it = morphMap.find(make_pair(norm, tag)); + + if (senseId == undefSenseId) + { + auto it = morphSenseMap.find(make_pair(norm, tag)); + if (it == morphSenseMap.end() || it->second.empty()) + { + throw FormatException{ "cannot find morpheme : " + utf16To8(fd[i]) }; + } + + if (it->second.size() == 1) + { + senseId = it->second[0]; + } + else + { + throw FormatException{ "ambiguous morpheme : " + utf16To8(fd[i]) }; + } + } + + auto it = morphMap.find(make_tuple(norm, senseId, tag)); if (it == morphMap.end()) { throw FormatException{ "cannot find morpheme : " + utf16To8(fd[i]) }; @@ -426,44 +512,64 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem for (auto& m : morphemes) { if (!isIrregular(m.tag)) continue; - auto it = morphMap.find(make_pair(forms[m.kform].form, clearIrregular(m.tag))); + auto it = morphMap.find(make_tuple(forms[m.kform].form, m.senseId, clearIrregular(m.tag))); if (it != morphMap.end()) continue; morphMap.emplace( - make_pair(forms[m.kform].form, clearIrregular(m.tag)), - make_pair(morphMap.find(make_pair(forms[m.kform].form, m.tag))->second.first, (size_t)(&m - morphemes.data())) + make_tuple(forms[m.kform].form, m.senseId, clearIrregular(m.tag)), + make_pair(morphMap.find(make_tuple(forms[m.kform].form, m.senseId, m.tag))->second.first, (size_t)(&m - morphemes.data())) + ); + } + + for (auto& p : morphSenseMap) + { + morphMap.emplace( + make_tuple(p.first.first, undefSenseId, p.first.second), + morphMap.find(make_tuple(p.first.first, p.second[0], p.first.second))->second ); } return morphMap; } -auto KiwiBuilder::restoreMorphemeMap() const -> MorphemeMap +auto KiwiBuilder::restoreMorphemeMap(bool separateDefaultMorpheme) const -> MorphemeMap { MorphemeMap ret; for (size_t i = defaultTagSize + 1; i < morphemes.size(); ++i) { size_t id = morphemes[i].lmMorphemeId; - if (!id) id = i; - ret.emplace(make_pair(forms[morphemes[i].kform].form, morphemes[i].tag), make_pair(id, id)); + if (!id) + { + id = i; + } + else if (separateDefaultMorpheme && id < defaultFormSize + 2) + { + id = i; + } + ret.emplace(make_tuple(forms[morphemes[i].kform].form, morphemes[i].senseId, morphemes[i].tag), make_pair(id, id)); + ret.emplace(make_tuple(forms[morphemes[i].kform].form, undefSenseId, morphemes[i].tag), make_pair(id, id)); } for (auto& m : morphemes) { if (!isIrregular(m.tag)) continue; - auto it = ret.find(make_pair(forms[m.kform].form, clearIrregular(m.tag))); + auto it = ret.find(make_tuple(forms[m.kform].form, m.senseId, clearIrregular(m.tag))); if (it != ret.end()) continue; ret.emplace( - make_pair(forms[m.kform].form, clearIrregular(m.tag)), - make_pair(ret.find(make_pair(forms[m.kform].form, m.tag))->second.first, (size_t)(&m - morphemes.data())) + make_tuple(forms[m.kform].form, m.senseId, clearIrregular(m.tag)), + make_pair(ret.find(make_tuple(forms[m.kform].form, m.senseId, m.tag))->second.first, (size_t)(&m - morphemes.data())) ); } return ret; } -void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, +template +void KiwiBuilder::_addCorpusTo( + RaggedVector& out, + std::istream& is, + MorphemeMap& morphMap, double splitRatio, - RaggedVector* splitOut + RaggedVector* splitOut ) const { - Vector wids; + Vector wids; double splitCnt = 0; size_t numLine = 0; string line; @@ -492,6 +598,14 @@ void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, Mor { auto f = normalizeHangul(fields[i]); if (f.empty()) continue; + auto senseId = 0; + auto spos = f.find(u"__"); + if (spos != f.npos) + { + auto s = f.substr(spos + 2); + senseId = stol(s.begin(), s.end()); + f = f.substr(0, spos); + } auto t = toPOSTag(fields[i + 1]); if (t == POSTag::max && !alreadyPrintError) @@ -505,7 +619,7 @@ void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, Mor f[0] = u'어'; } - auto it = morphMap.find(make_pair(f, t)); + auto it = morphMap.find(make_tuple(f, senseId, t)); if (it != morphMap.end()) { auto& morph = morphemes[it->second.first]; @@ -540,6 +654,17 @@ void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, Mor } } + if (senseId) + { + it = morphMap.find(make_tuple(f, undefSenseId, t)); + if (it != morphMap.end()) + { + cerr << "Wrong senseId for '" << utf16To8(joinHangul(f)) << "' at line " << numLine << " :\t" << line << endl; + wids.emplace_back(it->second.first); + continue; + } + } + if (t < POSTag::p && t != POSTag::unknown) { wids.emplace_back(getDefaultMorphemeId(t)); @@ -551,6 +676,21 @@ void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, Mor } } +void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const +{ + return _addCorpusTo(out, is, morphMap, splitRatio, splitOut); +} + +void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const +{ + return _addCorpusTo(out, is, morphMap, splitRatio, splitOut); +} + +void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const +{ + return _addCorpusTo(out, is, morphMap, splitRatio, splitOut); +} + void KiwiBuilder::updateForms() { vector> formOrder; @@ -723,6 +863,15 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) new (&pool) utils::ThreadPool{ args.numWorkers }; } auto cntNodes = utils::count(sents.begin(), sents.end(), args.lmMinCnt, 1, args.lmOrder, (args.numWorkers > 1 ? &pool : nullptr), &bigramList, args.useLmTagHistory ? &historyTx : nullptr); + // discount for bos node cnt + if (args.useLmTagHistory) + { + cntNodes.root().getNext(lmVocabSize)->val /= 2; + } + else + { + cntNodes.root().getNext(0)->val /= 2; + } std::vector minCnts(args.lmOrder, args.lmMinCnt); minCnts.back() = args.lmLastOrderMinCnt; langMdl.knlm = lm::KnLangModelBase::create(lm::KnLangModelBase::build( @@ -853,7 +1002,7 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) return true; }; - sbg = sb::SkipBigramTrainer{ sents, sbgTokenFilter, sbgPairFilter, 0, 150, 20, true, 0.333f, 1, 1000000, langMdl.knlm->nonLeafNodeSize() }; + sbg = sb::SkipBigramTrainer{ sents, sbgTokenFilter, sbgPairFilter, 0, 150, 20, true, 0.333f, 1, args.sbgSize, langMdl.knlm->nonLeafNodeSize() }; Vector lmLogProbs; Vector baseNodes; auto tc = sbg.newContext(); @@ -945,7 +1094,13 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) ofstream ofs{ modelPath + "/sbg.result.log" }; sbg.printParameters(ofs << "AvgLL: " << llMean << "\n", [&](size_t v) { - return utf16To8(joinHangul(forms[morphemes[v].kform].form)) + "/" + tagToString(morphemes[v].tag); + auto s = utf16To8(joinHangul(forms[morphemes[v].kform].form)); + if (morphemes[v].senseId) + { + s += "__"; + s += to_string(morphemes[v].senseId); + } + return s + "/" + tagToString(morphemes[v].tag); }); { @@ -2028,10 +2183,11 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, double dropoutProb, const TokenFilter& tokenFilter, double splitRatio, + bool separateDefaultMorpheme, HSDataset* splitDataset ) const { - auto realMorph = restoreMorphemeMap(); + auto realMorph = restoreMorphemeMap(separateDefaultMorpheme); HSDataset dataset{ batchSize, windowSize, numWorkers, dropoutProb }; auto& sents = dataset.sents.get(); dataset.knlm = langMdl.knlm; @@ -2059,8 +2215,14 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, tokenSize = std::max(tokenSize, sents.raw().empty() ? (size_t)0 : *std::max_element(sents.raw().begin(), sents.raw().end()) + 1); } + const size_t knlmVocabSize = langMdl.knlm->getHeader().vocab_size; + size_t filteredKnlmVocabSize = 0; for (size_t i = 0; i < tokenSize; ++i) { + if (i == knlmVocabSize) + { + filteredKnlmVocabSize = dataset.vocabToToken.size(); + } if (tokenFilter && !tokenFilter(joinHangul(forms[morphemes[i].kform].form), morphemes[i].tag)) { dataset.tokenToVocab.emplace_back(HSDataset::nonVocab); @@ -2069,6 +2231,11 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, dataset.tokenToVocab.emplace_back(dataset.vocabToToken.size()); dataset.vocabToToken.emplace_back(i); } + if (tokenSize == knlmVocabSize) + { + filteredKnlmVocabSize = dataset.vocabToToken.size(); + } + dataset.knlmVocabSize = filteredKnlmVocabSize; for (size_t i = 0; i < sents.size(); ++i) { @@ -2079,6 +2246,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, { splitDataset->tokenToVocab = dataset.tokenToVocab; splitDataset->vocabToToken = dataset.vocabToToken; + splitDataset->knlmVocabSize = dataset.knlmVocabSize; for (size_t i = 0; i < splitDataset->sents.get().size(); ++i) { splitDataset->totalTokens += splitDataset->numValidTokensInSent(i) - 1; diff --git a/src/SwTokenizer.cpp b/src/SwTokenizer.cpp index 5beb8407..ad948989 100644 --- a/src/SwTokenizer.cpp +++ b/src/SwTokenizer.cpp @@ -2163,7 +2163,8 @@ float UnigramSwTrainer::buildSubwordVocabs(const size_t minCnt, const size_t max bool isSubword = s[0] != u' '; size_t realSize = s.size() - (isSubword ? 0 : 1); if (realSize <= 1) return true; - if (isLowSurrogate(s.front()) || isHighSurrogate(s.back())) return false; + if (isLowSurrogate(s.front())) return false; + if (isHighSurrogate(s.back())) return true; if (count(s.begin(), s.end(), '\x00') || count(s.begin(), s.end(), '\x01') || count(s.begin() + 1, s.end(), u' ') || s.size() > maxPrefixLength) return false; if (trainConfig.removeRepetitive && testRepetition(s.data() + (isSubword ? 0 : 1), s.size() - (isSubword ? 0 : 1))) return false; diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index b47d4ba9..66a4815e 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -402,7 +402,7 @@ TEST(KiwiCpp, HSDataset) }; HSDataset trainset, devset; - trainset = kw.makeHSDataset(data, batchSize, windowSize, 1, 0., tokenFilter, 0.1, &devset); + trainset = kw.makeHSDataset(data, batchSize, windowSize, 1, 0., tokenFilter, 0.1, false, &devset); for (size_t i = 0; i < 2; ++i) { { diff --git a/tools/Evaluator.cpp b/tools/Evaluator.cpp index 07c18d71..1bef793c 100644 --- a/tools/Evaluator.cpp +++ b/tools/Evaluator.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -130,7 +130,10 @@ Evaluator::Score Evaluator::evaluate() ostream& operator<<(ostream& o, const kiwi::TokenInfo& t) { - return o << utf16To8(t.str) << "/" << kiwi::tagToString(t.tag); + o << utf16To8(t.str); + if (t.senseId) o << "__" << (int)t.senseId; + o << "/" << kiwi::tagToString(t.tag); + return o; } void Evaluator::TestResult::writeResult(ostream& out) const diff --git a/tools/model_builder.cpp b/tools/model_builder.cpp index 7776b68f..199e13ab 100644 --- a/tools/model_builder.cpp +++ b/tools/model_builder.cpp @@ -52,6 +52,7 @@ int main(int argc, const char* argv[]) ValueArg lmMinCnt{ "", "min_cnt", "min count of LM", false, 1, "int" }; ValueArg lmLastOrderMinCnt{ "", "last_min_cnt", "min count of the last order of LM", false, 2, "int" }; ValueArg output{ "o", "output", "output model path", true, "", "string" }; + ValueArg sbgSize{ "", "sbg_size", "sbg size", false, 1000000, "int" }; UnlabeledMultiArg inputs{ "inputs", "input copora", true, "string" }; cmd.add(output); @@ -66,6 +67,7 @@ int main(int argc, const char* argv[]) cmd.add(lmMinCnt); cmd.add(lmLastOrderMinCnt); cmd.add(workers); + cmd.add(sbgSize); try { @@ -87,6 +89,7 @@ int main(int argc, const char* argv[]) args.lmMinCnt = lmMinCnt; args.lmLastOrderMinCnt = lmLastOrderMinCnt; args.numWorkers = workers; + args.sbgSize = sbgSize; return run(args, output, skipBigram); }