Skip to content

Commit

Permalink
Populates the typing corrector's penalty/bonus to the candidates cont…
Browse files Browse the repository at this point in the history
…aining auto-kana-modifier-expansions generated by composer or dictionary-based expansions.

We can evaluate all typing corrections with uniform scoring.

The actual populates are implemented inside the supplemental model.

PiperOrigin-RevId: 675586853
  • Loading branch information
taku910 authored and hiroyuki-komatsu committed Sep 17, 2024
1 parent 0c29141 commit ff08a9c
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 21 deletions.
5 changes: 5 additions & 0 deletions src/engine/supplemental_model_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class SupplementalModelInterface {
return std::nullopt;
}

// Populates the typing correction penalty and attribute to `results`.
virtual void PopulateTypeCorrectedQuery(
const ConversionRequest &request, const Segments &segments,
std::vector<prediction::Result> *results) const {}

// Reranks (boost or promote) the typing corrected candidates at `results`.
virtual void RerankTypingCorrection(
const ConversionRequest &request, const Segments &segments,
Expand Down
4 changes: 4 additions & 0 deletions src/engine/supplemental_model_mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ class MockSupplementalModel : public SupplementalModelInterface {
CorrectComposition,
(const ConversionRequest &request, absl::string_view context),
(const, override));
MOCK_METHOD(void, PopulateTypeCorrectedQuery,
(const ConversionRequest &request, const Segments &segments,
std::vector<prediction::Result> *results),
(const, override));
MOCK_METHOD(void, RerankTypingCorrection,
(const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const prediction::Result *>> *results),
Expand Down
2 changes: 2 additions & 0 deletions src/prediction/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,10 @@ mozc_cc_library(
deps = [
":zero_query_dict",
"//base/strings:unicode",
"//composer:query",
"//converter:segments",
"//dictionary:dictionary_token",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
Expand Down
40 changes: 21 additions & 19 deletions src/prediction/dictionary_prediction_aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,13 @@ class DictionaryPredictionAggregator::PredictiveLookupCallback
return TRAVERSE_CONTINUE;
}

results_->push_back(Result());
results_->back().InitializeByTokenAndTypes(token, types_);
results_->back().wcost += penalty_;
results_->back().source_info |= source_info_;
results_->back().non_expanded_original_key =
std::string(non_expanded_original_key_);
Result result;
result.InitializeByTokenAndTypes(token, types_);
result.wcost += penalty_;
result.source_info |= source_info_;
result.non_expanded_original_key = std::string(non_expanded_original_key_);
if (penalty_ > 0) result.types |= KANA_MODIFIER_EXPANDED;
results_->emplace_back(std::move(result));
return (results_->size() < limit_) ? TRAVERSE_CONTINUE : TRAVERSE_DONE;
}

Expand Down Expand Up @@ -691,6 +692,8 @@ PredictionTypes DictionaryPredictionAggregator::AggregatePrediction(
}
}

MaybePopulateTypingCorrectionPenalty(request, segments, results);

return selected_types;
}

Expand Down Expand Up @@ -1773,19 +1776,7 @@ void DictionaryPredictionAggregator::AggregateTypingCorrectedPrediction(

// Appends the result with TYPING_CORRECTION attribute.
for (Result &result : corrected_results) {
if (query.type & TypeCorrectedQuery::CORRECTION) {
result.types |= TYPING_CORRECTION;
}
if (query.type & TypeCorrectedQuery::COMPLETION) {
result.types |= TYPING_COMPLETION;
}
result.typing_correction_score = query.score;
// bias = hyp_score - base_score, so larger is better.
// bias is computed in log10 domain, so we need to use the different
// scale factor. 500 * log(10) = ~1150.
const int adjustment = -1150 * query.bias;
result.typing_correction_adjustment = adjustment;
result.wcost += adjustment;
PopulateTypeCorrectedQuery(query, &result);
result.value = manager->ConvertConversionString(result.value);
results->emplace_back(std::move(result));
}
Expand Down Expand Up @@ -1933,6 +1924,17 @@ bool DictionaryPredictionAggregator::IsZipCodeRequest(
}
return true;
}

void DictionaryPredictionAggregator::MaybePopulateTypingCorrectionPenalty(
const ConversionRequest &request, const Segments &segments,
std::vector<Result> *results) const {
const engine::SupplementalModelInterface *supplemental_model =
modules_.GetSupplementalModel();
if (!supplemental_model) return;

supplemental_model->PopulateTypeCorrectedQuery(request, segments, results);
}

} // namespace prediction
} // namespace mozc

