diff --git a/include/kiwi/Joiner.h b/include/kiwi/Joiner.h index 298c82db..4447d3d1 100644 --- a/include/kiwi/Joiner.h +++ b/include/kiwi/Joiner.h @@ -28,6 +28,7 @@ namespace kiwi template friend struct Candidate; const CompiledRule* cr = nullptr; KString stack; + std::vector> ranges; size_t activeStart = 0; POSTag lastTag = POSTag::unknown, anteLastTag = POSTag::unknown; @@ -45,8 +46,8 @@ namespace kiwi void add(const std::u16string& form, POSTag tag, Space space = Space::none); void add(const char16_t* form, POSTag tag, Space space = Space::none); - std::u16string getU16() const; - std::string getU8() const; + std::u16string getU16(std::vector>* rangesOut = nullptr) const; + std::string getU8(std::vector>* rangesOut = nullptr) const; }; template @@ -115,8 +116,8 @@ namespace kiwi void add(const std::u16string& form, POSTag tag, bool inferRegularity = true, Space space = Space::none); void add(const char16_t* form, POSTag tag, bool inferRegularity = true, Space space = Space::none); - std::u16string getU16() const; - std::string getU8() const; + std::u16string getU16(std::vector>* rangesOut = nullptr) const; + std::string getU8(std::vector>* rangesOut = nullptr) const; }; } } diff --git a/include/kiwi/Utils.h b/include/kiwi/Utils.h index fbe786c1..bf35076c 100644 --- a/include/kiwi/Utils.h +++ b/include/kiwi/Utils.h @@ -108,6 +108,31 @@ namespace kiwi return ret; } + template + inline std::u16string joinHangul(It first, It last, std::vector& positionOut) + { + std::u16string ret; + ret.reserve(std::distance(first, last)); + positionOut.clear(); + positionOut.reserve(std::distance(first, last)); + for (; first != last; ++first) + { + auto c = *first; + if (isHangulCoda(c) && !ret.empty() && isHangulSyllable(ret.back())) + { + if ((ret.back() - 0xAC00) % 28) ret.push_back(c); + else ret.back() += c - 0x11A7; + positionOut.emplace_back(ret.size() - 1); + } + else + { + ret.push_back(c); + positionOut.emplace_back(ret.size() - 1); + } + } + return ret; + } + inline std::u16string joinHangul(const KString& hangul) { return joinHangul(hangul.begin(), hangul.end()); diff --git a/src/Combiner.cpp b/src/Combiner.cpp index 7c518db4..eface277 100644 --- a/src/Combiner.cpp +++ b/src/Combiner.cpp @@ -1147,7 +1147,7 @@ Vector CompiledRule::combineImpl( return ret; } -pair CompiledRule::combineOneImpl( +tuple CompiledRule::combineOneImpl( U16StringView leftForm, POSTag leftTag, U16StringView rightForm, POSTag rightTag, CondVowel cv, CondPolarity cp @@ -1163,12 +1163,12 @@ pair CompiledRule::combineOneImpl( { for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) { - if(p.score >= 0) return make_pair(p.str, p.rightBegin); + if(p.score >= 0) return make_tuple(p.str, p.leftEnd, p.rightBegin); KString ret; ret.reserve(leftForm.size() + rightForm.size()); ret.insert(ret.end(), leftForm.begin(), leftForm.end()); ret.insert(ret.end(), rightForm.begin(), rightForm.end()); - return make_pair(ret, leftForm.size()); + return make_tuple(ret, leftForm.size(), leftForm.size()); } } @@ -1183,7 +1183,7 @@ pair CompiledRule::combineOneImpl( { for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) { - return make_pair(p.str, p.rightBegin); + return make_tuple(p.str, p.leftEnd, p.rightBegin); } } } @@ -1198,14 +1198,14 @@ pair CompiledRule::combineOneImpl( ret.insert(ret.end(), leftForm.begin(), leftForm.end()); ret.push_back(u'아'); // `어`를 `아`로 교체하여 삽입 ret.insert(ret.end(), rightForm.begin() + 1, rightForm.end()); - return make_pair(ret, leftForm.size()); + return make_tuple(ret, leftForm.size(), leftForm.size()); } } KString ret; ret.reserve(leftForm.size() + rightForm.size()); ret.insert(ret.end(), leftForm.begin(), leftForm.end()); ret.insert(ret.end(), rightForm.begin(), rightForm.end()); - return make_pair(ret, leftForm.size()); + return make_tuple(ret, leftForm.size(), leftForm.size()); } Vector> CompiledRule::testLeftPattern(U16StringView leftForm, size_t ruleId) const diff --git a/src/Combiner.h b/src/Combiner.h index 542ec3dc..0d407d83 100644 --- a/src/Combiner.h +++ b/src/Combiner.h @@ -215,7 +215,10 @@ namespace kiwi CondVowel cv = CondVowel::none, CondPolarity cp = CondPolarity::none ) const; - std::pair combineOneImpl( + /** + * @return tuple(combinedForm, leftFormBoundary, rightFormBoundary) + */ + std::tuple combineOneImpl( U16StringView leftForm, POSTag leftTag, U16StringView rightForm, POSTag rightTag, CondVowel cv = CondVowel::none, CondPolarity cp = CondPolarity::none diff --git a/src/Joiner.cpp b/src/Joiner.cpp index 10081341..8de2b999 100644 --- a/src/Joiner.cpp +++ b/src/Joiner.cpp @@ -26,6 +26,7 @@ namespace kiwi if (l == POSTag::sn && r == POSTag::nr) return false; if (l == POSTag::sso || l == POSTag::ssc) return false; if (r == POSTag::sso) return true; + if ((isJClass(l) || isEClass(l)) && r == POSTag::ss) return true; if (r == POSTag::vx && rform.size() == 1 && (rform[0] == u'하' || rform[0] == u'지')) return false; @@ -79,9 +80,11 @@ namespace kiwi void Joiner::add(U16StringView form, POSTag tag, Space space) { + KString normForm = normalizeHangul(form); if (stack.size() == activeStart) { - stack += normalizeHangul(form); + ranges.emplace_back(stack.size(), stack.size() + normForm.size()); + stack += normForm; lastTag = tag; return; } @@ -90,7 +93,8 @@ namespace kiwi { if (stack.empty() || !isSpace(stack.back())) stack.push_back(u' '); activeStart = stack.size(); - stack += normalizeHangul(form); + ranges.emplace_back(stack.size(), stack.size() + normForm.size()); + stack += normForm; } else { @@ -100,8 +104,6 @@ namespace kiwi cv = isHangulSyllable(stack[activeStart - 1]) ? CondVowel::vowel : CondVowel::non_vowel; } - KString normForm = normalizeHangul(form); - if (!stack.empty() && (isJClass(tag) || isEClass(tag))) { if (isEClass(tag) && normForm[0] == u'아') normForm[0] = u'어'; @@ -148,8 +150,10 @@ namespace kiwi } auto r = cr->combineOneImpl({ stack.data() + activeStart, stack.size() - activeStart }, lastTag, normForm, tag, cv); stack.erase(stack.begin() + activeStart, stack.end()); - stack += r.first; - activeStart += r.second; + ranges.back().second = activeStart + get<1>(r); + ranges.emplace_back(activeStart + get<2>(r), activeStart + get<0>(r).size()); + stack += get<0>(r); + activeStart += get<2>(r); } anteLastTag = lastTag; lastTag = tag; @@ -165,14 +169,46 @@ namespace kiwi return add(U16StringView{ form }, tag, space); } - u16string Joiner::getU16() const + u16string Joiner::getU16(vector>* rangesOut) const { - return joinHangul(stack); + if (rangesOut) + { + rangesOut->clear(); + rangesOut->reserve(ranges.size()); + Vector u16pos; + auto ret = joinHangul(stack.begin(), stack.end(), u16pos); + u16pos.emplace_back(ret.size()); + for (auto& r : ranges) + { + auto endOffset = u16pos[r.second] + (r.second > 0 && u16pos[r.second - 1] == u16pos[r.second] ? 1 : 0); + rangesOut->emplace_back(u16pos[r.first], endOffset); + } + return ret; + } + else + { + return joinHangul(stack); + } } - string Joiner::getU8() const + string Joiner::getU8(vector>* rangesOut) const { - return utf16To8(joinHangul(stack)); + auto u16 = getU16(rangesOut); + if (rangesOut) + { + Vector positions; + auto ret = utf16To8(u16, positions); + for (auto& r : *rangesOut) + { + r.first = positions[r.first]; + r.second = positions[r.second]; + } + return ret; + } + else + { + return utf16To8(u16); + } } AutoJoiner::~AutoJoiner() @@ -264,16 +300,23 @@ namespace kiwi if (!node) break; } + // prevent unknown or partial tag + POSTag fixedTag = tag; + if (tag == POSTag::unknown || tag == POSTag::p) + { + fixedTag = POSTag::nnp; + } + if (node && kiwi->formTrie.hasMatch(formHead = node->val(kiwi->formTrie))) { Vector cands; foreachMorpheme(formHead, [&](const Morpheme* m) { - if (inferRegularity && clearIrregular(m->tag) == clearIrregular(tag)) + if (inferRegularity && clearIrregular(m->tag) == clearIrregular(fixedTag)) { cands.emplace_back(m); } - else if (!inferRegularity && m->tag == tag) + else if (!inferRegularity && m->tag == fixedTag) { cands.emplace_back(m); } @@ -281,7 +324,7 @@ namespace kiwi if (cands.size() <= 1) { - auto lmId = cands.empty() ? getDefaultMorphemeId(clearIrregular(tag)) : cands[0]->lmMorphemeId; + auto lmId = cands.empty() ? getDefaultMorphemeId(clearIrregular(fixedTag)) : cands[0]->lmMorphemeId; if (!cands.empty()) tag = cands[0]->tag; for (auto& cand : candidates) { @@ -308,11 +351,36 @@ namespace kiwi n.score += n.lmState.next(kiwi->langMdl, cands[0]->lmMorphemeId); n.joiner.add(form, cands[0]->tag, space); } + + UnorderedMap> bestScoreByState; + for (size_t i = 0; i < candidates.size(); ++i) + { + auto& c = candidates[i]; + auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, i)); + if (!inserted.second) + { + if (inserted.first->second.first < c.score) + { + inserted.first->second = make_pair(c.score, i); + } + } + } + + if (bestScoreByState.size() < candidates.size()) + { + Vector> newCandidates; + newCandidates.reserve(bestScoreByState.size()); + for (auto& p : bestScoreByState) + { + newCandidates.emplace_back(move(candidates[p.second.second])); + } + candidates = move(newCandidates); + } } } else { - auto lmId = getDefaultMorphemeId(clearIrregular(tag)); + auto lmId = getDefaultMorphemeId(clearIrregular(fixedTag)); for (auto& cand : candidates) { cand.score += cand.lmState.next(kiwi->langMdl, lmId); @@ -422,19 +490,33 @@ namespace kiwi struct GetU16Visitor { + vector>* rangesOut; + + GetU16Visitor(vector>* _rangesOut) + : rangesOut{ _rangesOut } + { + } + template u16string operator()(const Vector>& o) const { - return o[0].joiner.getU16(); + return o[0].joiner.getU16(rangesOut); } }; struct GetU8Visitor { + vector>* rangesOut; + + GetU8Visitor(vector>* _rangesOut) + : rangesOut{ _rangesOut } + { + } + template string operator()(const Vector>& o) const { - return o[0].joiner.getU8(); + return o[0].joiner.getU8(rangesOut); } }; @@ -458,14 +540,14 @@ namespace kiwi return mapbox::util::apply_visitor(AddVisitor{ this, form, tag, false, space }, reinterpret_cast(candBuf)); } - u16string AutoJoiner::getU16() const + u16string AutoJoiner::getU16(vector>* rangesOut) const { - return mapbox::util::apply_visitor(GetU16Visitor{}, reinterpret_cast(candBuf)); + return mapbox::util::apply_visitor(GetU16Visitor{ rangesOut }, reinterpret_cast(candBuf)); } - string AutoJoiner::getU8() const + string AutoJoiner::getU8(vector>* rangesOut) const { - return mapbox::util::apply_visitor(GetU8Visitor{}, reinterpret_cast(candBuf)); + return mapbox::util::apply_visitor(GetU8Visitor{ rangesOut }, reinterpret_cast(candBuf)); } } } diff --git a/src/LmState.hpp b/src/LmState.hpp index ea225886..089cf224 100644 --- a/src/LmState.hpp +++ b/src/LmState.hpp @@ -30,6 +30,7 @@ namespace kiwi template class KnLMState { + friend class Hash>; int32_t node = 0; public: static constexpr ArchType arch = _arch; @@ -56,6 +57,7 @@ namespace kiwi template class SbgState : public KnLMState<_arch, VocabTy> { + friend class Hash>; size_t historyPos = 0; std::array history = { {0,} }; public: @@ -99,6 +101,41 @@ namespace kiwi } }; + // hash for LmState + template + struct Hash> + { + size_t operator()(const VoidState& state) const + { + return 0; + } + }; + + template + struct Hash> + { + size_t operator()(const KnLMState& state) const + { + std::hash hasher; + return hasher(state.node); + } + }; + + template + struct Hash> + { + size_t operator()(const SbgState& state) const + { + Hash> hasher; + std::hash vocabHasher; + size_t ret = hasher(state); + for (size_t i = 0; i < windowSize; ++i) + { + ret = vocabHasher(state.history[i]) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + } + return ret; + } + }; template struct WrappedKnLM diff --git a/src/StrUtils.h b/src/StrUtils.h index c8bc34da..0c508b36 100644 --- a/src/StrUtils.h +++ b/src/StrUtils.h @@ -393,6 +393,56 @@ namespace kiwi return ret; } + template + inline std::string utf16To8(nonstd::u16string_view str, std::vector& positions) + { + std::string ret; + positions.clear(); + for (auto it = str.begin(); it != str.end(); ++it) + { + size_t code = *it; + positions.emplace_back(ret.size()); + if (isHighSurrogate(code)) + { + if (++it == str.end()) throw UnicodeException{ "unpaired surrogate" }; + size_t code2 = *it; + if (!isLowSurrogate(code2)) throw UnicodeException{ "unpaired surrogate" }; + positions.emplace_back(ret.size()); + code = mergeSurrogate(code, code2); + } + + if (code <= 0x7F) + { + ret.push_back((char)code); + } + else if (code <= 0x7FF) + { + ret.push_back((char)(0xC0 | (code >> 6))); + ret.push_back((char)(0x80 | (code & 0x3F))); + } + else if (code <= 0xFFFF) + { + ret.push_back((char)(0xE0 | (code >> 12))); + ret.push_back((char)(0x80 | ((code >> 6) & 0x3F))); + ret.push_back((char)(0x80 | (code & 0x3F))); + } + else if (code <= 0x10FFFF) + { + ret.push_back((char)(0xF0 | (code >> 18))); + ret.push_back((char)(0x80 | ((code >> 12) & 0x3F))); + ret.push_back((char)(0x80 | ((code >> 6) & 0x3F))); + ret.push_back((char)(0x80 | (code & 0x3F))); + } + else + { + throw UnicodeException{ "unicode error" }; + } + } + positions.emplace_back(ret.size()); + + return ret; + } + inline std::string utf16To8(const char16_t* str) { return utf16To8(U16StringView{ str }); diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index 51275fd9..9dceb6aa 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -32,6 +32,21 @@ inline testing::AssertionResult testTokenization(Kiwi& kiwi, const std::u16strin } } +#if defined(__GNUC__) && __GNUC__ < 5 +template +constexpr std::vector> toPair(std::initializer_list init) +{ + return std::vector>{ (const std::pair*)init.begin(), (const std::pair*)init.begin() + init.size() / 2 }; +} +#else +template +constexpr std::vector> toPair(const ATy(&init)[n]) +{ + static_assert(n % 2 == 0, "initializer_list must have an even number of elements."); + return std::vector>{ (const std::pair*)init, (const std::pair*)init + n / 2 }; +} +#endif + Kiwi& reuseKiwiInstance() { static Kiwi kiwi = KiwiBuilder{ MODEL_PATH, 0, BuildOption::default_, }.build(); @@ -960,118 +975,139 @@ TEST(KiwiCpp, JoinAffix) TEST(KiwiCpp, AutoJoiner) { Kiwi& kiwi = reuseKiwiInstance(); + std::vector> ranges; auto joiner = kiwi.newJoiner(); joiner.add(u"시동", POSTag::nng); joiner.add(u"를", POSTag::jko); - EXPECT_EQ(joiner.getU16(), u"시동을"); + EXPECT_EQ(joiner.getU16(&ranges), u"시동을"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"시동", POSTag::nng); joiner.add(u"ᆯ", POSTag::jko); - EXPECT_EQ(joiner.getU16(), u"시동을"); + EXPECT_EQ(joiner.getU16(&ranges), u"시동을"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"나", POSTag::np); joiner.add(u"ᆯ", POSTag::jko); - EXPECT_EQ(joiner.getU16(), u"날"); + EXPECT_EQ(joiner.getU16(&ranges), u"날"); + EXPECT_EQ(ranges, toPair({ 0, 1, 0, 1 })); joiner = kiwi.newJoiner(); joiner.add(u"시도", POSTag::nng); joiner.add(u"를", POSTag::jko); - EXPECT_EQ(joiner.getU16(), u"시도를"); + EXPECT_EQ(joiner.getU16(&ranges), u"시도를"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"바다", POSTag::nng); joiner.add(u"가", POSTag::jks); - EXPECT_EQ(joiner.getU16(), u"바다가"); + EXPECT_EQ(joiner.getU16(&ranges), u"바다가"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"바닥", POSTag::nng); joiner.add(u"가", POSTag::jks); - EXPECT_EQ(joiner.getU16(), u"바닥이"); + EXPECT_EQ(joiner.getU16(&ranges), u"바닥이"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"불", POSTag::nng); joiner.add(u"으로", POSTag::jkb); - EXPECT_EQ(joiner.getU16(), u"불로"); + EXPECT_EQ(joiner.getU16(&ranges), u"불로"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2 })); joiner = kiwi.newJoiner(); joiner.add(u"북", POSTag::nng); joiner.add(u"으로", POSTag::jkb); - EXPECT_EQ(joiner.getU16(), u"북으로"); + EXPECT_EQ(joiner.getU16(&ranges), u"북으로"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"갈", POSTag::vv); joiner.add(u"면", POSTag::ec); - EXPECT_EQ(joiner.getU16(), u"갈면"); + EXPECT_EQ(joiner.getU16(&ranges), u"갈면"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2 })); joiner = kiwi.newJoiner(); joiner.add(u"갈", POSTag::vv); joiner.add(u"시", POSTag::ep); joiner.add(u"았", POSTag::ep); joiner.add(u"면", POSTag::ec); - EXPECT_EQ(joiner.getU16(), u"가셨으면"); + EXPECT_EQ(joiner.getU16(&ranges), u"가셨으면"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2, 1, 2, 2, 4 })); joiner = kiwi.newJoiner(); joiner.add(u"하", POSTag::vv); joiner.add(u"았", POSTag::ep); joiner.add(u"다", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"했다"); + EXPECT_EQ(joiner.getU16(&ranges), u"했다"); + EXPECT_EQ(ranges, toPair({ 0, 1, 0, 1, 1, 2 })); joiner = kiwi.newJoiner(); joiner.add(u"날", POSTag::vv); joiner.add(u"어", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"날아"); + EXPECT_EQ(joiner.getU16(&ranges), u"날아"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2 })); joiner = kiwi.newJoiner(); joiner.add(u"고기", POSTag::nng); joiner.add(u"을", POSTag::jko); joiner.add(u"굽", POSTag::vv); joiner.add(u"어", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"고기를 구워"); + EXPECT_EQ(joiner.getU16(&ranges), u"고기를 구워"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 3, 4, 6, 5, 6 })); joiner = kiwi.newJoiner(); joiner.add(u"길", POSTag::nng); joiner.add(u"을", POSTag::jko); joiner.add(u"걷", POSTag::vv); joiner.add(u"어요", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"길을 걸어요"); + EXPECT_EQ(joiner.getU16(&ranges), u"길을 걸어요"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2, 3, 4, 4, 6 })); joiner = kiwi.newJoiner(false); joiner.add(u"길", POSTag::nng); joiner.add(u"을", POSTag::jko); joiner.add(u"걷", POSTag::vv); joiner.add(u"어요", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"길을 걷어요"); + EXPECT_EQ(joiner.getU16(&ranges), u"길을 걷어요"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2, 3, 4, 4, 6 })); joiner = kiwi.newJoiner(); joiner.add(u"땅", POSTag::nng); joiner.add(u"에", POSTag::jkb); joiner.add(u"묻", POSTag::vv); joiner.add(u"어요", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"땅에 묻어요"); + EXPECT_EQ(joiner.getU16(&ranges), u"땅에 묻어요"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2, 3, 4, 4, 6 })); joiner = kiwi.newJoiner(); joiner.add(u"땅", POSTag::nng); joiner.add(u"이", POSTag::vcp); joiner.add(u"에요", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"땅이에요"); + EXPECT_EQ(joiner.getU16(&ranges), u"땅이에요"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 2, 2, 4 })); joiner = kiwi.newJoiner(); joiner.add(u"바다", POSTag::nng); joiner.add(u"이", POSTag::vcp); joiner.add(u"에요", POSTag::ef); - EXPECT_EQ(joiner.getU16(), u"바다에요"); + EXPECT_EQ(joiner.getU16(&ranges), u"바다에요"); + EXPECT_EQ(ranges, toPair({ 0, 2, 2, 2, 2, 4 })); joiner = kiwi.newJoiner(); joiner.add(u"좋", POSTag::va); joiner.add(u"은데", POSTag::ec); - EXPECT_EQ(joiner.getU16(), u"좋은데"); + EXPECT_EQ(joiner.getU16(&ranges), u"좋은데"); + EXPECT_EQ(ranges, toPair({ 0, 1, 1, 3 })); joiner = kiwi.newJoiner(); joiner.add(u"크", POSTag::va); joiner.add(u"은데", POSTag::ec); - EXPECT_EQ(joiner.getU16(), u"큰데"); + EXPECT_EQ(joiner.getU16(&ranges), u"큰데"); + EXPECT_EQ(ranges, toPair({ 0, 1, 0, 2 })); } TEST(KiwiCpp, UserWordWithNumeric)