diff --git a/src/engine/supplemental_model_interface.h b/src/engine/supplemental_model_interface.h index 580641908..da9dc0f12 100644 --- a/src/engine/supplemental_model_interface.h +++ b/src/engine/supplemental_model_interface.h @@ -78,6 +78,11 @@ class SupplementalModelInterface { return std::nullopt; } + // Populates the typing correction penalty and attribute to `results`. + virtual void PopulateTypeCorrectedQuery( + const ConversionRequest &request, const Segments &segments, + std::vector *results) const {} + // Reranks (boost or promote) the typing corrected candidates at `results`. virtual void RerankTypingCorrection( const ConversionRequest &request, const Segments &segments, diff --git a/src/engine/supplemental_model_mock.h b/src/engine/supplemental_model_mock.h index 7f1d47989..515227577 100644 --- a/src/engine/supplemental_model_mock.h +++ b/src/engine/supplemental_model_mock.h @@ -60,6 +60,10 @@ class MockSupplementalModel : public SupplementalModelInterface { CorrectComposition, (const ConversionRequest &request, absl::string_view context), (const, override)); + MOCK_METHOD(void, PopulateTypeCorrectedQuery, + (const ConversionRequest &request, const Segments &segments, + std::vector *results), + (const, override)); MOCK_METHOD(void, RerankTypingCorrection, (const ConversionRequest &request, const Segments &segments, std::vector> *results), diff --git a/src/prediction/BUILD.bazel b/src/prediction/BUILD.bazel index 02e439ac6..aba1a80af 100644 --- a/src/prediction/BUILD.bazel +++ b/src/prediction/BUILD.bazel @@ -208,8 +208,10 @@ mozc_cc_library( deps = [ ":zero_query_dict", "//base/strings:unicode", + "//composer:query", "//converter:segments", "//dictionary:dictionary_token", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/src/prediction/dictionary_prediction_aggregator.cc b/src/prediction/dictionary_prediction_aggregator.cc index 50658a883..61134d931 100644 --- a/src/prediction/dictionary_prediction_aggregator.cc +++ b/src/prediction/dictionary_prediction_aggregator.cc @@ -316,12 +316,13 @@ class DictionaryPredictionAggregator::PredictiveLookupCallback return TRAVERSE_CONTINUE; } - results_->push_back(Result()); - results_->back().InitializeByTokenAndTypes(token, types_); - results_->back().wcost += penalty_; - results_->back().source_info |= source_info_; - results_->back().non_expanded_original_key = - std::string(non_expanded_original_key_); + Result result; + result.InitializeByTokenAndTypes(token, types_); + result.wcost += penalty_; + result.source_info |= source_info_; + result.non_expanded_original_key = std::string(non_expanded_original_key_); + if (penalty_ > 0) result.types |= KANA_MODIFIER_EXPANDED; + results_->emplace_back(std::move(result)); return (results_->size() < limit_) ? TRAVERSE_CONTINUE : TRAVERSE_DONE; } @@ -691,6 +692,8 @@ PredictionTypes DictionaryPredictionAggregator::AggregatePrediction( } } + MaybePopulateTypingCorrectionPenalty(request, segments, results); + return selected_types; } @@ -1773,19 +1776,7 @@ void DictionaryPredictionAggregator::AggregateTypingCorrectedPrediction( // Appends the result with TYPING_CORRECTION attribute. for (Result &result : corrected_results) { - if (query.type & TypeCorrectedQuery::CORRECTION) { - result.types |= TYPING_CORRECTION; - } - if (query.type & TypeCorrectedQuery::COMPLETION) { - result.types |= TYPING_COMPLETION; - } - result.typing_correction_score = query.score; - // bias = hyp_score - base_score, so larger is better. - // bias is computed in log10 domain, so we need to use the different - // scale factor. 500 * log(10) = ~1150. - const int adjustment = -1150 * query.bias; - result.typing_correction_adjustment = adjustment; - result.wcost += adjustment; + PopulateTypeCorrectedQuery(query, &result); result.value = manager->ConvertConversionString(result.value); results->emplace_back(std::move(result)); } @@ -1933,6 +1924,17 @@ bool DictionaryPredictionAggregator::IsZipCodeRequest( } return true; } + +void DictionaryPredictionAggregator::MaybePopulateTypingCorrectionPenalty( + const ConversionRequest &request, const Segments &segments, + std::vector *results) const { + const engine::SupplementalModelInterface *supplemental_model = + modules_.GetSupplementalModel(); + if (!supplemental_model) return; + + supplemental_model->PopulateTypeCorrectedQuery(request, segments, results); +} + } // namespace prediction } // namespace mozc diff --git a/src/prediction/dictionary_prediction_aggregator.h b/src/prediction/dictionary_prediction_aggregator.h index 0a39e6b89..f1dfd1871 100644 --- a/src/prediction/dictionary_prediction_aggregator.h +++ b/src/prediction/dictionary_prediction_aggregator.h @@ -278,6 +278,10 @@ class DictionaryPredictionAggregator : public PredictionAggregatorInterface { const ConversionRequest &request, const Segments &segments, int zip_code_id, int unknown_id, std::vector *results); + void MaybePopulateTypingCorrectionPenalty(const ConversionRequest &request, + const Segments &segments, + std::vector *results) const; + // Test peer to access private methods friend class DictionaryPredictionAggregatorTestPeer; diff --git a/src/prediction/dictionary_predictor.cc b/src/prediction/dictionary_predictor.cc index 33092a36d..56ef64e85 100644 --- a/src/prediction/dictionary_predictor.cc +++ b/src/prediction/dictionary_predictor.cc @@ -893,11 +893,17 @@ std::string DictionaryPredictor::GetPredictionTypeDebugString( debug_desc.append(1, 'E'); } if (types & PredictionType::TYPING_CORRECTION) { - debug_desc.append("T"); + debug_desc.append(1, 'T'); + } + if (types & PredictionType::TYPING_COMPLETION) { + debug_desc.append(1, 'C'); } if (types & PredictionType::SUPPLEMENTAL_MODEL) { debug_desc.append(1, 'X'); } + if (types & PredictionType::KANA_MODIFIER_EXPANDED) { + debug_desc.append(1, 'K'); + } return debug_desc; } diff --git a/src/prediction/result.cc b/src/prediction/result.cc index 4c2969c05..efa806be1 100644 --- a/src/prediction/result.cc +++ b/src/prediction/result.cc @@ -31,9 +31,11 @@ #include +#include "absl/base/nullability.h" #include "absl/log/log.h" #include "absl/strings/string_view.h" #include "base/strings/unicode.h" +#include "composer/query.h" #include "converter/segments.h" #include "dictionary/dictionary_token.h" #include "prediction/zero_query_dict.h" @@ -138,5 +140,23 @@ void Result::SetSourceInfoForZeroQuery(ZeroQueryType type) { } } +void PopulateTypeCorrectedQuery( + const composer::TypeCorrectedQuery &typing_corrected_result, + absl::Nonnull result) { + if (typing_corrected_result.type & composer::TypeCorrectedQuery::CORRECTION) { + result->types |= TYPING_CORRECTION; + } + if (typing_corrected_result.type & composer::TypeCorrectedQuery::COMPLETION) { + result->types |= TYPING_COMPLETION; + } + result->typing_correction_score = typing_corrected_result.score; + // bias = hyp_score - base_score, so larger is better. + // bias is computed in log10 domain, so we need to use the different + // scale factor. 500 * log(10) = ~1150. + const int adjustment = -1150 * typing_corrected_result.bias; + result->typing_correction_adjustment = adjustment; + result->wcost += adjustment; +} + } // namespace prediction } // namespace mozc diff --git a/src/prediction/result.h b/src/prediction/result.h index 01ad8129b..c04bc1db7 100644 --- a/src/prediction/result.h +++ b/src/prediction/result.h @@ -36,10 +36,12 @@ #include #include +#include "absl/base/nullability.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "composer/query.h" #include "converter/segments.h" #include "dictionary/dictionary_token.h" #include "prediction/zero_query_dict.h" @@ -80,6 +82,10 @@ enum PredictionType { // TODO(noriyukit): This label should be integrated with REALTIME. This is // why 65536 is used to indicate that it is a temporary assignment. REALTIME_TOP = 65536, + + // Kana modifier is expanded inside the dictionary lookup. + // TODO(taku): This label should be migrated to TYPING_CORRECTION. + KANA_MODIFIER_EXPANDED = 32768, }; // Bitfield to store a set of PredictionType. using PredictionTypes = int32_t; @@ -200,6 +206,12 @@ struct ResultCostLess { } }; +// Populates the typing correction result in `query` to prediction::Result +// TODO(taku): rename `query` as it is not a query. +void PopulateTypeCorrectedQuery( + const composer::TypeCorrectedQuery &typing_corrected_result, + absl::Nonnull result); + #ifndef NDEBUG #define MOZC_WORD_LOG_MESSAGE(message) \ absl::StrCat(__FILE__, ":", __LINE__, " ", message, "\n") diff --git a/src/protocol/commands.proto b/src/protocol/commands.proto index c08731168..f50cb3951 100644 --- a/src/protocol/commands.proto +++ b/src/protocol/commands.proto @@ -573,7 +573,7 @@ message Capability { [default = NO_TEXT_DELETION_CAPABILITY]; } -// Next ID: 96 +// Next ID: 97 // Bundles together some Android experiment flags so that they can be easily // retrieved throughout the native code. These flags are generally specific to // the decoder, and are made available when the decoder is initialized.