Skip to content

Commit

Permalink
Merge pull request #152 from bab2min/dev_joiner_position
Browse files Browse the repository at this point in the history
`AutoJoiner` 개선
  • Loading branch information
bab2min authored Jan 30, 2024
2 parents b708dda + 5f44f86 commit 1e055bc
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 51 deletions.
9 changes: 5 additions & 4 deletions include/kiwi/Joiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace kiwi
template<class LmState> friend struct Candidate;
const CompiledRule* cr = nullptr;
KString stack;
std::vector<std::pair<uint32_t, uint32_t>> ranges;
size_t activeStart = 0;
POSTag lastTag = POSTag::unknown, anteLastTag = POSTag::unknown;

Expand All @@ -45,8 +46,8 @@ namespace kiwi
void add(const std::u16string& form, POSTag tag, Space space = Space::none);
void add(const char16_t* form, POSTag tag, Space space = Space::none);

std::u16string getU16() const;
std::string getU8() const;
std::u16string getU16(std::vector<std::pair<uint32_t, uint32_t>>* rangesOut = nullptr) const;
std::string getU8(std::vector<std::pair<uint32_t, uint32_t>>* rangesOut = nullptr) const;
};

template<class LmState>
Expand Down Expand Up @@ -115,8 +116,8 @@ namespace kiwi
void add(const std::u16string& form, POSTag tag, bool inferRegularity = true, Space space = Space::none);
void add(const char16_t* form, POSTag tag, bool inferRegularity = true, Space space = Space::none);

std::u16string getU16() const;
std::string getU8() const;
std::u16string getU16(std::vector<std::pair<uint32_t, uint32_t>>* rangesOut = nullptr) const;
std::string getU8(std::vector<std::pair<uint32_t, uint32_t>>* rangesOut = nullptr) const;
};
}
}
25 changes: 25 additions & 0 deletions include/kiwi/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,31 @@ namespace kiwi
return ret;
}

/**
 * Joins a sequence of UTF-16 code units into a string, folding a coda
 * (as judged by isHangulCoda) into the preceding output character when that
 * character is a Hangul syllable (as judged by isHangulSyllable) that has no
 * final consonant yet. It also records where every input element ended up.
 *
 * @param first        iterator to the first input code unit
 * @param last         iterator one past the final input code unit
 * @param positionOut  cleared, then filled with exactly one entry per input
 *                     element: the index in the returned string of the
 *                     character that holds that element. A coda folded into
 *                     its syllable shares the syllable's index.
 * @return the joined string
 */
template<class It, class Ty, class Alloc>
inline std::u16string joinHangul(It first, It last, std::vector<Ty, Alloc>& positionOut)
{
	const auto inputLen = std::distance(first, last);

	std::u16string joined;
	joined.reserve(inputLen);
	positionOut.clear();
	positionOut.reserve(inputLen);

	for (It cur = first; cur != last; ++cur)
	{
		const auto ch = *cur;
		const bool followsSyllable = isHangulCoda(ch) && !joined.empty() && isHangulSyllable(joined.back());

		// (syllable - 0xAC00) % 28 is the final-consonant slot of a precomposed
		// syllable; zero means the slot is still free and the coda can be folded in.
		if (followsSyllable && (joined.back() - 0xAC00) % 28 == 0)
		{
			// 0x11A7 is one below the first trailing-consonant jamo (U+11A8),
			// so the difference is the 1-based final-consonant index.
			joined.back() += ch - 0x11A7;
		}
		else
		{
			// Either not a coda, nothing suitable to attach to, or the
			// syllable already carries a final consonant: keep it standalone.
			joined.push_back(ch);
		}

		// In both cases the element now lives in the last output character.
		positionOut.emplace_back(joined.size() - 1);
	}
	return joined;
}

inline std::u16string joinHangul(const KString& hangul)
{
return joinHangul(hangul.begin(), hangul.end());
Expand Down
12 changes: 6 additions & 6 deletions src/Combiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,7 @@ Vector<KString> CompiledRule::combineImpl(
return ret;
}

pair<KString, size_t> CompiledRule::combineOneImpl(
tuple<KString, size_t, size_t> CompiledRule::combineOneImpl(
U16StringView leftForm, POSTag leftTag,
U16StringView rightForm, POSTag rightTag,
CondVowel cv, CondPolarity cp
Expand All @@ -1163,12 +1163,12 @@ pair<KString, size_t> CompiledRule::combineOneImpl(
{
for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second]))
{
if(p.score >= 0) return make_pair(p.str, p.rightBegin);
if(p.score >= 0) return make_tuple(p.str, p.leftEnd, p.rightBegin);
KString ret;
ret.reserve(leftForm.size() + rightForm.size());
ret.insert(ret.end(), leftForm.begin(), leftForm.end());
ret.insert(ret.end(), rightForm.begin(), rightForm.end());
return make_pair(ret, leftForm.size());
return make_tuple(ret, leftForm.size(), leftForm.size());
}
}

Expand All @@ -1183,7 +1183,7 @@ pair<KString, size_t> CompiledRule::combineOneImpl(
{
for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second]))
{
return make_pair(p.str, p.rightBegin);
return make_tuple(p.str, p.leftEnd, p.rightBegin);
}
}
}
Expand All @@ -1198,14 +1198,14 @@ pair<KString, size_t> CompiledRule::combineOneImpl(
ret.insert(ret.end(), leftForm.begin(), leftForm.end());
ret.push_back(u''); // `어`를 `아`로 교체하여 삽입
ret.insert(ret.end(), rightForm.begin() + 1, rightForm.end());
return make_pair(ret, leftForm.size());
return make_tuple(ret, leftForm.size(), leftForm.size());
}
}
KString ret;
ret.reserve(leftForm.size() + rightForm.size());
ret.insert(ret.end(), leftForm.begin(), leftForm.end());
ret.insert(ret.end(), rightForm.begin(), rightForm.end());
return make_pair(ret, leftForm.size());
return make_tuple(ret, leftForm.size(), leftForm.size());
}

