Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

자잘한 버그 수정 #194

Merged
merged 9 commits into from
Oct 16, 2024
7 changes: 6 additions & 1 deletion include/kiwi/Dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,19 @@ namespace kiwi
HiddenMember<RaggedVector<uint32_t>, sizeof(Vector<size_t>) * 2> sents;
std::shared_ptr<lm::KnLangModelBase> knlm;
std::unique_ptr<utils::ThreadPool> workers;
std::shared_ptr<KiwiBuilder> dummyBuilder;
std::discrete_distribution<> dropout;
std::mt19937_64 rng;
Vector<ThreadLocal> locals;
Vector<size_t> shuffledIdx;
Vector<int32_t> tokenToVocab, vocabToToken;
Vector<uint8_t> windowTokenValidness;
Deque<OptionalFuture<size_t>> futures;
const Vector<MorphemeRaw>* morphemes = nullptr;
const Vector<FormRaw>* forms = nullptr;
size_t knlmVocabSize = 0;
size_t batchSize = 0;
size_t causalContextSize = 0;
size_t windowSize = 0;
size_t totalTokens = 0;
size_t passedSents = 0;
Expand All @@ -68,7 +71,7 @@ namespace kiwi
size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut);

public:
HSDataset(size_t _batchSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0);
HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0);
~HSDataset();
HSDataset(const HSDataset&) = delete;
HSDataset(HSDataset&&) /*noexcept*/;
Expand All @@ -80,7 +83,9 @@ namespace kiwi
size_t numTokens() const;

size_t getBatchSize() const { return batchSize; }
size_t getCausalContextSize() const { return causalContextSize; }
size_t getWindowSize() const { return windowSize; }
const Vector<uint8_t>& getWindowTokenValidness() const { return windowTokenValidness; }

void seed(size_t newSeed);
void reset();
Expand Down
10 changes: 7 additions & 3 deletions include/kiwi/Kiwi.h
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,8 @@ namespace kiwi

using MorphemeMap = UnorderedMap<std::tuple<KString, uint8_t, POSTag>, std::pair<size_t, size_t>>;

void initMorphemes();

template<class Fn>
MorphemeMap loadMorphemesFromTxt(std::istream& is, Fn&& filter);

Expand Down Expand Up @@ -612,8 +614,7 @@ namespace kiwi
std::vector<std::string> corpora;
size_t minMorphCnt = 10;
size_t lmOrder = 4;
size_t lmMinCnt = 1;
size_t lmLastOrderMinCnt = 2;
std::vector<size_t> lmMinCnts = { 1 };
size_t numWorkers = 1;
size_t sbgSize = 1000000;
bool useLmTagHistory = true;
Expand Down Expand Up @@ -801,11 +802,14 @@ namespace kiwi
using TokenFilter = std::function<bool(const std::u16string&, POSTag)>;

HSDataset makeHSDataset(const std::vector<std::string>& inputPathes,
size_t batchSize, size_t windowSize, size_t numWorkers,
size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers,
double dropoutProb = 0,
const TokenFilter& tokenFilter = {},
const TokenFilter& windowFilter = {},
double splitRatio = 0,
bool separateDefaultMorpheme = false,
const std::string& morphemeDefPath = {},
size_t morphemeDefMinCnt = 0,
HSDataset* splitDataset = nullptr
) const;
};
Expand Down
2 changes: 1 addition & 1 deletion include/kiwi/TagUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace kiwi
inline bool isSuffix(POSTag tag)
{
tag = clearIrregular(tag);
return POSTag::xsn <= tag && tag <= POSTag::xsa;
return POSTag::xsn <= tag && tag <= POSTag::xsm;
}

inline bool isSpecialClass(POSTag tag)
Expand Down
51 changes: 40 additions & 11 deletions src/Dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

using namespace kiwi;

HSDataset::HSDataset(size_t _batchSize, size_t _windowSize, size_t _workers, double _dropoutProb)
HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windowSize, size_t _workers, double _dropoutProb)
: workers{ _workers ? make_unique<utils::ThreadPool>(_workers) : nullptr },
dropout{ {1 - _dropoutProb * 3, _dropoutProb, _dropoutProb, _dropoutProb} },
locals( _workers ? workers->size() : 1),
batchSize{ _batchSize },
causalContextSize{ _causalContextSize },
windowSize{ _windowSize }
{
}
Expand Down Expand Up @@ -113,12 +114,21 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,

local.lmLProbsBuf.resize(tokens.size());
local.outNgramNodeBuf.resize(tokens.size());
knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin());
if (knlm)
{
knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin());
}

