Skip to content

Commit

Permalink
Merge pull request #177 from bab2min/dev_fix176
Browse files Browse the repository at this point in the history
Minor Fix including #176
  • Loading branch information
bab2min authored Aug 30, 2024
2 parents 5a20486 + 90b121f commit 2696a9f
Show file tree
Hide file tree
Showing 22 changed files with 807 additions and 175 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ set ( CORE_SRCS
src/Joiner.cpp
src/Kiwi.cpp
src/KiwiBuilder.cpp
src/Knlm.cpp
src/KTrie.cpp
src/PatternMatcher.cpp
src/search.cpp
Expand Down
27 changes: 27 additions & 0 deletions include/kiwi/FrozenTrie.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,26 @@ namespace kiwi
std::unique_ptr<Key[]> nextKeys;
std::unique_ptr<Diff[]> nextDiffs;

template<class Fn>
void traverse(Fn&& visitor, const Node* node, std::vector<Key>& prefix, size_t maxDepth) const
{
auto* keys = &nextKeys[node->nextOffset];
auto* diffs = &nextDiffs[node->nextOffset];
for (size_t i = 0; i < node->numNexts; ++i)
{
const auto* child = node + diffs[i];
const auto val = child->val(*this);
if (!hasMatch(val)) continue;
prefix.emplace_back(keys[i]);
visitor(val, prefix);
if (prefix.size() < maxDepth)
{
traverse(visitor, child, prefix, maxDepth);
}
prefix.pop_back();
}
}

public:

FrozenTrie() = default;
Expand All @@ -117,6 +137,13 @@ namespace kiwi
const Value& value(size_t idx) const { return values[idx]; };

bool hasMatch(_Value v) const { return !this->isNull(v) && !this->hasSubmatch(v); }

template<class Fn>
void traverse(Fn&& visitor, size_t maxDepth = -1) const
{
std::vector<Key> prefix;
traverse(std::forward<Fn>(visitor), root(), prefix, maxDepth);
}
};
}
}
34 changes: 29 additions & 5 deletions include/kiwi/Knlm.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#include <algorithm>
#include <numeric>

#include "Utils.h"
#include "Mmap.h"
#include "ArchUtils.h"

namespace kiwi
{
Expand All @@ -20,6 +22,7 @@ namespace kiwi
uint64_t num_nodes, node_offset, key_offset, ll_offset, gamma_offset, qtable_offset, htx_offset;
uint64_t unk_id, bos_id, eos_id, vocab_size;
uint8_t order, key_size, diff_size, quantized;
uint32_t extra_buf_size;
};

template<class KeyType, class DiffType = int32_t>
Expand All @@ -43,6 +46,7 @@ namespace kiwi
virtual float _progress(ptrdiff_t& node_idx, size_t next) const = 0;
virtual std::vector<float> allNextLL(ptrdiff_t node_idx) const = 0;
virtual std::vector<float> allNextLL(ptrdiff_t node_idx, std::vector<ptrdiff_t>& next_node_idx) const = 0;
virtual void nextTopN(ptrdiff_t node_idx, size_t top_n, uint32_t* idx_out, float* ll_out) const = 0;

public:

Expand All @@ -55,21 +59,28 @@ namespace kiwi
virtual size_t llSize() const = 0;
virtual const float* getLLBuf() const = 0;
virtual const float* getGammaBuf() const = 0;
virtual const void* getExtraBuf() const = 0;

static std::unique_ptr<KnLangModelBase> create(utils::MemoryObject&& mem, ArchType archType = ArchType::none);

template<class TrieNode, class HistoryTx = std::vector<Vid>>
static utils::MemoryOwner build(const utils::ContinuousTrie<TrieNode>& ngram_cf,
size_t order, size_t min_cf, size_t last_min_cf,
template<class Trie, class HistoryTx = std::vector<Vid>>
static utils::MemoryOwner build(Trie&& ngram_cf,
size_t order, const std::vector<size_t>& min_cf_by_order,
size_t unk_id, size_t bos_id, size_t eos_id,
float unigram_alpha, size_t quantize, bool compress,
const std::vector<std::pair<Vid, Vid>>* bigram_list = nullptr,
const HistoryTx* historyTransformer = nullptr
const HistoryTx* history_transformer = nullptr,
const void* extra_buf = nullptr,
size_t extra_buf_size = 0
);

const utils::MemoryObject& getMemory() const { return base; }

//virtual float progress(ptrdiff_t& node_idx, size_t next) const = 0;
template<class Ty>
float progress(ptrdiff_t& node_idx, Ty next) const
{
return _progress(node_idx, next);
}

template<class InTy, class OutTy>
void evaluate(InTy in_first, InTy in_last, OutTy out_first) const
Expand Down Expand Up @@ -130,6 +141,19 @@ namespace kiwi
}
}

