From 01a22a8b74da5e36b1b06f5efe5f291f8dcfceea Mon Sep 17 00:00:00 2001
From: bab2min
Date: Sat, 12 Oct 2024 20:29:11 +0900
Subject: [PATCH] Add new features to `HSDataset`

`HSDataset` and `KiwiBuilder::makeHSDataset` gain three additions: a causal
(left-to-right) context of configurable size emitted alongside the existing
window input, a `windowFilter` that controls which tokens may enter the
window history (tracked via `windowTokenValidness`), and an optional external
morpheme definition (`morphemeDefPath`/`morphemeDefMinCnt`) loaded through an
internal dummy builder, which also makes the KN language model optional.

---
 include/kiwi/Dataset.h |  7 +++-
 include/kiwi/Kiwi.h    |  7 +++-
 src/Dataset.cpp        | 51 +++++++++++++++++++------
 src/KiwiBuilder.cpp    | 86 +++++++++++++++++++++++++++++++++---------
 test/test_cpp.cpp      |  2 +-
 5 files changed, 121 insertions(+), 32 deletions(-)

diff --git a/include/kiwi/Dataset.h b/include/kiwi/Dataset.h
index 6ce86c38..32be4879 100644
--- a/include/kiwi/Dataset.h
+++ b/include/kiwi/Dataset.h
@@ -47,16 +47,19 @@ namespace kiwi
 		HiddenMember<RaggedVector<int32_t>, sizeof(Vector<size_t>) * 2> sents;
 		std::shared_ptr<lm::KnLangModelBase> knlm;
 		std::unique_ptr<utils::ThreadPool> workers;
+		std::shared_ptr<KiwiBuilder> dummyBuilder;
 		std::discrete_distribution<> dropout;
 		std::mt19937_64 rng;
 		Vector<ThreadLocal> locals;
 		Vector<size_t> shuffledIdx;
 		Vector<int32_t> tokenToVocab, vocabToToken;
+		Vector<uint8_t> windowTokenValidness;
 		Deque<std::future<size_t>> futures;
 		const Vector<MorphemeRaw>* morphemes = nullptr;
 		const Vector<FormRaw>* forms = nullptr;
 		size_t knlmVocabSize = 0;
 		size_t batchSize = 0;
+		size_t causalContextSize = 0;
 		size_t windowSize = 0;
 		size_t totalTokens = 0;
 		size_t passedSents = 0;
@@ -68,7 +71,7 @@
 		size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut);
 
 	public:
-		HSDataset(size_t _batchSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0);
+		HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0);
 		~HSDataset();
 		HSDataset(const HSDataset&) = delete;
 		HSDataset(HSDataset&&) /*noexcept*/;
@@ -80,7 +83,9 @@ namespace kiwi
 		size_t numTokens() const;
 
 		size_t getBatchSize() const { return batchSize; }
+		size_t getCausalContextSize() const { return causalContextSize; }
 		size_t getWindowSize() const { return windowSize; }
+		const Vector<uint8_t>& getWindowTokenValidness() const { return windowTokenValidness; }
 
 		void seed(size_t newSeed);
 		void reset();
diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h
index 7cec3b2b..44a14d5c 100644
--- a/include/kiwi/Kiwi.h
+++ b/include/kiwi/Kiwi.h
@@ -548,6 +548,8 @@ namespace kiwi
 
 		using MorphemeMap = UnorderedMap<std::pair<KString, POSTag>, std::pair<size_t, size_t>>;
 
+		void initMorphemes();
+
 		template<class Fn>
 		MorphemeMap loadMorphemesFromTxt(std::istream& is, Fn&& filter);
 
@@ -801,11 +803,14 @@ namespace kiwi
 		using TokenFilter = std::function<bool(const std::u16string&, POSTag)>;
 
 		HSDataset makeHSDataset(const std::vector<std::string>& inputPathes,
-			size_t batchSize, size_t windowSize, size_t numWorkers,
+			size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers,
 			double dropoutProb = 0,
 			const TokenFilter& tokenFilter = {},
+			const TokenFilter& windowFilter = {},
 			double splitRatio = 0,
 			bool separateDefaultMorpheme = false,
+			const std::string& morphemeDefPath = {},
+			size_t morphemeDefMinCnt = 0,
 			HSDataset* splitDataset = nullptr
 		) const;
 	};
diff --git a/src/Dataset.cpp b/src/Dataset.cpp
index bf9bd1a4..f3e47edd 100644
--- a/src/Dataset.cpp
+++ b/src/Dataset.cpp
@@ -3,11 +3,12 @@
 
 using namespace kiwi;
 
-HSDataset::HSDataset(size_t _batchSize, size_t _windowSize, size_t _workers, double _dropoutProb)
+HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windowSize, size_t _workers, double _dropoutProb)
 	: workers{ _workers ? make_unique<utils::ThreadPool>(_workers) : nullptr },
 	dropout{ {1 - _dropoutProb * 3, _dropoutProb, _dropoutProb, _dropoutProb} },
 	locals( _workers ? workers->size() : 1),
 	batchSize{ _batchSize },
+	causalContextSize{ _causalContextSize },
 	windowSize{ _windowSize }
 {
 }
@@ -113,12 +114,21 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,
 	local.lmLProbsBuf.resize(tokens.size());
 	local.outNgramNodeBuf.resize(tokens.size());
-	knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin());
+	if (knlm)
+	{
+		knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin());
+	}
 
 	auto& history = local.historyBuf;
 	history.clear();
-	history.resize(windowSize, -1);
-	history.back() = tokenToVocab[tokens[0]];
+	if (windowSize)
+	{
+		history.resize(windowSize, -1);
+		if (windowTokenValidness[tokens[0]])
+		{
+			history.back() = tokenToVocab[tokens[0]];
+		}
+	}
 
 	for (size_t i = 1; i < tokens.size(); ++i)
 	{
 		int32_t v = tokenToVocab[tokens[i]];
@@ -134,13 +144,32 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,
 			local.restLmLProbsCntData[r] += 1;
 			continue;
 		}
-		std::copy(history.begin(), history.end(), std::back_inserter(local.inData));
+
+		if (causalContextSize)
+		{
+			for (size_t j = 0; j < causalContextSize; ++j)
+			{
+				local.inData.emplace_back(i + j < causalContextSize ?
+					nonVocab : tokenToVocab[tokens[i + j - causalContextSize]]);
+			}
+		}
+		if (windowSize)
+		{
+			if (windowTokenValidness[v])
+			{
+				std::copy(history.begin(), history.end(), std::back_inserter(local.inData));
+				history.pop_front();
+				history.push_back(v);
+			}
+			else
+			{
+				local.inData.resize(local.inData.size() + windowSize, -1);
+			}
+		}
+
 		local.outData.emplace_back(v);
 		local.lmLProbsData.emplace_back(local.lmLProbsBuf[i]);
 		local.outNgramNodeData.emplace_back(local.outNgramNodeBuf[i]);
-
-		history.pop_front();
-		history.push_back(v);
 	}
 
 	size_t r = local.outData.size() / batchSize;
@@ -217,14 +246,14 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,
 
 	auto& l = locals[localId];
 	size_t rest = std::min(l.outData.size(), batchSize);
-	std::copy(l.inData.begin(), l.inData.begin() + rest * windowSize, in);
+	std::copy(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize), in);
 	std::copy(l.outData.begin(), l.outData.begin() + rest, out);
 	std::copy(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest, lmLProbs);
 	std::copy(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest, outNgramNode);
 	restLmOut = l.restLmLProbsData.front();
 	restLmCntOut = l.restLmLProbsCntData.front();
 
-	l.inData.erase(l.inData.begin(), l.inData.begin() + rest * windowSize);
+	l.inData.erase(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize));
 	l.outData.erase(l.outData.begin(), l.outData.begin() + rest);
 	l.lmLProbsData.erase(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest);
 	l.outNgramNodeData.erase(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest);
@@ -245,7 +274,7 @@ size_t HSDataset::next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outN
 
 size_t HSDataset::ngramNodeSize() const
 {
-	return knlm->nonLeafNodeSize();
+	return knlm ? knlm->nonLeafNodeSize() : 0;
 }
 
 const MorphemeRaw& HSDataset::vocabInfo(uint32_t vocab) const
diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp
index b216fff5..10a798e3 100644
--- a/src/KiwiBuilder.cpp
+++ b/src/KiwiBuilder.cpp
@@ -784,10 +784,8 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOptio
 	}
 }
 
-KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args)
+void KiwiBuilder::initMorphemes()
 {
-	archType = getSelectedArch(ArchType::default_);
-
 	forms.resize(defaultFormSize);
 	morphemes.resize(defaultFormSize + 2); // additional places for ,
 	for (size_t i = 1; i < defaultTagSize; ++i)
@@ -805,6 +803,18 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args)
 		morphemes[i + defaultTagSize + 1].userScore = -1.5f;
 	}
 
+}
+
+KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args)
+{
+	if (!(args.lmMinCnts.size() == 1 || args.lmMinCnts.size() == args.lmOrder))
+	{
+		throw invalid_argument{ "lmMinCnts should have 1 or lmOrder elements" };
+	}
+
+	archType = getSelectedArch(ArchType::default_);
+	initMorphemes();
+
 	ifstream ifs;
 	auto realMorph = loadMorphemesFromTxt(openFile(ifs, args.morphemeDef), [&](POSTag tag, float cnt)
 	{
@@ -2179,43 +2189,72 @@ vector<WordInfo> KiwiBuilder::extractAddWords(const U16MultipleReader& reader, s
 }
 
 HSDataset KiwiBuilder::makeHSDataset(const vector<string>& inputPathes,
-	size_t batchSize, size_t windowSize, size_t numWorkers,
+	size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers,
 	double dropoutProb,
 	const TokenFilter& tokenFilter,
+	const TokenFilter& windowFilter,
 	double splitRatio,
 	bool separateDefaultMorpheme,
+	const string& morphemeDefPath,
+	size_t morphemeDefMinCnt,
 	HSDataset* splitDataset
 ) const
 {
-	auto realMorph = restoreMorphemeMap(separateDefaultMorpheme);
-	HSDataset dataset{ batchSize, windowSize, numWorkers, dropoutProb };
+	HSDataset dataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb };
 	auto& sents = dataset.sents.get();
-	dataset.knlm = langMdl.knlm;
-	dataset.morphemes = &morphemes;
-	dataset.forms = &forms;
+
+	const KiwiBuilder* srcBuilder = this;
+	MorphemeMap realMorph;
+	size_t maxTokenId = 0;
+	if (morphemeDefPath.empty())
+	{
+		realMorph = restoreMorphemeMap(separateDefaultMorpheme);
+	}
+	else
+	{
+		dataset.dummyBuilder = make_shared<KiwiBuilder>();
+		dataset.dummyBuilder->initMorphemes();
+		ifstream ifs;
+		realMorph = dataset.dummyBuilder->loadMorphemesFromTxt(openFile(ifs, morphemeDefPath), [&](POSTag tag, float cnt)
+		{
+			return cnt >= morphemeDefMinCnt;
+		});
+		srcBuilder = dataset.dummyBuilder.get();
+
+		for (auto& p : realMorph)
+		{
+			maxTokenId = max(p.second.first + 1, maxTokenId);
+		}
+	}
+
+	auto& knlm = srcBuilder->langMdl.knlm;
+	dataset.knlm = knlm;
+	dataset.morphemes = &srcBuilder->morphemes;
+	dataset.forms = &srcBuilder->forms;
 
 	if (splitDataset)
 	{
-		*splitDataset = HSDataset{ batchSize, windowSize, numWorkers, dropoutProb };
-		splitDataset->knlm = langMdl.knlm;
-		splitDataset->morphemes = &morphemes;
-		splitDataset->forms = &forms;
+		*splitDataset = HSDataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb };
+		splitDataset->dummyBuilder = dataset.dummyBuilder;
+		splitDataset->knlm = knlm;
+		splitDataset->morphemes = &srcBuilder->morphemes;
+		splitDataset->forms = &srcBuilder->forms;
 	}
 
 	for (auto& path : inputPathes)
 	{
 		ifstream ifs;
-		addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr);
+		srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr);
 	}
 
-	size_t tokenSize = sents.raw().empty() ? 0 : *std::max_element(sents.raw().begin(), sents.raw().end()) + 1;
+	size_t tokenSize = sents.raw().empty() ? 0 : *max_element(sents.raw().begin(), sents.raw().end()) + 1;
 	if (splitDataset)
 	{
 		auto& sents = splitDataset->sents.get();
-		tokenSize = std::max(tokenSize, sents.raw().empty() ? (size_t)0 : *std::max_element(sents.raw().begin(), sents.raw().end()) + 1);
+		tokenSize = max(tokenSize, sents.raw().empty() ? (size_t)0 : *max_element(sents.raw().begin(), sents.raw().end()) + 1);
 	}
 
-	const size_t knlmVocabSize = langMdl.knlm->getHeader().vocab_size;
+	const size_t knlmVocabSize = knlm ? knlm->getHeader().vocab_size : maxTokenId;
+	tokenSize = max(tokenSize, knlmVocabSize);
 	size_t filteredKnlmVocabSize = 0;
 	for (size_t i = 0; i < tokenSize; ++i)
 	{
@@ -2223,7 +2262,17 @@ HSDataset KiwiBuilder::makeHSDataset(const vector<string>& inputPathes,
 		{
 			filteredKnlmVocabSize = dataset.vocabToToken.size();
 		}
-		if (tokenFilter && !tokenFilter(joinHangul(forms[morphemes[i].kform].form), morphemes[i].tag))
+
+		if (windowFilter && !windowFilter(joinHangul(srcBuilder->forms[srcBuilder->morphemes[i].kform].form), srcBuilder->morphemes[i].tag))
+		{
+			dataset.windowTokenValidness.emplace_back(0);
+		}
+		else
+		{
+			dataset.windowTokenValidness.emplace_back(1);
+		}
+
+		if (tokenFilter && !tokenFilter(joinHangul(srcBuilder->forms[srcBuilder->morphemes[i].kform].form), srcBuilder->morphemes[i].tag))
 		{
 			dataset.tokenToVocab.emplace_back(HSDataset::nonVocab);
 			continue;
 		}
@@ -2244,6 +2293,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector<string>& inputPathes,
 
 	if (splitDataset)
 	{
+		splitDataset->windowTokenValidness = dataset.windowTokenValidness;
 		splitDataset->tokenToVocab = dataset.tokenToVocab;
 		splitDataset->vocabToToken = dataset.vocabToToken;
 		splitDataset->knlmVocabSize = dataset.knlmVocabSize;
diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp
index d9605450..602fa73b 100644
--- a/test/test_cpp.cpp
+++ b/test/test_cpp.cpp
@@ -414,7 +414,7 @@ TEST(KiwiCpp, HSDataset)
 	};
 
 	HSDataset trainset, devset;
-	trainset = kw.makeHSDataset(data, batchSize, windowSize, 1, 0., tokenFilter, 0.1, false, &devset);
+	trainset = kw.makeHSDataset(data, batchSize, 0, windowSize, 1, 0., tokenFilter, {}, 0.1, false, {}, 0, &devset);
 	for (size_t i = 0; i < 2; ++i)
 	{
 		{
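
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the
extended `makeHSDataset` signature above might be called, mirroring the updated call in
test/test_cpp.cpp. The corpus path, parameter values, and the trivial window predicate are
placeholder assumptions rather than values taken from the Kiwi sources.

    #include <kiwi/Kiwi.h>

    // `builder` is assumed to be an already-initialized KiwiBuilder,
    // playing the role of `kw` in the updated test.
    kiwi::HSDataset buildTrainset(const kiwi::KiwiBuilder& builder, kiwi::HSDataset* devset)
    {
        // Accept every token into the sliding-window history; a real filter
        // would inspect the surface form and POS tag passed to it.
        auto windowFilter = [](const std::u16string&, kiwi::POSTag)
        {
            return true;
        };

        return builder.makeHSDataset(
            { "corpus.txt" },   // input paths (placeholder)
            32,                 // batchSize
            4,                  // causalContextSize (new)
            8,                  // windowSize
            1,                  // numWorkers
            0.0,                // dropoutProb
            {},                 // tokenFilter: keep every token in the vocabulary
            windowFilter,       // windowFilter (new)
            0.1,                // splitRatio
            false,              // separateDefaultMorpheme
            {},                 // morphemeDefPath (new): empty -> use the builder's own morphemes
            0,                  // morphemeDefMinCnt (new)
            devset);            // optional held-out split
    }

With these settings each row of the `in` buffer filled by HSDataset::next() holds
causalContextSize + windowSize = 12 values, matching the (causalContextSize + windowSize)
stride used in HSDataset::_next above.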