From 36332210ce366baa6661aee42483729bdeecda5f Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 25 Dec 2023 01:16:49 +0900 Subject: [PATCH 1/5] added an argument `rangesOut` to `AutoJoiner::getU8/16` --- include/kiwi/Joiner.h | 9 ++--- include/kiwi/Utils.h | 25 +++++++++++++ src/Combiner.cpp | 12 +++---- src/Combiner.h | 5 ++- src/Joiner.cpp | 81 ++++++++++++++++++++++++++++++++++--------- src/StrUtils.h | 50 ++++++++++++++++++++++++++ 6 files changed, 155 insertions(+), 27 deletions(-) diff --git a/include/kiwi/Joiner.h b/include/kiwi/Joiner.h index 298c82db..4447d3d1 100644 --- a/include/kiwi/Joiner.h +++ b/include/kiwi/Joiner.h @@ -28,6 +28,7 @@ namespace kiwi template friend struct Candidate; const CompiledRule* cr = nullptr; KString stack; + std::vector> ranges; size_t activeStart = 0; POSTag lastTag = POSTag::unknown, anteLastTag = POSTag::unknown; @@ -45,8 +46,8 @@ namespace kiwi void add(const std::u16string& form, POSTag tag, Space space = Space::none); void add(const char16_t* form, POSTag tag, Space space = Space::none); - std::u16string getU16() const; - std::string getU8() const; + std::u16string getU16(std::vector>* rangesOut = nullptr) const; + std::string getU8(std::vector>* rangesOut = nullptr) const; }; template @@ -115,8 +116,8 @@ namespace kiwi void add(const std::u16string& form, POSTag tag, bool inferRegularity = true, Space space = Space::none); void add(const char16_t* form, POSTag tag, bool inferRegularity = true, Space space = Space::none); - std::u16string getU16() const; - std::string getU8() const; + std::u16string getU16(std::vector>* rangesOut = nullptr) const; + std::string getU8(std::vector>* rangesOut = nullptr) const; }; } } diff --git a/include/kiwi/Utils.h b/include/kiwi/Utils.h index fbe786c1..bf35076c 100644 --- a/include/kiwi/Utils.h +++ b/include/kiwi/Utils.h @@ -108,6 +108,31 @@ namespace kiwi return ret; } + template + inline std::u16string joinHangul(It first, It last, std::vector& positionOut) + { + std::u16string ret; + ret.reserve(std::distance(first, last)); + positionOut.clear(); + positionOut.reserve(std::distance(first, last)); + for (; first != last; ++first) + { + auto c = *first; + if (isHangulCoda(c) && !ret.empty() && isHangulSyllable(ret.back())) + { + if ((ret.back() - 0xAC00) % 28) ret.push_back(c); + else ret.back() += c - 0x11A7; + positionOut.emplace_back(ret.size() - 1); + } + else + { + ret.push_back(c); + positionOut.emplace_back(ret.size() - 1); + } + } + return ret; + } + inline std::u16string joinHangul(const KString& hangul) { return joinHangul(hangul.begin(), hangul.end()); diff --git a/src/Combiner.cpp b/src/Combiner.cpp index 7c518db4..eface277 100644 --- a/src/Combiner.cpp +++ b/src/Combiner.cpp @@ -1147,7 +1147,7 @@ Vector CompiledRule::combineImpl( return ret; } -pair CompiledRule::combineOneImpl( +tuple CompiledRule::combineOneImpl( U16StringView leftForm, POSTag leftTag, U16StringView rightForm, POSTag rightTag, CondVowel cv, CondPolarity cp @@ -1163,12 +1163,12 @@ pair CompiledRule::combineOneImpl( { for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) { - if(p.score >= 0) return make_pair(p.str, p.rightBegin); + if(p.score >= 0) return make_tuple(p.str, p.leftEnd, p.rightBegin); KString ret; ret.reserve(leftForm.size() + rightForm.size()); ret.insert(ret.end(), leftForm.begin(), leftForm.end()); ret.insert(ret.end(), rightForm.begin(), rightForm.end()); - return make_pair(ret, leftForm.size()); + return make_tuple(ret, leftForm.size(), leftForm.size()); } } @@ -1183,7 +1183,7 @@ pair CompiledRule::combineOneImpl( { for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) { - return make_pair(p.str, p.rightBegin); + return make_tuple(p.str, p.leftEnd, p.rightBegin); } } } @@ -1198,14 +1198,14 @@ pair CompiledRule::combineOneImpl( ret.insert(ret.end(), leftForm.begin(), leftForm.end()); ret.push_back(u'아'); // `어`를 `아`로 교체하여 삽입 ret.insert(ret.end(), rightForm.begin() + 1, rightForm.end()); - return make_pair(ret, leftForm.size()); + return make_tuple(ret, leftForm.size(), leftForm.size()); } } KString ret; ret.reserve(leftForm.size() + rightForm.size()); ret.insert(ret.end(), leftForm.begin(), leftForm.end()); ret.insert(ret.end(), rightForm.begin(), rightForm.end()); - return make_pair(ret, leftForm.size()); + return make_tuple(ret, leftForm.size(), leftForm.size()); } Vector> CompiledRule::testLeftPattern(U16StringView leftForm, size_t ruleId) const diff --git a/src/Combiner.h b/src/Combiner.h index 542ec3dc..0d407d83 100644 --- a/src/Combiner.h +++ b/src/Combiner.h @@ -215,7 +215,10 @@ namespace kiwi CondVowel cv = CondVowel::none, CondPolarity cp = CondPolarity::none ) const; - std::pair combineOneImpl( + /** + * @return tuple(combinedForm, leftFormBoundary, rightFormBoundary) + */ + std::tuple combineOneImpl( U16StringView leftForm, POSTag leftTag, U16StringView rightForm, POSTag rightTag, CondVowel cv = CondVowel::none, CondPolarity cp = CondPolarity::none diff --git a/src/Joiner.cpp b/src/Joiner.cpp index 10081341..445252ec 100644 --- a/src/Joiner.cpp +++ b/src/Joiner.cpp @@ -79,9 +79,11 @@ namespace kiwi void Joiner::add(U16StringView form, POSTag tag, Space space) { + KString normForm = normalizeHangul(form); if (stack.size() == activeStart) { - stack += normalizeHangul(form); + ranges.emplace_back(stack.size(), stack.size() + normForm.size()); + stack += normForm; lastTag = tag; return; } @@ -90,7 +92,8 @@ namespace kiwi { if (stack.empty() || !isSpace(stack.back())) stack.push_back(u' '); activeStart = stack.size(); - stack += normalizeHangul(form); + ranges.emplace_back(stack.size(), stack.size() + normForm.size()); + stack += normForm; } else { @@ -100,8 +103,6 @@ namespace kiwi cv = isHangulSyllable(stack[activeStart - 1]) ? CondVowel::vowel : CondVowel::non_vowel; } - KString normForm = normalizeHangul(form); - if (!stack.empty() && (isJClass(tag) || isEClass(tag))) { if (isEClass(tag) && normForm[0] == u'아') normForm[0] = u'어'; @@ -148,8 +149,10 @@ namespace kiwi } auto r = cr->combineOneImpl({ stack.data() + activeStart, stack.size() - activeStart }, lastTag, normForm, tag, cv); stack.erase(stack.begin() + activeStart, stack.end()); - stack += r.first; - activeStart += r.second; + ranges.back().second = activeStart + get<1>(r); + ranges.emplace_back(activeStart + get<2>(r), activeStart + get<0>(r).size()); + stack += get<0>(r); + activeStart += get<2>(r); } anteLastTag = lastTag; lastTag = tag; @@ -165,14 +168,46 @@ namespace kiwi return add(U16StringView{ form }, tag, space); } - u16string Joiner::getU16() const + u16string Joiner::getU16(vector>* rangesOut) const { - return joinHangul(stack); + if (rangesOut) + { + rangesOut->clear(); + rangesOut->reserve(ranges.size()); + Vector u16pos; + auto ret = joinHangul(stack.begin(), stack.end(), u16pos); + u16pos.emplace_back(ret.size()); + for (auto& r : ranges) + { + auto endOffset = u16pos[r.second] + (r.second > 0 && u16pos[r.second - 1] == u16pos[r.second] ? 1 : 0); + rangesOut->emplace_back(u16pos[r.first], endOffset); + } + return ret; + } + else + { + return joinHangul(stack); + } } - string Joiner::getU8() const + string Joiner::getU8(vector>* rangesOut) const { - return utf16To8(joinHangul(stack)); + auto u16 = getU16(rangesOut); + if (rangesOut) + { + Vector positions; + auto ret = utf16To8(u16, positions); + for (auto& r : *rangesOut) + { + r.first = positions[r.first]; + r.second = positions[r.second]; + } + return ret; + } + else + { + return utf16To8(u16); + } } AutoJoiner::~AutoJoiner() @@ -422,19 +457,33 @@ namespace kiwi struct GetU16Visitor { + vector>* rangesOut; + + GetU16Visitor(vector>* _rangesOut) + : rangesOut{ _rangesOut } + { + } + template u16string operator()(const Vector>& o) const { - return o[0].joiner.getU16(); + return o[0].joiner.getU16(rangesOut); } }; struct GetU8Visitor { + vector>* rangesOut; + + GetU8Visitor(vector>* _rangesOut) + : rangesOut{ _rangesOut } + { + } + template string operator()(const Vector>& o) const { - return o[0].joiner.getU8(); + return o[0].joiner.getU8(rangesOut); } }; @@ -458,14 +507,14 @@ namespace kiwi return mapbox::util::apply_visitor(AddVisitor{ this, form, tag, false, space }, reinterpret_cast(candBuf)); } - u16string AutoJoiner::getU16() const + u16string AutoJoiner::getU16(vector>* rangesOut) const { - return mapbox::util::apply_visitor(GetU16Visitor{}, reinterpret_cast(candBuf)); + return mapbox::util::apply_visitor(GetU16Visitor{ rangesOut }, reinterpret_cast(candBuf)); } - string AutoJoiner::getU8() const + string AutoJoiner::getU8(vector>* rangesOut) const { - return mapbox::util::apply_visitor(GetU8Visitor{}, reinterpret_cast(candBuf)); + return mapbox::util::apply_visitor(GetU8Visitor{ rangesOut }, reinterpret_cast(candBuf)); } } } diff --git a/src/StrUtils.h b/src/StrUtils.h index c8bc34da..0c508b36 100644 --- a/src/StrUtils.h +++ b/src/StrUtils.h @@ -393,6 +393,56 @@ namespace kiwi return ret; } + template + inline std::string utf16To8(nonstd::u16string_view str, std::vector& positions) + { + std::string ret; + positions.clear(); + for (auto it = str.begin(); it != str.end(); ++it) + { + size_t code = *it; + positions.emplace_back(ret.size()); + if (isHighSurrogate(code)) + { + if (++it == str.end()) throw UnicodeException{ "unpaired surrogate" }; + size_t code2 = *it; + if (!isLowSurrogate(code2)) throw UnicodeException{ "unpaired surrogate" }; + positions.emplace_back(ret.size()); + code = mergeSurrogate(code, code2); + } + + if (code <= 0x7F) + { + ret.push_back((char)code); + } + else if (code <= 0x7FF) + { + ret.push_back((char)(0xC0 | (code >> 6))); + ret.push_back((char)(0x80 | (code & 0x3F))); + } + else if (code <= 0xFFFF) + { + ret.push_back((char)(0xE0 | (code >> 12))); + ret.push_back((char)(0x80 | ((code >> 6) & 0x3F))); + ret.push_back((char)(0x80 | (code & 0x3F))); + } + else if (code <= 0x10FFFF) + { + ret.push_back((char)(0xF0 | (code >> 18))); + ret.push_back((char)(0x80 | ((code >> 12) & 0x3F))); + ret.push_back((char)(0x80 | ((code >> 6) & 0x3F))); + ret.push_back((char)(0x80 | (code & 0x3F))); + } + else + { + throw UnicodeException{ "unicode error" }; + } + } + positions.emplace_back(ret.size()); + + return ret; + } + inline std::string utf16To8(const char16_t* str) { return utf16To8(U16StringView{ str }); From 9e814baaf25026f06b91126a323de9ecaca8068d Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 25 Dec 2023 01:17:48 +0900 Subject: [PATCH 2/5] added test cases for rangeOut of `AutoJoiner::getU16` --- test/test_cpp.cpp | 76 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 20 deletions(-) diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index 51275fd9..9dceb6aa 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -32,6 +32,21 @@ inline testing::AssertionResult testTokenization(Kiwi& kiwi, const std::u16strin } } +#if defined(__GNUC__) && __GNUC__ < 5 +template +constexpr std::vector> toPair(std::initializer_list init) +{ + return std::vector>{ (const std::pair*)init.begin(), (const std::pair*)init.begin() + init.size() / 2 }; +} +#else +template +constexpr std::vector> toPair(const ATy(&init)[n]) +{ + static_assert(n % 2 == 0, "initializer_list must have an even number of elements."); + return std::vector>{ (const std::pair*)init, (const std::pair*)init + n / 2 }; +} +#endif + Kiwi& reuseKiwiInstance() { static Kiwi kiwi = KiwiBuilder{ MODEL_PATH, 0, BuildOption::default_, }.build(); @@ -960,118 +975,139 @@ TEST(KiwiCpp, JoinAffix) TEST(KiwiCpp, AutoJoiner) { Kiwi& kiwi = reuseKiwiInstance(); + std::vector> ranges; auto joiner = kiwi.newJoiner(); joiner.add(u"시동", POSTag::nng); joiner.add(u"를", POSTag::jko); - EXPECT_EQ(joiner.getU16(), u"시동을"); + EXPECT_EQ(joiner.getU16(&ranges), u"시동을"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"시동", POSTag::nng); joiner.add(u"ᆯ", POSTag::jko); - EXPECT_EQ(joiner.getU16(), u"시동을"); + EXPECT_EQ(joiner.getU16(&ranges), u"시동을"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"나", POSTag::np); joiner.add(u"ᆯ", POSTag::jko); - EXPECT_EQ(joiner.getU16(), u"날"); + EXPECT_EQ(joiner.getU16(&ranges), u"날"); + EXPECT_EQ(ranges, toPair({ 0, 1, 0, 1 })); joiner = kiwi.newJoiner(); joiner.add(u"시도", POSTag::nng); joiner.add(u"를", POSTag::jko); - EXPECT_EQ(joiner.getU16(), u"시도를"); + EXPECT_EQ(joiner.getU16(&ranges), u"시도를"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"바다", POSTag::nng); joiner.add(u"가", POSTag::jks); - EXPECT_EQ(joiner.getU16(), u"바다가"); + EXPECT_EQ(joiner.getU16(&ranges), u"바다가"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"바닥", POSTag::nng); joiner.add(u"가", POSTag::jks); - EXPECT_EQ(joiner.getU16(), u"바닥이"); + EXPECT_EQ(joiner.getU16(&ranges), u"바닥이"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"불", POSTag::nng); joiner.add(u"으로", POSTag::jkb); - EXPECT_EQ(joiner.getU16(), u"불로"); + EXPECT_EQ(joiner.getU16(&ranges), u"불로"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2 })); joiner = kiwi.newJoiner(); joiner.add(u"북", POSTag::nng); joiner.add(u"으로", POSTag::jkb); - EXPECT_EQ(joiner.getU16(), u"북으로"); + EXPECT_EQ(joiner.getU16(&ranges), u"북으로"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"갈", POSTag::vv); joiner.add(u"면", POSTag::ec); - EXPECT_EQ(joiner.getU16(), u"갈면"); + EXPECT_EQ(joiner.getU16(&ranges), u"갈면"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2 })); joiner = kiwi.newJoiner(); joiner.add(u"갈", POSTag::vv); joiner.add(u"시", POSTag::ep); joiner.add(u"았", POSTag::ep); joiner.add(u"면", POSTag::ec); - EXPECT_EQ(joiner.getU16(), u"가셨으면"); + EXPECT_EQ(joiner.getU16(&ranges), u"가셨으면"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2, 1, 2, 2, 4 })); joiner = kiwi.newJoiner(); joiner.add(u"하", POSTag::vv); joiner.add(u"았", POSTag::ep); joiner.add(u"다", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"했다"); + EXPECT_EQ(joiner.getU16(&ranges), u"했다"); + EXPECT_EQ(ranges, toPair({ 0, 1, 0, 1, 1, 2 })); joiner = kiwi.newJoiner(); joiner.add(u"날", POSTag::vv); joiner.add(u"어", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"날아"); + EXPECT_EQ(joiner.getU16(&ranges), u"날아"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2 })); joiner = kiwi.newJoiner(); joiner.add(u"고기", POSTag::nng); joiner.add(u"을", POSTag::jko); joiner.add(u"굽", POSTag::vv); joiner.add(u"어", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"고기를 구워"); + EXPECT_EQ(joiner.getU16(&ranges), u"고기를 구워"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 3, 4, 6, 5, 6 })); joiner = kiwi.newJoiner(); joiner.add(u"길", POSTag::nng); joiner.add(u"을", POSTag::jko); joiner.add(u"걷", POSTag::vv); joiner.add(u"어요", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"길을 걸어요"); + EXPECT_EQ(joiner.getU16(&ranges), u"길을 걸어요"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2, 3, 4, 4, 6 })); joiner = kiwi.newJoiner(false); joiner.add(u"길", POSTag::nng); joiner.add(u"을", POSTag::jko); joiner.add(u"걷", POSTag::vv); joiner.add(u"어요", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"길을 걷어요"); + EXPECT_EQ(joiner.getU16(&ranges), u"길을 걷어요"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2, 3, 4, 4, 6 })); joiner = kiwi.newJoiner(); joiner.add(u"땅", POSTag::nng); joiner.add(u"에", POSTag::jkb); joiner.add(u"묻", POSTag::vv); joiner.add(u"어요", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"땅에 묻어요"); + EXPECT_EQ(joiner.getU16(&ranges), u"땅에 묻어요"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2, 3, 4, 4, 6 })); joiner = kiwi.newJoiner(); joiner.add(u"땅", POSTag::nng); joiner.add(u"이", POSTag::vcp); joiner.add(u"에요", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"땅이에요"); + EXPECT_EQ(joiner.getU16(&ranges), u"땅이에요"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2, 2, 4 })); joiner = kiwi.newJoiner(); joiner.add(u"바다", POSTag::nng); joiner.add(u"이", POSTag::vcp); joiner.add(u"에요", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"바다에요"); + EXPECT_EQ(joiner.getU16(&ranges), u"바다에요"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 2, 2, 4 })); joiner = kiwi.newJoiner(); joiner.add(u"좋", POSTag::va); joiner.add(u"은데", POSTag::ec); - EXPECT_EQ(joiner.getU16(), u"좋은데"); + EXPECT_EQ(joiner.getU16(&ranges), u"좋은데"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"크", POSTag::va); joiner.add(u"은데", POSTag::ec); - EXPECT_EQ(joiner.getU16(), u"큰데"); + EXPECT_EQ(joiner.getU16(&ranges), u"큰데"); + EXPECT_EQ(ranges, toPair({ 0, 1, 0, 2 })); } TEST(KiwiCpp, UserWordWithNumeric) From 2938d2e7c84954ae7ca1fb0e877a6bb3128c4ba3 Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 25 Dec 2023 19:03:24 +0900 Subject: [PATCH 3/5] optimized performance of `AutoJoiner` for many morphs --- src/Joiner.cpp | 27 ++++++++++++++++++++++++++- src/LmState.hpp | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/src/Joiner.cpp b/src/Joiner.cpp index 445252ec..35c4e1b7 100644 --- a/src/Joiner.cpp +++ b/src/Joiner.cpp @@ -1,4 +1,4 @@ -#include "Joiner.hpp" +#include "Joiner.hpp" #include "FrozenTrie.hpp" using namespace std; @@ -343,6 +343,31 @@ namespace kiwi n.score += n.lmState.next(kiwi->langMdl, cands[0]->lmMorphemeId); n.joiner.add(form, cands[0]->tag, space); } + + UnorderedMap> bestScoreByState; + for (size_t i = 0; i < candidates.size(); ++i) + { + auto& c = candidates[i]; + auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, i)); + if (!inserted.second) + { + if (inserted.first->second.first < c.score) + { + inserted.first->second = make_pair(c.score, i); + } + } + } + + if (bestScoreByState.size() < candidates.size()) + { + Vector> newCandidates; + newCandidates.reserve(bestScoreByState.size()); + for (auto& p : bestScoreByState) + { + newCandidates.emplace_back(move(candidates[p.second.second])); + } + candidates = move(newCandidates); + } } } else diff --git a/src/LmState.hpp b/src/LmState.hpp index ea225886..089cf224 100644 --- a/src/LmState.hpp +++ b/src/LmState.hpp @@ -30,6 +30,7 @@ namespace kiwi template class KnLMState { + friend class Hash>; int32_t node = 0; public: static constexpr ArchType arch = _arch; @@ -56,6 +57,7 @@ namespace kiwi template class SbgState : public KnLMState<_arch, VocabTy> { + friend class Hash>; size_t historyPos = 0; std::array history = { {0,} }; public: @@ -99,6 +101,41 @@ namespace kiwi } }; + // hash for LmState + template + struct Hash> + { + size_t operator()(const VoidState& state) const + { + return 0; + } + }; + + template + struct Hash> + { + size_t operator()(const KnLMState& state) const + { + std::hash hasher; + return hasher(state.node); + } + }; + + template + struct Hash> + { + size_t operator()(const SbgState& state) const + { + Hash> hasher; + std::hash vocabHasher; + size_t ret = hasher(state); + for (size_t i = 0; i < windowSize; ++i) + { + ret = vocabHasher(state.history[i]) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + } + return ret; + } + }; template struct WrappedKnLM From 3a7ab76c95be2d76b023af4ebefe30bbfe93dbee Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 25 Dec 2023 19:03:58 +0900 Subject: [PATCH 4/5] improved auto spacing of `Joiner` --- src/Joiner.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Joiner.cpp b/src/Joiner.cpp index 35c4e1b7..f69bdffb 100644 --- a/src/Joiner.cpp +++ b/src/Joiner.cpp @@ -26,6 +26,7 @@ namespace kiwi if (l == POSTag::sn && r == POSTag::nr) return false; if (l == POSTag::sso || l == POSTag::ssc) return false; if (r == POSTag::sso) return true; + if ((isJClass(l) || isEClass(l)) && r == POSTag::ss) return true; if (r == POSTag::vx && rform.size() == 1 && (rform[0] == u'하' || rform[0] == u'지')) return false; From 5f44f860993599da8921976fea292407283e45ac Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 25 Dec 2023 19:05:48 +0900 Subject: [PATCH 5/5] fixed a bug where `AutoJoiner` failed with unk tag input --- src/Joiner.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/Joiner.cpp b/src/Joiner.cpp index f69bdffb..8de2b999 100644 --- a/src/Joiner.cpp +++ b/src/Joiner.cpp @@ -1,4 +1,4 @@ -#include "Joiner.hpp" +#include "Joiner.hpp" #include "FrozenTrie.hpp" using namespace std; @@ -300,16 +300,23 @@ namespace kiwi if (!node) break; } + // prevent unknown or partial tag + POSTag fixedTag = tag; + if (tag == POSTag::unknown || tag == POSTag::p) + { + fixedTag = POSTag::nnp; + } + if (node && kiwi->formTrie.hasMatch(formHead = node->val(kiwi->formTrie))) { Vector cands; foreachMorpheme(formHead, [&](const Morpheme* m) { - if (inferRegularity && clearIrregular(m->tag) == clearIrregular(tag)) + if (inferRegularity && clearIrregular(m->tag) == clearIrregular(fixedTag)) { cands.emplace_back(m); } - else if (!inferRegularity && m->tag == tag) + else if (!inferRegularity && m->tag == fixedTag) { cands.emplace_back(m); } @@ -317,7 +324,7 @@ namespace kiwi if (cands.size() <= 1) { - auto lmId = cands.empty() ? getDefaultMorphemeId(clearIrregular(tag)) : cands[0]->lmMorphemeId; + auto lmId = cands.empty() ? getDefaultMorphemeId(clearIrregular(fixedTag)) : cands[0]->lmMorphemeId; if (!cands.empty()) tag = cands[0]->tag; for (auto& cand : candidates) { @@ -373,7 +380,7 @@ namespace kiwi } else { - auto lmId = getDefaultMorphemeId(clearIrregular(tag)); + auto lmId = getDefaultMorphemeId(clearIrregular(fixedTag)); for (auto& cand : candidates) { cand.score += cand.lmState.next(kiwi->langMdl, lmId);