diff --git a/examples/serialize_bench.rs b/examples/serialize_bench.rs index 50012c3..d9ee85c 100644 --- a/examples/serialize_bench.rs +++ b/examples/serialize_bench.rs @@ -10,19 +10,20 @@ static ALLOCATOR: jemallocator::Jemalloc = jemallocator::Jemalloc; fn main() { let args: Vec = env::args().collect(); - if args.len() <= 1 { - panic!("Usage: cargo run --example serialize_bench "); + if args.len() <= 3 { + panic!("Usage: cargo run --example serialize_bench "); } - let number_of_universities: &String = &args[1]; - let zarr_path = format!("{}-lubm", number_of_universities); + let rdf_path: &String = &args[1]; + let zarr_path: &String = &args[2]; + let shard_size: &String = &args[3]; let before = Instant::now(); LocalStorage::new(MatrixLayout) .serialize( - format!("{}.zarr", zarr_path).as_str(), - format!("../lubm-uba-improved/out/{}.ttl", zarr_path).as_str(), - ChunkingStrategy::Sharding(10240), + &zarr_path.as_str(), + &rdf_path.as_str(), + ChunkingStrategy::Sharding(shard_size.parse::().unwrap()), ) .unwrap(); diff --git a/src/dictionary.rs b/src/dictionary.rs index aa1a401..6a74608 100644 --- a/src/dictionary.rs +++ b/src/dictionary.rs @@ -77,10 +77,7 @@ impl Dictionary { pub fn get_predicate_idx(&self, predicate: &str) -> Option { let mut locator = self.predicates.locator(); - match locator.run(predicate) { - Some(value) => Some(value + 1), - None => None, - } + locator.run(predicate).map(|value| value + 1) } pub fn get_predicate_idx_unchecked(&self, predicate: &str) -> usize { diff --git a/src/engine/array.rs b/src/engine/array.rs index d4042fb..5ca2f5d 100644 --- a/src/engine/array.rs +++ b/src/engine/array.rs @@ -1,21 +1,25 @@ -use sprs::CsVec; +use sprs::{CsMat, TriMat}; use crate::storage::ZarrArray; use super::{EngineResult, EngineStrategy}; -impl EngineStrategy> for ZarrArray { - fn get_subject(&self, index: usize) -> EngineResult> { - let selection = CsVec::new(self.rows(), vec![index], vec![1]); - Ok(&self.transpose_view() * &selection) +impl EngineStrategy> for ZarrArray { + fn get_subject(&self, index: usize) -> EngineResult> { + let mut matrix = TriMat::new((self.rows(), self.rows())); + matrix.add_triplet(index, index, 1); + let matrix = matrix.to_csc(); + Ok(&matrix * self) } - fn get_predicate(&self, index: usize) -> EngineResult> { + fn get_predicate(&self, _value: u8) -> EngineResult> { unimplemented!() } - fn get_object(&self, index: usize) -> EngineResult> { - let selection = CsVec::new(self.cols(), vec![index], vec![1]); - Ok(self * &selection) + fn get_object(&self, index: usize) -> EngineResult> { + let mut matrix = TriMat::new((self.cols(), self.cols())); + matrix.add_triplet(index, index, 1); + let matrix = matrix.to_csc(); + Ok(self * &matrix) } } diff --git a/src/engine/chunk.rs b/src/engine/chunk.rs index 88029ab..1761804 100644 --- a/src/engine/chunk.rs +++ b/src/engine/chunk.rs @@ -23,7 +23,7 @@ impl EngineStrategy> for Array { } } - fn get_predicate(&self, index: usize) -> EngineResult> { + fn get_predicate(&self, _index: u8) -> EngineResult> { unimplemented!() } diff --git a/src/engine/mod.rs b/src/engine/mod.rs index b8acc0b..f4b2400 100644 --- a/src/engine/mod.rs +++ b/src/engine/mod.rs @@ -7,6 +7,6 @@ pub type EngineResult = Result; pub trait EngineStrategy { fn get_subject(&self, index: usize) -> EngineResult; - fn get_predicate(&self, index: usize) -> EngineResult; + fn get_predicate(&self, index: u8) -> EngineResult; fn get_object(&self, index: usize) -> EngineResult; } diff --git a/src/storage/tabular.rs b/src/storage/tabular.rs index 853dfda..8cca3f1 100644 --- a/src/storage/tabular.rs +++ b/src/storage/tabular.rs @@ -80,13 +80,12 @@ where graph .iter() .enumerate() - .map(|(subject, triples)| { + .flat_map(|(subject, triples)| { triples .iter() .map(|&(predicate, object)| (subject as u32, predicate, object)) .collect::>() }) - .flatten() .collect::>() } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index e50f97f..3f08f76 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -180,7 +180,6 @@ impl Graph { Object::GCHQ.get_idx(dictionary), Predicate::Manufacturer.get_idx(dictionary), ); - ans.to_csc() } } diff --git a/tests/get_object_test.rs b/tests/get_object_test.rs index 644b083..15bfd4e 100644 --- a/tests/get_object_test.rs +++ b/tests/get_object_test.rs @@ -2,8 +2,7 @@ use remote_hdt::{ engine::EngineStrategy, storage::{matrix::MatrixLayout, tabular::TabularLayout, ChunkingStrategy, LocalStorage}, }; -use sprs::CsVec; - +use sprs::TriMat; mod common; #[test] @@ -49,5 +48,8 @@ fn get_object_tabular_test() { .get_object(common::Object::Alan.get_idx(&storage.get_dictionary())) .unwrap(); - assert_eq!(actual, CsVec::new(4, vec![1], vec![3])) + let mut expected = TriMat::new((4, 9)); + expected.add_triplet(1, 3, 3); + let expected = expected.to_csc(); + assert_eq!(actual, expected) } diff --git a/tests/get_subject_test.rs b/tests/get_subject_test.rs index ee57cb1..4a8d5fb 100644 --- a/tests/get_subject_test.rs +++ b/tests/get_subject_test.rs @@ -2,8 +2,7 @@ use remote_hdt::{ engine::EngineStrategy, storage::{matrix::MatrixLayout, tabular::TabularLayout, ChunkingStrategy, LocalStorage}, }; -use sprs::CsVec; - +use sprs::TriMat; mod common; #[test] @@ -49,8 +48,12 @@ fn get_subject_tabular_test() { .get_subject(common::Subject::Alan.get_idx(&storage.get_dictionary())) .unwrap(); - assert_eq!( - actual, - CsVec::new(9, vec![0, 1, 2, 7, 8], vec![2, 4, 5, 7, 8]) - ) + let mut result = TriMat::new((4, 9)); + result.add_triplet(0, 0, 2); + result.add_triplet(0, 1, 4); + result.add_triplet(0, 2, 5); + result.add_triplet(0, 7, 7); + result.add_triplet(0, 8, 8); + let result = result.to_csc(); + assert_eq!(actual, result) }