Skip to content

Commit

Permalink
Remove unchecked functions (#144)
Browse files Browse the repository at this point in the history
* Remove unchecked functions

* fix

* fix
  • Loading branch information
vbkaisetsu authored Sep 6, 2023
1 parent eaa0274 commit 0ec8152
Show file tree
Hide file tree
Showing 16 changed files with 3 additions and 283 deletions.
4 changes: 0 additions & 4 deletions benchmark/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@ name = "benchmark"
version = "0.1.0"
edition = "2021"

[features]
default = []
unchecked = []

[dependencies]
vibrato = { path = "../vibrato" }
clap = { version = "4.0", features = ["derive"] } # MIT or Apache-2.0
Expand Down
3 changes: 0 additions & 3 deletions benchmark/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let args = Args::parse();

let reader = zstd::Decoder::new(File::open(args.sysdic)?)?;
#[cfg(not(feature = "unchecked"))]
let dict = Dictionary::read(reader)?;
#[cfg(feature = "unchecked")]
let dict = unsafe { Dictionary::read_unchecked(reader)? };

let tokenizer = Tokenizer::new(dict)
.ignore_space(args.ignore_space)?
Expand Down
7 changes: 0 additions & 7 deletions docs/benchmark.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,3 @@ and sentences in `test.txt` with the following command.
```
$ cargo run --release -p benchmark -- -i system.dic.zst < test.txt
```

If you can guarantee that `system.dic.zst` is exported from this library,
you can specify `--features=unchecked` for faster tokenization.

```
$ cargo run --release -p benchmark --features=unchecked -- -i system.dic.zst < test.txt
```
3 changes: 0 additions & 3 deletions map/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ fn main() -> Result<(), Box<dyn Error>> {

eprintln!("Loading the dictionary...");
let reader = zstd::Decoder::new(File::open(args.sysdic_in)?)?;
#[cfg(not(feature = "unchecked"))]
let dict = Dictionary::read(reader)?;
#[cfg(feature = "unchecked")]
let dict = unsafe { Dictionary::read_unchecked(reader)? };

eprintln!("Loading and doing the mapping...");
let lmap = {
Expand Down
40 changes: 0 additions & 40 deletions vibrato/src/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ pub(crate) struct DictionaryInner {
/// Dictionary for tokenization.
pub struct Dictionary {
pub(crate) data: DictionaryInner,
pub(crate) need_check: bool,
}

impl Dictionary {
Expand Down Expand Up @@ -182,45 +181,6 @@ impl Dictionary {
{
Ok(Self {
data: Self::read_common(rdr)?,
need_check: true,
})
}

/// Creates a dictionary from raw dictionary data.
///
/// The argument must be a byte sequence exported by the [`Dictionary::write()`] function.
///
/// Unlike the [`Dictionary::read()`] function, this function does not check the correctness of
/// the dictionary.
///
/// # Examples
///
/// ```no_run
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use std::fs::File;
///
/// use vibrato::Dictionary;
///
/// let reader = File::open("path/to/system.dic")?;
/// let dict = unsafe { Dictionary::read_unchecked(reader)? } ;
/// # Ok(())
/// # }
/// ```
///
/// # Safety
///
/// The given reader must be a correct file exported by [`Dictionary::write()`].
///
/// # Errors
///
/// When bincode generates an error, it will be returned as is.
pub unsafe fn read_unchecked<R>(rdr: R) -> Result<Self>
where
R: Read,
{
Ok(Self {
data: Self::read_common(rdr)?,
need_check: false,
})
}

Expand Down
1 change: 0 additions & 1 deletion vibrato/src/dictionary/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ impl SystemDictionaryBuilder {
char_prop,
unk_handler,
},
need_check: false,
})
}

Expand Down
3 changes: 0 additions & 3 deletions vibrato/src/dictionary/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ pub trait Connector {
pub trait ConnectorCost: Connector {
/// Gets the value of the connection matrix
fn cost(&self, right_id: u16, left_id: u16) -> i32;

/// Gets the value of the connection matrix
unsafe fn cost_unchecked(&self, right_id: u16, left_id: u16) -> i32;
}

#[derive(Decode, Encode)]
Expand Down
14 changes: 0 additions & 14 deletions vibrato/src/dictionary/connector/dual_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,20 +276,6 @@ impl ConnectorCost for DualConnector {
);
matrix_cost + raw_cost
}

#[inline(always)]
unsafe fn cost_unchecked(&self, right_id: u16, left_id: u16) -> i32 {
let right_conn_id = *self.right_conn_id_map.get_unchecked(usize::from(right_id));
let left_conn_id = *self.left_conn_id_map.get_unchecked(usize::from(left_id));
let matrix_cost = self
.matrix_connector
.cost_unchecked(right_conn_id, left_conn_id);
let raw_cost = self.raw_scorer.accumulate_cost(
&[*self.right_feat_ids.get_unchecked(usize::from(right_id))],
&[*self.left_feat_ids.get_unchecked(usize::from(left_id))],
);
matrix_cost + raw_cost
}
}

#[cfg(test)]
Expand Down
7 changes: 0 additions & 7 deletions vibrato/src/dictionary/connector/matrix_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,6 @@ impl ConnectorCost for MatrixConnector {
let index = self.index(right_id, left_id);
i32::from(self.data[index])
}

#[inline(always)]
unsafe fn cost_unchecked(&self, right_id: u16, left_id: u16) -> i32 {
let index = self.index(right_id, left_id);
// The tokenization time can be shortened by 5--10%.
i32::from(*self.data.get_unchecked(index))
}
}

#[cfg(test)]
Expand Down
6 changes: 0 additions & 6 deletions vibrato/src/dictionary/connector/raw_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,6 @@ impl ConnectorCost for RawConnector {
self.left_feature_ids(left_id),
)
}

