Skip to content

Commit

Permalink
Removed legacy ShouldRevertTypingCorrection.
Browse files Browse the repository at this point in the history
Migrating to new TypingCorrectionReranker component.

PiperOrigin-RevId: 690046621
  • Loading branch information
taku910 authored and hiroyuki-komatsu committed Oct 26, 2024
1 parent 550b151 commit 4a1069a
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 268 deletions.
10 changes: 0 additions & 10 deletions src/engine/supplemental_model_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,6 @@ class SupplementalModelInterface {
const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const prediction::Result *>> *results) const {}

// Returns true if the final typing-corrected result is not confident.
// This default implementation always returns false.
// TODO(taku): Remove this function after finishing the migration to
// the more general SuppressTypingCorrection method.
virtual bool ShouldRevertTypingCorrection(
const ConversionRequest &request, const Segments &segments,
absl::Span<const prediction::Result> literal_results,
absl::Span<const prediction::Result> typing_corrected_results) const {
return false;
}

// Performs general post correction on `segments`.
// This default implementation is a no-op.
virtual void PostCorrect(const ConversionRequest &request,
absl::Nonnull<Segments *> segments) const {}
Expand Down
5 changes: 0 additions & 5 deletions src/engine/supplemental_model_mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,6 @@ class MockSupplementalModel : public SupplementalModelInterface {
(const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const prediction::Result *>> *results),
(const, override));
// Mock override of SupplementalModelInterface::ShouldRevertTypingCorrection.
MOCK_METHOD(bool, ShouldRevertTypingCorrection,
(const ConversionRequest &request, const Segments &segments,
absl::Span<const prediction::Result> literal_results,
absl::Span<const prediction::Result> typing_corrected_results),
(const, override));
// Mock override of SupplementalModelInterface::PostCorrect.
MOCK_METHOD(void, PostCorrect,
(const ConversionRequest &, absl::Nonnull<Segments *> segments),
(const, override));
Expand Down
123 changes: 11 additions & 112 deletions src/prediction/dictionary_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,11 @@ bool DictionaryPredictor::PredictForRequest(const ConversionRequest &request,
RewriteResultsForPrediction(request, *segments, &results);

// Explicitly populate the typing corrected results.
const TypingCorrectionMixingParams typing_correction_mixing_params =
MaybePopulateTypingCorrectedResults(request, *segments, &results);
MaybePopulateTypingCorrectedResults(request, *segments, &results);

MaybeRescoreResults(request, *segments, absl::MakeSpan(results));

return AddPredictionToCandidates(request, segments,
typing_correction_mixing_params,
absl::MakeSpan(results));
return AddPredictionToCandidates(request, segments, absl::MakeSpan(results));
}

void DictionaryPredictor::RewriteResultsForPrediction(
Expand Down Expand Up @@ -343,42 +340,30 @@ void DictionaryPredictor::RewriteResultsForPrediction(
}
}

