diff --git a/Cargo.toml b/Cargo.toml index cad62245..108ba706 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,14 +11,19 @@ crate-type = ["lib", "cdylib"] [dependencies] autocxx = "0.26.0" +bio = { version = "1.5.0", features = ["phylogeny"] } clap = { version = "4.4.18", features = ["derive", "cargo"] } csv = "1.3.0" cxx = "1.0.115" flate2 = "1.0.28" +libc = "0.2.153" +log = "0.4.20" +log4rs = "1.3.0" nom = "7.1.3" polars = { version = "0.37.0", features = ["lazy"] } pyo3 = { version = "0.20", features = ["abi3-py38", "extension-module", "multiple-pymethods"], optional = true } pyo3-polars = { version = "0.11.1", optional = true } +rust-htslib = { version = "0.46.0", features = ["libdeflate"] } serde = { version = "1.0.195", features = ["derive"] } tempfile = "3.9.0" diff --git a/README+s.md b/README+s.md index a682f7b0..99815cc0 100644 --- a/README+s.md +++ b/README+s.md @@ -26,7 +26,7 @@ This repository provides Python wrappers for the following programs: And rewrite of the following programs: -- [smc2bed-all](./argweaver/scripts/smc2bed_all.py) +- [smc2bed-all](./argweavers/scripts/smc2bed_all.py) ## Usage diff --git a/src/bin/arg-sample.rs b/src/bin/arg-sample.rs index f5e50643..286c1b47 100644 --- a/src/bin/arg-sample.rs +++ b/src/bin/arg-sample.rs @@ -1,9 +1,18 @@ -use std::path::PathBuf; +use std::fs::read_to_string; +use std::path::{Path, PathBuf}; +use std::time::{SystemTime, UNIX_EPOCH}; -use clap::{Args, Parser}; +use clap::{builder::ArgAction, Args, Parser}; +use log::{error, info, warn, LevelFilter}; +use log4rs::append::console::ConsoleAppender; +use log4rs::append::file::FileAppender; +use log4rs::config::{Appender, Config as LogConfig, Root}; +use log4rs::filter::threshold::ThresholdFilter; +use polars::prelude::*; use argweavers::{ ser::{StatsRecord, StatsStage, StatsWriter}, + sites::Sites, Result, }; @@ -32,6 +41,28 @@ struct InputOptions { vcf_list_file: Option, } +impl InputOptions { + fn sites(&self) -> Result { + if let Some(sites_file) = 
&self.sites_file { + return Sites::from_path(sites_file); + } + if let Some(fasta_file) = &self.fasta_file { + return Sites::from_msa(fasta_file); + } + if let Some(vcf_file) = &self.vcf_file { + return Sites::from_vcf(vcf_file); + } + if let Some(vcf_list_file) = &self.vcf_list_file { + let vcfs: Vec = read_to_string(vcf_list_file)? + .lines() + .map(PathBuf::from) + .collect(); + return Sites::from_vcfs(&vcfs); + } + Err("No input file provided".into()) + } +} + #[derive(Args)] struct IOArgs { /// prefix for all output filenames @@ -52,6 +83,33 @@ struct IOArgs { /// set will be ignored. #[arg(long = "rename-seqs", value_name = "name_map_file.txt")] rename_file: Option, + + /// sample ARG for only a region of the sites (optional). Note the [chr:] + /// prefix should be added ONLY if alignment is in vcf format. If this + /// option is given, the vcf must also be indexed with tabix + #[arg(long = "region", value_name = "[chr:]start-end")] + region: Option, + + /// file listing NAMES from sites file (or sequences from fasta) to keep; + /// others will not be used. May be diploid or haploid names (i.e., ind will + /// have same effect as ind_1 and ind_2). + #[arg(long, visible_aliases = ["subsites", "keep"], value_delimiter = ' ', num_args = 1..)] + keep_ids: Option>, + + /// data is unphased (will integrate over phasings). 
+ #[arg(long)] + unphased: bool, + + /// do not gzip output files + #[clap(long = "no-compress-output", action = ArgAction::SetFalse)] + compress_output: bool, + + #[clap( + long = "compress-output", + overrides_with = "compress_output", + hide = true + )] + _no_compress_output: (), } #[derive(Args)] @@ -64,6 +122,13 @@ struct SamplingArgs { overwrite: bool, } +#[derive(Args)] +struct MiscArgs { + /// seed for random number generator (default=current time) + #[arg(short = 'x', long = "randseed", value_name = "random seed")] + seed: Option, +} + /// Sampler for large ancestral recombination graphs #[derive(Parser)] #[clap(author, version, about)] @@ -72,6 +137,8 @@ struct Config { io_args: IOArgs, #[command(flatten, next_help_heading = "Sampling")] sampling_args: SamplingArgs, + #[command(flatten, next_help_heading = "Miscellaneous")] + misc_args: MiscArgs, } impl Config { @@ -89,6 +156,13 @@ impl Config { } Ok(()) } + fn out_postfix(&self) -> &str { + if self.io_args.compress_output { + ".gz" + } else { + "" + } + } } fn ensure_output_dir_exists>(path: P) -> Result<()> { @@ -99,20 +173,160 @@ fn ensure_output_dir_exists>(path: P) -> Result<()> { Ok(()) } +fn setup_logging(log_file: &Path, resume: bool) -> Result<()> { + let stdout = ConsoleAppender::builder().build(); + + let logfile = FileAppender::builder().append(resume).build(log_file)?; + + let config = LogConfig::builder() + .appender(Appender::builder().build("stdout", Box::new(stdout))) + .appender( + Appender::builder() + .filter(Box::new(ThresholdFilter::new(LevelFilter::Debug))) + .build("logfile", Box::new(logfile)), + ) + .build( + Root::builder() + .appender("stdout") + .appender("logfile") + .build(LevelFilter::Trace), + )?; + + let _handle = log4rs::init_config(config)?; + Ok(()) +} + +fn find_previous_smc_file(out_prefix: &PathBuf) -> Result<(PathBuf, i64)> { + let stats_filename = out_prefix.with_extension("stats"); + info!( + "Checking previous run from stats file: {}", + stats_filename.display() 
+ ); + let stats = CsvReader::from_path(stats_filename)? + .has_header(true) + .with_separator(b'\t') + .finish() + .map_err(|e| match e { + polars::error::PolarsError::NoData(_) => { + polars::error::PolarsError::NoData("stats file is empty".into()) + } + _ => e, + })?; + let stage = stats.column("stage")?; + let resample = stats + .filter(&stage.equal("resample")?)? + .sort(["iter"], true, true)?; + for i in resample.column("iter")?.i64()? { + let i = i.unwrap(); + let smc_file = out_prefix.with_extension(format!("{}.smc.gz", i)); + if smc_file.exists() { + return Ok((smc_file, i)); + } + let smc_file = smc_file.with_extension(""); + if smc_file.exists() { + return Ok((smc_file, i)); + } + } + let msg = "Could not find any previously written SMC files. Try disabling resume"; + error!("{}", msg); + Err(std::io::Error::new(std::io::ErrorKind::NotFound, msg).into()) +} + +fn setup_resume(config: &mut Config) -> Result<()> { + info!("Resuming from previous run"); + let (arg_file, resume_iter) = find_previous_smc_file(&config.io_args.out_prefix)?; + let sites_file = config.io_args.out_prefix.with_extension(format!( + "{}.sites{}", + resume_iter, + config.out_postfix(), + )); + if sites_file.exists() { + let test_sites = Sites::from_path(&sites_file)?; + info!( + "Detected phased output sites file. 
Using {} as input and assuming data is unphased", + &sites_file.display() + ); + config.io_args.input_options.sites_file = Some(sites_file); + config.io_args.unphased = true; + let mask_file = config + .io_args + .out_prefix + .with_extension("masked_regions.bed"); + } + let sites = config.io_args.input_options.sites()?; + info!( + "resuming at stage={}, iter={}, arg={}", + "resample", + resume_iter, + arg_file.display() + ); + Ok(()) +} + +fn now() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() +} + fn main() -> Result<()> { - let config = Config::parse(); + let mut config = Config::parse(); ensure_output_dir_exists(&config.io_args.out_prefix)?; config.check_overwrite()?; if let Some(fasta_file) = &config.io_args.input_options.fasta_file { println!("fasta_file: {:?}", fasta_file); } let stats_filename = &config.stats_filename(); + setup_logging( + &config.io_args.out_prefix.with_extension("log"), + config.sampling_args.resume, + )?; + if config.sampling_args.resume { + match setup_resume(&mut config) { + Ok(_) => info!("RESUME"), + Err(e) => { + if config.sampling_args.overwrite { + warn!("Resume failed. 
Sampling will start from scratch since overwrite is enabled."); + config.sampling_args.resume = false; + } else { + error!("Could not resume: {}", e); + return Err(e); + } + } + } + } + info!("arg-sample {}", env!("CARGO_PKG_VERSION")); + info!( + "next line for backward compatibility\ncommand: {}", + std::env::args() + .into_iter() + .map(|s| s.to_string()) + .collect::>() + .join(" ") + ); + let seed = config.misc_args.seed.unwrap_or(now()); + info!("random seed: {}", seed); + unsafe { + libc::srand(seed as _); + } + if let Some(keep_ids) = config.io_args.keep_ids { + if keep_ids.len() == 1 && PathBuf::from(&keep_ids[0]).exists() { + let keep_ids = std::fs::read_to_string(&keep_ids[0])?; + config.io_args.keep_ids = + Some(keep_ids.split_whitespace().map(|s| s.to_string()).collect()); + } + } let mut writer = StatsWriter::from_path(stats_filename)?; writer.serialize(&StatsRecord { stage: StatsStage::Resample, iter: 0, prior: 0.0, })?; - println!("Hello, world!"); + writer.serialize(&StatsRecord { + stage: StatsStage::Resample, + iter: 1, + prior: 0.0, + })?; Ok(()) } diff --git a/src/ser/stats.rs b/src/ser/stats.rs index 20de4c20..c305d2f5 100644 --- a/src/ser/stats.rs +++ b/src/ser/stats.rs @@ -1,7 +1,7 @@ use crate::Result; use csv::WriterBuilder; -#[derive(serde::Serialize)] +#[derive(serde::Serialize, serde::Deserialize)] #[serde(rename_all = "snake_case")] pub enum StatsStage { Resample, @@ -10,7 +10,7 @@ pub enum StatsStage { Climb, } -#[derive(serde::Serialize)] +#[derive(serde::Serialize, serde::Deserialize)] pub struct StatsRecord { pub stage: StatsStage, pub iter: usize, diff --git a/src/sites.rs b/src/sites.rs index 4a5b5bd9..4a43b8e5 100644 --- a/src/sites.rs +++ b/src/sites.rs @@ -2,6 +2,7 @@ use std::convert::TryInto; use std::fs::read_to_string; use autocxx::prelude::{c_int, UniquePtr, WithinUniquePtr}; +use bio::{bio_types::genome::AbstractLocus, io::fasta}; use nom::{ bytes::complete::tag, character::complete::{alphanumeric1, digit1, newline, 
one_of, tab}, @@ -19,6 +20,7 @@ use polars::{ use pyo3::{exceptions::PyIndexError, prelude::*, types::PySlice}; #[cfg(feature = "extension-module")] use pyo3_polars::PyDataFrame; +use rust_htslib::bcf::{self, Read}; use crate::{de::parse_names, ffi, Result}; @@ -194,7 +196,7 @@ fn parse_locs(input: &str) -> IResult<&str, (Vec, Vec>)> { separated_pair( parse_u32, tab, - map_res(many1(one_of("ACGT")), |r| { + map_res(many1(one_of("ACGTN-")), |r| { Ok::, nom::error::Error<&str>>( r.into_iter().map(|c| c as u32).collect::>(), ) @@ -206,7 +208,7 @@ fn parse_locs(input: &str) -> IResult<&str, (Vec, Vec>)> { } impl Sites { - pub fn from_path(path: std::path::PathBuf) -> Result { + pub fn from_path(path: &std::path::PathBuf) -> Result { let content = read_to_string(path)?; let (input, names) = parse_names(&content).map_err(|e| e.map_input(|s| s.to_owned()))?; let num_of_cols = names.len(); @@ -241,11 +243,119 @@ impl Sites { end, }) } + /// Read a sites file from a multi-sequence alignment FASTA file. + /// All sequences must be the same length. 
+ pub fn from_msa(path: &std::path::PathBuf) -> Result { + let reader = fasta::Reader::from_file(path)?; + let mut records = reader.records(); + let first = records.next().unwrap()?; + let seq = first.seq(); + let len = seq.len(); + let mut seqs = vec![Series::from_iter(seq.iter().map(|b| *b as u32)).with_name(first.id())]; + for record in records { + let record = record?; + let seq = record.seq(); + assert_eq!(seq.len(), len); + seqs.push(Series::from_iter(seq.iter().map(|b| *b as u32)).with_name(record.id())); + } + let pos = Series::from_iter(1..=len as u32).with_name("pos"); + Ok(Self { + data: DataFrame::new(std::iter::once(pos).chain(seqs).collect())?, + chrom: "chr".to_string(), + start: 1, + end: len, + }) + } + + pub fn from_vcf(path: &std::path::PathBuf) -> Result { + let mut reader = bcf::Reader::from_path(path)?; + let header = reader.header(); + let names: Vec = header + .samples() + .into_iter() + .map(|s| std::str::from_utf8(s).unwrap().to_owned()) + .collect(); + let samples = names + .into_iter() + .enumerate() + .map(|(i, name)| { + let col = reader + .records() + .into_iter() + .map(|r| { + let r = r.unwrap(); + let allele = std::str::from_utf8(r.alleles()[i + 1]).unwrap(); + assert!( + allele.len() == 1, + "allele {} has length {}", + allele, + allele.len() + ); + allele.chars().next().unwrap() as u32 + }) + .collect::>(); + let series = ChunkedArray::::from_vec(&name, col) + .into_series(); + series + }) + .collect::>(); + let pos = ChunkedArray::::from_vec( + "pos", + reader + .records() + .into_iter() + .map(|r| r.unwrap().pos() as u32) + .collect::>(), + ) + .into_series(); + let chrom = reader + .records() + .into_iter() + .next() + .unwrap() + .unwrap() + .contig() + .to_string(); + Ok(Self { + data: DataFrame::new(std::iter::once(pos).chain(samples).collect())?, + chrom, + start: 1, + end: 100000, + }) + } + + pub fn from_vcfs(paths: &[std::path::PathBuf]) -> Result { + let mut sites = Self::from_vcf(&paths[0])?; + for path in &paths[1..] 
{ + let other = Self::from_vcf(path)?; + sites.hstack_mut(&other); + } + Ok(sites) + } + + fn hstack_mut(&mut self, other: &Self) { + let self_names = self.data.get_column_names(); + let names: Vec<&str> = other + .data + .get_column_names() + .iter() + .filter(|&&c| !self_names.contains(&c)) + .map(|&c| c) + .collect(); + let columns: Vec = other + .data + .columns(names) + .unwrap() + .into_iter() + .map(|s| s.to_owned()) + .collect(); + self.data.hstack_mut(&columns).unwrap(); + } } #[cfg(feature = "extension-module")] #[pyfunction] pub fn read_sites(path: std::path::PathBuf) -> PyResult { - let sites = Sites::from_path(std::path::PathBuf::from(path)).unwrap(); + let sites = Sites::from_path(&path).unwrap(); Ok(sites) } diff --git a/src/tests/mod.rs b/src/tests/mod.rs index fcf77a42..a69443a9 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -27,7 +27,7 @@ fn test_ffi_read_sites() { #[test] fn test_from_path() { let path = std::path::PathBuf::from("examples/sim1/sim1.sites"); - let sites = Sites::from_path(path).unwrap(); + let sites = Sites::from_path(&path).unwrap(); assert_eq!(sites.chrom, "chr"); assert_eq!(sites.start, 1); assert_eq!(sites.end, 100000); @@ -37,7 +37,7 @@ fn test_from_path() { #[test] fn test_from_path_and_to_string() { let path = std::path::PathBuf::from("examples/sim1/sim1.sites"); - let sites = Sites::from_path(path).unwrap(); + let sites = Sites::from_path(&path).unwrap(); let s = sites.to_string(); assert_eq!( s.to_string(), @@ -48,7 +48,7 @@ fn test_from_path_and_to_string() { #[test] fn test_try_into_ffi_sites() { let path = std::path::PathBuf::from("examples/sim1/sim1.sites"); - let sites = Sites::from_path(path).unwrap(); + let sites = Sites::from_path(&path).unwrap(); let ffi_sites: UniquePtr = sites.try_into().unwrap(); let num_seq: i32 = ffi_sites.get_num_seqs().into(); assert_eq!(num_seq, 8);