/// TODO: Implement unchecked optimization.
#[inline(always)]
unsafe fn cost_unchecked(&self, right_id: u16, left_id: u16) -> i32 {
self.cost(right_id, left_id)
}
}

/// Builder for components of [`RawConnector`] using simple data structures.
Expand Down
16 changes: 0 additions & 16 deletions vibrato/src/dictionary/lexicon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,6 @@ impl Lexicon {
})
}

#[inline(always)]
pub unsafe fn common_prefix_iterator_unchecked<'a>(
&'a self,
input: &'a [char],
) -> impl Iterator<Item = LexMatch> + 'a {
self.map
.common_prefix_iterator_unchecked(input)
.map(move |(word_id, end_char)| {
LexMatch::new(
WordIdx::new(self.lex_type, word_id),
self.params.get(usize::from_u32(word_id)),
end_char,
)
})
}

/// Do NOT make this function public to maintain consistency in
/// the connection-id mapping among members of `Dictionary`.
/// The consistency is managed in `Dictionary`.
Expand Down
12 changes: 0 additions & 12 deletions vibrato/src/dictionary/lexicon/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,6 @@ impl WordMap {
.map(move |word_id| (word_id, e.end_char))
})
}

#[inline(always)]
pub unsafe fn common_prefix_iterator_unchecked<'a>(
&'a self,
input: &'a [char],
) -> impl Iterator<Item = (u32, usize)> + 'a {
self.trie.common_prefix_iterator(input).flat_map(move |e| {
self.postings
.ids_unchecked(usize::from_u32(e.value))
.map(move |word_id| (word_id, e.end_char))
})
}
}

#[derive(Default)]
Expand Down
7 changes: 0 additions & 7 deletions vibrato/src/dictionary/lexicon/map/posting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@ impl Postings {
let len = usize::from_u32(self.data[i]);
self.data[i + 1..i + 1 + len].iter().cloned()
}

#[inline(always)]
pub unsafe fn ids_unchecked(&'_ self, i: usize) -> impl Iterator<Item = u32> + '_ {
let len = usize::from_u32(self.data[i]);
// The tokenization time can be shortened by 10%.
self.data.get_unchecked(i + 1..i + 1 + len).iter().cloned()
}
}

#[derive(Default)]
Expand Down
88 changes: 3 additions & 85 deletions vibrato/src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,27 +129,13 @@ impl Tokenizer {
break;
}

if self.dict.need_check {
self.add_lattice_edges(sent, lattice, start_node, start_word, connector);
} else {
unsafe {
self.add_lattice_edges_unchecked(
sent, lattice, start_node, start_word, connector,
);
}
}
self.add_lattice_edges(sent, lattice, start_node, start_word, connector);

start_word += 1;
start_node = start_word;
}

if self.dict.need_check {
lattice.insert_eos(start_node, connector);
} else {
unsafe {
lattice.insert_eos_unchecked(start_node, connector);
}
}
lattice.insert_eos(start_node, connector);
}

fn add_lattice_edges<C>(
Expand All @@ -164,9 +150,7 @@ impl Tokenizer {
{
let mut has_matched = false;

// Safety: `start_word < sent.len_char()` is already checked in `build_lattice()`.
debug_assert!(start_word < sent.len_char());
let suffix = unsafe { sent.chars().get_unchecked(start_word..) };
let suffix = &sent.chars()[start_word..];

if let Some(user_lexicon) = self.dict.user_lexicon() {
for m in user_lexicon.common_prefix_iterator(suffix) {
Expand Down Expand Up @@ -213,72 +197,6 @@ impl Tokenizer {
},
);
}

unsafe fn add_lattice_edges_unchecked<C>(
&self,
sent: &Sentence,
lattice: &mut Lattice,
start_node: usize,
start_word: usize,
connector: &C,
) where
C: ConnectorCost,
{
let mut has_matched = false;

// Safety: `start_word < sent.len_char()` is already checked in `build_lattice()`.
debug_assert!(start_word < sent.len_char());
let suffix = sent.chars().get_unchecked(start_word..);

if let Some(user_lexicon) = self.dict.user_lexicon() {
for m in user_lexicon.common_prefix_iterator_unchecked(suffix) {
debug_assert!(start_word + m.end_char <= sent.len_char());
lattice.insert_node_unchecked(
start_node,
start_word,
start_word + m.end_char,
m.word_idx,
m.word_param,
connector,
);
has_matched = true;
}
}

for m in self
.dict
.system_lexicon()
.common_prefix_iterator_unchecked(suffix)
{
debug_assert!(start_word + m.end_char <= sent.len_char());
lattice.insert_node_unchecked(
start_node,
start_word,
start_word + m.end_char,
m.word_idx,
m.word_param,
connector,
);
has_matched = true;
}

self.dict.unk_handler().gen_unk_words(
sent,
start_word,
has_matched,
self.max_grouping_len,
|w| {
lattice.insert_node_unchecked(
start_node,
w.start_char(),
w.end_char(),
w.word_idx(),
w.word_param(),
connector,
);
},
);
}
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit 0ec8152

Please sign in to comment.