auto& history = local.historyBuf;
history.clear();
history.resize(windowSize, -1);
history.back() = tokenToVocab[tokens[0]];
if (windowSize)
{
history.resize(windowSize, -1);
if (windowTokenValidness[tokens[0]])
{
history.back() = tokenToVocab[tokens[0]];
}
}
for (size_t i = 1; i < tokens.size(); ++i)
{
int32_t v = tokenToVocab[tokens[i]];
Expand All @@ -134,13 +144,32 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,
local.restLmLProbsCntData[r] += 1;
continue;
}
std::copy(history.begin(), history.end(), std::back_inserter(local.inData));

if (causalContextSize)
{
for (size_t j = 0; j < causalContextSize; ++j)
{
local.inData.emplace_back(i + j < causalContextSize ?
nonVocab : tokenToVocab[tokens[i + j - causalContextSize]]);
}
}
if (windowSize)
{
if (windowTokenValidness[v])
{
std::copy(history.begin(), history.end(), std::back_inserter(local.inData));
history.pop_front();
history.push_back(v);
}
else
{
local.inData.resize(local.inData.size() + windowSize, -1);
}
}

local.outData.emplace_back(v);
local.lmLProbsData.emplace_back(local.lmLProbsBuf[i]);
local.outNgramNodeData.emplace_back(local.outNgramNodeBuf[i]);

history.pop_front();
history.push_back(v);
}

size_t r = local.outData.size() / batchSize;
Expand Down Expand Up @@ -217,14 +246,14 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,
auto& l = locals[localId];

size_t rest = std::min(l.outData.size(), batchSize);
std::copy(l.inData.begin(), l.inData.begin() + rest * windowSize, in);
std::copy(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize), in);
std::copy(l.outData.begin(), l.outData.begin() + rest, out);
std::copy(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest, lmLProbs);
std::copy(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest, outNgramNode);
restLmOut = l.restLmLProbsData.front();
restLmCntOut = l.restLmLProbsCntData.front();

l.inData.erase(l.inData.begin(), l.inData.begin() + rest * windowSize);
l.inData.erase(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize));
l.outData.erase(l.outData.begin(), l.outData.begin() + rest);
l.lmLProbsData.erase(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest);
l.outNgramNodeData.erase(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest);
Expand All @@ -245,7 +274,7 @@ size_t HSDataset::next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outN

size_t HSDataset::ngramNodeSize() const
{
return knlm->nonLeafNodeSize();
return knlm ? knlm->nonLeafNodeSize() : 0;
}

const MorphemeRaw& HSDataset::vocabInfo(uint32_t vocab) const
Expand Down
28 changes: 19 additions & 9 deletions src/KTrie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -653,8 +653,8 @@ size_t kiwi::splitByTrie(
const auto scanStart = max(endPosMap[nBeginWithMultiplier].first, (uint32_t)1), scanEnd = endPosMap[nBeginWithMultiplier].second;
const bool longestMatched = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
{
const auto start = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier;
return nBeginWithMultiplier == g.endPos && (lastSpecialEndPos == start || specialStartPos == start);
const size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier;
return nBeginWithMultiplier == g.endPos && (lastSpecialEndPos * posMultiplier == startPos || specialStartPos * posMultiplier == startPos);
});

// insert unknown form
Expand Down Expand Up @@ -742,7 +742,7 @@ size_t kiwi::splitByTrie(
const auto scanStart = max(endPosMap[unkFormEndPos * posMultiplier].first, (uint32_t)1), scanEnd = endPosMap[unkFormEndPos * posMultiplier].second;
const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
{
size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier;
const size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier;
return startPos == lastSpecialEndPos * posMultiplier && g.endPos == unkFormEndPos * posMultiplier;
});
if (unkFormEndPos > lastSpecialEndPos && !duplicated)
Expand Down Expand Up @@ -1215,9 +1215,10 @@ size_t kiwi::splitByTrie(
return n + startOffset;
}

template<ArchType arch>
template<ArchType arch, bool typoTolerant>
const Form* kiwi::findForm(
const utils::FrozenTrie<kchar_t, const Form*>& trie,
const Form* formData,
const KString& str
)
{
Expand All @@ -1228,7 +1229,12 @@ const Form* kiwi::findForm(
if (!node) return nullptr;
}
if (trie.hasSubmatch(node->val(trie))) return nullptr;
return node->val(trie);
auto ret = node->val(trie);
if (typoTolerant)
{
ret = &reinterpret_cast<const TypoForm*>(ret)->form(formData);
}
return ret;
}

