From 0ec815229f1cbbad8cf37e359cbffbc9811235b0 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Wed, 6 Sep 2023 13:34:12 +0900 Subject: [PATCH] Remove unchecked functions (#144) * Remove unchecked functions * fix * fix --- benchmark/Cargo.toml | 4 - benchmark/src/main.rs | 3 - docs/benchmark.md | 7 -- map/src/main.rs | 3 - vibrato/src/dictionary.rs | 40 --------- vibrato/src/dictionary/builder.rs | 1 - vibrato/src/dictionary/connector.rs | 3 - .../dictionary/connector/dual_connector.rs | 14 --- .../dictionary/connector/matrix_connector.rs | 7 -- .../src/dictionary/connector/raw_connector.rs | 6 -- vibrato/src/dictionary/lexicon.rs | 16 ---- vibrato/src/dictionary/lexicon/map.rs | 12 --- vibrato/src/dictionary/lexicon/map/posting.rs | 7 -- vibrato/src/tokenizer.rs | 88 +------------------ vibrato/src/tokenizer/lattice.rs | 74 ---------------- vibrato/src/trainer/config.rs | 1 - 16 files changed, 3 insertions(+), 283 deletions(-) diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml index 6aaf8a4c..dbad03e5 100644 --- a/benchmark/Cargo.toml +++ b/benchmark/Cargo.toml @@ -3,10 +3,6 @@ name = "benchmark" version = "0.1.0" edition = "2021" -[features] -default = [] -unchecked = [] - [dependencies] vibrato = { path = "../vibrato" } clap = { version = "4.0", features = ["derive"] } # MIT or Apache-2.0 diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 7edff6b8..9d33fc37 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -37,10 +37,7 @@ fn main() -> Result<(), Box> { let args = Args::parse(); let reader = zstd::Decoder::new(File::open(args.sysdic)?)?; - #[cfg(not(feature = "unchecked"))] let dict = Dictionary::read(reader)?; - #[cfg(feature = "unchecked")] - let dict = unsafe { Dictionary::read_unchecked(reader)? }; let tokenizer = Tokenizer::new(dict) .ignore_space(args.ignore_space)? diff --git a/docs/benchmark.md b/docs/benchmark.md index 13a74cb7..295ad72a 100644 --- a/docs/benchmark.md +++ b/docs/benchmark.md @@ -6,10 +6,3 @@ and sentences in `test.txt` with the following command. ``` $ cargo run --release -p benchmark -- -i system.dic.zst < test.txt ``` - -If you can guarantee that `system.dic.zst` is exported from this library, -you can specify `--features=unchecked` for faster tokenization. - -``` -$ cargo run --release -p benchmark --features=unchecked -- -i system.dic.zst < test.txt -``` diff --git a/map/src/main.rs b/map/src/main.rs index 8acc9684..36ea3c55 100644 --- a/map/src/main.rs +++ b/map/src/main.rs @@ -32,10 +32,7 @@ fn main() -> Result<(), Box> { eprintln!("Loading the dictionary..."); let reader = zstd::Decoder::new(File::open(args.sysdic_in)?)?; - #[cfg(not(feature = "unchecked"))] let dict = Dictionary::read(reader)?; - #[cfg(feature = "unchecked")] - let dict = unsafe { Dictionary::read_unchecked(reader)? }; eprintln!("Loading and doing the mapping..."); let lmap = { diff --git a/vibrato/src/dictionary.rs b/vibrato/src/dictionary.rs index cce5b6a7..430e4f7a 100644 --- a/vibrato/src/dictionary.rs +++ b/vibrato/src/dictionary.rs @@ -58,7 +58,6 @@ pub(crate) struct DictionaryInner { /// Dictionary for tokenization. pub struct Dictionary { pub(crate) data: DictionaryInner, - pub(crate) need_check: bool, } impl Dictionary { @@ -182,45 +181,6 @@ impl Dictionary { { Ok(Self { data: Self::read_common(rdr)?, - need_check: true, - }) - } - - /// Creates a dictionary from raw dictionary data. - /// - /// The argument must be a byte sequence exported by the [`Dictionary::write()`] function. - /// - /// Unlike the [`Dictionary::read()`] function, this function does not check the correctness of - /// the dictionary. - /// - /// # Examples - /// - /// ```no_run - /// # fn main() -> Result<(), Box> { - /// use std::fs::File; - /// - /// use vibrato::Dictionary; - /// - /// let reader = File::open("path/to/system.dic")?; - /// let dict = unsafe { Dictionary::read_unchecked(reader)? } ; - /// # Ok(()) - /// # } - /// ``` - /// - /// # Safety - /// - /// The given reader must be a correct file exported by [`Dictionary::write()`]. - /// - /// # Errors - /// - /// When bincode generates an error, it will be returned as is. - pub unsafe fn read_unchecked(rdr: R) -> Result - where - R: Read, - { - Ok(Self { - data: Self::read_common(rdr)?, - need_check: false, }) } diff --git a/vibrato/src/dictionary/builder.rs b/vibrato/src/dictionary/builder.rs index 1dc1f706..53d9da43 100644 --- a/vibrato/src/dictionary/builder.rs +++ b/vibrato/src/dictionary/builder.rs @@ -43,7 +43,6 @@ impl SystemDictionaryBuilder { char_prop, unk_handler, }, - need_check: false, }) } diff --git a/vibrato/src/dictionary/connector.rs b/vibrato/src/dictionary/connector.rs index fa6c7663..1df06f00 100644 --- a/vibrato/src/dictionary/connector.rs +++ b/vibrato/src/dictionary/connector.rs @@ -25,9 +25,6 @@ pub trait Connector { pub trait ConnectorCost: Connector { /// Gets the value of the connection matrix fn cost(&self, right_id: u16, left_id: u16) -> i32; - - /// Gets the value of the connection matrix - unsafe fn cost_unchecked(&self, right_id: u16, left_id: u16) -> i32; } #[derive(Decode, Encode)] diff --git a/vibrato/src/dictionary/connector/dual_connector.rs b/vibrato/src/dictionary/connector/dual_connector.rs index aeefb406..bcf7d328 100644 --- a/vibrato/src/dictionary/connector/dual_connector.rs +++ b/vibrato/src/dictionary/connector/dual_connector.rs @@ -276,20 +276,6 @@ impl ConnectorCost for DualConnector { ); matrix_cost + raw_cost } - - #[inline(always)] - unsafe fn cost_unchecked(&self, right_id: u16, left_id: u16) -> i32 { - let right_conn_id = *self.right_conn_id_map.get_unchecked(usize::from(right_id)); - let left_conn_id = *self.left_conn_id_map.get_unchecked(usize::from(left_id)); - let matrix_cost = self - .matrix_connector - .cost_unchecked(right_conn_id, left_conn_id); - let raw_cost = self.raw_scorer.accumulate_cost( - &[*self.right_feat_ids.get_unchecked(usize::from(right_id))], - &[*self.left_feat_ids.get_unchecked(usize::from(left_id))], - ); - matrix_cost + raw_cost - } } #[cfg(test)] diff --git a/vibrato/src/dictionary/connector/matrix_connector.rs b/vibrato/src/dictionary/connector/matrix_connector.rs index d9401174..df6c3f8d 100644 --- a/vibrato/src/dictionary/connector/matrix_connector.rs +++ b/vibrato/src/dictionary/connector/matrix_connector.rs @@ -122,13 +122,6 @@ impl ConnectorCost for MatrixConnector { let index = self.index(right_id, left_id); i32::from(self.data[index]) } - - #[inline(always)] - unsafe fn cost_unchecked(&self, right_id: u16, left_id: u16) -> i32 { - let index = self.index(right_id, left_id); - // The tokenization time can be shortened by 5--10%. - i32::from(*self.data.get_unchecked(index)) - } } #[cfg(test)] diff --git a/vibrato/src/dictionary/connector/raw_connector.rs b/vibrato/src/dictionary/connector/raw_connector.rs index a6e12547..b548b316 100644 --- a/vibrato/src/dictionary/connector/raw_connector.rs +++ b/vibrato/src/dictionary/connector/raw_connector.rs @@ -158,12 +158,6 @@ impl ConnectorCost for RawConnector { self.left_feature_ids(left_id), ) } - - /// TODO: Implement unchecked optimization. - #[inline(always)] - unsafe fn cost_unchecked(&self, right_id: u16, left_id: u16) -> i32 { - self.cost(right_id, left_id) - } } /// Builder for components of [`RawConnector`] using simple data structures. diff --git a/vibrato/src/dictionary/lexicon.rs b/vibrato/src/dictionary/lexicon.rs index 8e68fc23..34bfd8aa 100644 --- a/vibrato/src/dictionary/lexicon.rs +++ b/vibrato/src/dictionary/lexicon.rs @@ -45,22 +45,6 @@ impl Lexicon { }) } - #[inline(always)] - pub unsafe fn common_prefix_iterator_unchecked<'a>( - &'a self, - input: &'a [char], - ) -> impl Iterator + 'a { - self.map - .common_prefix_iterator_unchecked(input) - .map(move |(word_id, end_char)| { - LexMatch::new( - WordIdx::new(self.lex_type, word_id), - self.params.get(usize::from_u32(word_id)), - end_char, - ) - }) - } - /// Do NOT make this function public to maintain consistency in /// the connection-id mapping among members of `Dictionary`. /// The consistency is managed in `Dictionary`. diff --git a/vibrato/src/dictionary/lexicon/map.rs b/vibrato/src/dictionary/lexicon/map.rs index e4523560..3eeda625 100644 --- a/vibrato/src/dictionary/lexicon/map.rs +++ b/vibrato/src/dictionary/lexicon/map.rs @@ -40,18 +40,6 @@ impl WordMap { .map(move |word_id| (word_id, e.end_char)) }) } - - #[inline(always)] - pub unsafe fn common_prefix_iterator_unchecked<'a>( - &'a self, - input: &'a [char], - ) -> impl Iterator + 'a { - self.trie.common_prefix_iterator(input).flat_map(move |e| { - self.postings - .ids_unchecked(usize::from_u32(e.value)) - .map(move |word_id| (word_id, e.end_char)) - }) - } } #[derive(Default)] diff --git a/vibrato/src/dictionary/lexicon/map/posting.rs b/vibrato/src/dictionary/lexicon/map/posting.rs index 372468f5..6eb6809a 100644 --- a/vibrato/src/dictionary/lexicon/map/posting.rs +++ b/vibrato/src/dictionary/lexicon/map/posting.rs @@ -19,13 +19,6 @@ impl Postings { let len = usize::from_u32(self.data[i]); self.data[i + 1..i + 1 + len].iter().cloned() } - - #[inline(always)] - pub unsafe fn ids_unchecked(&'_ self, i: usize) -> impl Iterator + '_ { - let len = usize::from_u32(self.data[i]); - // The tokenization time can be shortened by 10%. - self.data.get_unchecked(i + 1..i + 1 + len).iter().cloned() - } } #[derive(Default)] diff --git a/vibrato/src/tokenizer.rs b/vibrato/src/tokenizer.rs index 10dbd579..ebc919d1 100644 --- a/vibrato/src/tokenizer.rs +++ b/vibrato/src/tokenizer.rs @@ -129,27 +129,13 @@ impl Tokenizer { break; } - if self.dict.need_check { - self.add_lattice_edges(sent, lattice, start_node, start_word, connector); - } else { - unsafe { - self.add_lattice_edges_unchecked( - sent, lattice, start_node, start_word, connector, - ); - } - } + self.add_lattice_edges(sent, lattice, start_node, start_word, connector); start_word += 1; start_node = start_word; } - if self.dict.need_check { - lattice.insert_eos(start_node, connector); - } else { - unsafe { - lattice.insert_eos_unchecked(start_node, connector); - } - } + lattice.insert_eos(start_node, connector); } fn add_lattice_edges( @@ -164,9 +150,7 @@ impl Tokenizer { { let mut has_matched = false; - // Safety: `start_word < sent.len_char()` is already checked in `build_lattice()`. - debug_assert!(start_word < sent.len_char()); - let suffix = unsafe { sent.chars().get_unchecked(start_word..) }; + let suffix = &sent.chars()[start_word..]; if let Some(user_lexicon) = self.dict.user_lexicon() { for m in user_lexicon.common_prefix_iterator(suffix) { @@ -213,72 +197,6 @@ impl Tokenizer { }, ); } - - unsafe fn add_lattice_edges_unchecked( - &self, - sent: &Sentence, - lattice: &mut Lattice, - start_node: usize, - start_word: usize, - connector: &C, - ) where - C: ConnectorCost, - { - let mut has_matched = false; - - // Safety: `start_word < sent.len_char()` is already checked in `build_lattice()`. - debug_assert!(start_word < sent.len_char()); - let suffix = sent.chars().get_unchecked(start_word..); - - if let Some(user_lexicon) = self.dict.user_lexicon() { - for m in user_lexicon.common_prefix_iterator_unchecked(suffix) { - debug_assert!(start_word + m.end_char <= sent.len_char()); - lattice.insert_node_unchecked( - start_node, - start_word, - start_word + m.end_char, - m.word_idx, - m.word_param, - connector, - ); - has_matched = true; - } - } - - for m in self - .dict - .system_lexicon() - .common_prefix_iterator_unchecked(suffix) - { - debug_assert!(start_word + m.end_char <= sent.len_char()); - lattice.insert_node_unchecked( - start_node, - start_word, - start_word + m.end_char, - m.word_idx, - m.word_param, - connector, - ); - has_matched = true; - } - - self.dict.unk_handler().gen_unk_words( - sent, - start_word, - has_matched, - self.max_grouping_len, - |w| { - lattice.insert_node_unchecked( - start_node, - w.start_char(), - w.end_char(), - w.word_idx(), - w.word_param(), - connector, - ); - }, - ); - } } #[cfg(test)] diff --git a/vibrato/src/tokenizer/lattice.rs b/vibrato/src/tokenizer/lattice.rs index aed914f9..9a2a5c1c 100644 --- a/vibrato/src/tokenizer/lattice.rs +++ b/vibrato/src/tokenizer/lattice.rs @@ -100,24 +100,6 @@ impl Lattice { }); } - pub unsafe fn insert_eos_unchecked(&mut self, start_node: usize, connector: &C) - where - C: ConnectorCost, - { - let (min_idx, min_cost) = - self.search_min_node_unchecked(start_node, BOS_EOS_CONNECTION_ID, connector); - self.eos = Some(Node { - word_id: u32::MAX, - lex_type: LexType::default(), - start_node, - start_word: self.len_char(), - left_id: BOS_EOS_CONNECTION_ID, - right_id: u16::MAX, - min_idx, - min_cost, - }); - } - pub fn insert_node( &mut self, start_node: usize, @@ -144,33 +126,6 @@ impl Lattice { }); } - pub unsafe fn insert_node_unchecked( - &mut self, - start_node: usize, - start_word: usize, - end_word: usize, - word_idx: WordIdx, - word_param: WordParam, - connector: &C, - ) where - C: ConnectorCost, - { - debug_assert!(start_node <= start_word); - debug_assert!(start_word < end_word); - let (min_idx, min_cost) = - self.search_min_node_unchecked(start_node, word_param.left_id, connector); - self.ends[end_word].push(Node { - word_id: word_idx.word_id, - lex_type: word_idx.lex_type, - start_node, - start_word, - left_id: word_param.left_id, - right_id: word_param.right_id, - min_idx, - min_cost: min_cost + i32::from(word_param.word_cost), - }); - } - fn search_min_node(&self, start_node: usize, left_id: u16, connector: &C) -> (u16, i32) where C: ConnectorCost, @@ -195,35 +150,6 @@ impl Lattice { (min_idx, min_cost) } - unsafe fn search_min_node_unchecked( - &self, - start_node: usize, - left_id: u16, - connector: &C, - ) -> (u16, i32) - where - C: ConnectorCost, - { - debug_assert!(!self.ends[start_node].is_empty()); - - let mut min_idx = INVALID_IDX; - let mut min_cost = MAX_COST; - for (i, left_node) in self.ends[start_node].iter().enumerate() { - debug_assert!(left_node.is_connected_to_bos()); - let conn_cost = connector.cost_unchecked(left_node.right_id, left_id); - let new_cost = left_node.min_cost + conn_cost; - // Depending on the order of tie-breaking, the result can be different from MeCab. - // Using <= (not <) will produce results identical to MeCab in most case (empirically). - if new_cost <= min_cost { - min_idx = i as u16; - min_cost = new_cost; - } - } - - debug_assert_ne!(min_idx, INVALID_IDX); - (min_idx, min_cost) - } - /// Checks if there exist at least one at the word end boundary #[inline(always)] pub fn has_previous_node(&self, i: usize) -> bool { diff --git a/vibrato/src/trainer/config.rs b/vibrato/src/trainer/config.rs index f1e4aa0a..ecdac601 100644 --- a/vibrato/src/trainer/config.rs +++ b/vibrato/src/trainer/config.rs @@ -34,7 +34,6 @@ impl Decode for TrainerConfig { let right_rewriter = Decode::decode(decoder)?; let dict = Dictionary { data: Decode::decode(decoder)?, - need_check: true, }; let surfaces = Decode::decode(decoder)?; Ok(Self {