Skip to content

Commit

Permalink
improved efficiency of splitting long text
Browse files Browse the repository at this point in the history
  • Loading branch information
bab2min committed Mar 3, 2024
1 parent 6f5c94b commit 2c0ee92
Showing 1 changed file with 44 additions and 47 deletions.
91 changes: 44 additions & 47 deletions src/KTrie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace kiwi
{
static constexpr uint32_t npos = -1;

if (endPosMap[startPos].first == npos)
if (endPosMap[startPos].first == endPosMap[startPos].second)
{
return false;
}
Expand All @@ -30,15 +30,15 @@ namespace kiwi
nnode.prev = newId - endPosMap[startPos].first;
if (nnode.endPos >= endPosMap.size()) return true;

if (endPosMap[nnode.endPos].first == npos)
if (endPosMap[nnode.endPos].first == endPosMap[nnode.endPos].second)
{
endPosMap[nnode.endPos].first = newId;
endPosMap[nnode.endPos].second = newId;
endPosMap[nnode.endPos].second = newId + 1;
}
else
{
nodes[endPosMap[nnode.endPos].second].sibling = newId - endPosMap[nnode.endPos].second;
endPosMap[nnode.endPos].second = newId;
nodes[endPosMap[nnode.endPos].second - 1].sibling = newId - (endPosMap[nnode.endPos].second - 1);
endPosMap[nnode.endPos].second = newId + 1;
}
return true;
}
Expand Down Expand Up @@ -118,46 +118,34 @@ namespace kiwi
return true;
}

inline void removeUnconnected(Vector<KGraphNode>& ret, const Vector<KGraphNode>& graph)
inline void removeUnconnected(Vector<KGraphNode>& ret, const Vector<KGraphNode>& graph, const Vector<std::pair<uint32_t, uint32_t>>& endPosMap)
{
Vector<uint8_t> connectedList(graph.size());
Vector<uint16_t> newIndexDiff(graph.size());
connectedList[graph.size() - 1] = true;
connectedList[0] = true;
// forward searching
for (size_t i = 1; i < graph.size(); ++i)
{
bool connected = false;
for (auto prev = graph[i].getPrev(); prev; prev = prev->getSibling())
thread_local Vector<uint8_t> connectedList;
thread_local Vector<uint16_t> newIndexDiff;
thread_local Deque<uint32_t> updateList;
connectedList.clear();
connectedList.resize(graph.size());
newIndexDiff.clear();
newIndexDiff.resize(graph.size());
updateList.clear();
updateList.emplace_back(graph.size() - 1);
connectedList[graph.size() - 1] = 1;

while (!updateList.empty())
{
const auto id = updateList.front();
updateList.pop_front();
const auto& node = graph[id];
const auto scanStart = endPosMap[node.startPos].first, scanEnd = endPosMap[node.startPos].second;
for (auto i = scanStart; i < scanEnd; ++i)
{
if (connectedList[prev - graph.data()])
{
connected = true;
break;
}
if (graph[i].endPos != node.startPos) continue;
if (connectedList[i]) continue;
updateList.emplace_back(i);
}
connectedList[i] = connected ? 1 : 0;
}
// backward searching
for (size_t i = graph.size() - 1; i-- > 1; )
{
bool connected = false;
for (size_t j = i + 1; j < graph.size(); ++j)
{
for (auto prev = graph[j].getPrev(); prev; prev = prev->getSibling())
{
if (prev > &graph[i]) break;
if (prev < &graph[i]) continue;
if (connectedList[j])
{
connected = true;
goto break_2;
}
}
}
break_2:
connectedList[i] = (connectedList[i] && connected) ? 1 : 0;
fill(connectedList.begin() + scanStart, connectedList.begin() + scanEnd, 1);
}

size_t connectedCnt = accumulate(connectedList.begin(), connectedList.end(), 0);
newIndexDiff[0] = connectedList[0];
for (size_t i = 1; i < graph.size(); ++i)
Expand Down Expand Up @@ -231,10 +219,15 @@ size_t kiwi::splitByTrie(
const PretokenizedSpanGroup::Span* pretokenizedLast
)
{
/*
* endPosMap[i]에는 out[x].endPos == i를 만족하는 첫번째 x(first)와 마지막 x + 1(second)가 들어 있다.
* first == second인 경우 endPos가 i인 노드가 없다는 것을 의미한다.
* first <= x && x < second인 out[x] 중에는 endPos가 i가 아닌 것도 있을 수 있으므로 주의해야 한다.
*/
thread_local Vector<pair<uint32_t, uint32_t>> endPosMap;
endPosMap.clear();
endPosMap.resize(str.size() + 1, make_pair<uint32_t, uint32_t>(-1, -1));
endPosMap[0] = make_pair(0, 0);
endPosMap[0] = make_pair(0, 1);

thread_local Vector<uint32_t> nonSpaces;
nonSpaces.clear();
Expand All @@ -259,7 +252,8 @@ size_t kiwi::splitByTrie(
for (auto& cand : candidates)
{
const size_t nBegin = typoTolerant ? candTypoCostStarts[&cand - candidates.data()].start : (nonSpaces.size() - cand->sizeWithoutSpace());
const bool longestMatched = any_of(out.begin() + 1, out.end(), [&](const KGraphNode& g)
const auto scanStart = max(endPosMap[nBegin].first, (uint32_t)1), scanEnd = endPosMap[nBegin].second;
const bool longestMatched = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
{
return nBegin == g.endPos && lastSpecialEndPos == g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size());
});
Expand Down Expand Up @@ -335,7 +329,8 @@ size_t kiwi::splitByTrie(
}
}

bool duplicated = any_of(out.begin() + 1, out.end(), [&](const KGraphNode& g)
const auto scanStart = max(endPosMap[unkFormEndPos].first, (uint32_t)1), scanEnd = endPosMap[unkFormEndPos].second;
const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
{
size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size());
return startPos == lastSpecialEndPos && g.endPos == unkFormEndPos;
Expand Down Expand Up @@ -483,7 +478,8 @@ size_t kiwi::splitByTrie(
// sequence of speical characters found
if (lastChrType != POSTag::max && lastChrType != POSTag::unknown && lastChrType != lastMatchedPattern)
{
bool duplicated = any_of(out.begin() + 1, out.end(), [&](const KGraphNode& g)
const auto scanStart = max(endPosMap[nonSpaces.size()].first, (uint32_t)1), scanEnd = endPosMap[nonSpaces.size()].second;
const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
{
return nonSpaces.size() == g.endPos;
});
Expand Down Expand Up @@ -635,7 +631,8 @@ size_t kiwi::splitByTrie(
// sequence of speical characters found
if (lastChrType != POSTag::max && lastChrType != POSTag::unknown && !isWebTag(lastChrType))
{
bool duplicated = any_of(out.begin() + 1, out.end(), [&](const KGraphNode& g)
const auto scanStart = max(endPosMap[nonSpaces.size()].first, (uint32_t)1), scanEnd = endPosMap[nonSpaces.size()].second;
const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
{
return nonSpaces.size() == g.endPos;
});
Expand Down Expand Up @@ -667,7 +664,7 @@ size_t kiwi::splitByTrie(

nonSpaces.emplace_back(n);

removeUnconnected(ret, out);
removeUnconnected(ret, out, endPosMap);
for (size_t i = 1; i < ret.size() - 1; ++i)
{
auto& r = ret[i];
Expand Down

0 comments on commit 2c0ee92

Please sign in to comment.