# Patch summary (text before the first "diff --git" header is not part of any hunk):
#  - include/kiwi/Types.h : adds areTagsEqual(a, b, ignoreRegularity) and a new
#                           BasicToken::inferRegularity member (defaults to 1).
#  - src/Joiner.cpp       : replaces two hand-written tag comparisons with areTagsEqual().
#  - src/Kiwi.cpp         : null-checks t.morph before dereferencing it; pretokenized spans
#                           now compare tags through areTagsEqual(), the "fallback form"
#                           branch is removed, and up to two matching morphemes are kept.
#  - test/test_cpp.cpp    : adds two extra cases to TEST(KiwiCpp, Pretokenized).
# NOTE(review): this patch was rebuilt from a whitespace-collapsed copy. Indentation and the
# template arguments <const Morpheme*> / <PretokenizedSpan> are reconstructed -- confirm
# against the target files (or apply with --ignore-whitespace). Post-image blob ids on the
# "index" lines predate the added comments.
diff --git a/include/kiwi/Types.h b/include/kiwi/Types.h
index 2c19675f..ea1fa317 100644
--- a/include/kiwi/Types.h
+++ b/include/kiwi/Types.h
@@ -241,6 +241,14 @@ namespace kiwi
 		return irregular ? setIrregular(tag) : clearIrregular(tag);
 	}
 