TypingCorrectionMixingParams
DictionaryPredictor::MaybePopulateTypingCorrectedResults(
void DictionaryPredictor::MaybePopulateTypingCorrectedResults(
const ConversionRequest &request, const Segments &segments,
std::vector<Result> *results) const {
if (!IsTypingCorrectionEnabled(request)) {
return {};
}

if (results->empty()) {
return {};
if (!IsTypingCorrectionEnabled(request) || results->empty()) {
return;
}

const size_t key_len = Util::CharsLen(segments.conversion_segment(0).key());
constexpr int kMinTypingCorrectionKeyLen = 3;
if (key_len < kMinTypingCorrectionKeyLen) {
return {};
return;
}

std::vector<Result> typing_corrected_results =
aggregator_->AggregateTypingCorrectedResults(request, segments);
RewriteResultsForPrediction(request, segments, &typing_corrected_results);

const TypingCorrectionMixingParams typing_correction_mixing_params =
GetTypingCorrectionMixingParams(request, segments, *results,
typing_corrected_results);

for (auto &result : typing_corrected_results) {
results->emplace_back(std::move(result));
}

return typing_correction_mixing_params;
}

bool DictionaryPredictor::AddPredictionToCandidates(
const ConversionRequest &request, Segments *segments,
const TypingCorrectionMixingParams &typing_correction_mixing_params,
absl::Span<Result> results) const {
DCHECK(segments);

Expand Down Expand Up @@ -462,14 +447,8 @@ bool DictionaryPredictor::AddPredictionToCandidates(
final_results_ptrs.emplace_back(&result);
}

const auto &params = request.request().decoder_experiment_params();
if (params.typing_correction_result_reranker_mode() > 0) {
MaybeRerankAggressiveTypingCorrection(request, *segments,
&final_results_ptrs);
} else {
MaybeSuppressAggressiveTypingCorrection(
request, typing_correction_mixing_params, &final_results_ptrs);
}
MaybeRerankAggressiveTypingCorrection(request, *segments,
&final_results_ptrs);

// Fill segments from final_results_ptrs.
for (const Result *result : final_results_ptrs) {
Expand All @@ -491,72 +470,15 @@ bool DictionaryPredictor::AddPredictionToCandidates(
void DictionaryPredictor::MaybeRerankAggressiveTypingCorrection(
const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const Result *>> *results) const {
const auto &params = request.request().decoder_experiment_params();
if (params.typing_correction_result_reranker_mode() == 0) return;

if (!IsTypingCorrectionEnabled(request) || results->empty()) {
return;
}
const engine::SupplementalModelInterface *supplemental_model =
modules_.GetSupplementalModel();
if (supplemental_model == nullptr) return;

supplemental_model->RerankTypingCorrection(request, segments, results);
}

// static
// Keeps an aggressive typing correction from hiding the literal candidate.
//
// Does nothing unless the top entry of `results` is a typing correction and
// at least one flag of `typing_correction_mixing_params` is set. Otherwise the
// first non-typing-corrected result found among the top 10 entries is moved
//   - to the top when `literal_on_top` is set, or
//   - to the second position when only `literal_at_least_second` is set.
// The relative order of all other results is preserved. `request` is unused.
void DictionaryPredictor::MaybeSuppressAggressiveTypingCorrection(
    const ConversionRequest &request,
    const TypingCorrectionMixingParams &typing_correction_mixing_params,
    std::vector<absl::Nonnull<const Result *>> *results) {
  if (results->empty()) return;

  const auto is_typing_correction = [](const Result &result) {
    return (result.types & PredictionType::TYPING_CORRECTION) ||
           (result.candidate_attributes &
            Segment::Candidate::TYPING_CORRECTION);
  };

  // Nothing to do when the top is not a typing correction.
  if (!is_typing_correction(*results->front())) {
    return;
  }

  const bool force_literal_on_top =
      typing_correction_mixing_params.literal_on_top;
  const bool literal_at_least_second =
      typing_correction_mixing_params.literal_at_least_second;
  if (!force_literal_on_top && !literal_at_least_second) {
    return;
  }

  // Only the leading results are inspected when looking for the literal.
  constexpr int kMaxLookupSize = 10;
  const auto first = results->begin();
  const auto last = first + std::min<int>(kMaxLookupSize, results->size());

  // Finds the first non-typing-corrected candidate after the top.
  const auto literal =
      std::find_if(first + 1, last, [&](const Result *result) {
        return !is_typing_correction(*result);
      });
  if (literal == last) {
    return;
  }

  // The literal goes to the top when forced; otherwise it is moved to the
  // second position (a no-op when it is already there).
  const auto destination = force_literal_on_top ? first : first + 1;
  if (literal != destination) {
    // Moves `*literal` to `destination`, shifting the results in between
    // one position toward the back.
    std::rotate(destination, literal, literal + 1);
  }
}

// static
void DictionaryPredictor::MaybeApplyPostCorrection(
const ConversionRequest &request, const engine::Modules &modules,
Expand Down Expand Up @@ -1423,29 +1345,6 @@ std::shared_ptr<Result> DictionaryPredictor::MaybeGetPreviousTopResult(
return nullptr;
}

// Builds the mixing params for the literal and typing-corrected results.
// `literal_on_top` is taken from the supplemental model's
// ShouldRevertTypingCorrection() when a model is available and keeps its
// default value otherwise. `literal_at_least_second` is always set to true.
TypingCorrectionMixingParams
DictionaryPredictor::GetTypingCorrectionMixingParams(
    const ConversionRequest &request, const Segments &segments,
    absl::Span<const Result> literal_results,
    absl::Span<const Result> typing_corrected_results) const {
  TypingCorrectionMixingParams mixing_params;
  mixing_params.literal_at_least_second = true;

  const engine::SupplementalModelInterface *const model =
      modules_.GetSupplementalModel();
  if (model == nullptr) {
    // Without a supplemental model, `literal_on_top` keeps its default.
    return mixing_params;
  }

  mixing_params.literal_on_top = model->ShouldRevertTypingCorrection(
      request, segments, literal_results, typing_corrected_results);
  return mixing_params;
}

} // namespace mozc::prediction

#undef MOZC_WORD_LOG_MESSAGE
Expand Down
39 changes: 6 additions & 33 deletions src/prediction/dictionary_predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,6 @@ struct KeyValueView {

} // namespace dictionary_predictor_internal

// Controls how the literal and typing-corrected results are mixed.
// The flags define where the literal candidate is placed relative to the
// typing-corrected results, and are determined dynamically using various
// quality signals.
struct TypingCorrectionMixingParams {
  // When true, the literal candidate is moved to the top position even if a
  // typing-corrected result is placed at the top. Set this flag when the
  // typing correction is less confident.
  bool literal_on_top{false};

  // When true, the literal candidate is moved to at least the second
  // position. Nothing is done when the literal candidate is already at the
  // top.
  bool literal_at_least_second{false};
};

// Dictionary-based predictor
class DictionaryPredictor : public PredictorInterface {
public:
Expand Down Expand Up @@ -156,24 +142,16 @@ class DictionaryPredictor : public PredictorInterface {
aggregator,
const ImmutableConverterInterface *immutable_converter);

bool AddPredictionToCandidates(
const ConversionRequest &request, Segments *segments,
const TypingCorrectionMixingParams &typing_correction_mixing_params,
absl::Span<Result> results) const;
bool AddPredictionToCandidates(const ConversionRequest &request,
Segments *segments,
absl::Span<Result> results) const;

void FillCandidate(
const ConversionRequest &request, const Result &result,
dictionary_predictor_internal::KeyValueView key_value,
const absl::flat_hash_map<std::string, int32_t> &merged_types,
Segment::Candidate *candidate) const;

// Computes the typing correction mixing params
// from the `literal_results` and `typing_corrected_results`.
TypingCorrectionMixingParams GetTypingCorrectionMixingParams(
const ConversionRequest &request, const Segments &segments,
absl::Span<const Result> literal_results,
absl::Span<const Result> typing_corrected_results) const;

// Returns the position of the misspelled character.
//
// Example:
Expand Down Expand Up @@ -287,19 +265,14 @@ class DictionaryPredictor : public PredictorInterface {
absl::flat_hash_map<PrefixPenaltyKey, int> *cache) const;

// Populates typing corrected results to `results`.
TypingCorrectionMixingParams MaybePopulateTypingCorrectedResults(
const ConversionRequest &request, const Segments &segments,
std::vector<Result> *results) const;
void MaybePopulateTypingCorrectedResults(const ConversionRequest &request,
const Segments &segments,
std::vector<Result> *results) const;

void MaybeRerankAggressiveTypingCorrection(
const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const Result *>> *results) const;

static void MaybeSuppressAggressiveTypingCorrection(
const ConversionRequest &request,
const TypingCorrectionMixingParams &typing_correction_mixing_params,
std::vector<absl::Nonnull<const Result *>> *results);

static void MaybeApplyPostCorrection(const ConversionRequest &request,
const engine::Modules &modules,
Segments *segments);
Expand Down
Loading

0 comments on commit 4a1069a

Please sign in to comment.