diff --git a/src/main/java/org/janelia/colormipsearch/SparkMaskSearch.java b/src/main/java/org/janelia/colormipsearch/SparkMaskSearch.java index 755e9b43..18a6d8d3 100644 --- a/src/main/java/org/janelia/colormipsearch/SparkMaskSearch.java +++ b/src/main/java/org/janelia/colormipsearch/SparkMaskSearch.java @@ -18,9 +18,13 @@ import javax.imageio.ImageIO; import java.io.*; +import java.nio.file.DirectoryStream; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.stream.Stream; @@ -75,6 +79,16 @@ private ImagePlus readImagePlus(String filepath, String title, PortableDataStrea } } + private ImagePlus readImagePlus(String filepath, String title) throws Exception { + switch (getImageFormat(filepath)) { + case PNG: + return readPngToImagePlus(title, new FileInputStream(filepath)); + case TIFF: + return readTiffToImagePlus(title, new FileInputStream(filepath)); + } + throw new IllegalArgumentException("Image must be in PNG or TIFF format"); + } + private MaskSearchResult search(String filepath, ImagePlus image, ImagePlus mask, Integer maskThreshold) { try { @@ -107,29 +121,38 @@ private MaskSearchResult search(String filepath, ImagePlus image, ImagePlus mask * Load an image archive into memory. * @param imagesFilepath */ - public void loadImages(String imagesFilepath) { + public void loadImages(String imagesFilepath) throws IOException { - // We have to ensure each filepath ends with a wildcard, because it's very slow without. 
- // See https://issues.apache.org/jira/browse/SPARK-8437 - StringBuffer filepaths = new StringBuffer(); + List paths = new ArrayList<>(); for(String filepath : imagesFilepath.split(",")) { - if (!filepath.contains("*")) { - if (!filepath.endsWith("/")) { - filepath += "/"; + log.info("Loading image archive at: {}", filepath); + Path folder = Paths.get(filepath); + try (DirectoryStream stream = Files.newDirectoryStream(folder)) { + int c = 0; + for (Path entry : stream) { + paths.add(entry.toString()); + c++; } - filepath += "*"; + log.info(" Read {} files", c); } - if (filepaths.length()>0) filepaths.append(","); - filepaths.append(filepath); } - log.info("Loading image archive at: {}", filepaths); + // Randomize path list so that each task has some paths from each directory. Otherwise, some tasks would only get + // files from an "easy" directory where all the files are small + Collections.shuffle(paths); + + log.info("Total paths: {}", paths.size()); log.info("Default parallelism: {}", context.defaultParallelism()); - JavaPairRDD filesRdd = context.binaryFiles(filepaths.toString()); - log.info("filesRdd.numPartitions: {}", filesRdd.getNumPartitions()); + // This is a lot faster than using binaryFiles because 1) the paths are shuffled, 2) we use an optimized + // directory listing stream which does not consider file sizes. As a bonus, it actually respects the parallelism + // setting, unlike binaryFiles which ignores it unless you set other arcane settings like openCostInBytes. 
+ JavaRDD pathRDD = context.parallelize(paths); + log.info("filesRdd.numPartitions: {}", pathRDD.getNumPartitions()); - this.imagePlusRDD = filesRdd.mapToPair(pair -> new Tuple2<>(pair._1, readImagePlus(pair._1, "search", pair._2))).cache(); + // This RDD is cached so that it can be reused to search with multiple masks + this.imagePlusRDD = pathRDD.mapToPair(filepath -> + new Tuple2<>(filepath, readImagePlus(filepath, "search"))).cache(); log.info("imagePlusRDD.numPartitions: {}", imagePlusRDD.getNumPartitions()); log.info("imagePlusRDD.count: {}", imagePlusRDD.count());