+	/**
+	 * @brief Compares two tags; if `ignoreRegularity` is true, both are passed through clearIrregular() first.
+	 */
+	inline constexpr bool areTagsEqual(POSTag a, POSTag b, bool ignoreRegularity = false)
+	{
+		return ignoreRegularity ? (clearIrregular(a) == clearIrregular(b)) : (a == b);
+	}
+
 	constexpr size_t defaultTagSize = (size_t)POSTag::p;
 
 	/**
@@ -349,9 +357,10 @@ namespace kiwi
 		std::u16string form;
 		uint32_t begin = -1, end = -1;
 		POSTag tag = POSTag::unknown;
+		uint8_t inferRegularity = 1; // non-zero: `tag` is matched with ignoreRegularity (see areTagsEqual)
 
-		BasicToken(const std::u16string& _form = {}, uint32_t _begin = -1, uint32_t _end = -1, POSTag _tag = POSTag::unknown)
-			: form{ _form }, begin{ _begin }, end{ _end }, tag{ _tag }
+		BasicToken(const std::u16string& _form = {}, uint32_t _begin = -1, uint32_t _end = -1, POSTag _tag = POSTag::unknown, uint8_t _inferRegularity = 1)
+			: form{ _form }, begin{ _begin }, end{ _end }, tag{ _tag }, inferRegularity{ _inferRegularity }
 		{}
 	};
 
diff --git a/src/Joiner.cpp b/src/Joiner.cpp
index 8de2b999..d6c3de03 100644
--- a/src/Joiner.cpp
+++ b/src/Joiner.cpp
@@ -312,11 +312,7 @@ namespace kiwi
 			Vector<const Morpheme*> cands;
 			foreachMorpheme(formHead, [&](const Morpheme* m)
 			{
-				if (inferRegularity && clearIrregular(m->tag) == clearIrregular(fixedTag))
-				{
-					cands.emplace_back(m);
-				}
-				else if (!inferRegularity && m->tag == fixedTag)
+				if (areTagsEqual(m->tag, fixedTag, inferRegularity))
 				{
 					cands.emplace_back(m);
 				}
@@ -412,7 +408,7 @@ namespace kiwi
 			Vector<const Morpheme*> cands;
 			foreachMorpheme(formHead, [&](const Morpheme* m)
 			{
-				if (clearIrregular(m->tag) == clearIrregular(tag))
+				if (areTagsEqual(m->tag, tag, true))
 				{
 					cands.emplace_back(m);
 				}
diff --git a/src/Kiwi.cpp b/src/Kiwi.cpp
index cc0e3649..c01df3e1 100644
--- a/src/Kiwi.cpp
+++ b/src/Kiwi.cpp
@@ -247,7 +247,7 @@ namespace kiwi
 			case POSTag::vcp:
 			case POSTag::etm:
 			case POSTag::ec:
-				if (t.tag == POSTag::jx && *t.morph->kform == u"요")
+				if (t.tag == POSTag::jx && t.morph && *t.morph->kform == u"요")
 				{
 					if (state == State::ef)
 					{
@@ -804,36 +804,40 @@ namespace kiwi
 			{
 				auto formStr = normalizeHangul(s.tokenization[0].form);
 				auto* tform = findForm(formTrie, formStr);
-				if (tform && tform->candidate.size() == 1 && tform->candidate[0]->tag == s.tokenization[0].tag) // reuse the predefined form & morpheme
+				if (tform && tform->candidate.size() == 1 &&
+					areTagsEqual(tform->candidate[0]->tag, s.tokenization[0].tag, !!s.tokenization[0].inferRegularity))
+					// reuse the predefined form & morpheme
 				{
 					span.form = tform;
 				}
-				else if (formStr == normStr.substr(span.begin, span.end - span.begin)) // use a fallback form
-				{
-					span.form = formTrie.value((size_t)clearIrregular(s.tokenization[0].tag));
-				}
 				else // or add a new form & morpheme
 				{
 					ret.forms.emplace_back();
 					auto& form = ret.forms.back();
 					form.form = move(formStr);
-					form.candidate = FixedVector<const Morpheme*>{ 1 };
-					const Morpheme* foundMorph = nullptr;
+					const Morpheme* foundMorph[2] = { nullptr, nullptr };
 					if (tform)
 					{
+						size_t i = 0;
 						for (auto m : tform->candidate)
 						{
-							if (m->tag == s.tokenization[0].tag)
+							if (areTagsEqual(m->tag, s.tokenization[0].tag, !!s.tokenization[0].inferRegularity))
 							{
-								foundMorph = m;
-								break;
+								foundMorph[i++] = m;
+								if (i >= 2) break;
 							}
 						}
 					}
+
+					form.candidate = FixedVector<const Morpheme*>{ (size_t)(foundMorph[1] ? 2 : 1) };
 
-					if (foundMorph)
+					if (foundMorph[0])
 					{
-						form.candidate[0] = foundMorph;
+						form.candidate[0] = foundMorph[0];
+						if (foundMorph[1])
+						{
+							form.candidate[1] = foundMorph[1];
+						}
 					}
 					else
 					{
diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp
index c3d85940..299f4fe6 100644
--- a/test/test_cpp.cpp
+++ b/test/test_cpp.cpp
@@ -271,6 +271,34 @@ TEST(KiwiCpp, Pretokenized)
 		EXPECT_EQ(res[13].str, u"매트");
 		EXPECT_EQ(res[13].tag, POSTag::nng);
 	}
+
+	{
+		std::vector<PretokenizedSpan> pretokenized = {
+			PretokenizedSpan{ 9, 10, { BasicToken{ u"가", 0, 1, POSTag::jks } } },
+			PretokenizedSpan{ 16, 17, { BasicToken{ u"에", 0, 1, POSTag::jkb } } },
+		};
+
+		auto ref = kiwi.analyze(str, Match::allWithNormalizing).first;
+		res = kiwi.analyze(str, Match::allWithNormalizing, nullptr, pretokenized).first;
+		EXPECT_EQ(res[2].tag, POSTag::jks);
+		EXPECT_EQ(res[2].morph, ref[2].morph);
+		EXPECT_EQ(res[2].score, ref[2].score);
+		EXPECT_EQ(res[5].tag, POSTag::jkb);
+		EXPECT_EQ(res[5].morph, ref[5].morph);
+		EXPECT_EQ(res[5].score, ref[5].score);
+	}
+
+	{
+		auto str2 = u"길을 걷다";
+		std::vector<PretokenizedSpan> pretokenized = {
+			PretokenizedSpan{ 3, 4, { BasicToken{ u"걷", 0, 1, POSTag::vv } } },
+		};
+
+		auto ref = kiwi.analyze(str2, Match::allWithNormalizing).first;
+		res = kiwi.analyze(str2, Match::allWithNormalizing, nullptr, pretokenized).first;
+		EXPECT_EQ(res[2].tag, POSTag::vvi); // span is given as vv; inferRegularity is left at its default (1)
+		EXPECT_EQ(res[2].morph, ref[2].morph);
+	}
 }
 
 TEST(KiwiCpp, TagRoundTrip)