From f0e32c6c3e04a16f2ca1210f46c7a8a7e0236708 Mon Sep 17 00:00:00 2001 From: Chris Norman Date: Tue, 3 May 2022 16:15:01 -0400 Subject: [PATCH] Add support for re-serializing (sharding and reassembling) CRAM containers to a new stream. --- .../samtools/CRAMContainerStreamRewriter.java | 93 +++++++ .../java/htsjdk/samtools/CRAMIndexer.java | 1 - .../samtools/cram/structure/Container.java | 25 +- .../cram/structure/ContainerHeader.java | 6 +- .../htsjdk/samtools/cram/structure/Slice.java | 20 +- .../CRAMContainerStreamRewriterTest.java | 227 ++++++++++++++++++ 6 files changed, 367 insertions(+), 5 deletions(-) create mode 100644 src/main/java/htsjdk/samtools/CRAMContainerStreamRewriter.java create mode 100644 src/test/java/htsjdk/samtools/CRAMContainerStreamRewriterTest.java diff --git a/src/main/java/htsjdk/samtools/CRAMContainerStreamRewriter.java b/src/main/java/htsjdk/samtools/CRAMContainerStreamRewriter.java new file mode 100644 index 0000000000..4e9a929440 --- /dev/null +++ b/src/main/java/htsjdk/samtools/CRAMContainerStreamRewriter.java @@ -0,0 +1,93 @@ +package htsjdk.samtools; + +import htsjdk.samtools.cram.build.CramIO; +import htsjdk.samtools.cram.structure.Container; +import htsjdk.samtools.cram.structure.CramHeader; +import htsjdk.samtools.util.RuntimeIOException; + +import java.io.IOException; +import java.io.OutputStream; + +/** + * Rewrite a series of containers to a new stream. The CRAM header and SAMFileHeader containers are automatically + * written to the stream when this class is instantiated. An EOF container is automatically written when + * {@link #finish()} is called. + */ +public class CRAMContainerStreamRewriter { + private final OutputStream outputStream; + private final String outputStreamIdentifier; + private final CramHeader cramHeader; + private final SAMFileHeader samFileHeader; + private final CRAMIndexer cramIndexer; + + private long streamOffset = 0L; + private long recordCounter = 0L; + + /** + * Create a CRAMContainerStreamWriter for writing SAM records into a series of CRAM + * containers on an output stream, with an optional output index. + * + * @param outputStream where to write the CRAM stream. + * @param samFileHeader {@link SAMFileHeader} to be used. Sort order is determined by the sortOrder property of this arg. + * @param outputStreamIdentifier used for display in error message display + * @param indexer CRAM indexer. Can be null if no index is required. + */ + public CRAMContainerStreamRewriter( + final OutputStream outputStream, + final CramHeader cramHeader, + final SAMFileHeader samFileHeader, + final String outputStreamIdentifier, + final CRAMIndexer indexer) { + this.outputStream = outputStream; + this.cramHeader = cramHeader; + this.samFileHeader = samFileHeader; + this.outputStreamIdentifier = outputStreamIdentifier; + this.cramIndexer = indexer; + + //TODO: update the SAMFileHeader with a program group to leave a paper trail? + streamOffset = CramIO.writeCramHeader(cramHeader, outputStream); + streamOffset += Container.writeSAMFileHeaderContainer(cramHeader.getCRAMVersion(), samFileHeader, outputStream); + } + + /** + * Writes a container to a stream, updating the (stream-relative) global record counter and byte offsets. + * + * Since this method mutates the values in the container, the container is no longer valid in the context + * of the stream from which it originated. + * + * @param container the container to emit to the stream. the container must conform to the version and sort + * order specified in the CRAM header and SAM header provided to the constructor + * {@link #CRAMContainerStreamRewriter(OutputStream, CramHeader, SAMFileHeader, String, CRAMIndexer)}. + * All the containers serialized to a single stream using this method must have originated from the + * same original context(/stream), obtained via {@link htsjdk.samtools.cram.build.CramContainerIterator}. + */ + public void rewriteContainer(final Container container) { + // update the container and slices with the correct global record counter and byte offsets + // (required for indexing) + container.relocateContainer(recordCounter, streamOffset); + + // re-serialize the entire container and slice(s), block by block + streamOffset += container.write(cramHeader.getCRAMVersion(), outputStream); + recordCounter += container.getContainerHeader().getNumberOfRecords(); + + if (cramIndexer != null) { + cramIndexer.processContainer(container, ValidationStringency.SILENT); + } + } + + /** + * Finish writing to the stream. Flushes the record cache and optionally emits an EOF container. + */ + public void finish() { + try { + CramIO.writeCramEOF(cramHeader.getCRAMVersion(), outputStream); + outputStream.flush(); + if (cramIndexer != null) { + cramIndexer.finish(); + } + } catch (final IOException e) { + throw new RuntimeIOException(String.format("IOException closing stream for %s", outputStreamIdentifier)); + } + } + +} diff --git a/src/main/java/htsjdk/samtools/CRAMIndexer.java b/src/main/java/htsjdk/samtools/CRAMIndexer.java index 5e332a87ea..f58a2d8ce8 100644 --- a/src/main/java/htsjdk/samtools/CRAMIndexer.java +++ b/src/main/java/htsjdk/samtools/CRAMIndexer.java @@ -1,6 +1,5 @@ package htsjdk.samtools; -import htsjdk.samtools.cram.structure.CompressorCache; import htsjdk.samtools.cram.structure.Container; /** diff --git a/src/main/java/htsjdk/samtools/cram/structure/Container.java b/src/main/java/htsjdk/samtools/cram/structure/Container.java index 890fb9db81..160d5eaf3c 100644 --- a/src/main/java/htsjdk/samtools/cram/structure/Container.java +++ b/src/main/java/htsjdk/samtools/cram/structure/Container.java @@ -49,7 +49,7 @@ public class Container { private final List slices; // container's byte offset from the start of the containing stream, used for indexing - private final long containerByteOffset; + private long containerByteOffset; /** * Create a Container with a {@link ReferenceContext} derived from its {@link Slice}s. @@ -190,6 +190,7 @@ public int write(final CRAMVersion cramVersion, final OutputStream outputStream) // landmark 0 = byte length of the compression header // landmarks after 0 = byte length of the compression header plus all slices before this one landmarks.add(tempOutputStream.size()); + slice.byteOffsetOfContainer = containerByteOffset; slice.write(cramVersion, tempOutputStream); } getContainerHeader().setLandmarks(landmarks); @@ -335,6 +336,28 @@ public List getSAMRecords( public CompressionHeader getCompressionHeader() { return compressionHeader; } public AlignmentContext getAlignmentContext() { return containerHeader.getAlignmentContext(); } public long getContainerByteOffset() { return containerByteOffset; } + + /** + * Update the stream-relative values (global record counter and stream byte offset) for this + * container. For use when re-serializing a container that has been read from an existing stream + * into a new stream. This method mutates the container and it's slices - the container is no + * longer valid in the context of it's original stream. + * + * @param containerRecordCounter the new global record counter for this container + * @param streamByteOffset the new stream byte offset counter for this container + * @return the updated global record counter + */ + public long relocateContainer(final long containerRecordCounter, final long streamByteOffset) { + this.containerByteOffset = streamByteOffset; + this.getContainerHeader().setGlobalRecordCounter(containerRecordCounter); + + long sliceRecordCounter = containerRecordCounter; + for (final Slice slice : getSlices()) { + sliceRecordCounter = slice.relocateSlice(sliceRecordCounter, streamByteOffset); + } + return sliceRecordCounter; + } + public List getSlices() { return slices; } public boolean isEOF() { return containerHeader.isEOF() && (getSlices() == null || getSlices().size() == 0); diff --git a/src/main/java/htsjdk/samtools/cram/structure/ContainerHeader.java b/src/main/java/htsjdk/samtools/cram/structure/ContainerHeader.java index fd4cac79d7..0b9319d3bc 100644 --- a/src/main/java/htsjdk/samtools/cram/structure/ContainerHeader.java +++ b/src/main/java/htsjdk/samtools/cram/structure/ContainerHeader.java @@ -42,7 +42,7 @@ public class ContainerHeader { // total length of all blocks in this container (total length of this container, minus the Container Header). private final AlignmentContext alignmentContext; private final int recordCount; - private final long globalRecordCounter; + private long globalRecordCounter; private final long baseCount; private final int blockCount; @@ -249,4 +249,8 @@ public boolean isEOF() { return v3 || v2; } + void setGlobalRecordCounter(final long recordCounter) { + this.globalRecordCounter = recordCount; + } + } diff --git a/src/main/java/htsjdk/samtools/cram/structure/Slice.java b/src/main/java/htsjdk/samtools/cram/structure/Slice.java index 0b6f37ec46..626e46b26e 100644 --- a/src/main/java/htsjdk/samtools/cram/structure/Slice.java +++ b/src/main/java/htsjdk/samtools/cram/structure/Slice.java @@ -67,7 +67,7 @@ public class Slice { // Slice header components as defined in the spec private final AlignmentContext alignmentContext; // ref sequence, alignment start and span private final int nRecords; - private final long globalRecordCounter; + private long globalRecordCounter; private final int nSliceBlocks; // includes the core block and external blocks, but not the header block private List contentIDs; private int embeddedReferenceBlockContentID = EMBEDDED_REFERENCE_ABSENT_CONTENT_ID; @@ -78,7 +78,7 @@ public class Slice { private final CompressionHeader compressionHeader; private final SliceBlocks sliceBlocks; - private final long byteOffsetOfContainer; + public long byteOffsetOfContainer; private Block sliceHeaderBlock; @@ -518,6 +518,22 @@ public void normalizeCRAMRecords(final List cramCompressi } } + /** + * Update the stream-relative values (global record counter and container stream byte offset) for + * this slice. For use when re-serializing a container that has been read from an existing stream + * into a new stream. This method mutates the container and it's slices - the container is no + * longer valid in the context of it's original stream. + * + * @param sliceRecordCounter the new global record counter for this slice + * @param containerByteOffset the new stream byte offset counter for this slice's enclosing container + * @return the updated global record counter + */ + long relocateSlice(final long sliceRecordCounter, final long containerByteOffset) { + this.byteOffsetOfContainer = containerByteOffset; + this.globalRecordCounter = sliceRecordCounter; + return sliceRecordCounter + getNumberOfRecords(); + } + private int getReferenceOffset(final boolean hasEmbeddedReference) { final ReferenceContext sliceReferenceContext = getAlignmentContext().getReferenceContext(); return sliceReferenceContext.isMappedSingleRef() && hasEmbeddedReference ? diff --git a/src/test/java/htsjdk/samtools/CRAMContainerStreamRewriterTest.java b/src/test/java/htsjdk/samtools/CRAMContainerStreamRewriterTest.java new file mode 100644 index 0000000000..11ea586692 --- /dev/null +++ b/src/test/java/htsjdk/samtools/CRAMContainerStreamRewriterTest.java @@ -0,0 +1,227 @@ +package htsjdk.samtools; + +import htsjdk.HtsjdkTest; +import htsjdk.beta.io.IOPathUtils; +import htsjdk.io.IOPath; +import htsjdk.samtools.cram.build.CramContainerIterator; +import htsjdk.samtools.util.IOUtil; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.function.Function; + +public class CRAMContainerStreamRewriterTest extends HtsjdkTest { + private static final File TEST_DATA_DIR = new File("src/test/resources/htsjdk/samtools/cram"); + + // use a test file with artificially small containers and slices, since it has multiple containers + // (most test files in the repo use the default container size and have only one container), includes + // some unmapped reads, and already has an index + public final static File testCRAM = new File(TEST_DATA_DIR, "NA12878.20.21.1-100.100-SeqsPerSlice.500-unMapped.cram"); + public final static File testFASTA = new File(TEST_DATA_DIR, "human_g1k_v37.20.21.1-100.fasta"); + + private enum CRAM_TEST_INDEX_TYPE { + NO_INDEX, + BAI_INDEX, + CRAI_INDEX + } + + @DataProvider(name="containerStreamRewriterTests") + public Object[] getContainerStreamRewriterTests() { + return new Object[][] { + { + testCRAM, + testFASTA, + CRAM_TEST_INDEX_TYPE.NO_INDEX, + null }, + { + testCRAM, + testFASTA, + CRAM_TEST_INDEX_TYPE.BAI_INDEX, + (Function) (SamReader samReader) -> samReader.query("20", 1, 100000, true), + }, + { + testCRAM, + testFASTA, + CRAM_TEST_INDEX_TYPE.CRAI_INDEX, + (Function) (SamReader samReader) -> samReader.query("20", 1, 100000, true), + }, + { + testCRAM, + testFASTA, + CRAM_TEST_INDEX_TYPE.BAI_INDEX, + (Function) (SamReader samReader) -> samReader.query("20", 34, 134, false), + }, + { + testCRAM, + testFASTA, + CRAM_TEST_INDEX_TYPE.CRAI_INDEX, + (Function) (SamReader samReader) -> samReader.query("20", 34, 134, false), + }, + { + testCRAM, + testFASTA, + CRAM_TEST_INDEX_TYPE.BAI_INDEX, + (Function) (SamReader samReader) -> samReader.queryUnmapped(), + }, + { + testCRAM, + testFASTA, + CRAM_TEST_INDEX_TYPE.CRAI_INDEX, + (Function) (SamReader samReader) -> samReader.queryUnmapped(), + } + }; + } + @Test(dataProvider = "containerStreamRewriterTests") + private void testCRAMRewriteContainerStream( + final File testCRAM, + final File referenceFile, + final CRAM_TEST_INDEX_TYPE indexType, + final Function> queryFunction) throws IOException { + final IOPath tempOutputCRAM = IOPathUtils.createTempPath("cramContainerStreamRewriterTest", ".cram"); + + try (final CramContainerIterator cramContainerIterator = + new CramContainerIterator(new BufferedInputStream(new FileInputStream(testCRAM.toPath().toFile()))); + final BufferedOutputStream outputStream = + new BufferedOutputStream(new FileOutputStream(tempOutputCRAM.toPath().toFile())) + ) { + final CRAMContainerStreamRewriter containerStreamRewriter = + new CRAMContainerStreamRewriter( + outputStream, + cramContainerIterator.getCramHeader(), + cramContainerIterator.getSamFileHeader(), + "test", + getIndexerForType(indexType, tempOutputCRAM, cramContainerIterator.getSamFileHeader())); + while (cramContainerIterator.hasNext()) { + containerStreamRewriter.rewriteContainer(cramContainerIterator.next()); + } + containerStreamRewriter.finish(); + } + + // iterate through all the records in the rewritten file and compare them with those in the original file + try (final SamReader originalReader = SamReaderFactory.makeDefault() + .referenceSequence(referenceFile) + .validationStringency(ValidationStringency.SILENT) + .open(testCRAM); + final SamReader rewrittenReader = SamReaderFactory.makeDefault() + .referenceSequence(referenceFile) + .validationStringency(ValidationStringency.SILENT) + .open(tempOutputCRAM.toPath())) { + // rewriting the SAMHeader "upgrades" it's version because it gets re-serialized by the text codec, so we + // can't compare the headers directly, so settle for a sequence dictionary check + Assert.assertEquals( + rewrittenReader.getFileHeader().getSequenceDictionary(), + originalReader.getFileHeader().getSequenceDictionary()); + + final Iterator originalIterator = originalReader.iterator(); + final Iterator rewrittenIterator = rewrittenReader.iterator(); + while (originalIterator.hasNext() && rewrittenIterator.hasNext()) { + Assert.assertEquals(originalIterator.next(), rewrittenIterator.next()); + } + Assert.assertEquals(originalIterator.hasNext(), rewrittenIterator.hasNext()); + } + + // now compare the results from a simple index query on the original with the results from the rewritten file + if (indexType != CRAM_TEST_INDEX_TYPE.NO_INDEX) { + try (final SamReader originalReader = SamReaderFactory.makeDefault() + .referenceSequence(referenceFile) + .validationStringency(ValidationStringency.SILENT) + .open(testCRAM); + final SamReader rewrittenReader = SamReaderFactory.makeDefault() + .referenceSequence(referenceFile) + .validationStringency(ValidationStringency.SILENT) + .open(tempOutputCRAM.toPath())) { + Assert.assertEquals(queryFunction.apply(originalReader), queryFunction.apply(rewrittenReader)); + } + } + } + + @Test + private void testShardingReassembly() throws IOException { + // break up a file into multiple shards (1 container per shard), then reassemble, and compare the results + // of the contents of the original with the contents of the reassembled file + final List outputShards = new ArrayList<>(); + try (final CramContainerIterator cramContainerIterator = + new CramContainerIterator(new BufferedInputStream(new FileInputStream(testCRAM)))) { + while (cramContainerIterator.hasNext()) { + final IOPath tempOutputCRAM = IOPathUtils.createTempPath("cramContainerStreamRewriterTest", ".cram"); + outputShards.add(tempOutputCRAM.toPath().toFile()); + + try (final BufferedOutputStream outputStream = + new BufferedOutputStream(new FileOutputStream(tempOutputCRAM.toPath().toFile()))) { + final CRAMContainerStreamRewriter containerStreamRewriter = new CRAMContainerStreamRewriter( + outputStream, + cramContainerIterator.getCramHeader(), + cramContainerIterator.getSamFileHeader(), + "test", + null); + containerStreamRewriter.rewriteContainer(cramContainerIterator.next()); + containerStreamRewriter.finish(); + } + } + } + + // we need to make sure we have at least a few shards for this to be interesting + Assert.assertEquals(outputShards.size(), 9); + + try (final SamReader originalReader = SamReaderFactory.makeDefault() + .referenceSequence(testFASTA) + .validationStringency(ValidationStringency.SILENT).open(testCRAM)) { + final Iterator originalIterator = originalReader.iterator(); + final Iterator rewrittenIterator = getIteratorFromShards(outputShards); + while (originalIterator.hasNext() && rewrittenIterator.hasNext()) { + Assert.assertEquals(originalIterator.next(), rewrittenIterator.next()); + } + Assert.assertEquals(originalIterator.hasNext(), rewrittenIterator.hasNext()); + } + } + + private Iterator getIteratorFromShards(final List outputShards) throws IOException { + final List shardedSAMRecords = new ArrayList<>(); + for (final File shardFile: outputShards) { + try (final SamReader originalReader = SamReaderFactory.makeDefault() + .referenceSequence(testFASTA) + .validationStringency(ValidationStringency.SILENT).open(shardFile)) { + for ( final SAMRecord samRecord: originalReader) { + shardedSAMRecords.add(samRecord); + } + } + } + return shardedSAMRecords.iterator(); + } + + private CRAMIndexer getIndexerForType( + final CRAM_TEST_INDEX_TYPE indexType, + final IOPath cramFile, + final SAMFileHeader samFileHeader) throws IOException { + switch (indexType) { + case NO_INDEX: + return null; + case BAI_INDEX: + final Path tempOutputBAI = IOUtil.addExtension(cramFile.toPath(), ".bai"); + IOUtil.deleteOnExit(tempOutputBAI); + return new CRAMBAIIndexer( + new BufferedOutputStream(new FileOutputStream(tempOutputBAI.toFile())), + samFileHeader); + case CRAI_INDEX: + final Path tempOutputCRAI = IOUtil.addExtension(cramFile.toPath(), ".crai"); + IOUtil.deleteOnExit(tempOutputCRAI); + return new CRAMCRAIIndexer( + new BufferedOutputStream(new FileOutputStream(tempOutputCRAI.toFile())), + samFileHeader); + default: + throw new IllegalArgumentException("Unknown cram index type"); + } + } + +}