template<class InTy>
void predictTopN(InTy in_first, InTy in_last, size_t top_n, uint32_t* idx_out, float* ll_out) const
{
ptrdiff_t node_idx = 0;
for (; in_first != in_last; ++in_first)
{
_progress(node_idx, *in_first);
nextTopN(node_idx, top_n, idx_out, ll_out);
idx_out += top_n;
ll_out += top_n;
}
}

template<class PfTy, class SfTy, class OutTy>
void fillIn(PfTy prefix_first, PfTy prefix_last, SfTy suffix_first, SfTy suffix_last, OutTy out_first, bool reduce = true) const
{
Expand Down
5 changes: 5 additions & 0 deletions include/kiwi/Mmap.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ namespace kiwi
setp(epptr() + off, epptr());
else if (dir == std::ios_base::beg)
setp(pbase() + off, epptr());

if (!(which & std::ios_base::in))
{
return pptr() - pbase();
}
}
return gptr() - eback();
}
Expand Down
49 changes: 49 additions & 0 deletions include/kiwi/SubstringExtractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include <vector>
#include <string>

#include <kiwi/FrozenTrie.h>
#include <kiwi/Knlm.h>

namespace kiwi
{
std::vector<std::pair<std::u16string, size_t>> extractSubstrings(
Expand All @@ -13,4 +16,50 @@ namespace kiwi
size_t maxLength = 32,
bool longestOnly = true,
char16_t stopChr = 0);


class PrefixCounter
{
size_t prefixSize = 0, minCf = 0, numArrays = 0;
UnorderedMap<uint32_t, uint32_t> token2id;
Vector<uint32_t> id2Token;
Vector<uint16_t> buf;
Vector<size_t> tokenClusters;
Vector<size_t> tokenCnts;
std::shared_ptr<void> threadPool;

template<class It>
void _addArray(It first, It last);

Vector<std::pair<uint32_t, float>> computeClusterScore() const;

public:
PrefixCounter(size_t _prefixSize, size_t _minCf, size_t _numWorkers,
const std::vector<std::vector<size_t>>& clusters = {}
);
void addArray(const uint16_t* first, const uint16_t* last);
void addArray(const uint32_t* first, const uint32_t* last);
void addArray(const uint64_t* first, const uint64_t* last);
utils::FrozenTrie<uint32_t, uint32_t> count() const;
std::unique_ptr<lm::KnLangModelBase> buildLM(
const std::vector<size_t>& minCfByOrder,
size_t bosTokenId,
size_t eosTokenId,
size_t unkTokenId,
ArchType archType = ArchType::none
) const;
};

class ClusterData
{
const std::pair<uint32_t, float>* clusterScores = nullptr;
size_t clusterSize = 0;
public:
ClusterData();
ClusterData(const void* _ptr, size_t _size);

size_t size() const;
size_t cluster(size_t i) const;
float score(size_t i) const;
};
}
25 changes: 25 additions & 0 deletions include/kiwi/Trie.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,24 @@ namespace kiwi
return;
}