Expand Down
4 changes: 4 additions & 0 deletions src/prediction/dictionary_prediction_aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ class DictionaryPredictionAggregator : public PredictionAggregatorInterface {
const ConversionRequest &request, const Segments &segments,
int zip_code_id, int unknown_id, std::vector<Result> *results);

void MaybePopulateTypingCorrectionPenalty(const ConversionRequest &request,
const Segments &segments,
std::vector<Result> *results) const;

// Test peer to access private methods
friend class DictionaryPredictionAggregatorTestPeer;

Expand Down
8 changes: 7 additions & 1 deletion src/prediction/dictionary_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -893,11 +893,17 @@ std::string DictionaryPredictor::GetPredictionTypeDebugString(
debug_desc.append(1, 'E');
}
if (types & PredictionType::TYPING_CORRECTION) {
debug_desc.append("T");
debug_desc.append(1, 'T');
}
if (types & PredictionType::TYPING_COMPLETION) {
debug_desc.append(1, 'C');
}
if (types & PredictionType::SUPPLEMENTAL_MODEL) {
debug_desc.append(1, 'X');
}
if (types & PredictionType::KANA_MODIFIER_EXPANDED) {
debug_desc.append(1, 'K');
}
return debug_desc;
}

Expand Down
20 changes: 20 additions & 0 deletions src/prediction/result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@

#include <tuple>

#include "absl/base/nullability.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "base/strings/unicode.h"
#include "composer/query.h"
#include "converter/segments.h"
#include "dictionary/dictionary_token.h"
#include "prediction/zero_query_dict.h"
Expand Down Expand Up @@ -138,5 +140,23 @@ void Result::SetSourceInfoForZeroQuery(ZeroQueryType type) {
}
}

void PopulateTypeCorrectedQuery(
const composer::TypeCorrectedQuery &typing_corrected_result,
absl::Nonnull<Result *> result) {
if (typing_corrected_result.type & composer::TypeCorrectedQuery::CORRECTION) {
result->types |= TYPING_CORRECTION;
}
if (typing_corrected_result.type & composer::TypeCorrectedQuery::COMPLETION) {
result->types |= TYPING_COMPLETION;
}
result->typing_correction_score = typing_corrected_result.score;
// bias = hyp_score - base_score, so larger is better.
// bias is computed in log10 domain, so we need to use the different
// scale factor. 500 * log(10) = ~1150.
const int adjustment = -1150 * typing_corrected_result.bias;
result->typing_correction_adjustment = adjustment;
result->wcost += adjustment;
}

} // namespace prediction
} // namespace mozc
12 changes: 12 additions & 0 deletions src/prediction/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@
#include <utility>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "composer/query.h"
#include "converter/segments.h"
#include "dictionary/dictionary_token.h"
#include "prediction/zero_query_dict.h"
Expand Down Expand Up @@ -80,6 +82,10 @@ enum PredictionType {
// TODO(noriyukit): This label should be integrated with REALTIME. This is
// why 65536 is used to indicate that it is a temporary assignment.
REALTIME_TOP = 65536,

// Kana modifier is expanded inside the dictionary lookup.
// TODO(taku): This label should be migrated to TYPING_CORRECTION.
KANA_MODIFIER_EXPANDED = 32768,
};
// Bitfield to store a set of PredictionType.
using PredictionTypes = int32_t;
Expand Down Expand Up @@ -200,6 +206,12 @@ struct ResultCostLess {
}
};

// Populates the typing correction result in `query` to prediction::Result
// TODO(taku): rename `query` as it is not a query.
void PopulateTypeCorrectedQuery(
const composer::TypeCorrectedQuery &typing_corrected_result,
absl::Nonnull<Result *> result);

#ifndef NDEBUG
#define MOZC_WORD_LOG_MESSAGE(message) \
absl::StrCat(__FILE__, ":", __LINE__, " ", message, "\n")
Expand Down
2 changes: 1 addition & 1 deletion src/protocol/commands.proto
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ message Capability {
[default = NO_TEXT_DELETION_CAPABILITY];
}

// Next ID: 96
// Next ID: 97
// Bundles together some Android experiment flags so that they can be easily
// retrieved throughout the native code. These flags are generally specific to
// the decoder, and are made available when the decoder is initialized.
Expand Down

0 comments on commit ff08a9c

Please sign in to comment.