Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pretokenized span 개선 #170

Merged
merged 5 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions include/kiwi/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ namespace kiwi
return irregular ? setIrregular(tag) : clearIrregular(tag);
}

inline constexpr bool areTagsEqual(POSTag a, POSTag b, bool ignoreRegularity = false)
{
return ignoreRegularity ? (clearIrregular(a) == clearIrregular(b)) : (a == b);
}

constexpr size_t defaultTagSize = (size_t)POSTag::p;

/**
Expand Down Expand Up @@ -349,9 +354,10 @@ namespace kiwi
std::u16string form;
uint32_t begin = -1, end = -1;
POSTag tag = POSTag::unknown;
uint8_t inferRegularity = 1;

BasicToken(const std::u16string& _form = {}, uint32_t _begin = -1, uint32_t _end = -1, POSTag _tag = POSTag::unknown)
: form{ _form }, begin{ _begin }, end{ _end }, tag{ _tag }
BasicToken(const std::u16string& _form = {}, uint32_t _begin = -1, uint32_t _end = -1, POSTag _tag = POSTag::unknown, uint8_t _inferRegularity = 1)
: form{ _form }, begin{ _begin }, end{ _end }, tag{ _tag }, inferRegularity{ _inferRegularity }
{}
};

Expand Down
8 changes: 2 additions & 6 deletions src/Joiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,7 @@ namespace kiwi
Vector<const Morpheme*> cands;
foreachMorpheme(formHead, [&](const Morpheme* m)
{
if (inferRegularity && clearIrregular(m->tag) == clearIrregular(fixedTag))
{
cands.emplace_back(m);
}
else if (!inferRegularity && m->tag == fixedTag)
if (areTagsEqual(m->tag, fixedTag, inferRegularity))
{
cands.emplace_back(m);
}
Expand Down Expand Up @@ -412,7 +408,7 @@ namespace kiwi
Vector<const Morpheme*> cands;
foreachMorpheme(formHead, [&](const Morpheme* m)
{
if (clearIrregular(m->tag) == clearIrregular(tag))
if (areTagsEqual(m->tag, tag, true))
{
cands.emplace_back(m);
}
Expand Down
30 changes: 17 additions & 13 deletions src/Kiwi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ namespace kiwi
case POSTag::vcp:
case POSTag::etm:
case POSTag::ec:
if (t.tag == POSTag::jx && *t.morph->kform == u"요")
if (t.tag == POSTag::jx && t.morph && *t.morph->kform == u"요")
{
if (state == State::ef)
{
Expand Down Expand Up @@ -804,36 +804,40 @@ namespace kiwi
{
auto formStr = normalizeHangul(s.tokenization[0].form);
auto* tform = findForm(formTrie, formStr);
if (tform && tform->candidate.size() == 1 && tform->candidate[0]->tag == s.tokenization[0].tag) // reuse the predefined form & morpheme
if (tform && tform->candidate.size() == 1 &&
areTagsEqual(tform->candidate[0]->tag, s.tokenization[0].tag, !!s.tokenization[0].inferRegularity))
// reuse the predefined form & morpheme
{
span.form = tform;
}
else if (formStr == normStr.substr(span.begin, span.end - span.begin)) // use a fallback form
{
span.form = formTrie.value((size_t)clearIrregular(s.tokenization[0].tag));
}
else // or add a new form & morpheme
{
ret.forms.emplace_back();
auto& form = ret.forms.back();
form.form = move(formStr);
form.candidate = FixedVector<const Morpheme*>{ 1 };
const Morpheme* foundMorph = nullptr;
const Morpheme* foundMorph[2] = { nullptr, nullptr };
if (tform)
{
size_t i = 0;
for (auto m : tform->candidate)
{
if (m->tag == s.tokenization[0].tag)
if (areTagsEqual(m->tag, s.tokenization[0].tag, s.tokenization[0].inferRegularity))
{
foundMorph = m;
break;
foundMorph[i++] = m;
if (i >= 2) break;
}
}
}

form.candidate = FixedVector<const Morpheme*>{ (size_t)(foundMorph[1] ? 2 : 1) };

if (foundMorph)
if (foundMorph[0])
{
form.candidate[0] = foundMorph;
form.candidate[0] = foundMorph[0];
if (foundMorph[1])
{
form.candidate[1] = foundMorph[1];
}
}
else
{
Expand Down
28 changes: 28 additions & 0 deletions test/test_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,34 @@ TEST(KiwiCpp, Pretokenized)
EXPECT_EQ(res[13].str, u"매트");
EXPECT_EQ(res[13].tag, POSTag::nng);
}

{
std::vector<PretokenizedSpan> pretokenized = {
PretokenizedSpan{ 9, 10, { BasicToken{ u"가", 0, 1, POSTag::jks } } },
PretokenizedSpan{ 16, 17, { BasicToken{ u"에", 0, 1, POSTag::jkb } } },
};

auto ref = kiwi.analyze(str, Match::allWithNormalizing).first;
res = kiwi.analyze(str, Match::allWithNormalizing, nullptr, pretokenized).first;
EXPECT_EQ(res[2].tag, POSTag::jks);
EXPECT_EQ(res[2].morph, ref[2].morph);
EXPECT_EQ(res[2].score, ref[2].score);
EXPECT_EQ(res[5].tag, POSTag::jkb);
EXPECT_EQ(res[5].morph, ref[5].morph);
EXPECT_EQ(res[5].score, ref[5].score);
}

{
auto str2 = u"길을 걷다";
std::vector<PretokenizedSpan> pretokenized = {
PretokenizedSpan{ 3, 4, { BasicToken{ u"걷", 0, 1, POSTag::vv } } },
};

auto ref = kiwi.analyze(str2, Match::allWithNormalizing).first;
res = kiwi.analyze(str2, Match::allWithNormalizing, nullptr, pretokenized).first;
EXPECT_EQ(res[2].tag, POSTag::vvi);
EXPECT_EQ(res[2].morph, ref[2].morph);
}
}

TEST(KiwiCpp, TagRoundTrip)
Expand Down
Loading