namespace kiwi
Expand Down Expand Up @@ -1266,19 +1272,23 @@ FnSplitByTrie kiwi::getSplitByTrieFn(ArchType arch, bool typoTolerant, bool cont

namespace kiwi
{
template<bool typoTolerant>
struct FindFormGetter
{
template<std::ptrdiff_t i>
struct Wrapper
{
static constexpr FnFindForm value = &findForm<static_cast<ArchType>(i)>;
static constexpr FnFindForm value = &findForm<static_cast<ArchType>(i), typoTolerant>;
};
};
}

FnFindForm kiwi::getFindFormFn(ArchType arch)
FnFindForm kiwi::getFindFormFn(ArchType arch, bool typoTolerant)
{
static tp::Table<FnFindForm, AvailableArch> table{ FindFormGetter{} };
static std::array<tp::Table<FnFindForm, AvailableArch>, 2> table{
FindFormGetter<false>{},
FindFormGetter<true>{},
};

return table[static_cast<std::ptrdiff_t>(arch)];
return table[typoTolerant ? 1 : 0][static_cast<std::ptrdiff_t>(arch)];
}
7 changes: 4 additions & 3 deletions src/KTrie.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,18 @@ namespace kiwi
const PretokenizedSpanGroup::Span* pretokenizedLast
);

template<ArchType arch>
template<ArchType arch, bool typoTolerant>
const Form* findForm(
const utils::FrozenTrie<kchar_t, const Form*>& trie,
const Form* formData,
const KString& str
);

using FnSplitByTrie = decltype(&splitByTrie<ArchType::default_>);
FnSplitByTrie getSplitByTrieFn(ArchType arch, bool typoTolerant, bool continualTypoTolerant, bool lengtheningTypoTolerant);

using FnFindForm = decltype(&findForm<ArchType::default_>);
FnFindForm getFindFormFn(ArchType arch);
using FnFindForm = decltype(&findForm<ArchType::default_, false>);
FnFindForm getFindFormFn(ArchType arch, bool typoTolerant);

struct KTrie : public utils::TrieNode<char16_t, const Form*, utils::ConstAccess<map<char16_t, int32_t>>, KTrie>
{
Expand Down
16 changes: 9 additions & 7 deletions src/Kiwi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ namespace kiwi
typoTolerant,
continualTypoTolerant,
lengtheningTypoTolerant);
dfFindForm = (void*)getFindFormFn(selectedArch);
dfFindForm = (void*)getFindFormFn(selectedArch, typoTolerant);

static tp::Table<FnFindBestPath, AvailableArch> lmKnLM_8{ FindBestPathGetter<WrappedKnLM<uint8_t>::type>{} };
static tp::Table<FnFindBestPath, AvailableArch> lmKnLM_16{ FindBestPathGetter<WrappedKnLM<uint16_t>::type>{} };
Expand Down Expand Up @@ -802,7 +802,8 @@ namespace kiwi
const Vector<uint32_t>& positionTable,
const KString& normStr,
FnFindForm findForm,
const utils::FrozenTrie<kchar_t, const Form*>& formTrie
const utils::FrozenTrie<kchar_t, const Form*>& formTrie,
const Form* formData
)
{
if (pretokenized.empty()) return;
Expand Down Expand Up @@ -833,7 +834,7 @@ namespace kiwi
if (s.tokenization.empty())
{
auto formStr = normStr.substr(span.begin, span.end - span.begin);
span.form = findForm(formTrie, formStr); // reuse the predefined form & morpheme
span.form = findForm(formTrie, formData, formStr); // reuse the predefined form & morpheme
if (!span.form) // or use a fallback form
{
span.form = formTrie.value((size_t)POSTag::nnp);
Expand All @@ -842,7 +843,7 @@ namespace kiwi
else if (s.tokenization.size() == 1)
{
auto formStr = normalizeHangul(s.tokenization[0].form);
auto* tform = findForm(formTrie, formStr);
auto* tform = findForm(formTrie, formData, formStr);
if (tform && tform->candidate.size() == 1 &&
areTagsEqual(tform->candidate[0]->tag, s.tokenization[0].tag, !!s.tokenization[0].inferRegularity))
// reuse the predefined form & morpheme
Expand Down Expand Up @@ -908,7 +909,7 @@ namespace kiwi
{
auto& t = s.tokenization[i];
auto formStr = normalizeHangul(t.form);
auto* tform = findForm(formTrie, formStr);
auto* tform = findForm(formTrie, formData, formStr);
const Morpheme* foundMorph = nullptr;
if (tform)
{
Expand Down Expand Up @@ -999,7 +1000,8 @@ namespace kiwi
positionTable,
normalizedStr,
reinterpret_cast<FnFindForm>(dfFindForm),
formTrie
formTrie,
forms.data()
);

// 분석할 문장에 포함된 개별 문자에 대해 어절번호를 생성한다
Expand Down Expand Up @@ -1317,7 +1319,7 @@ namespace kiwi
void Kiwi::findMorpheme(vector<const Morpheme*>& ret, const u16string& s, POSTag tag) const
{
auto normalized = normalizeHangul(s);
auto form = (*reinterpret_cast<FnFindForm>(dfFindForm))(formTrie, normalized);
auto form = (*reinterpret_cast<FnFindForm>(dfFindForm))(formTrie, forms.data(), normalized);
if (!form) return;
tag = clearIrregular(tag);
for (auto c : form->candidate)
Expand Down
Loading
Loading