Vector<tuple<size_t, size_t, CondPolarity>> CompiledRule::testLeftPattern(U16StringView leftForm, size_t ruleId) const
Expand Down
5 changes: 4 additions & 1 deletion src/Combiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,10 @@ namespace kiwi
CondVowel cv = CondVowel::none, CondPolarity cp = CondPolarity::none
) const;

std::pair<KString, size_t> combineOneImpl(
/**
* @return tuple(combinedForm, leftFormBoundary, rightFormBoundary)
*/
std::tuple<KString, size_t, size_t> combineOneImpl(
U16StringView leftForm, POSTag leftTag,
U16StringView rightForm, POSTag rightTag,
CondVowel cv = CondVowel::none, CondPolarity cp = CondPolarity::none
Expand Down
122 changes: 102 additions & 20 deletions src/Joiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace kiwi
if (l == POSTag::sn && r == POSTag::nr) return false;
if (l == POSTag::sso || l == POSTag::ssc) return false;
if (r == POSTag::sso) return true;
if ((isJClass(l) || isEClass(l)) && r == POSTag::ss) return true;

if (r == POSTag::vx && rform.size() == 1 && (rform[0] == u'' || rform[0] == u'')) return false;

Expand Down Expand Up @@ -79,9 +80,11 @@ namespace kiwi

void Joiner::add(U16StringView form, POSTag tag, Space space)
{
KString normForm = normalizeHangul(form);
if (stack.size() == activeStart)
{
stack += normalizeHangul(form);
ranges.emplace_back(stack.size(), stack.size() + normForm.size());
stack += normForm;
lastTag = tag;
return;
}
Expand All @@ -90,7 +93,8 @@ namespace kiwi
{
if (stack.empty() || !isSpace(stack.back())) stack.push_back(u' ');
activeStart = stack.size();
stack += normalizeHangul(form);
ranges.emplace_back(stack.size(), stack.size() + normForm.size());
stack += normForm;
}
else
{
Expand All @@ -100,8 +104,6 @@ namespace kiwi
cv = isHangulSyllable(stack[activeStart - 1]) ? CondVowel::vowel : CondVowel::non_vowel;
}

KString normForm = normalizeHangul(form);

if (!stack.empty() && (isJClass(tag) || isEClass(tag)))
{
if (isEClass(tag) && normForm[0] == u'') normForm[0] = u'';
Expand Down Expand Up @@ -148,8 +150,10 @@ namespace kiwi
}
auto r = cr->combineOneImpl({ stack.data() + activeStart, stack.size() - activeStart }, lastTag, normForm, tag, cv);
stack.erase(stack.begin() + activeStart, stack.end());
stack += r.first;
activeStart += r.second;
ranges.back().second = activeStart + get<1>(r);
ranges.emplace_back(activeStart + get<2>(r), activeStart + get<0>(r).size());
stack += get<0>(r);
activeStart += get<2>(r);
}
anteLastTag = lastTag;
lastTag = tag;
Expand All @@ -165,14 +169,46 @@ namespace kiwi
return add(U16StringView{ form }, tag, space);
}

u16string Joiner::getU16() const
u16string Joiner::getU16(vector<pair<uint32_t, uint32_t>>* rangesOut) const
{
return joinHangul(stack);
if (rangesOut)
{
rangesOut->clear();
rangesOut->reserve(ranges.size());
Vector<uint32_t> u16pos;
auto ret = joinHangul(stack.begin(), stack.end(), u16pos);
u16pos.emplace_back(ret.size());
for (auto& r : ranges)
{
auto endOffset = u16pos[r.second] + (r.second > 0 && u16pos[r.second - 1] == u16pos[r.second] ? 1 : 0);
rangesOut->emplace_back(u16pos[r.first], endOffset);
}
return ret;
}
else
{
return joinHangul(stack);
}
}

string Joiner::getU8() const
string Joiner::getU8(vector<pair<uint32_t, uint32_t>>* rangesOut) const
{
return utf16To8(joinHangul(stack));
auto u16 = getU16(rangesOut);
if (rangesOut)
{
Vector<uint32_t> positions;
auto ret = utf16To8(u16, positions);
for (auto& r : *rangesOut)
{
r.first = positions[r.first];
r.second = positions[r.second];
}
return ret;
}
else
{
return utf16To8(u16);
}
}

AutoJoiner::~AutoJoiner()
Expand Down Expand Up @@ -264,24 +300,31 @@ namespace kiwi
if (!node) break;
}

