diff --git a/.gitignore b/.gitignore index da8219f..ae5aa3c 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,6 @@ fixtures/*.ipynb # Mac droppings .DS_Store + +# Gen tracking dir +.gen/ diff --git a/Cargo.lock b/Cargo.lock index db8004f..bb301fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -89,7 +89,7 @@ version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" dependencies = [ - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -99,7 +99,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" dependencies = [ "anstyle", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -228,9 +228,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.18" +version = "1.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476" +checksum = "45bcde016d64c21da4be18b655631e5ab6d3107607e71a73a9f53eb48aae23fb" dependencies = [ "shlex", ] @@ -411,6 +411,16 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "fallible-iterator" version = "0.3.0" @@ -423,6 +433,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" +[[package]] +name = "fastrand" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" + [[package]] name = "fixedbitset" version = "0.4.2" @@ -526,6 +542,7 @@ dependencies = [ "rusqlite_migration", "sha2", "tempdir", + "tempfile", ] [[package]] @@ -598,9 +615,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "iana-time-zone" -version = "0.1.60" +version = "0.1.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -794,6 +811,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + [[package]] name = "log" version = "0.4.22" @@ -829,9 +852,9 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "memmap2" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe751422e4a8caa417e13c3ea66452215d7d63e19e604f4980461212f3ae1322" +checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" dependencies = [ "libc", ] @@ -1318,6 +1341,19 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustix" +version = "0.38.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + [[package]] name = "serde" version = "1.0.210" @@ -1403,6 +1439,19 @@ dependencies = [ "remove_dir_all", ] +[[package]] +name = "tempfile" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" +dependencies = [ + "cfg-if", + "fastrand", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "tokio" version = "1.40.0" @@ -1558,6 +1607,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-targets" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index 286474b..1e71daa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ itertools = "0.13.0" rusqlite = { version = "0.31.0", features = ["bundled", "array", "session"] } rusqlite_migration = { version = "1.2.0" , features = ["from-directory"]} sha2 = "0.10.8" +tempfile = "3.12.0" noodles = { version = "0.78.0", features = ["core", "vcf", "fasta", "async"] } petgraph = "0.6.5" chrono = "0.4.38" diff --git a/fixtures/no_path.gfa b/fixtures/no_path.gfa new file mode 100644 index 0000000..58e1fe5 --- /dev/null +++ b/fixtures/no_path.gfa @@ -0,0 +1,7 @@ +S 0 GGGG * +S 1 AAAA * +S 2 TTTT * +S 3 CCCC * +L 1 + 2 + * +L 2 + 0 + * +L 0 + 3 + * diff --git a/src/exports.rs b/src/exports.rs new file mode 100644 index 0000000..5bcbabe --- /dev/null +++ b/src/exports.rs @@ -0,0 +1 @@ +pub mod gfa; diff --git a/src/exports/gfa.rs b/src/exports/gfa.rs new file mode 100644 index 0000000..e0bf431 --- /dev/null +++ b/src/exports/gfa.rs @@ -0,0 +1,212 @@ +use itertools::Itertools; +use rusqlite::Connection; +use std::collections::{HashMap, HashSet}; +use std::fs; +use std::fs::File; +use std::io::{BufWriter, Write}; +use std::path::PathBuf; + +use crate::models::{ + self, + block_group::BlockGroup, + block_group_edge::BlockGroupEdge, + collection::Collection, + edge::{Edge, GroupBlock}, + path::Path, + sequence::Sequence, + strand::Strand, +}; + +pub fn export_gfa(conn: &Connection, collection_name: &str, filename: &PathBuf) { + let block_groups = Collection::get_block_groups(conn, collection_name); + + let mut edge_set = HashSet::new(); + for block_group in block_groups { + let block_group_edges = BlockGroupEdge::edges_for_block_group(conn, block_group.id); + edge_set.extend(block_group_edges.into_iter()); + } + + let mut edges = edge_set.into_iter().collect(); + let (blocks, boundary_edges) = Edge::blocks_from_edges(conn, &edges); + edges.extend(boundary_edges.clone()); + + let (graph, edges_by_node_pair) = Edge::build_graph(conn, &edges, &blocks); + + let mut file = File::create(filename).unwrap(); + let mut writer = BufWriter::new(file); + + let mut terminal_block_ids = HashSet::new(); + for block in &blocks { + if block.sequence_hash == Sequence::PATH_START_HASH + || block.sequence_hash == Sequence::PATH_END_HASH + { + terminal_block_ids.insert(block.id); + continue; + } + writer + .write_all(&segment_line(&block.sequence, block.id as usize).into_bytes()) + .unwrap_or_else(|_| { + panic!( + "Error writing segment with sequence {} to GFA stream", + block.sequence, + ) + }); + } + + let blocks_by_id = blocks + .clone() + .into_iter() + .map(|block| (block.id, block)) + .collect::>(); + + for (source, target, ()) in graph.all_edges() { + if terminal_block_ids.contains(&source) || terminal_block_ids.contains(&target) { + continue; + } + let edge = edges_by_node_pair.get(&(source, target)).unwrap(); + writer + .write_all( + &link_line(source, edge.source_strand, target, edge.target_strand).into_bytes(), + ) + .unwrap_or_else(|_| { + panic!( + "Error writing link from segment {} to {} to GFA stream", + source, target, + ) + }); + } +} + +fn segment_line(sequence: &str, index: usize) -> String { + format!("S\t{}\t{}\t{}\n", index, sequence, "*") +} + +fn link_line( + source_index: i32, + source_strand: Strand, + target_index: i32, + target_strand: Strand, +) -> String { + format!( + "L\t{}\t{}\t{}\t{}\t*\n", + source_index, source_strand, target_index, target_strand + ) +} + +mod tests { + use rusqlite::Connection; + // Note this useful idiom: importing names from outer (for mod tests) scope. + use super::*; + + use crate::imports::gfa::import_gfa; + use crate::models::{ + block_group::BlockGroup, block_group_edge::BlockGroupEdge, collection::Collection, + edge::Edge, sequence::Sequence, + }; + use crate::test_helpers::get_connection; + use tempfile::tempdir; + + #[test] + fn test_simple_export() { + // Sets up a basic graph and then exports it to a GFA file + let conn = get_connection(None); + + let collection_name = "test collection"; + let collection = Collection::create(&conn, collection_name); + let block_group = BlockGroup::create(&conn, collection_name, None, "test block group"); + let sequence1 = Sequence::new() + .sequence_type("DNA") + .sequence("AAAA") + .save(&conn); + let sequence2 = Sequence::new() + .sequence_type("DNA") + .sequence("TTTT") + .save(&conn); + let sequence3 = Sequence::new() + .sequence_type("DNA") + .sequence("GGGG") + .save(&conn); + let sequence4 = Sequence::new() + .sequence_type("DNA") + .sequence("CCCC") + .save(&conn); + + let edge1 = Edge::create( + &conn, + Sequence::PATH_START_HASH.to_string(), + 0, + Strand::Forward, + sequence1.hash.clone(), + 0, + Strand::Forward, + 0, + 0, + ); + let edge2 = Edge::create( + &conn, + sequence1.hash, + 4, + Strand::Forward, + sequence2.hash.clone(), + 0, + Strand::Forward, + 0, + 0, + ); + let edge3 = Edge::create( + &conn, + sequence2.hash, + 4, + Strand::Forward, + sequence3.hash.clone(), + 0, + Strand::Forward, + 0, + 0, + ); + let edge4 = Edge::create( + &conn, + sequence3.hash, + 4, + Strand::Forward, + sequence4.hash.clone(), + 0, + Strand::Forward, + 0, + 0, + ); + let edge5 = Edge::create( + &conn, + sequence4.hash, + 4, + Strand::Forward, + Sequence::PATH_END_HASH.to_string(), + 0, + Strand::Forward, + 0, + 0, + ); + + BlockGroupEdge::bulk_create( + &conn, + block_group.id, + &[edge1.id, edge2.id, edge3.id, edge4.id, edge5.id], + ); + let all_sequences = BlockGroup::get_all_sequences(&conn, block_group.id); + + let temp_dir = tempdir().expect("Couldn't get handle to temp directory"); + let mut gfa_path = PathBuf::from(temp_dir.path()); + gfa_path.push("intermediate.gfa"); + + export_gfa(&conn, collection_name, &gfa_path); + // NOTE: Not directly checking file contents because segments are written in random order + import_gfa(&gfa_path, "test collection 2", &conn); + + let block_group2 = Collection::get_block_groups(&conn, "test collection 2") + .pop() + .unwrap(); + let all_sequences2 = BlockGroup::get_all_sequences(&conn, block_group2.id); + + assert_eq!(all_sequences, all_sequences2); + } +} diff --git a/src/imports/fasta.rs b/src/imports/fasta.rs index 43fd274..4cd3989 100644 --- a/src/imports/fasta.rs +++ b/src/imports/fasta.rs @@ -5,8 +5,8 @@ use std::str; use crate::models::file_types::FileTypes; use crate::models::operations::{FileAddition, Operation, OperationSummary}; use crate::models::{ - self, block_group::BlockGroup, block_group_edge::BlockGroupEdge, edge::Edge, metadata, - path::Path, sequence::Sequence, strand::Strand, + self, block_group::BlockGroup, block_group_edge::BlockGroupEdge, collection::Collection, + edge::Edge, metadata, path::Path, sequence::Sequence, strand::Strand, }; use crate::operation_management; use noodles::fasta; @@ -29,8 +29,8 @@ pub fn import_fasta( let operation = Operation::create(operation_conn, &db_uuid, name, "fasta_addition", change.id); - if !models::Collection::exists(conn, name) { - let collection = models::Collection::create(conn, name); + if !Collection::exists(conn, name) { + let collection = Collection::create(conn, name); let mut summary: HashMap = HashMap::new(); for result in reader.records() { diff --git a/src/imports/gfa.rs b/src/imports/gfa.rs index 3de45b0..82a84cf 100644 --- a/src/imports/gfa.rs +++ b/src/imports/gfa.rs @@ -1,11 +1,14 @@ use gfa_reader::Gfa; use rusqlite::Connection; use std::collections::{HashMap, HashSet}; +use std::path::Path as FilePath; +use std::path::PathBuf; use crate::models::{ self, block_group::BlockGroup, block_group_edge::BlockGroupEdge, + collection::Collection, edge::{Edge, EdgeData}, path::Path, sequence::Sequence, @@ -20,10 +23,10 @@ fn bool_to_strand(direction: bool) -> Strand { } } -pub fn import_gfa(gfa_path: &str, collection_name: &str, conn: &Connection) { - models::Collection::create(conn, collection_name); +pub fn import_gfa(gfa_path: &FilePath, collection_name: &str, conn: &Connection) { + Collection::create(conn, collection_name); let block_group = BlockGroup::create(conn, collection_name, None, ""); - let gfa: Gfa = Gfa::parse_gfa_file(gfa_path); + let gfa: Gfa = Gfa::parse_gfa_file(gfa_path.to_str().unwrap()); let mut sequences_by_segment_id: HashMap = HashMap::new(); for segment in &gfa.segments { @@ -221,7 +224,7 @@ mod tests { gfa_path.push("fixtures/simple.gfa"); let collection_name = "test".to_string(); let conn = &get_connection(None); - import_gfa(gfa_path.to_str().unwrap(), &collection_name, conn); + import_gfa(&gfa_path, &collection_name, conn); let block_group_id = BlockGroup::get_id(conn, &collection_name, None, ""); let path = Path::get_paths( @@ -238,13 +241,29 @@ mod tests { assert_eq!(result, "ATGGCATATTCGCAGCT"); } + #[test] + fn test_import_no_path_gfa() { + let mut gfa_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + gfa_path.push("fixtures/no_path.gfa"); + let collection_name = "no path".to_string(); + let conn = &get_connection(None); + import_gfa(&gfa_path, &collection_name, conn); + + let block_group_id = BlockGroup::get_id(conn, &collection_name, None, ""); + let all_sequences = BlockGroup::get_all_sequences(conn, block_group_id); + assert_eq!( + all_sequences, + HashSet::from_iter(vec!["AAAATTTTGGGGCCCC".to_string()]) + ); + } + #[test] fn test_import_gfa_with_walk() { let mut gfa_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); gfa_path.push("fixtures/walk.gfa"); let collection_name = "walk".to_string(); let conn = &mut get_connection(None); - import_gfa(gfa_path.to_str().unwrap(), &collection_name, conn); + import_gfa(&gfa_path, &collection_name, conn); let block_group_id = BlockGroup::get_id(conn, &collection_name, None, ""); let path = Path::get_paths( @@ -267,7 +286,7 @@ mod tests { gfa_path.push("fixtures/reverse_strand.gfa"); let collection_name = "test".to_string(); let conn = &get_connection(None); - import_gfa(gfa_path.to_str().unwrap(), &collection_name, conn); + import_gfa(&gfa_path, &collection_name, conn); let block_group_id = BlockGroup::get_id(conn, &collection_name, None, ""); let path = Path::get_paths( diff --git a/src/lib.rs b/src/lib.rs index 9c35216..5040e64 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ use std::str; pub mod config; +pub mod exports; pub mod graph; pub mod imports; pub mod migrations; diff --git a/src/main.rs b/src/main.rs index 6458bbd..b04f3fa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,8 @@ use clap::{Parser, Subcommand}; use gen::config; use gen::config::get_operation_connection; + +use gen::exports::gfa::export_gfa; use gen::get_connection; use gen::imports::fasta::import_fasta; use gen::imports::gfa::import_gfa; @@ -11,6 +13,7 @@ use gen::operation_management; use gen::updates::vcf::update_with_vcf; use rusqlite::types::Value; use std::fmt::Debug; +use std::path::PathBuf; use std::str; #[derive(Parser)] @@ -68,6 +71,14 @@ enum Commands { }, /// View operations carried out against a database Operations {}, + Export { + /// The name of the collection to export + #[arg(short, long)] + name: String, + /// The name of the GFA file to export to + #[arg(short, long)] + gfa: String, + }, } fn main() { @@ -102,7 +113,7 @@ fn main() { &operation_conn, ); } else if gfa.is_some() { - import_gfa(&gfa.clone().unwrap(), name, &conn); + import_gfa(&PathBuf::from(gfa.clone().unwrap()), name, &conn); } else { panic!( "ERROR: Import command attempted but no recognized file format was specified" @@ -163,6 +174,11 @@ fn main() { Some(Commands::Checkout { id }) => { operation_management::move_to(&conn, &Operation::get_by_id(&operation_conn, *id)); } + Some(Commands::Export { name, gfa }) => { + conn.execute("BEGIN TRANSACTION", []).unwrap(); + export_gfa(&conn, name, &PathBuf::from(gfa)); + conn.execute("END TRANSACTION", []).unwrap(); + } None => {} } } diff --git a/src/models.rs b/src/models.rs index 855ae09..0539bc5 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,10 +1,10 @@ use rusqlite::types::Value; use rusqlite::{params_from_iter, Connection}; -use sha2::{Digest, Sha256}; use std::fmt::*; pub mod block_group; pub mod block_group_edge; +pub mod collection; pub mod edge; pub mod file_types; pub mod metadata; @@ -16,54 +16,6 @@ pub mod strand; use crate::models; -#[derive(Clone, Debug)] -pub struct Collection { - pub name: String, -} - -impl Collection { - pub fn exists(conn: &Connection, name: &str) -> bool { - let mut stmt = conn - .prepare("select name from collection where name = ?1") - .unwrap(); - stmt.exists([name]).unwrap() - } - pub fn create(conn: &Connection, name: &str) -> Collection { - let mut stmt = conn - .prepare("INSERT INTO collection (name) VALUES (?1) RETURNING *") - .unwrap(); - let mut rows = stmt - .query_map((name,), |row| Ok(models::Collection { name: row.get(0)? })) - .unwrap(); - rows.next().unwrap().unwrap() - } - - pub fn bulk_create(conn: &Connection, names: &Vec) -> Vec { - let placeholders = names.iter().map(|_| "(?)").collect::>().join(", "); - let q = format!( - "INSERT INTO collection (name) VALUES {} RETURNING *", - placeholders - ); - let mut stmt = conn.prepare(&q).unwrap(); - let rows = stmt - .query_map(params_from_iter(names), |row| { - Ok(Collection { name: row.get(0)? }) - }) - .unwrap(); - rows.map(|row| row.unwrap()).collect() - } - - pub fn query(conn: &Connection, query: &str, placeholders: Vec) -> Vec { - let mut stmt = conn.prepare(query).unwrap(); - let rows = stmt - .query_map(params_from_iter(placeholders), |row| { - Ok(Collection { name: row.get(0)? }) - }) - .unwrap(); - rows.map(|row| row.unwrap()).collect() - } -} - #[derive(Debug)] pub struct Sample { pub name: String, diff --git a/src/models/block_group.rs b/src/models/block_group.rs index 19b8d85..ccd98eb 100644 --- a/src/models/block_group.rs +++ b/src/models/block_group.rs @@ -1,5 +1,4 @@ use intervaltree::IntervalTree; -use itertools::Itertools; use petgraph::graphmap::DiGraphMap; use petgraph::Direction; use rusqlite::{params_from_iter, types::Value as SQLValue, Connection}; @@ -7,7 +6,7 @@ use std::collections::{HashMap, HashSet}; use crate::graph::all_simple_paths; use crate::models::block_group_edge::BlockGroupEdge; -use crate::models::edge::{Edge, EdgeData}; +use crate::models::edge::{Edge, EdgeData, GroupBlock}; use crate::models::path::{NewBlock, Path, PathData}; use crate::models::path_edge::PathEdge; use crate::models::sequence::Sequence; @@ -28,21 +27,6 @@ pub struct BlockGroupData<'a> { pub name: String, } -#[derive(Clone)] -pub struct GroupBlock { - pub id: i32, - pub sequence_hash: String, - pub sequence: String, - pub start: i32, - pub end: i32, -} - -#[derive(Eq, Hash, PartialEq)] -pub struct BlockKey { - pub sequence_hash: String, - pub coordinate: i32, -} - #[derive(Clone, Debug)] pub struct PathChange { pub block_group_id: i32, @@ -254,218 +238,11 @@ impl BlockGroup { } } - pub fn get_block_boundaries( - source_edges: Option<&Vec<&Edge>>, - target_edges: Option<&Vec<&Edge>>, - sequence_length: i32, - ) -> Vec { - let mut block_boundary_coordinates = HashSet::new(); - if let Some(actual_source_edges) = source_edges { - for source_edge in actual_source_edges { - if source_edge.source_coordinate > 0 - && source_edge.source_coordinate < sequence_length - { - block_boundary_coordinates.insert(source_edge.source_coordinate); - } - } - } - if let Some(actual_target_edges) = target_edges { - for target_edge in actual_target_edges { - if target_edge.target_coordinate > 0 - && target_edge.target_coordinate < sequence_length - { - block_boundary_coordinates.insert(target_edge.target_coordinate); - } - } - } - - block_boundary_coordinates - .into_iter() - .sorted_by(|c1, c2| Ord::cmp(&c1, &c2)) - .collect::>() - } - - pub fn blocks_from_edges(conn: &Connection, edges: &Vec) -> (Vec, Vec) { - let mut sequence_hashes = HashSet::new(); - let mut edges_by_source_hash: HashMap<&str, Vec<&Edge>> = HashMap::new(); - let mut edges_by_target_hash: HashMap<&str, Vec<&Edge>> = HashMap::new(); - for edge in edges { - if edge.source_hash != Sequence::PATH_START_HASH { - sequence_hashes.insert(edge.source_hash.as_str()); - edges_by_source_hash - .entry(&edge.source_hash) - .and_modify(|edges| edges.push(edge)) - .or_default(); - } - if edge.target_hash != Sequence::PATH_END_HASH { - sequence_hashes.insert(edge.target_hash.as_str()); - edges_by_target_hash - .entry(&edge.target_hash) - .and_modify(|edges| edges.push(edge)) - .or_default(); - } - } - - let sequences_by_hash = - Sequence::sequences_by_hash(conn, sequence_hashes.into_iter().collect::>()); - let mut blocks = vec![]; - let mut block_index = 0; - let mut boundary_edges = vec![]; - for (hash, sequence) in sequences_by_hash.into_iter() { - let block_boundaries = BlockGroup::get_block_boundaries( - edges_by_source_hash.get(hash.as_str()), - edges_by_target_hash.get(hash.as_str()), - sequence.length, - ); - for block_boundary in &block_boundaries { - // NOTE: Most of this data is bogus, the Edge struct is just a convenient wrapper - // for the data we need to set up boundary edges in the block group graph - boundary_edges.push(Edge { - id: -1, - source_hash: hash.clone(), - source_coordinate: *block_boundary, - source_strand: Strand::Unknown, - target_hash: hash.clone(), - target_coordinate: *block_boundary, - target_strand: Strand::Unknown, - chromosome_index: 0, - phased: 0, - }); - } - - if !block_boundaries.is_empty() { - let start = 0; - let end = block_boundaries[0]; - let block_sequence = sequence.get_sequence(start, end).to_string(); - let first_block = GroupBlock { - id: block_index, - sequence_hash: hash.clone(), - sequence: block_sequence, - start, - end, - }; - blocks.push(first_block); - block_index += 1; - for (start, end) in block_boundaries.clone().into_iter().tuple_windows() { - let block_sequence = sequence.get_sequence(start, end).to_string(); - let block = GroupBlock { - id: block_index, - sequence_hash: hash.clone(), - sequence: block_sequence, - start, - end, - }; - blocks.push(block); - block_index += 1; - } - let start = block_boundaries[block_boundaries.len() - 1]; - let end = sequence.length; - let block_sequence = sequence.get_sequence(start, end).to_string(); - let last_block = GroupBlock { - id: block_index, - sequence_hash: hash.clone(), - sequence: block_sequence, - start, - end, - }; - blocks.push(last_block); - block_index += 1; - } else { - blocks.push(GroupBlock { - id: block_index, - sequence_hash: hash.clone(), - sequence: sequence.get_sequence(None, None), - start: 0, - end: sequence.length, - }); - block_index += 1; - } - } - - // NOTE: We need a dedicated start node and a dedicated end node for the graph formed by the - // block group, since different paths in the block group may start or end at different - // places on sequences. These two "start sequence" and "end sequence" blocks will serve - // that role. - let start_sequence = Sequence::sequence_from_hash(conn, Sequence::PATH_START_HASH).unwrap(); - let start_block = GroupBlock { - id: block_index + 1, - sequence_hash: start_sequence.hash.clone(), - sequence: "".to_string(), - start: 0, - end: 0, - }; - blocks.push(start_block); - let end_sequence = Sequence::sequence_from_hash(conn, Sequence::PATH_END_HASH).unwrap(); - let end_block = GroupBlock { - id: block_index + 2, - sequence_hash: end_sequence.hash.clone(), - sequence: "".to_string(), - start: 0, - end: 0, - }; - blocks.push(end_block); - (blocks, boundary_edges) - } - pub fn get_all_sequences(conn: &Connection, block_group_id: i32) -> HashSet { let mut edges = BlockGroupEdge::edges_for_block_group(conn, block_group_id); - let (blocks, boundary_edges) = BlockGroup::blocks_from_edges(conn, &edges); + let (blocks, boundary_edges) = Edge::blocks_from_edges(conn, &edges); edges.extend(boundary_edges.clone()); - - let blocks_by_start = blocks - .clone() - .into_iter() - .map(|block| { - ( - BlockKey { - sequence_hash: block.sequence_hash, - coordinate: block.start, - }, - block.id, - ) - }) - .collect::>(); - let blocks_by_end = blocks - .clone() - .into_iter() - .map(|block| { - ( - BlockKey { - sequence_hash: block.sequence_hash, - coordinate: block.end, - }, - block.id, - ) - }) - .collect::>(); - let blocks_by_id = blocks - .clone() - .into_iter() - .map(|block| (block.id, block)) - .collect::>(); - - let mut graph: DiGraphMap = DiGraphMap::new(); - for block in blocks { - graph.add_node(block.id); - } - for edge in edges { - let source_key = BlockKey { - sequence_hash: edge.source_hash, - coordinate: edge.source_coordinate, - }; - let source_id = blocks_by_end.get(&source_key); - let target_key = BlockKey { - sequence_hash: edge.target_hash, - coordinate: edge.target_coordinate, - }; - let target_id = blocks_by_start.get(&target_key); - - if let Some(source_id_value) = source_id { - if let Some(target_id_value) = target_id { - graph.add_edge(*source_id_value, *target_id_value, ()); - } - } - } + let (graph, _) = Edge::build_graph(conn, &edges, &blocks); let mut start_nodes = vec![]; let mut end_nodes = vec![]; @@ -479,6 +256,12 @@ impl BlockGroup { end_nodes.push(node); } } + + let blocks_by_id = blocks + .clone() + .into_iter() + .map(|block| (block.id, block)) + .collect::>(); let mut sequences = HashSet::::new(); for start_node in start_nodes { @@ -486,7 +269,11 @@ impl BlockGroup { // TODO: maybe make all_simple_paths return a single path id where start == end if start_node == *end_node { let block = blocks_by_id.get(&start_node).unwrap(); - sequences.insert(block.sequence.clone()); + if block.sequence_hash != Sequence::PATH_START_HASH + && block.sequence_hash != Sequence::PATH_END_HASH + { + sequences.insert(block.sequence.clone()); + } } else { for path in all_simple_paths(&graph, start_node, *end_node) { let mut current_sequence = "".to_string(); @@ -629,7 +416,7 @@ impl BlockGroup { #[cfg(test)] mod tests { use super::*; - use crate::models::{Collection, Sample}; + use crate::models::{collection::Collection, Sample}; use crate::test_helpers::get_connection; fn setup_block_group(conn: &Connection) -> (i32, Path) { diff --git a/src/models/collection.rs b/src/models/collection.rs new file mode 100644 index 0000000..e60185a --- /dev/null +++ b/src/models/collection.rs @@ -0,0 +1,70 @@ +use rusqlite::types::Value; +use rusqlite::{params_from_iter, Connection}; + +use crate::models::block_group::BlockGroup; + +#[derive(Clone, Debug)] +pub struct Collection { + pub name: String, +} + +impl Collection { + pub fn exists(conn: &Connection, name: &str) -> bool { + let mut stmt = conn + .prepare("select name from collection where name = ?1") + .unwrap(); + stmt.exists([name]).unwrap() + } + pub fn create(conn: &Connection, name: &str) -> Collection { + let mut stmt = conn + .prepare("INSERT INTO collection (name) VALUES (?1) RETURNING *") + .unwrap(); + let mut rows = stmt + .query_map((name,), |row| Ok(Collection { name: row.get(0)? })) + .unwrap(); + rows.next().unwrap().unwrap() + } + + pub fn bulk_create(conn: &Connection, names: &Vec) -> Vec { + let placeholders = names.iter().map(|_| "(?)").collect::>().join(", "); + let q = format!( + "INSERT INTO collection (name) VALUES {} RETURNING *", + placeholders + ); + let mut stmt = conn.prepare(&q).unwrap(); + let rows = stmt + .query_map(params_from_iter(names), |row| { + Ok(Collection { name: row.get(0)? }) + }) + .unwrap(); + rows.map(|row| row.unwrap()).collect() + } + + pub fn get_block_groups(conn: &Connection, collection_name: &str) -> Vec { + // Load all block groups that have the given collection_name + let mut stmt = conn + .prepare("SELECT * FROM block_group WHERE collection_name = ?1") + .unwrap(); + let block_group_iter = stmt + .query_map([collection_name], |row| { + Ok(BlockGroup { + id: row.get(0)?, + collection_name: row.get(1)?, + sample_name: row.get(2)?, + name: row.get(3)?, + }) + }) + .unwrap(); + block_group_iter.map(|bg| bg.unwrap()).collect() + } + + pub fn query(conn: &Connection, query: &str, placeholders: Vec) -> Vec { + let mut stmt = conn.prepare(query).unwrap(); + let rows = stmt + .query_map(params_from_iter(placeholders), |row| { + Ok(Collection { name: row.get(0)? }) + }) + .unwrap(); + rows.map(|row| row.unwrap()).collect() + } +} diff --git a/src/models/edge.rs b/src/models/edge.rs index 226d8ac..6e8c6c7 100644 --- a/src/models/edge.rs +++ b/src/models/edge.rs @@ -1,11 +1,13 @@ +use itertools::Itertools; +use petgraph::graphmap::DiGraphMap; use rusqlite::types::Value; use rusqlite::{params_from_iter, Connection}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::hash::RandomState; -use crate::models::strand::Strand; +use crate::models::{sequence::Sequence, strand::Strand}; -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, Hash, PartialEq)] pub struct Edge { pub id: i32, pub source_hash: String, @@ -30,6 +32,21 @@ pub struct EdgeData { pub phased: i32, } +#[derive(Eq, Hash, PartialEq)] +pub struct BlockKey { + pub sequence_hash: String, + pub coordinate: i32, +} + +#[derive(Clone)] +pub struct GroupBlock { + pub id: i32, + pub sequence_hash: String, + pub sequence: String, + pub start: i32, + pub end: i32, +} + impl Edge { #[allow(clippy::too_many_arguments)] pub fn create( @@ -224,6 +241,219 @@ impl Edge { phased: edge.phased, } } + + fn get_block_boundaries( + source_edges: Option<&Vec<&Edge>>, + target_edges: Option<&Vec<&Edge>>, + sequence_length: i32, + ) -> Vec { + let mut block_boundary_coordinates = HashSet::new(); + if let Some(actual_source_edges) = source_edges { + for source_edge in actual_source_edges { + if source_edge.source_coordinate > 0 + && source_edge.source_coordinate < sequence_length + { + block_boundary_coordinates.insert(source_edge.source_coordinate); + } + } + } + if let Some(actual_target_edges) = target_edges { + for target_edge in actual_target_edges { + if target_edge.target_coordinate > 0 + && target_edge.target_coordinate < sequence_length + { + block_boundary_coordinates.insert(target_edge.target_coordinate); + } + } + } + + block_boundary_coordinates + .into_iter() + .sorted_by(|c1, c2| Ord::cmp(&c1, &c2)) + .collect::>() + } + + pub fn blocks_from_edges(conn: &Connection, edges: &Vec) -> (Vec, Vec) { + let mut sequence_hashes = HashSet::new(); + let mut edges_by_source_hash: HashMap<&str, Vec<&Edge>> = HashMap::new(); + let mut edges_by_target_hash: HashMap<&str, Vec<&Edge>> = HashMap::new(); + for edge in edges { + if edge.source_hash != Sequence::PATH_START_HASH { + sequence_hashes.insert(edge.source_hash.as_str()); + edges_by_source_hash + .entry(&edge.source_hash) + .and_modify(|edges| edges.push(edge)) + .or_default(); + } + if edge.target_hash != Sequence::PATH_END_HASH { + sequence_hashes.insert(edge.target_hash.as_str()); + edges_by_target_hash + .entry(&edge.target_hash) + .and_modify(|edges| edges.push(edge)) + .or_default(); + } + } + + let sequences_by_hash = + Sequence::sequences_by_hash(conn, sequence_hashes.into_iter().collect::>()); + let mut blocks = vec![]; + let mut block_index = 0; + let mut boundary_edges = vec![]; + for (hash, sequence) in sequences_by_hash.into_iter() { + let block_boundaries = Edge::get_block_boundaries( + edges_by_source_hash.get(hash.as_str()), + edges_by_target_hash.get(hash.as_str()), + sequence.length, + ); + for block_boundary in &block_boundaries { + // NOTE: Most of this data is bogus, the Edge struct is just a convenient wrapper + // for the data we need to set up boundary edges in the block group graph + boundary_edges.push(Edge { + id: -1, + source_hash: hash.clone(), + source_coordinate: *block_boundary, + source_strand: Strand::Unknown, + target_hash: hash.clone(), + target_coordinate: *block_boundary, + target_strand: Strand::Unknown, + chromosome_index: 0, + phased: 0, + }); + } + + if !block_boundaries.is_empty() { + let start = 0; + let end = block_boundaries[0]; + let block_sequence = sequence.get_sequence(start, end).to_string(); + let first_block = GroupBlock { + id: block_index, + sequence_hash: hash.clone(), + sequence: block_sequence, + start, + end, + }; + blocks.push(first_block); + block_index += 1; + for (start, end) in block_boundaries.clone().into_iter().tuple_windows() { + let block_sequence = sequence.get_sequence(start, end).to_string(); + let block = GroupBlock { + id: block_index, + sequence_hash: hash.clone(), + sequence: block_sequence, + start, + end, + }; + blocks.push(block); + block_index += 1; + } + let start = block_boundaries[block_boundaries.len() - 1]; + let end = sequence.length; + let block_sequence = sequence.get_sequence(start, end).to_string(); + let last_block = GroupBlock { + id: block_index, + sequence_hash: hash.clone(), + sequence: block_sequence, + start, + end, + }; + blocks.push(last_block); + block_index += 1; + } else { + blocks.push(GroupBlock { + id: block_index, + sequence_hash: hash.clone(), + sequence: sequence.get_sequence(None, None), + start: 0, + end: sequence.length, + }); + block_index += 1; + } + } + + // NOTE: We need a dedicated start node and a dedicated end node for the graph formed by the + // block group, since different paths in the block group may start or end at different + // places on sequences. These two "start sequence" and "end sequence" blocks will serve + // that role. + let start_sequence = Sequence::sequence_from_hash(conn, Sequence::PATH_START_HASH).unwrap(); + let start_block = GroupBlock { + id: block_index + 1, + sequence_hash: start_sequence.hash.clone(), + sequence: "".to_string(), + start: 0, + end: 0, + }; + blocks.push(start_block); + let end_sequence = Sequence::sequence_from_hash(conn, Sequence::PATH_END_HASH).unwrap(); + let end_block = GroupBlock { + id: block_index + 2, + sequence_hash: end_sequence.hash.clone(), + sequence: "".to_string(), + start: 0, + end: 0, + }; + blocks.push(end_block); + (blocks, boundary_edges) + } + + pub fn build_graph( + conn: &Connection, + edges: &Vec, + blocks: &Vec, + ) -> (DiGraphMap, HashMap<(i32, i32), Edge>) { + let blocks_by_start = blocks + .clone() + .into_iter() + .map(|block| { + ( + BlockKey { + sequence_hash: block.sequence_hash, + coordinate: block.start, + }, + block.id, + ) + }) + .collect::>(); + let blocks_by_end = blocks + .clone() + .into_iter() + .map(|block| { + ( + BlockKey { + sequence_hash: block.sequence_hash, + coordinate: block.end, + }, + block.id, + ) + }) + .collect::>(); + + let mut graph: DiGraphMap = DiGraphMap::new(); + let mut edges_by_node_pair = HashMap::new(); + for block in blocks { + graph.add_node(block.id); + } + for edge in edges { + let source_key = BlockKey { + sequence_hash: edge.source_hash.clone(), + coordinate: edge.source_coordinate, + }; + let source_id = blocks_by_end.get(&source_key); + let target_key = BlockKey { + sequence_hash: edge.target_hash.clone(), + coordinate: edge.target_coordinate, + }; + let target_id = blocks_by_start.get(&target_key); + + if let Some(source_id_value) = source_id { + if let Some(target_id_value) = target_id { + graph.add_edge(*source_id_value, *target_id_value, ()); + edges_by_node_pair.insert((*source_id_value, *target_id_value), edge.clone()); + } + } + } + + (graph, edges_by_node_pair) + } } mod tests { @@ -232,7 +462,7 @@ mod tests { use super::*; use std::collections::HashMap; - use crate::models::{sequence::Sequence, Collection}; + use crate::models::{collection::Collection, sequence::Sequence}; use crate::test_helpers::get_connection; #[test] diff --git a/src/models/path.rs b/src/models/path.rs index 9ffdde9..5932ed6 100644 --- a/src/models/path.rs +++ b/src/models/path.rs @@ -256,7 +256,7 @@ mod tests { // Note this useful idiom: importing names from outer (for mod tests) scope. use super::*; - use crate::models::{block_group::BlockGroup, sequence::NewSequence, Collection}; + use crate::models::{block_group::BlockGroup, collection::Collection, sequence::NewSequence}; use crate::test_helpers::get_connection; #[test]