template<typename _Fn, typename _CKey>
void traverse(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const
{
fn(this->val, rkeys);

if (rkeys.size() >= maxDepth) return;

for (auto& p : next)
{
if (ignoreNegative ? (p.second > 0) : (p.second))
{
rkeys.emplace_back(p.first);
getNext(p.first)->traverse(fn, rkeys, maxDepth, ignoreNegative);
rkeys.pop_back();
}
}
}

template<typename _Fn, typename _CKey>
void traverseWithKeys(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const
{
Expand Down Expand Up @@ -462,6 +480,13 @@ namespace kiwi
return nodes[0].fillFail(std::forward<HistoryTx>(htx), ignoreNegative);
}

template<typename _Fn>
void traverse(_Fn&& fn, size_t maxDepth = -1, bool ignoreNegative = false) const
{
std::vector<typename Node::Key> rkeys;
return nodes[0].traverse(std::forward<_Fn>(fn), rkeys, maxDepth, ignoreNegative);
}

template<typename _Fn, typename _CKey>
void traverseWithKeys(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const
{
Expand Down
40 changes: 39 additions & 1 deletion include/kiwi/Utils.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#pragma once
#pragma once
#include <iostream>
#include <string>
#include <memory>
#include <array>
#include "Types.h"

namespace kiwi
Expand Down Expand Up @@ -82,6 +83,11 @@ namespace kiwi
return within(chr, 0x302E, 0x3030);
}

inline bool isCompatibleHangulConsonant(char16_t chr)
{
return within(chr, 0x3131, 0x314E) || within(chr, 0x3165, 0x3186);
}

struct ComparatorIgnoringSpace
{
static bool less(const KString& a, const KString& b, const kchar_t space = u' ');
Expand Down Expand Up @@ -146,6 +152,38 @@ namespace kiwi
return joinHangul(hangul.begin(), hangul.end());
}

inline bool isHighSurrogate(char16_t c)
{
return (c & 0xFC00) == 0xD800;
}

inline bool isLowSurrogate(char16_t c)
{
return (c & 0xFC00) == 0xDC00;
}

inline char32_t mergeSurrogate(char16_t h, char16_t l)
{
return (((h & 0x3FF) << 10) | (l & 0x3FF)) + 0x10000;
}

inline std::array<char16_t, 2> decomposeSurrogate(char32_t c)
{
std::array<char16_t, 2> ret;
if (c < 0x10000)
{
ret[0] = c;
ret[1] = 0;
}
else
{
c -= 0x10000;
ret[0] = ((c >> 10) & 0x3FF) | 0xD800;
ret[1] = (c & 0x3FF) | 0xDC00;
}
return ret;
}

POSTag identifySpecialChr(char32_t chr);
size_t getSSType(char16_t c);
size_t getSBType(const std::u16string& form);
Expand Down
2 changes: 1 addition & 1 deletion src/FrozenTrie.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ namespace kiwi
for (size_t i = 0; i < trie.size(); ++i)
{
auto& o = trie[i];
nodes[i].numNexts = o.next.size();
nodes[i].numNexts = (Key)o.next.size();
values[i] = xform(o);
nodes[i].nextOffset = ptr;

Expand Down
2 changes: 1 addition & 1 deletion src/Joiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ namespace kiwi
for (size_t i = 0; i < candidates.size(); ++i)
{
auto& c = candidates[i];
auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, i));
auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, (uint32_t)i));
if (!inserted.second)
{
if (inserted.first->second.first < c.score)
Expand Down
1 change: 1 addition & 0 deletions src/Kiwi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,7 @@ namespace kiwi
morph.tag = s.tokenization[0].tag;
morph.vowel = CondVowel::none;
morph.polar = CondPolarity::none;
morph.complex = 0;
morph.lmMorphemeId = getDefaultMorphemeId(s.tokenization[0].tag);
form.candidate[0] = &morph;
}
Expand Down
5 changes: 3 additions & 2 deletions src/KiwiBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -723,10 +723,11 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args)
new (&pool) utils::ThreadPool{ args.numWorkers };
}
auto cntNodes = utils::count(sents.begin(), sents.end(), args.lmMinCnt, 1, args.lmOrder, (args.numWorkers > 1 ? &pool : nullptr), &bigramList, args.useLmTagHistory ? &historyTx : nullptr);
cntNodes.root().getNext(lmVocabSize)->val /= 2;
std::vector<size_t> minCnts(args.lmOrder, args.lmMinCnt);
minCnts.back() = args.lmLastOrderMinCnt;
langMdl.knlm = lm::KnLangModelBase::create(lm::KnLangModelBase::build(
cntNodes,
args.lmOrder, args.lmMinCnt, args.lmLastOrderMinCnt,
args.lmOrder, minCnts,
2, 0, 1, 1e-5,
args.quantizeLm ? 8 : 0,
args.compressLm,
Expand Down
46 changes: 46 additions & 0 deletions src/Knlm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "Knlm.hpp"

namespace kiwi
{
namespace lm
{
template<ArchType archType>
std::unique_ptr<KnLangModelBase> createOptimizedModel(utils::MemoryObject&& mem)
{
auto* ptr = reinterpret_cast<const char*>(mem.get());
auto& header = *reinterpret_cast<const Header*>(ptr);
switch (header.key_size)
{
case 1:
return make_unique<KnLangModel<archType, uint8_t>>(std::move(mem));
case 2:
return make_unique<KnLangModel<archType, uint16_t>>(std::move(mem));
case 4:
return make_unique<KnLangModel<archType, uint32_t>>(std::move(mem));
case 8:
return make_unique<KnLangModel<archType, uint64_t>>(std::move(mem));
default:
throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.key_size) };
}
}

using FnCreateOptimizedModel = decltype(&createOptimizedModel<ArchType::none>);

struct CreateOptimizedModelGetter
{
template<std::ptrdiff_t i>
struct Wrapper
{
static constexpr FnCreateOptimizedModel value = &createOptimizedModel<static_cast<ArchType>(i)>;
};
};

std::unique_ptr<KnLangModelBase> KnLangModelBase::create(utils::MemoryObject&& mem, ArchType archType)
{
static tp::Table<FnCreateOptimizedModel, AvailableArch> table{ CreateOptimizedModelGetter{} };
auto fn = table[static_cast<std::ptrdiff_t>(archType)];
if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) };
return (*fn)(std::move(mem));
}
}
}
Loading

0 comments on commit 2696a9f

Please sign in to comment.