// prevent unknown or partial tag
POSTag fixedTag = tag;
if (tag == POSTag::unknown || tag == POSTag::p)
{
fixedTag = POSTag::nnp;
}

if (node && kiwi->formTrie.hasMatch(formHead = node->val(kiwi->formTrie)))
{
Vector<const Morpheme*> cands;
foreachMorpheme(formHead, [&](const Morpheme* m)
{
if (inferRegularity && clearIrregular(m->tag) == clearIrregular(tag))
if (inferRegularity && clearIrregular(m->tag) == clearIrregular(fixedTag))
{
cands.emplace_back(m);
}
else if (!inferRegularity && m->tag == tag)
else if (!inferRegularity && m->tag == fixedTag)
{
cands.emplace_back(m);
}
});

if (cands.size() <= 1)
{
auto lmId = cands.empty() ? getDefaultMorphemeId(clearIrregular(tag)) : cands[0]->lmMorphemeId;
auto lmId = cands.empty() ? getDefaultMorphemeId(clearIrregular(fixedTag)) : cands[0]->lmMorphemeId;
if (!cands.empty()) tag = cands[0]->tag;
for (auto& cand : candidates)
{
Expand All @@ -308,11 +351,36 @@ namespace kiwi
n.score += n.lmState.next(kiwi->langMdl, cands[0]->lmMorphemeId);
n.joiner.add(form, cands[0]->tag, space);
}

UnorderedMap<LmState, pair<float, uint32_t>> bestScoreByState;
for (size_t i = 0; i < candidates.size(); ++i)
{
auto& c = candidates[i];
auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, i));
if (!inserted.second)
{
if (inserted.first->second.first < c.score)
{
inserted.first->second = make_pair(c.score, i);
}
}
}

if (bestScoreByState.size() < candidates.size())
{
Vector<Candidate<LmState>> newCandidates;
newCandidates.reserve(bestScoreByState.size());
for (auto& p : bestScoreByState)
{
newCandidates.emplace_back(move(candidates[p.second.second]));
}
candidates = move(newCandidates);
}
}
}
else
{
auto lmId = getDefaultMorphemeId(clearIrregular(tag));
auto lmId = getDefaultMorphemeId(clearIrregular(fixedTag));
for (auto& cand : candidates)
{
cand.score += cand.lmState.next(kiwi->langMdl, lmId);
Expand Down Expand Up @@ -422,19 +490,33 @@ namespace kiwi

struct GetU16Visitor
{
vector<pair<uint32_t, uint32_t>>* rangesOut;

GetU16Visitor(vector<pair<uint32_t, uint32_t>>* _rangesOut)
: rangesOut{ _rangesOut }
{
}

template<class LmState>
u16string operator()(const Vector<Candidate<LmState>>& o) const
{
return o[0].joiner.getU16();
return o[0].joiner.getU16(rangesOut);
}
};

struct GetU8Visitor
{
vector<pair<uint32_t, uint32_t>>* rangesOut;

GetU8Visitor(vector<pair<uint32_t, uint32_t>>* _rangesOut)
: rangesOut{ _rangesOut }
{
}

template<class LmState>
string operator()(const Vector<Candidate<LmState>>& o) const
{
return o[0].joiner.getU8();
return o[0].joiner.getU8(rangesOut);
}
};

Expand All @@ -458,14 +540,14 @@ namespace kiwi
return mapbox::util::apply_visitor(AddVisitor{ this, form, tag, false, space }, reinterpret_cast<CandVector&>(candBuf));
}

u16string AutoJoiner::getU16() const
u16string AutoJoiner::getU16(vector<pair<uint32_t, uint32_t>>* rangesOut) const
{
return mapbox::util::apply_visitor(GetU16Visitor{}, reinterpret_cast<const CandVector&>(candBuf));
return mapbox::util::apply_visitor(GetU16Visitor{ rangesOut }, reinterpret_cast<const CandVector&>(candBuf));
}

string AutoJoiner::getU8() const
string AutoJoiner::getU8(vector<pair<uint32_t, uint32_t>>* rangesOut) const
{
return mapbox::util::apply_visitor(GetU8Visitor{}, reinterpret_cast<const CandVector&>(candBuf));
return mapbox::util::apply_visitor(GetU8Visitor{ rangesOut }, reinterpret_cast<const CandVector&>(candBuf));
}
}
}
Loading

0 comments on commit 1e055bc

Please sign in to comment.