diff --git a/.gitignore b/.gitignore index 1ea5fee..47f14f7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ __pycache__ .idea/ venv/ +.csv +.vscode/ +results/ \ No newline at end of file diff --git a/scripts/analyze_data.py b/scripts/analyze_data.py new file mode 100644 index 0000000..de35480 --- /dev/null +++ b/scripts/analyze_data.py @@ -0,0 +1,183 @@ +''' +This script loops on a config file (see init_data_config.py) to calculate metrics (.csv) and generate plots. +''' + +import os +import argparse +import json +import glob +from progress.bar import Bar +import csv + +from utils import get_img_path_from_mask_path, get_mask_path_from_img_path, edit_metric_dict, save_graphs, change_mask_suffix, get_deriv_sub_from_img_path, str_to_float_list, str_to_str_list, mergedict + +def run_analysis(args): + """ + Run analysis on a config file + """ + + short_suffix_disc = '_label' + short_suffix_seg = '_seg' + derivatives_folder = 'derivatives' + output_folder = 'results' + + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + if args.config: + data_form = 'split' + # Read json file and create a dictionary + with open(args.config, "r") as file: + config_data = json.load(file) + + if config_data['TYPE'] == 'LABEL': + isImage = False + elif config_data['TYPE'] == 'IMAGE': + isImage = True + else: + raise ValueError(f'config with unknown TYPE {config_data['TYPE']}') + + # Remove keys that are not lists of paths + keys = list(config_data.keys()) + for key in keys: + if key not in ['TRAINING', 'VALIDATION', 'TESTING']: + del config_data[key] + + elif args.paths_to_bids: + data_form = 'dataset' + config_data = {} + for path_bids in args.paths_to_bids: + files = glob.glob(path_bids + "/**/" + "*.nii.gz", recursive=True) # Get all niftii files + config_data[os.path.basename(os.path.normpath(path_bids))] = [f for f in files if derivatives_folder not in f] # Remove masks from derivatives folder + isImage = True + + elif args.paths_to_csv: + data_form = 'dataset' + config_data = {} + else: + raise ValueError(f"Need to specify either args.paths_to_bids, args.config or args.paths_to_csv !") + + # Initialize metrics dictionary + metrics_dict = dict() + + if args.paths_to_csv: + for path_csv in args.paths_to_csv: + dataset_name = os.path.basename(path_csv).split('_')[-1].split('.csv')[0] + metrics_dict[dataset_name] = {} + with open(path_csv) as csv_file: + reader = csv.reader(csv_file) + for k, v in dict(reader).items(): + metric = k.split('_') + if len(metric) == 2: + metric_name, metric_value = metric + if metric_name not in metrics_dict[dataset_name].keys(): + metrics_dict[dataset_name][metric_name] = {metric_value:int(v)} + else: + metrics_dict[dataset_name][metric_name][metric_value] = int(v) + else: + if k.startswith('mismatch'): + metrics_dict[dataset_name][k] = int(v) + else: + metrics_dict[dataset_name][k] = str_to_str_list(v) + + # Initialize data finguerprint + fprint_dict = dict() + + if config_data.keys(): + missing_data = [] + # Extract information from the data + for key in config_data.keys(): + metrics_dict[key] = dict() + fprint_dict[key] = dict() + + # Init progression bar + bar = Bar(f'Analyze data {key} ', max=len(config_data[key])) + + for path in config_data[key]: + if isImage: + img_path = path # str + deriv_sub_folders = get_deriv_sub_from_img_path(img_path=img_path, derivatives_folder=derivatives_folder) # list of str + seg_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_seg, deriv_sub_folders=deriv_sub_folders, counterexample=['lesion', 'GM', 'WM']) # list of str + discs_paths = get_mask_path_from_img_path(img_path, short_suffix=short_suffix_disc, deriv_sub_folders=deriv_sub_folders, counterexample=['compression', 'SC_mask', 'seg', 'lesion', 'GM', 'WM']) # list of str + else: + img_path = get_img_path_from_mask_path(path, derivatives_folder=derivatives_folder) + deriv_sub_folders = [os.path.dirname(path)] + # Extract field of view information thanks to discs labels + if short_suffix_disc in path: + discs_paths = [path] + seg_paths = [change_mask_suffix(discs_paths, short_suffix=short_suffix_seg)] + elif short_suffix_seg in path: + seg_paths = [path] + discs_paths = [change_mask_suffix(seg_paths, short_suffix=short_suffix_disc)] + else: + seg_paths = [change_mask_suffix(path, short_suffix=short_suffix_seg)] + discs_paths = [change_mask_suffix(path, short_suffix=short_suffix_disc)] + + # Extract data + if os.path.exists(img_path): + metrics_dict[key], fprint_dict[key] = edit_metric_dict(metrics_dict[key], fprint_dict[key], img_path, seg_paths, discs_paths, deriv_sub_folders) + else: + missing_data.append(img_path) + + # Plot progress + bar.suffix = f'{config_data[key].index(path)+1}/{len(config_data[key])}' + bar.next() + bar.finish() + + # Store csv with computed metrics + if args.create_csv: + # Based on https://stackoverflow.com/questions/8685809/writing-a-dictionary-to-a-csv-file-with-one-line-for-every-key-value + out_csv_folder = os.path.join(output_folder, 'files') + if not os.path.exists(out_csv_folder): + os.makedirs(out_csv_folder) + csv_path_sum = os.path.join(out_csv_folder, f'computed_metrics_{key}.csv') + with open(csv_path_sum, 'w') as csv_file: + writer = csv.writer(csv_file) + for metric_name, metric in sorted(metrics_dict[key].items()): + if isinstance(metric,dict): + for metric_value, count in sorted(metric.items()): + k = f'{metric_name}_{metric_value}' + writer.writerow([k, count]) + else: + writer.writerow([metric_name, metric]) + + # Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + csv_path_fprint = os.path.join(out_csv_folder, f'fprint_{key}.csv') + sub_list = [sub for sub in fprint_dict[key].keys() if sub.startswith('sub')] + fields = ['subject'] + [k for k in fprint_dict[key][sub_list[0]].keys()] + with open(csv_path_fprint, 'w') as f: + w = csv.DictWriter(f, fields) + w.writeheader() + for k, v in fprint_dict[key].items(): + w.writerow(mergedict({'subject': k},v)) + + + if missing_data: + print("missing files:\n" + '\n'.join(missing_data)) + + # Plot data informations + save_graphs(output_folder=output_folder, metrics_dict=metrics_dict, data_form=data_form) + + + + + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Analyse config file') + + ## Parameters + parser.add_argument('--paths-to-bids', default='', nargs='+', + help='Paths to BIDS compliant datasets (You can add multiple paths using spaces)') + parser.add_argument('--config', default='', + help='Path to JSON config file that contains all the training splits') + parser.add_argument('--paths-to-csv', default='', nargs='+', + help='Paths to csv files with already computed metrics (You can add multiple paths using spaces)') + parser.add_argument('--split', default='ALL', choices=('TRAINING', 'VALIDATION', 'TESTING', 'ALL'), + help='Split of the data that will be analysed (default="ALL")') + parser.add_argument('--create-csv', default=True, + help='Store computed metrics using a csv file in results/files (default=True)') + + # Start analysis + run_analysis(parser.parse_args()) \ No newline at end of file diff --git a/scripts/init_data_config.py b/scripts/init_data_config.py new file mode 100644 index 0000000..96fe5d4 --- /dev/null +++ b/scripts/init_data_config.py @@ -0,0 +1,120 @@ +""" +Generate a config file with all the paths to the used files. +See https://github.com/spinalcordtoolbox/disc-labeling-hourglass/issues/25#issuecomment-1695818382 +Script copied from https://github.com/spinalcordtoolbox/disc-labeling-hourglass +""" + +import os +import argparse +import random +import json +import itertools +import numpy as np + +from utils import get_img_path_from_mask_path, get_cont_path_from_other_cont, fetch_contrast, fetch_subject_and_session + + +# Determine specified contrasts +def init_data_config(args): + """ + Create a JSON configuration file from a TXT file where images paths are specified + """ + if (args.split_validation + args.split_test) > 1: + raise ValueError("The sum of the ratio between testing and validation cannot exceed 1") + + # Get input paths, could be label files or image files, + # and make sure they all exist. + file_paths = [os.path.abspath(path.replace('\n', '')) for path in open(args.txt)] + if args.type == 'LABEL': + label_paths = file_paths + img_paths = [get_img_path_from_mask_path(lp) for lp in label_paths] + file_paths = label_paths + img_paths + elif args.type == 'IMAGE': + img_paths = file_paths + elif args.type == 'CONTRAST': + if not args.cont: # If the target contrast is not specified + raise ValueError(f'When using the type CONTRAST, please specify the target contrast using the flag "--cont"') + img_paths = file_paths + new_contrast = args.cont + label_paths = [get_cont_path_from_other_cont(ip) for ip in img_paths] + file_paths = label_paths + img_paths + else: + raise ValueError(f"invalid args.type: {args.type}") + missing_paths = [ + path for path in file_paths + if not os.path.isfile(path) + ] + + if missing_paths: + raise ValueError("missing files:\n" + '\n'.join(missing_paths)) + + # Extract BIDS parent folder path + dataset_parent_path_list = ['/'.join(path.split('/sub')[0].split('/')[:-1]) for path in img_paths] + + # Check if all the BIDS folders are stored inside the same parent repository + if (np.array(dataset_parent_path_list) == dataset_parent_path_list[0]).all(): + dataset_parent_path = dataset_parent_path_list[0] + else: + raise ValueError('Please store all the BIDS datasets inside the same parent folder !') + + # Look up the right code for the set of contrasts present + contrasts = "_".join(tuple(sorted(set(map(fetch_contrast, img_paths))))) + + config = { + 'TYPE': args.type, + 'CONTRASTS': contrasts, + 'DATASETS_PATH': dataset_parent_path + } + + # Add target contrast when the type CONTRAST is used + if args.type == 'CONTRAST': + config['TARGET_CONTRAST'] = args.cont + + # Split into training, validation, and testing sets + split_ratio = (1 - (args.split_validation + args.split_test), args.split_validation, args.split_test) # TRAIN, VALIDATION, and TEST + config_paths = label_paths if args.type == 'LABEL' else img_paths + config_paths = [path.split(dataset_parent_path + '/')[-1] for path in config_paths] # Remove DATASETS_PATH + random.shuffle(config_paths) + splits = [0] + [ + int(len(config_paths) * ratio) + for ratio in itertools.accumulate(split_ratio) + ] + for key, (begin, end) in zip( + ['TRAINING', 'VALIDATION', 'TESTING'], + pairwise(splits), + ): + config[key] = config_paths[begin:end] + + # Save the config + config_path = args.txt.replace('.txt', '') + '.json' + json.dump(config, open(config_path, 'w'), indent=4) + +def pairwise(iterable): + # pairwise('ABCDEFG') --> AB BC CD DE EF FG + # based on https://docs.python.org/3.11/library/itertools.html + a, b = itertools.tee(iterable) + next(b, None) + return zip(a, b) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Create config JSON from a TXT file which contains list of paths') + + ## Parameters + parser.add_argument('--txt', required=True, + help='Path to TXT file that contains only image or label paths. (Required)') + parser.add_argument('--type', choices=('LABEL', 'IMAGE', 'CONTRAST'), + help='Type of paths specified. Choices are "LABEL", "IMAGE" or "CONTRAST". (Required)') + parser.add_argument('--cont', type=str, default='', + help='If the type CONTRAST is selected, this variable specifies the wanted contrast for target.') + parser.add_argument('--split-validation', type=float, default=0.1, + help='Split ratio for validation. Default=0.1') + parser.add_argument('--split-test', type=float, default=0.1, + help='Split ratio for testing. Default=0.1') + + args = parser.parse_args() + + if args.split_test > 0.9: + args.split_validation = 1 - args.split_test + + init_data_config(args) diff --git a/scripts/utils.py b/scripts/utils.py new file mode 100644 index 0000000..c37dffa --- /dev/null +++ b/scripts/utils.py @@ -0,0 +1,632 @@ +import os +import re +from pathlib import Path +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt +import numpy as np +import glob +import math + +from image import Image + +## Global variables +CONTRAST = {'t1': ['T1w'], + 't2': ['T2w'], + 't2s':['T2star'], + 't1_t2': ['T1w', 'T2w'], + 'psir': ['PSIR'], + 'stir': ['STIR'], + 'psir_stir': ['PSIR', 'STIR'], + 't1_t2_psir_stir': ['T1w', 'T2w', 'PSIR', 'STIR'] + } + +## Functions +def get_img_path_from_mask_path(str_path, derivatives_folder='derivatives'): + """ + This function does 2 things: ⚠️ Files need to be stored in a BIDS compliant dataset + - Step 1: Remove label suffix (e.g. "_labels-disc-manual"). The suffix is always between the MRI contrast and the file extension. + - Step 2: Remove derivatives path (e.g. derivatives/labels/). The first folders is always called derivatives but the second may vary (e.g. labels_soft) + + :param path: absolute path to the label img. Example: //derivatives/labels/sub-amuALT/anat/sub-amuALT_T1w_labels-disc-manual.nii.gz + :return: img path. Example: //sub-amuALT/anat/sub-amuALT_T1w.nii.gz + Based on https://github.com/spinalcordtoolbox/disc-labeling-hourglass + """ + # Load path + path = Path(str_path) + + # Extract file extension + ext = ''.join(path.suffixes) + + # Get img name + img_name = '_'.join(path.name.split('_')[:-1]) + ext + + # Create a list of the directories + dir_list = str(path.parent).split('/') + + # Remove "derivatives" and "labels" folders + derivatives_idx = dir_list.index(derivatives_folder) + dir_path = '/'.join(dir_list[0:derivatives_idx] + dir_list[derivatives_idx+2:]) + + # Recreate img path + img_path = os.path.join(dir_path, img_name) + + return img_path + +## +def get_mask_path_from_img_path(img_path, deriv_sub_folders, short_suffix='_seg', ext='.nii.gz', counterexample=[]): + """ + This function returns the mask path from an image path or an empty string if the path does not exists. Images need to be stored in a BIDS compliant dataset. + + :param img_path: String path to niftii image + :param suffix: Mask suffix + :param ext: File extension + + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + """ + # Extract information from path + subjectID, sessionID, filename, contrast, echoID, acq = fetch_subject_and_session(img_path) + + # Find corresponding mask + mask_path = [] + for deriv_path in deriv_sub_folders: + if counterexample: # Deal with counter examples + paths = [] + for path in glob.glob(deriv_path + filename.split(ext)[0] + short_suffix + "*" + ext): + iswrong = False + for c in counterexample: + if c in path: + iswrong = True + if not iswrong: + paths.append(path) + else: + paths = glob.glob(deriv_path + filename.split(ext)[0] + "*" + short_suffix + "*" + ext) + + if len(paths) > 1: + print(f'Image {img_path} has multiple masks\n: {"\n".join(paths)}') + elif len(paths) == 1: + mask_path.append(paths[0]) + return mask_path + + +def get_cont_path_from_other_cont(str_path, cont): + """ + :param str_path: absolute path to the input nifti img. Example: //sub-amuALT/anat/sub-amuALT_T1w.nii.gz + :param cont: contrast of the target output image stored in the same data folder. Example: T2w + :return: path to the output target image. Example: //sub-amuALT/anat/sub-amuALT_T2w.nii.gz + + """ + # Load path + path = Path(str_path) + + # Extract file extension + ext = ''.join(path.suffixes) + + # Remove input contrast from name + path_list = path.name.split('_') + suffixes_pos = [1 if len(part.split('-')) == 1 else 0 for part in path_list] + contrast_idx = suffixes_pos.index(1) # Find suffix + + # New image name + img_name = '_'.join(path_list[:contrast_idx]+[cont]) + ext + + # Recreate img path + img_path = os.path.join(str(path.parent), img_name) + + return img_path + +def get_deriv_sub_from_img_path(img_path, derivatives_folder='derivatives'): + """ + This function returns the derivatives path of the subject from an image path or an empty string if the path does not exists. Images need to be stored in a BIDS compliant dataset. + + :param img_path: String path to niftii image + :param derivatives_folder: List of derivatives paths + :param ext: File extension + """ + # Extract information from path + subjectID, sessionID, filename, contrast, echoID, acq = fetch_subject_and_session(img_path) + path_bids, path_sub_folder = img_path.split(subjectID)[0:-1] + path_sub_folder = subjectID + path_sub_folder + + # Find corresponding mask + deriv_sub_folder = glob.glob(path_bids + "**/" + derivatives_folder + "/**/" + path_sub_folder, recursive=True) + + return deriv_sub_folder + +## +def change_mask_suffix(mask_path, short_suffix='_seg', ext='.nii.gz'): + """ + This function replace the current suffix with a new suffix suffix. If path is specified, make sure the dataset is BIDS compliant. + + :param mask_path: Input mask filepath or filename + :param new_suffix: New mask suffix + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + """ + # Extract information from path + subjectID, sessionID, filename, contrast, echoID, acq = fetch_subject_and_session(mask_path) + path_deriv_sub = mask_path.split(filename)[0] + + # Find corresponding new_mask + new_mask_path = glob.glob(path_deriv_sub + '_'.join(filename.split('_')[:-1]) + short_suffix + "*" + ext) + + if len(new_mask_path) > 1: + print(f'Multiple {short_suffix} masks for subject {subjectID} \n: {'\n'.join(new_mask_path)}') + mask_path = '' + elif len(new_mask_path) == 1: + new_mask_path = new_mask_path[0] + else: # mask does not exist + new_mask_path = '' + + return new_mask_path + + +def list_der_suffixes(folder_path, ext='.nii.gz'): + """ + This function return all the labels suffixes. If path is specified, make sure the dataset is BIDS compliant. + + :param folder_path: Path to folder where labels are stored. + """ + folder_path = os.path.normpath(folder_path) + files = [file for file in os.listdir(folder_path) if file.endswith(ext)] + suffixes = [] + for file in files: + subjectID, sessionID, filename, contrast, echoID, acquisition = fetch_subject_and_session(file) + split_file = file.split(ext)[0].split('_') + skip_idx = 0 + for sp in [subjectID, sessionID, echoID, acquisition]: + if sp: + skip_idx = skip_idx + 1 + suffix = '_' + '_'.join(split_file[skip_idx+1:]) # +1 to skip contrast + if not suffix =='_': + suffixes.append(suffix) + return suffixes +## +def fetch_subject_and_session(filename_path): + """ + Get subject ID, session ID and filename from the input BIDS-compatible filename or file path + The function works both on absolute file path as well as filename + :param filename_path: input nifti filename (e.g., sub-001_ses-01_T1w.nii.gz) or file path + (e.g., /home/user/MRI/bids/derivatives/labels/sub-001/ses-01/anat/sub-001_ses-01_T1w.nii.gz + :return: subjectID: subject ID (e.g., sub-001) + :return: sessionID: session ID (e.g., ses-01) + :return: filename: nii filename (e.g., sub-001_ses-01_T1w.nii.gz) + :return: contrast: MRI modality (dwi or anat) + :return: echoID: echo ID (e.g., echo-1) + :return: acquisition: acquisition (e.g., acq_sag) + Based on https://github.com/spinalcordtoolbox/manual-correction + """ + + _, filename = os.path.split(filename_path) # Get just the filename (i.e., remove the path) + subject = re.search('sub-(.*?)[_/]', filename_path) # [_/] means either underscore or slash + subjectID = subject.group(0)[:-1] if subject else "" # [:-1] removes the last underscore or slash + + session = re.search('ses-(.*?)[_/]', filename_path) # [_/] means either underscore or slash + sessionID = session.group(0)[:-1] if session else "" # [:-1] removes the last underscore or slash + + echo = re.search('echo-(.*?)[_]', filename_path) # [_/] means either underscore or slash + echoID = echo.group(0)[:-1] if echo else "" # [:-1] removes the last underscore or slash + + acq = re.search('acq-(.*?)[_]', filename_path) # [_/] means either underscore or slash + acquisition = acq.group(0)[:-1] if acq else "" # [:-1] removes the last underscore or slash + # REGEX explanation + # . - match any character (except newline) + # *? - match the previous element as few times as possible (zero or more times) + + contrast = 'dwi' if 'dwi' in filename_path else 'anat' # Return contrast (dwi or anat) + + return subjectID, sessionID, filename, contrast, echoID, acquisition + + +def fetch_contrast(filename_path): + ''' + Extract MRI contrast from a BIDS-compatible IMAGE filename/filepath + The function handles images only. + :param filename_path: image file path or file name. (e.g sub-001_ses-01_T1w.nii.gz) + Copied from https://github.com/spinalcordtoolbox/disc-labeling-hourglass + ''' + return filename_path.rstrip(''.join(Path(filename_path).suffixes)).split('_')[-1] + +def str_to_str_list(string): + string = string[1:-1] # remove brackets + return [s[1:-1] for s in string.split(', ')] + +def str_to_float_list(string): + string = string[1:-1] # remove brackets + return [float(s) for s in string.split(', ')] + + +def edit_metric_dict(metrics_dict, fprint_dict, img_path, seg_paths, discs_paths, deriv_sub_folders): + ''' + This function extracts information and metadata from an image and its mask. Values are then + gathered inside a dictionary. + + :param metrics_dict: dictionary containing summary metadata + :param fprint_dict: dictionary containing all the informations + :param img_path: niftii image path + :param seg_path: corresponding niftii spinal cord segmentation path + :param discs_path: corresponding niftii discs mask path + + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + ''' + #-----------------------------------------------------------------------# + #----------------------- Extracting metadata ---------------------------# + #-----------------------------------------------------------------------# + # Extract original image orientation + img = Image(img_path) + orientation = img.orientation + + # Extract information from path + subjectID, sessionID, filename, c, echoID, acq = fetch_subject_and_session(img_path) + + # Extract image dimensions and resolutions + img_RPI = img.change_orientation("RPI") + nx, ny, nz, nt, px, py, pz, pt = img_RPI.dim + + # Extract discs + check for shape mismatch between discs labels and image + discs_labels = [] + count_discs = 0 + if discs_paths: + for path in discs_paths: + discs_mask = Image(path).change_orientation("RPI") + discs_labels += [list(coord)[-1] for coord in discs_mask.getNonZeroCoordinates(sorting='value')] + if img_RPI.data.shape != discs_mask.data.shape: + count_discs += 1 + + # Check for shape mismatch between segmentation and image + count_seg = 0 + if seg_paths: + for path in seg_paths: + if img_RPI.data.shape != Image(path).change_orientation("RPI").data.shape: + count_seg += 1 + + # Compute image size + X, Y, Z = nx*px, ny*py, nz*pz + + # Extract MRI contrast from image only + contrast = fetch_contrast(img_path) + + # Extract suffixes + suffixes = [] + for path in deriv_sub_folders: + for suf in list_der_suffixes(path): + if not suf in suffixes: + suffixes.append(suf) + + # Extract derivatives folder + der_folders = [] + for path in deriv_sub_folders: + der_folders.append(os.path.basename(os.path.dirname(path.split(subjectID)[0]))) + + #-------------------------------------------------------------------------------# + #--------------------- Adding metadata to summary dict -------------------------# + #-------------------------------------------------------------------------------# + list_of_metrics = [orientation, contrast, X, Y, Z, nx, ny, nz, nt, px, py, pz, pt] + list_of_keys = ['orientation', 'contrast', 'X', 'Y', 'Z', 'nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt'] + for key, metric in zip(list_of_keys, list_of_metrics): + if not isinstance(metric,str): + metric = str(metric) + if key not in metrics_dict.keys(): + metrics_dict[key] = {metric:1} + else: + if metric not in metrics_dict[key].keys(): + metrics_dict[key][metric] = 1 + else: + metrics_dict[key][metric] += 1 + + # Add count shape mismatch + key_mis_seg = 'mismatch-seg' + if key_mis_seg not in metrics_dict.keys(): + metrics_dict[key_mis_seg] = count_seg + else: + metrics_dict[key_mis_seg] += count_seg + + key_mis_disc = 'mismatch-disc' + if key_mis_disc not in metrics_dict.keys(): + metrics_dict[key_mis_disc] = count_discs + else: + metrics_dict[key_mis_disc] += count_discs + + # Add discs labels + key_discs = 'discs-labels' + if discs_labels: + if key_discs not in metrics_dict.keys(): + metrics_dict[key_discs] = {} + for disc in discs_labels: + disc = str(disc) + if disc not in metrics_dict[key_discs].keys(): + metrics_dict[key_discs][disc] = 1 + else: + metrics_dict[key_discs][disc] += 1 + + # Add suffixes + suf_key = 'suffixes' + if suf_key not in metrics_dict.keys(): + metrics_dict[suf_key] = suffixes + else: + for suf in suffixes: + if not suf in metrics_dict[suf_key]: + metrics_dict[suf_key].append(suf) + + # Add derivatives folders + der_key = 'derivatives' + if der_key not in metrics_dict.keys(): + metrics_dict[der_key] = der_folders + else: + for der in der_folders: + if not der in metrics_dict[der_key]: + metrics_dict[der_key].append(der) + + #--------------------------------------------------------------------------------# + #--------------------- Storing metadata to exhaustive dict -------------------------# + #--------------------------------------------------------------------------------# + fprint_dict[filename] = {} + + # Add contrast + fprint_dict[filename]['contrast'] = contrast + + # Add orientation + fprint_dict[filename]['img_orientation'] = orientation + + # Add info SC segmentations + if seg_paths: + fprint_dict[filename]['seg-sc'] = True + suf_seg = [path.split(contrast)[-1].split('.')[0] for path in seg_paths] + fprint_dict[filename]['seg-suffix'] = '/'.join(suf_seg) + fprint_dict[filename]['seg-mismatch'] = count_seg + else: + fprint_dict[filename]['seg-sc'] = False + fprint_dict[filename]['seg-suffix'] = '' + fprint_dict[filename]['seg-mismatch'] = count_seg + + # Add info discs labels + if discs_paths: + fprint_dict[filename]['discs-label'] = True + suf_discs = [path.split(contrast)[-1].split('.')[0] for path in discs_paths] + fprint_dict[filename]['discs-suffix'] = '/'.join(suf_discs) + fprint_dict[filename]['discs-mismatch'] = count_discs + else: + fprint_dict[filename]['discs-label'] = False + fprint_dict[filename]['discs-suffix'] = '' + fprint_dict[filename]['discs-mismatch'] = count_discs + + # Add discs labels + key_discs = 'discs-labels' + label_list = np.arange(1,27).tolist() + [49, 50, 60] + for num_label in label_list: + if num_label in discs_labels: + fprint_dict[filename][f'label_{str(num_label)}'] = True + else: + fprint_dict[filename][f'label_{str(num_label)}'] = False + + # Add dim and resolutions + list_of_metrics = [X, Y, Z, nx, ny, nz, nt, px, py, pz, pt] + list_of_keys = ['X', 'Y', 'Z', 'nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt'] + for key, metric in zip(list_of_keys, list_of_metrics): + fprint_dict[filename][key] = metric + + return metrics_dict, fprint_dict + + +def save_violin(names, values, output_path, x_axis, y_axis): + ''' + Create a violin plot + :param names: String list of the names + :param values: List of values associated with the names + :param output_path: Output path (string) + :param x_axis: x-axis name + :param y_axis: y-axis name + + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + ''' + + # Set position of bar on X axis + result_dict = {'names':[], 'values':[]} + for i, name in enumerate(names): + result_dict['values'] += values[i] + for j in range(len(values[i])): + result_dict['names'] += [name] + + result_df = pd.DataFrame(data=result_dict) + + # Make the plot + plt.figure() + sns.violinplot(x="names", y="values", data=result_df) + plt.xlabel(x_axis, fontsize = 15) + plt.ylabel(y_axis, fontsize = 15) + plt.title(y_axis, fontsize = 20) + + # Save plot + plt.savefig(output_path) + + +def save_group_violins(name, values, output_path, x_axis, y_axis): + ''' + Create a violin plot + :param name: Dataset name + :param values: List of metrics containing lists of values associated with the names + :param output_path: Output path (string) + :param x_axis: x-axis name + :param y_axis: List of y-axis name corresponding to each metrics + ''' + + # Create plot + fig, axs = plt.subplots(3, len(values)//3 + 1, figsize=(1.8*len(values),11)) + + fig.suptitle(f'{x_axis} : {name}', fontsize = 30) + + for idx_line, val in enumerate(values): + # Set position of bar on X axis + result_dict = {} + result_dict['values'] = val + result_dict['metrics'] = [y_axis[idx_line]]*len(val) + + result_df = pd.DataFrame(data=result_dict) + + # Make the plot + sns.violinplot(ax=axs[idx_line//4, idx_line%4], x="metrics", y="values", data=result_df) + axs[idx_line//4, idx_line%4].set(xticklabels=[]) + axs[idx_line//4, idx_line%4].set_ylabel("") + axs[idx_line//4, idx_line%4].set_xlabel("") + axs[idx_line//4, idx_line%4].set_title(y_axis[idx_line], fontsize=20) + + # Save plot + plt.savefig(output_path) + + +def save_hist(names, values, output_path, x_axis, y_axis): + ''' + Create a histogram plot + :param names: String list of the names + :param values: Values associated with the names + :param output_path: Output path (string) + :param x_axis: x-axis name + :param y_axis: y-axis name + + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + ''' + + # Set position of bar on X axis + result_dict = {'names':[], 'values':[]} + for i, name in enumerate(names): + result_dict['values'] += values[i] + for j in range(len(values[i])): + result_dict['names'] += [name] + + result_df = pd.DataFrame(data=result_dict) + + # Make the plot + binwidth= 1/(1*len(names)) if len(names) > 1 else 1/3 + shrink = 1 if len(names) > 1 else 0.7 + plt.figure(figsize=(np.max(result_dict['values']), 8)) + sns.histplot(data=result_df, x="values", hue="names", multiple="dodge", binwidth=binwidth, shrink=shrink) + plt.xlabel(x_axis, fontsize = 15) + plt.xticks(np.arange(1, np.max(result_dict['values'])+1)) + plt.ylabel(x_axis, fontsize = 15) + plt.title(y_axis, fontsize = 20) + + # Save plot + plt.savefig(output_path) + + +def save_pie(names, values, output_path, x_axis, y_axis): + ''' + Create a pie chart plot + :param names: String list of the names + :param values: Values associated with the names + :param output_path: Output path (string) + :param x_axis: x-axis name + :param y_axis: y-axis name + + Based on https://www.geeksforgeeks.org/how-to-create-a-pie-chart-in-seaborn/ + ''' + # Set position of bar on X axis + result_dict = {} + for i, name in enumerate(names): + result_dict[name] = {} + for val in values[i]: + if val not in result_dict[name].keys(): + result_dict[name][val] = 1 + else: + result_dict[name][val] += 1 + # Regroup small values + other_count = 0 + other_name_list = [] + for v, count in result_dict[name].items(): + if count <= math.ceil(0.004*len(values[i])): + other_count += count + other_name_list.append(v) + for v in other_name_list: + del result_dict[name][v] + if other_name_list: + result_dict[name]['other'] = other_count + + + # define Seaborn color palette to use + palette_color = sns.color_palette('bright') + + def autopct_format(values): + ''' + Based on https://stackoverflow.com/questions/53782591/how-to-display-actual-values-instead-of-percentages-on-my-pie-chart-using-matplo + ''' + def my_format(pct): + total = sum(values) + val = int(round(pct*total)/100) + return '{v:d}'.format(v=val) + return my_format + + # Make the plot + if len(names) == 1: + fig = plt.figure() + plt.pie(result_dict[names[0]].values(), labels=result_dict[names[0]].keys(), colors=palette_color, autopct=autopct_format(result_dict[names[0]].values())) + plt.title(y_axis, fontsize = 20) + plt.xlabel(names[0], fontsize = 15) + #plt.ylabel(y_axis, fontsize = 15) + else: + fig, axs = plt.subplots(1, len(names), figsize=(3*len(names),5)) + fig.suptitle(y_axis, fontsize = 8*len(names)) + + for j, name in enumerate(result_dict.keys()): + axs[j].pie(result_dict[name].values(), labels=result_dict[name].keys(), colors=palette_color, autopct=autopct_format(result_dict[names[j]].values())) + axs[j].set_title(name) + + axs[0].set(ylabel=y_axis) + + # Save plot + plt.savefig(output_path) + +def convert_dict_to_float_list(dic): + """ + This function converts dictionary with {str(value):int(nb_occurence)} to a list [float(value)]*nb_occurence + """ + out_list = [] + for value, count in dic.items(): + out_list += [float(value)]*count + return out_list + +def convert_dict_to_list(dic): + """ + This function converts dictionary with {str(value):int(nb_occurence)} to a list [str(value)]*nb_occurence + """ + out_list = [] + for value, count in dic.items(): + out_list += [value]*count + return out_list + +def save_graphs(output_folder, metrics_dict, data_form='split'): + ''' + Plot and save metrics into an output folder + + Based on https://github.com/spinalcordtoolbox/disc-labeling-benchmark + ''' + # Extract subjects and metrics + data_name = np.array(list(metrics_dict.keys())) + + # Use violin plots + # for metric, unit in zip(['nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt', 'X', 'Y', 'Z'], ['pixel', 'pixel', 'pixel', '', 'mm/pixel', 'mm/pixel', 'mm/pixel', '', 'mm', 'mm', 'mm']): + # out_path = os.path.join(output_folder, f'{metric}.png') + # metric_name = metric + ' ' + f'({unit})' + # save_violin(names=data_name, values=[convert_dict_to_float_list(metrics_dict[name][metric]) for name in data_name], output_path=out_path, x_axis=data_form, y_axis=metric_name) + + # Save violin plot in one fig + for name in data_name: + tot_values = [] + tot_names = [] + for metric, unit in zip(['nx', 'ny', 'nz', 'nt', 'px', 'py', 'pz', 'pt', 'X', 'Y', 'Z'], ['pixel', 'pixel', 'pixel', '', 'mm/pixel', 'mm/pixel', 'mm/pixel', '', 'mm', 'mm', 'mm']): + tot_values.append(convert_dict_to_float_list(metrics_dict[name][metric])) + tot_names.append(metric + ' ' + f'({unit})') + out_path = os.path.join(output_folder, f'violin_stats.png') + save_group_violins(name=name, values=tot_values, output_path=out_path, x_axis=data_form, y_axis=tot_names) + + # Use bar pie chart + for metric in ['orientation', 'contrast']: + out_path = os.path.join(output_folder, f'{metric}.png') + save_pie(names=data_name, values=[convert_dict_to_list(metrics_dict[name][metric]) for name in data_name], output_path=out_path, x_axis=data_form, y_axis=metric) + + # Use bar graphs + for metric in ['discs-labels']: + out_path = os.path.join(output_folder, f'{metric}.png') + save_hist(names=data_name, values=[convert_dict_to_float_list(metrics_dict[name][metric]) for name in data_name], output_path=out_path, x_axis=metric, y_axis='Count') + +def mergedict(a,b): + a.update(b) + return a \ No newline at end of file