From c6f24274aa95f51139a7c486c6b2b72421f1c6ca Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 14 Apr 2024 15:11:12 +0900 Subject: [PATCH] changed `topN` of evaluator to 1 & added more typo options --- tools/Evaluator.h | 4 ++-- tools/evaluator_main.cpp | 51 ++++++++++++++++++++++++++++++---------- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/tools/Evaluator.h b/tools/Evaluator.h index 0de707c9..a91161d3 100644 --- a/tools/Evaluator.h +++ b/tools/Evaluator.h @@ -27,9 +27,9 @@ class Evaluator std::vector testsets, errors; const kiwi::Kiwi* kw = nullptr; kiwi::Match matchOption; - size_t topN = 3; + size_t topN = 1; public: - Evaluator(const std::string& testSetFile, const kiwi::Kiwi* _kw, kiwi::Match _matchOption = kiwi::Match::all, size_t topN = 3); + Evaluator(const std::string& testSetFile, const kiwi::Kiwi* _kw, kiwi::Match _matchOption = kiwi::Match::all, size_t topN = 1); void run(); Score evaluate(); const std::vector& getErrors() const { return errors; } diff --git a/tools/evaluator_main.cpp b/tools/evaluator_main.cpp index fb7130c9..84a049f8 100644 --- a/tools/evaluator_main.cpp +++ b/tools/evaluator_main.cpp @@ -12,13 +12,28 @@ using namespace std; using namespace kiwi; int doEvaluate(const string& modelPath, const string& output, const vector& input, - bool normCoda, bool zCoda, bool useSBG, float typoCostWeight, bool cTypo) + bool normCoda, bool zCoda, bool multiDict, bool useSBG, + float typoCostWeight, bool bTypo, bool cTypo, + int repeat) { try { + if (typoCostWeight > 0 && !bTypo && !cTypo) + { + bTypo = true; + } + else if (typoCostWeight == 0) + { + bTypo = false; + cTypo = false; + } + + DefaultTypoSet typos[] = { DefaultTypoSet::withoutTypo, DefaultTypoSet::basicTypoSet, DefaultTypoSet::continualTypoSet, DefaultTypoSet::basicTypoSetWithContinual}; + tutils::Timer timer; - Kiwi kw = KiwiBuilder{ modelPath, 1, BuildOption::default_, useSBG }.build( - typoCostWeight > 0 ? (cTypo ? DefaultTypoSet::basicTypoSetWithContinual : DefaultTypoSet::basicTypoSet) : DefaultTypoSet::withoutTypo + auto option = (BuildOption::default_ & ~BuildOption::loadMultiDict) | (multiDict ? BuildOption::loadMultiDict : BuildOption::none); + Kiwi kw = KiwiBuilder{ modelPath, 1, option, useSBG }.build( + typos[(bTypo ? 1 : 0) + (cTypo ? 2 : 0)] ); if (typoCostWeight > 0) kw.setTypoCostWeight(typoCostWeight); @@ -34,10 +49,13 @@ int doEvaluate(const string& modelPath, const string& output, const vector model{ "m", "model", "Kiwi model path", false, "ModelGenerator", "string" }; ValueArg output{ "o", "output", "output dir for evaluation errors", false, "", "string" }; - SwitchArg withoutNormCoda{ "", "wcoda", "without normalizing coda", false }; - SwitchArg withoutZCoda{ "", "wzcoda", "without z-coda", false }; + SwitchArg noNormCoda{ "", "no-normcoda", "without normalizing coda", false }; + SwitchArg noZCoda{ "", "no-zcoda", "without z-coda", false }; + SwitchArg noMulti{ "", "no-multi", "turn off multi dict", false }; SwitchArg useSBG{ "", "sbg", "use SkipBigram", false }; - ValueArg typoTolerant{ "", "typo", "make typo-tolerant model", false, 0.f, "float"}; + ValueArg typoWeight{ "", "typo", "typo weight", false, 0.f, "float"}; + SwitchArg bTypo{ "", "btypo", "make basic-typo-tolerant model", false }; SwitchArg cTypo{ "", "ctypo", "make continual-typo-tolerant model", false }; + ValueArg repeat{ "", "repeat", "repeat evaluation for benchmark", false, 1, "int" }; UnlabeledMultiArg files{ "files", "evaluation set files", true, "string" }; cmd.add(model); cmd.add(output); cmd.add(files); - cmd.add(withoutNormCoda); - cmd.add(withoutZCoda); + cmd.add(noNormCoda); + cmd.add(noZCoda); + cmd.add(noMulti); cmd.add(useSBG); - cmd.add(typoTolerant); + cmd.add(typoWeight); + cmd.add(bTypo); cmd.add(cTypo); + cmd.add(repeat); try { @@ -118,6 +142,7 @@ int main(int argc, const char* argv[]) cerr << "error: " << e.error() << " for arg " << e.argId() << endl; return -1; } - return doEvaluate(model, output, files.getValue(), !withoutNormCoda, !withoutZCoda, useSBG, typoTolerant, cTypo); + return doEvaluate(model, output, files.getValue(), + !noNormCoda, !noZCoda, !noMulti, useSBG, typoWeight, bTypo, cTypo, repeat); }