diff --git a/Cargo.lock b/Cargo.lock index 2d45ded..556714c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -261,7 +261,7 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "labelme2yolo" -version = "0.2.4" +version = "0.2.5" dependencies = [ "base64", "clap", diff --git a/Cargo.toml b/Cargo.toml index 30613d2..280b28d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "labelme2yolo" -version = "0.2.4" +version = "0.2.5" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/main.rs b/src/main.rs index e030e9f..cab717e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,13 @@ +use std::collections::HashMap; +use std::fs::{self, copy, File}; +use std::io::{BufWriter, Write}; +use std::path::{Path, PathBuf}; +use std::str::FromStr; +use std::sync::{ + atomic::{AtomicUsize, Ordering::Relaxed}, + Arc, Mutex, +}; + use clap::{Parser, ValueEnum}; use env_logger; use glob::glob; @@ -9,15 +19,6 @@ use rand::SeedableRng; use rayon::prelude::*; use serde::{Deserialize, Serialize}; use serde_json; -use std::collections::HashMap; -use std::fs::{self, copy, File}; -use std::io::{BufWriter, Write}; -use std::path::{Path, PathBuf}; -use std::str::FromStr; -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, Mutex, -}; #[derive(Debug, Serialize, Deserialize, Clone)] struct Shape { @@ -46,45 +47,33 @@ struct ImageAnnotation { #[command(version, about = "Convert LabelMe JSON to YOLO format", long_about = None)] struct Args { // Directory containing LabelMe JSON files - #[arg( - short = 'd', - long = "json_dir", - help = "Directory containing LabelMe JSON files" - )] + #[arg(short = 'd', long = "json_dir")] json_dir: String, // Proportion of the dataset to use for validation - #[arg(long = "val_size", default_value_t = 0.2, value_parser = validate_size, help = "Proportion of the dataset to use for validation (between 0.0 and 1.0)")] + #[arg(long = "val_size", default_value_t = 0.2, value_parser = validate_size)] val_size: f32, // Proportion of the dataset to use for testing - #[arg(long = "test_size", default_value_t = 0.0, value_parser = validate_size, help = "Proportion of the dataset to use for testing (between 0.0 and 1.0)")] + #[arg(long = "test_size", default_value_t = 0.0, value_parser = validate_size)] test_size: f32, - // Output format (bbox or polygon) for YOLO annotations + // Output format for YOLO annotations: 'bbox' or 'polygon' #[arg( long = "output_format", visible_alias = "format", value_enum, - default_value = "bbox", - help = "Output format for YOLO annotations: 'bbox' or 'polygon'" + default_value = "bbox" )] output_format: Format, - // List of labels in the dataset - #[arg( - use_value_delimiter = true, - help = "Comma-separated list of labels in the dataset" - )] - label_list: Vec, - // Seed for random shuffling - #[arg( - long = "seed", - default_value_t = 42, - help = "Seed for random shuffling" - )] + #[arg(long = "seed", default_value_t = 42)] seed: u64, + + // List of labels in the dataset + #[arg(use_value_delimiter = true)] + label_list: Vec, } // Enumeration for the YOLO output format @@ -159,9 +148,10 @@ fn create_output_directory(path: &Path) -> std::io::Result { "Directory {:?} already exists. Deleting and recreating it.", path ); - fs::remove_dir_all(path)?; + fs::remove_dir_all(path).and_then(|_| fs::create_dir_all(path))?; + } else { + fs::create_dir_all(path)?; } - fs::create_dir_all(path)?; Ok(path.to_path_buf()) } @@ -264,7 +254,7 @@ fn initialize_label_map( for (id, label) in args.label_list.iter().enumerate() { map.insert(label.clone(), id); } - next_class_id.store(args.label_list.len(), Ordering::Relaxed); + next_class_id.store(args.label_list.len(), Relaxed); } else { // Otherwise, use labels found in the dataset split_data @@ -275,7 +265,7 @@ fn initialize_label_map( .flat_map(|(_, annotation)| annotation.shapes.iter()) .for_each(|shape| { if !map.contains_key(&shape.label) { - let new_id = next_class_id.fetch_add(1, Ordering::Relaxed); + let new_id = next_class_id.fetch_add(1, Relaxed); map.insert(shape.label.clone(), new_id); } }); @@ -392,7 +382,7 @@ fn process_annotation( let yolo_data = convert_to_yolo_format(annotation, args, label_map); - let sanitized_name = sanitize_filename::sanitize(path.file_stem().unwrap().to_str().unwrap()); + let sanitized_name = sanitize_filename::sanitize(path.file_name().unwrap().to_str().unwrap()); let output_path = labels_dir.join(&sanitized_name).with_extension("txt"); let file = File::create(&output_path)?; @@ -474,11 +464,12 @@ fn process_polygon_shape(yolo_data: &mut String, annotation: &ImageAnnotation, s yolo_data.push_str(&format!(" {:.6} {:.6}", x_norm, y_norm)); } } else if shape.shape_type == "circle" { + const CIRCLE_POINTS: usize = 12; let (cx, cy) = shape.points[0]; let (px, py) = shape.points[1]; let radius = ((cx - px).powi(2) + (cy - py).powi(2)).sqrt(); - for i in 0..12 { - let angle = 2.0 * std::f64::consts::PI * i as f64 / 12.0; + for i in 0..CIRCLE_POINTS { + let angle = 2.0 * std::f64::consts::PI * i as f64 / CIRCLE_POINTS as f64; let x = cx + radius * angle.cos(); let y = cy + radius * angle.sin(); let x_norm = x / annotation.image_width as f64;