Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

공백 없는 긴 텍스트에 대한 효율적인 분할 #156

Merged
merged 2 commits into from
Mar 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 50 additions & 47 deletions src/KTrie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace kiwi
{
static constexpr uint32_t npos = -1;

if (endPosMap[startPos].first == npos)
if (endPosMap[startPos].first == endPosMap[startPos].second)
{
return false;
}
Expand All @@ -30,15 +30,15 @@ namespace kiwi
nnode.prev = newId - endPosMap[startPos].first;
if (nnode.endPos >= endPosMap.size()) return true;

if (endPosMap[nnode.endPos].first == npos)
if (endPosMap[nnode.endPos].first == endPosMap[nnode.endPos].second)
{
endPosMap[nnode.endPos].first = newId;
endPosMap[nnode.endPos].second = newId;
endPosMap[nnode.endPos].second = newId + 1;
}
else
{
nodes[endPosMap[nnode.endPos].second].sibling = newId - endPosMap[nnode.endPos].second;
endPosMap[nnode.endPos].second = newId;
nodes[endPosMap[nnode.endPos].second - 1].sibling = newId - (endPosMap[nnode.endPos].second - 1);
endPosMap[nnode.endPos].second = newId + 1;
}
return true;
}
Expand Down Expand Up @@ -118,46 +118,34 @@ namespace kiwi
return true;
}

inline void removeUnconnected(Vector<KGraphNode>& ret, const Vector<KGraphNode>& graph)
inline void removeUnconnected(Vector<KGraphNode>& ret, const Vector<KGraphNode>& graph, const Vector<std::pair<uint32_t, uint32_t>>& endPosMap)
{
Vector<uint8_t> connectedList(graph.size());
Vector<uint16_t> newIndexDiff(graph.size());
connectedList[graph.size() - 1] = true;
connectedList[0] = true;
// forward searching
for (size_t i = 1; i < graph.size(); ++i)
{
bool connected = false;
for (auto prev = graph[i].getPrev(); prev; prev = prev->getSibling())
thread_local Vector<uint8_t> connectedList;
thread_local Vector<uint16_t> newIndexDiff;
thread_local Deque<uint32_t> updateList;
connectedList.clear();
connectedList.resize(graph.size());
newIndexDiff.clear();
newIndexDiff.resize(graph.size());
updateList.clear();
updateList.emplace_back(graph.size() - 1);
connectedList[graph.size() - 1] = 1;

while (!updateList.empty())
{
const auto id = updateList.front();
updateList.pop_front();
const auto& node = graph[id];
const auto scanStart = endPosMap[node.startPos].first, scanEnd = endPosMap[node.startPos].second;
for (auto i = scanStart; i < scanEnd; ++i)
{
if (connectedList[prev - graph.data()])
{
connected = true;
break;
}
if (graph[i].endPos != node.startPos) continue;
if (connectedList[i]) continue;
updateList.emplace_back(i);
}
connectedList[i] = connected ? 1 : 0;
}
// backward searching
for (size_t i = graph.size() - 1; i-- > 1; )
{
bool connected = false;
for (size_t j = i + 1; j < graph.size(); ++j)
{
for (auto prev = graph[j].getPrev(); prev; prev = prev->getSibling())
{
if (prev > &graph[i]) break;
if (prev < &graph[i]) continue;
if (connectedList[j])
{
connected = true;
goto break_2;
}
}
}
break_2:
connectedList[i] = (connectedList[i] && connected) ? 1 : 0;
fill(connectedList.begin() + scanStart, connectedList.begin() + scanEnd, 1);
}

size_t connectedCnt = accumulate(connectedList.begin(), connectedList.end(), 0);
newIndexDiff[0] = connectedList[0];
for (size_t i = 1; i < graph.size(); ++i)
Expand Down Expand Up @@ -231,10 +219,15 @@ size_t kiwi::splitByTrie(
const PretokenizedSpanGroup::Span* pretokenizedLast
)
{
/*
* endPosMap[i]에는 out[x].endPos == i를 만족하는 첫번째 x(first)와 마지막 x + 1(second)가 들어 있다.
* first == second인 경우 endPos가 i인 노드가 없다는 것을 의미한다.
* first <= x && x < second인 out[x] 중에는 endPos가 i가 아닌 것도 있을 수 있으므로 주의해야 한다.
*/
thread_local Vector<pair<uint32_t, uint32_t>> endPosMap;
endPosMap.clear();
endPosMap.resize(str.size() + 1, make_pair<uint32_t, uint32_t>(-1, -1));
endPosMap[0] = make_pair(0, 0);
endPosMap[0] = make_pair(0, 1);

thread_local Vector<uint32_t> nonSpaces;
nonSpaces.clear();
Expand All @@ -259,7 +252,8 @@ size_t kiwi::splitByTrie(
for (auto& cand : candidates)
{
const size_t nBegin = typoTolerant ? candTypoCostStarts[&cand - candidates.data()].start : (nonSpaces.size() - cand->sizeWithoutSpace());
const bool longestMatched = any_of(out.begin() + 1, out.end(), [&](const KGraphNode& g)
const auto scanStart = max(endPosMap[nBegin].first, (uint32_t)1), scanEnd = endPosMap[nBegin].second;
const bool longestMatched = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
{
return nBegin == g.endPos && lastSpecialEndPos == g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size());
});
Expand Down Expand Up @@ -335,7 +329,8 @@ size_t kiwi::splitByTrie(
}
}

bool duplicated = any_of(out.begin() + 1, out.end(), [&](const KGraphNode& g)
const auto scanStart = max(endPosMap[unkFormEndPos].first, (uint32_t)1), scanEnd = endPosMap[unkFormEndPos].second;
const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
{
size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size());
return startPos == lastSpecialEndPos && g.endPos == unkFormEndPos;
Expand Down Expand Up @@ -483,7 +478,8 @@ size_t kiwi::splitByTrie(
// sequence of special characters found
if (lastChrType != POSTag::max && lastChrType != POSTag::unknown && lastChrType != lastMatchedPattern)
{
bool duplicated = any_of(out.begin() + 1, out.end(), [&](const KGraphNode& g)
const auto scanStart = max(endPosMap[nonSpaces.size()].first, (uint32_t)1), scanEnd = endPosMap[nonSpaces.size()].second;
const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
{
return nonSpaces.size() == g.endPos;
});
Expand Down Expand Up @@ -514,6 +510,12 @@ size_t kiwi::splitByTrie(
break;
}
}
// 혹은 공백 문자가 아예 없는 경우 너무 길어지는 것을 방지하기 위해 강제로 중단
else if (n >= 8192)
{
lastChrType = chrType;
break;
}

// 공백문자를 무시하고 분할 진행
if (chrType == POSTag::unknown)
Expand Down Expand Up @@ -629,7 +631,8 @@ size_t kiwi::splitByTrie(
// sequence of special characters found
if (lastChrType != POSTag::max && lastChrType != POSTag::unknown && !isWebTag(lastChrType))
{
bool duplicated = any_of(out.begin() + 1, out.end(), [&](const KGraphNode& g)
const auto scanStart = max(endPosMap[nonSpaces.size()].first, (uint32_t)1), scanEnd = endPosMap[nonSpaces.size()].second;
const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
{
return nonSpaces.size() == g.endPos;
});
Expand Down Expand Up @@ -661,7 +664,7 @@ size_t kiwi::splitByTrie(

nonSpaces.emplace_back(n);

removeUnconnected(ret, out);
removeUnconnected(ret, out, endPosMap);
for (size_t i = 1; i < ret.size() - 1; ++i)
{
auto& r = ret[i];
Expand Down
Loading