Add new features to HSDataset
bab2min committed Oct 12, 2024
1 parent 3b29e28 commit 01a22a8
Showing 5 changed files with 121 additions and 32 deletions.
7 changes: 6 additions & 1 deletion include/kiwi/Dataset.h
@@ -47,16 +47,19 @@ namespace kiwi
 		HiddenMember<RaggedVector<uint32_t>, sizeof(Vector<size_t>) * 2> sents;
 		std::shared_ptr<lm::KnLangModelBase> knlm;
 		std::unique_ptr<utils::ThreadPool> workers;
+		std::shared_ptr<KiwiBuilder> dummyBuilder;
 		std::discrete_distribution<> dropout;
 		std::mt19937_64 rng;
 		Vector<ThreadLocal> locals;
 		Vector<size_t> shuffledIdx;
 		Vector<int32_t> tokenToVocab, vocabToToken;
+		Vector<uint8_t> windowTokenValidness;
 		Deque<OptionalFuture<size_t>> futures;
 		const Vector<MorphemeRaw>* morphemes = nullptr;
 		const Vector<FormRaw>* forms = nullptr;
 		size_t knlmVocabSize = 0;
 		size_t batchSize = 0;
+		size_t causalContextSize = 0;
 		size_t windowSize = 0;
 		size_t totalTokens = 0;
 		size_t passedSents = 0;
@@ -68,7 +71,7 @@
 		size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut);

 	public:
-		HSDataset(size_t _batchSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0);
+		HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0);
 		~HSDataset();
 		HSDataset(const HSDataset&) = delete;
 		HSDataset(HSDataset&&) /*noexcept*/;
@@ -80,7 +83,9 @@
 		size_t numTokens() const;

 		size_t getBatchSize() const { return batchSize; }
+		size_t getCausalContextSize() const { return causalContextSize; }
 		size_t getWindowSize() const { return windowSize; }
+		const Vector<uint8_t>& getWindowTokenValidness() const { return windowTokenValidness; }

 		void seed(size_t newSeed);
 		void reset();
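
Note: with this change each batch row filled by next() carries causalContextSize causal tokens followed by windowSize window tokens (see the stride change in src/Dataset.cpp below). A minimal buffer-sizing sketch against the new accessors; the already-built HSDataset is assumed, and the snippet is an illustration rather than part of the commit:

#include <cstdint>
#include <vector>
#include <kiwi/Dataset.h>

// Sketch: size the buffers passed to HSDataset::next() under the new layout.
void sizeBuffers(kiwi::HSDataset& dataset)
{
	// Each row = causalContextSize causal ids + windowSize window ids.
	const size_t stride = dataset.getCausalContextSize() + dataset.getWindowSize();
	std::vector<int64_t> in(dataset.getBatchSize() * stride);
	std::vector<int64_t> out(dataset.getBatchSize());
	std::vector<float> lmLProbs(dataset.getBatchSize());
	std::vector<int64_t> outNgramNode(dataset.getBatchSize());
	// ... fill batches via dataset.next(in.data(), out.data(), ...) ...
}
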
7 changes: 6 additions & 1 deletion include/kiwi/Kiwi.h
@@ -548,6 +548,8 @@ namespace kiwi

 		using MorphemeMap = UnorderedMap<std::tuple<KString, uint8_t, POSTag>, std::pair<size_t, size_t>>;

+		void initMorphemes();
+
 		template<class Fn>
 		MorphemeMap loadMorphemesFromTxt(std::istream& is, Fn&& filter);

@@ -801,11 +803,14 @@
 		using TokenFilter = std::function<bool(const std::u16string&, POSTag)>;

 		HSDataset makeHSDataset(const std::vector<std::string>& inputPathes,
-			size_t batchSize, size_t windowSize, size_t numWorkers,
+			size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers,
 			double dropoutProb = 0,
 			const TokenFilter& tokenFilter = {},
+			const TokenFilter& windowFilter = {},
 			double splitRatio = 0,
 			bool separateDefaultMorpheme = false,
+			const std::string& morphemeDefPath = {},
+			size_t morphemeDefMinCnt = 0,
 			HSDataset* splitDataset = nullptr
 		) const;
 	};
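
Note: the extended makeHSDataset signature adds causalContextSize, windowFilter, morphemeDefPath, and morphemeDefMinCnt. A hedged call sketch with placeholder values (the KiwiBuilder instance `builder` and the corpus path are assumptions, not from this commit):

// Sketch only; argument values are illustrative.
kiwi::HSDataset devset;
kiwi::HSDataset trainset = builder.makeHSDataset(
	{ "corpus.txt" },  // inputPathes (hypothetical)
	64,                // batchSize
	16,                // causalContextSize (new)
	8,                 // windowSize
	1,                 // numWorkers
	0.05,              // dropoutProb
	{},                // tokenFilter
	{},                // windowFilter (new)
	0.1,               // splitRatio
	false,             // separateDefaultMorpheme
	{},                // morphemeDefPath (new)
	0,                 // morphemeDefMinCnt (new)
	&devset);          // splitDataset
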
51 changes: 40 additions & 11 deletions src/Dataset.cpp
@@ -3,11 +3,12 @@

 using namespace kiwi;

-HSDataset::HSDataset(size_t _batchSize, size_t _windowSize, size_t _workers, double _dropoutProb)
+HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windowSize, size_t _workers, double _dropoutProb)
 	: workers{ _workers ? make_unique<utils::ThreadPool>(_workers) : nullptr },
 	dropout{ {1 - _dropoutProb * 3, _dropoutProb, _dropoutProb, _dropoutProb} },
 	locals( _workers ? workers->size() : 1),
 	batchSize{ _batchSize },
+	causalContextSize{ _causalContextSize },
 	windowSize{ _windowSize }
 {
 }
@@ -113,12 +114,21 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut)

 	local.lmLProbsBuf.resize(tokens.size());
 	local.outNgramNodeBuf.resize(tokens.size());
-	knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin());
+	if (knlm)
+	{
+		knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin());
+	}

 	auto& history = local.historyBuf;
 	history.clear();
-	history.resize(windowSize, -1);
-	history.back() = tokenToVocab[tokens[0]];
+	if (windowSize)
+	{
+		history.resize(windowSize, -1);
+		if (windowTokenValidness[tokens[0]])
+		{
+			history.back() = tokenToVocab[tokens[0]];
+		}
+	}
 	for (size_t i = 1; i < tokens.size(); ++i)
 	{
 		int32_t v = tokenToVocab[tokens[i]];
@@ -134,13 +144,32 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut)
 			local.restLmLProbsCntData[r] += 1;
 			continue;
 		}
-		std::copy(history.begin(), history.end(), std::back_inserter(local.inData));
+
+		if (causalContextSize)
+		{
+			for (size_t j = 0; j < causalContextSize; ++j)
+			{
+				local.inData.emplace_back(i + j < causalContextSize ?
+					nonVocab : tokenToVocab[tokens[i + j - causalContextSize]]);
+			}
+		}
+		if (windowSize)
+		{
+			if (windowTokenValidness[v])
+			{
+				std::copy(history.begin(), history.end(), std::back_inserter(local.inData));
+				history.pop_front();
+				history.push_back(v);
+			}
+			else
+			{
+				local.inData.resize(local.inData.size() + windowSize, -1);
+			}
+		}
+
 		local.outData.emplace_back(v);
 		local.lmLProbsData.emplace_back(local.lmLProbsBuf[i]);
 		local.outNgramNodeData.emplace_back(local.outNgramNodeBuf[i]);
-
-		history.pop_front();
-		history.push_back(v);
 	}

 	size_t r = local.outData.size() / batchSize;
@@ -217,14 +246,14 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut)
 	auto& l = locals[localId];

 	size_t rest = std::min(l.outData.size(), batchSize);
-	std::copy(l.inData.begin(), l.inData.begin() + rest * windowSize, in);
+	std::copy(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize), in);
 	std::copy(l.outData.begin(), l.outData.begin() + rest, out);
 	std::copy(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest, lmLProbs);
 	std::copy(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest, outNgramNode);
 	restLmOut = l.restLmLProbsData.front();
 	restLmCntOut = l.restLmLProbsCntData.front();

-	l.inData.erase(l.inData.begin(), l.inData.begin() + rest * windowSize);
+	l.inData.erase(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize));
 	l.outData.erase(l.outData.begin(), l.outData.begin() + rest);
 	l.lmLProbsData.erase(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest);
 	l.outNgramNodeData.erase(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest);
@@ -245,7 +274,7 @@ size_t HSDataset::next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outN

 size_t HSDataset::ngramNodeSize() const
 {
-	return knlm->nonLeafNodeSize();
+	return knlm ? knlm->nonLeafNodeSize() : 0;
 }

 const MorphemeRaw& HSDataset::vocabInfo(uint32_t vocab) const
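
Note: _next() now emits, per predicted position, causalContextSize causal ids (front-padded with nonVocab near the sentence start) followed by windowSize history ids; window-invalid targets receive a row of -1 padding and do not advance the history. A standalone reconstruction of that layout for a single position (illustrative helper, not library code; vocabIds[i] stands in for tokenToVocab[tokens[i]]):

#include <cstdint>
#include <deque>
#include <vector>

// Mirrors the input-row construction in HSDataset::_next() above.
std::vector<int32_t> buildInputRow(const std::vector<int32_t>& vocabIds,
	std::deque<int32_t>& history, size_t i,
	size_t causalContextSize, size_t windowSize,
	int32_t nonVocab, bool windowValid)
{
	std::vector<int32_t> row;
	for (size_t j = 0; j < causalContextSize; ++j)
	{
		// Causal part: the ids immediately to the left of position i.
		row.push_back(i + j < causalContextSize
			? nonVocab : vocabIds[i + j - causalContextSize]);
	}
	if (windowSize)
	{
		if (windowValid)
		{
			// Window part: the last windowSize window-valid ids,
			// after which the target id enters the history.
			row.insert(row.end(), history.begin(), history.end());
			history.pop_front();
			history.push_back(vocabIds[i]);
		}
		else
		{
			// Window-invalid target: pad with -1, leave history untouched.
			row.insert(row.end(), windowSize, -1);
		}
	}
	return row;
}
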
86 changes: 68 additions & 18 deletions src/KiwiBuilder.cpp
@@ -784,10 +784,8 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOptio
 	}
 }

-KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args)
+void KiwiBuilder::initMorphemes()
 {
-	archType = getSelectedArch(ArchType::default_);
-
 	forms.resize(defaultFormSize);
 	morphemes.resize(defaultFormSize + 2); // additional places for <s>, </s>
 	for (size_t i = 1; i < defaultTagSize; ++i)
@@ -805,6 +803,18 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args)
 		morphemes[i + defaultTagSize + 1].userScore = -1.5f;
 	}

+}
+
+KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args)
+{
+	if (!(args.lmMinCnts.size() == 1 || args.lmMinCnts.size() == args.lmOrder))
+	{
+		throw invalid_argument{ "lmMinCnts should have 1 or lmOrder elements" };
+	}
+
+	archType = getSelectedArch(ArchType::default_);
+	initMorphemes();
+
 	ifstream ifs;
 	auto realMorph = loadMorphemesFromTxt(openFile(ifs, args.morphemeDef), [&](POSTag tag, float cnt)
 	{
@@ -2179,51 +2189,90 @@ vector<WordInfo> KiwiBuilder::extractAddWords(const U16MultipleReader& reader, s
 }

 HSDataset KiwiBuilder::makeHSDataset(const vector<string>& inputPathes,
-	size_t batchSize, size_t windowSize, size_t numWorkers,
+	size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers,
 	double dropoutProb,
 	const TokenFilter& tokenFilter,
+	const TokenFilter& windowFilter,
 	double splitRatio,
 	bool separateDefaultMorpheme,
+	const string& morphemeDefPath,
+	size_t morphemeDefMinCnt,
 	HSDataset* splitDataset
 ) const
 {
-	auto realMorph = restoreMorphemeMap(separateDefaultMorpheme);
-	HSDataset dataset{ batchSize, windowSize, numWorkers, dropoutProb };
+	HSDataset dataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb };
 	auto& sents = dataset.sents.get();
-	dataset.knlm = langMdl.knlm;
-	dataset.morphemes = &morphemes;
-	dataset.forms = &forms;
+	const KiwiBuilder* srcBuilder = this;
+	MorphemeMap realMorph;
+	size_t maxTokenId = 0;
+	if (morphemeDefPath.empty())
+	{
+		realMorph = restoreMorphemeMap(separateDefaultMorpheme);
+	}
+	else
+	{
+		dataset.dummyBuilder = make_shared<KiwiBuilder>();
+		dataset.dummyBuilder->initMorphemes();
+		ifstream ifs;
+		realMorph = dataset.dummyBuilder->loadMorphemesFromTxt(openFile(ifs, morphemeDefPath), [&](POSTag tag, float cnt)
+		{
+			return cnt >= morphemeDefMinCnt;
+		});
+		srcBuilder = dataset.dummyBuilder.get();
+
+		for (auto& p : realMorph)
+		{
+			maxTokenId = max(p.second.first + 1, maxTokenId);
+		}
+	}
+
+	auto& knlm = srcBuilder->langMdl.knlm;
+	dataset.knlm = knlm;
+	dataset.morphemes = &srcBuilder->morphemes;
+	dataset.forms = &srcBuilder->forms;

 	if (splitDataset)
 	{
-		*splitDataset = HSDataset{ batchSize, windowSize, numWorkers, dropoutProb };
-		splitDataset->knlm = langMdl.knlm;
-		splitDataset->morphemes = &morphemes;
-		splitDataset->forms = &forms;
+		*splitDataset = HSDataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb };
+		splitDataset->dummyBuilder = dataset.dummyBuilder;
+		splitDataset->knlm = knlm;
+		splitDataset->morphemes = &srcBuilder->morphemes;
+		splitDataset->forms = &srcBuilder->forms;
 	}

 	for (auto& path : inputPathes)
 	{
 		ifstream ifs;
-		addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr);
+		srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr);
 	}
-	size_t tokenSize = sents.raw().empty() ? 0 : *std::max_element(sents.raw().begin(), sents.raw().end()) + 1;
+	size_t tokenSize = sents.raw().empty() ? 0 : *max_element(sents.raw().begin(), sents.raw().end()) + 1;

 	if (splitDataset)
 	{
 		auto& sents = splitDataset->sents.get();
-		tokenSize = std::max(tokenSize, sents.raw().empty() ? (size_t)0 : *std::max_element(sents.raw().begin(), sents.raw().end()) + 1);
+		tokenSize = max(tokenSize, sents.raw().empty() ? (size_t)0 : *max_element(sents.raw().begin(), sents.raw().end()) + 1);
 	}

-	const size_t knlmVocabSize = langMdl.knlm->getHeader().vocab_size;
+	const size_t knlmVocabSize = knlm ? knlm->getHeader().vocab_size : maxTokenId;
 	tokenSize = max(tokenSize, knlmVocabSize);
 	size_t filteredKnlmVocabSize = 0;
 	for (size_t i = 0; i < tokenSize; ++i)
 	{
 		if (i == knlmVocabSize)
 		{
 			filteredKnlmVocabSize = dataset.vocabToToken.size();
 		}
-		if (tokenFilter && !tokenFilter(joinHangul(forms[morphemes[i].kform].form), morphemes[i].tag))
+
+		if (windowFilter && !windowFilter(joinHangul(srcBuilder->forms[srcBuilder->morphemes[i].kform].form), srcBuilder->morphemes[i].tag))
+		{
+			dataset.windowTokenValidness.emplace_back(0);
+		}
+		else
+		{
+			dataset.windowTokenValidness.emplace_back(1);
+		}
+
+		if (tokenFilter && !tokenFilter(joinHangul(srcBuilder->forms[srcBuilder->morphemes[i].kform].form), srcBuilder->morphemes[i].tag))
 		{
 			dataset.tokenToVocab.emplace_back(HSDataset::nonVocab);
 			continue;
@@ -2244,6 +2293,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector<string>& inputPathes,

 	if (splitDataset)
 	{
+		splitDataset->windowTokenValidness = dataset.windowTokenValidness;
 		splitDataset->tokenToVocab = dataset.tokenToVocab;
 		splitDataset->vocabToToken = dataset.vocabToToken;
 		splitDataset->knlmVocabSize = dataset.knlmVocabSize;
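
Note: when morphemeDefPath is non-empty, the dataset's morpheme inventory comes from that text file via a private dummy KiwiBuilder (entries with counts below morphemeDefMinCnt are filtered out), no KnLM is attached, and knlmVocabSize falls back to the largest loaded token id. A hedged usage sketch; the file names and counts are illustrative, and `builder` is an assumed existing KiwiBuilder:

// Sketch: dataset driven by an external morpheme definition instead of the
// builder's own model; with no KnLM attached, ngramNodeSize() returns 0
// (see src/Dataset.cpp above).
kiwi::HSDataset dataset = builder.makeHSDataset(
	{ "corpus.txt" },  // inputPathes (hypothetical)
	64,                // batchSize
	16,                // causalContextSize
	0,                 // windowSize
	1,                 // numWorkers
	0.0,               // dropoutProb
	{}, {},            // tokenFilter, windowFilter
	0.0,               // splitRatio
	false,             // separateDefaultMorpheme
	"morphemes.txt",   // morphemeDefPath (hypothetical file)
	10);               // morphemeDefMinCnt
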
2 changes: 1 addition & 1 deletion test/test_cpp.cpp
@@ -414,7 +414,7 @@ TEST(KiwiCpp, HSDataset)
 	};

 	HSDataset trainset, devset;
-	trainset = kw.makeHSDataset(data, batchSize, windowSize, 1, 0., tokenFilter, 0.1, false, &devset);
+	trainset = kw.makeHSDataset(data, batchSize, 0, windowSize, 1, 0., tokenFilter, {}, 0.1, false, {}, 0, &devset);
 	for (size_t i = 0; i < 2; ++i)
 	{
 		{
