From 61b7cf134aeda9b150f31908f393b0de2784ca3d Mon Sep 17 00:00:00 2001 From: Yujing Huang Date: Fri, 15 Mar 2024 14:22:41 -0400 Subject: [PATCH] merge in Doug's changes to samseg scripts --- samseg/ProbabilisticAtlas.py | 8 ++++---- samseg/SamsegLongitudinal.py | 34 ++++++++++++++++++++++++++++------ samseg/cli/run_samseg_long.py | 18 +++++++++++++++++- 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/samseg/ProbabilisticAtlas.py b/samseg/ProbabilisticAtlas.py index 9e5774d..33b5448 100644 --- a/samseg/ProbabilisticAtlas.py +++ b/samseg/ProbabilisticAtlas.py @@ -191,10 +191,10 @@ def deformMesh(self, mesh, transform, data, mask, means, variances, mixtureWeigh globalTic = time.perf_counter() minLogLikelihoodTimesDeformationPrior, maximalDeformation = optimizer.step_optimizer_samseg() globalToc = time.perf_counter() - print("maximalDeformation=%.4f minLogLikelihood=%.4f" % ( - maximalDeformation, minLogLikelihoodTimesDeformationPrior)) + #print("maximalDeformation=%.4f minLogLikelihood=%.4f" % ( + #maximalDeformation, minLogLikelihoodTimesDeformationPrior)) historyOfDeformationCost.append(minLogLikelihoodTimesDeformationPrior) - print( f" Total time spent: {globalToc-globalTic:0.4f} sec" ) + #print( f" Total time spent: {globalToc-globalTic:0.4f} sec" ) historyOfMaximalDeformation.append(maximalDeformation) if maximalDeformation == 0: break @@ -325,7 +325,7 @@ def deformMesh(self, mesh, transform, data, mask, means, variances, mixtureWeigh historyOfMaximalDeformation.append(maximalDeformation) globalToc = time.perf_counter() - print( f" Total time spent: {globalToc-globalTic:0.4f} sec" ) + #print( f" Total time spent: {globalToc-globalTic:0.4f} sec" ) if computeHistoryOfDeformationCost: tic = time.perf_counter() diff --git a/samseg/SamsegLongitudinal.py b/samseg/SamsegLongitudinal.py index 1bed546..df53d6f 100644 --- a/samseg/SamsegLongitudinal.py +++ b/samseg/SamsegLongitudinal.py @@ -113,6 +113,7 @@ def __init__(self, self.pallidumAsWM = pallidumAsWM self.savePosteriors = savePosteriors self.tpToBaseTransforms = tpToBaseTransforms + self.saveModelProbabilities = saveModelProbabilities # Check if all time point to base transforms are identity matrices. # If so, we can derive a combined 4D mask during preprocessing @@ -177,24 +178,30 @@ def __init__(self, self.historyOfTotalTimepointCost = None self.historyOfLatentAtlasCost = None - def segment(self, saveWarp=False): + def segment(self, saveWarp=False, initTransformFile=None)): # ======================================================================================= # # Main function that runs the whole longitudinal segmentation pipeline # # ======================================================================================= - self.constructAndRegisterSubjectSpecificTemplate() + self.constructAndRegisterSubjectSpecificTemplate(initTransformFile) self.preProcess() self.fitModel() return self.postProcess(saveWarp=saveWarp) - def constructAndRegisterSubjectSpecificTemplate(self): + def constructAndRegisterSubjectSpecificTemplate(self, initTransformFile=None): # ======================================================================================= # # Construction and affine registration of subject-specific template (sst) # # ======================================================================================= + # Initialization transform for registration + initTransform = None + if initTransformFile: + trg = self.validateTransform(sf.load_affine(initTransformFile)) + initTransform = convertRASTransformToLPS(trg.convert(space='world').matrix) + # Generate the subject specific template (sst) self.sstFileNames = self.generateSubjectSpecificTemplate() sstDir, _ = os.path.split(self.sstFileNames[0]) @@ -208,7 +215,7 @@ def constructAndRegisterSubjectSpecificTemplate(self): affine = Affine(imageFileName=self.sstFileNames[0], meshCollectionFileName=affineRegistrationMeshCollectionFileName, templateFileName=templateFileName) - self.imageToImageTransformMatrix, _ = affine.registerAtlas(savePath=sstDir, visualizer=self.visualizer) + self.imageToImageTransformMatrix, _ = affine.registerAtlas(savePath=sstDir, visualizer=self.visualizer, initTransform=initTransform) def preProcess(self): @@ -752,6 +759,17 @@ def postProcess(self, saveWarp=False): # Using estimated parameters, segment and write out results for each time point # # ======================================================================================= + # + + sstDir = os.path.join(self.savePath, 'base') + os.makedirs(sstDir, exist_ok=True) + baseModel = self.sstModel; + # Save the final mesh collection + if self.saveModelProbabilities: + print('Saving base model probs') + baseModel.saveGaussianProbabilities(os.path.join(sstDir, 'probabilities') ) + if saveWarp: + baseModel.saveWarpField(os.path.join(sstDir, 'template.m3z')) self.timepointVolumesInCubicMm = [] for timepointNumber in range(self.numberOfTimepoints): @@ -785,7 +803,10 @@ def postProcess(self, saveWarp=False): deformedAtlasFileName = os.path.join(timepointModel.savePath, 'mesh.txt') timepointModel.probabilisticAtlas.saveDeformedAtlas(timepointModel.modelSpecifications.atlasFileName, deformedAtlasFileName, nodePositions) - + if self.saveModelProbabilities: + print('Saving model probs') + timepointModel.saveGaussianProbabilities( os.path.join(timepointModel.savePath, 'probabilities') ) + # Save the history of the parameter estimation process if self.saveHistory: history = {'input': { @@ -821,7 +842,7 @@ def postProcess(self, saveWarp=False): with open(os.path.join(self.savePath, 'history.p'), 'wb') as file: pickle.dump(self.history, file, protocol=pickle.HIGHEST_PROTOCOL) - def generateSubjectSpecificTemplate(self): + def generateSubjectSpecificTemplate(self, saveWarp=False): sstDir = os.path.join(self.savePath, 'base') os.makedirs(sstDir, exist_ok=True) @@ -866,6 +887,7 @@ def constructSstModel(self): userOptimizationOptions=self.userOptimizationOptions, visualizer=self.visualizer, saveHistory=True, + savePosteriors=self.savePosteriors, targetIntensity=self.targetIntensity, targetSearchStrings=self.targetSearchStrings, modeNames=self.modeNames, diff --git a/samseg/cli/run_samseg_long.py b/samseg/cli/run_samseg_long.py index e54142e..1748e63 100644 --- a/samseg/cli/run_samseg_long.py +++ b/samseg/cli/run_samseg_long.py @@ -5,6 +5,7 @@ import os import argparse import surfa as sf +import json import samseg import numpy as np from samseg import SAMSEGDIR @@ -31,8 +32,10 @@ def parseArguments(argv): # optional processing options parser.add_argument('-m', '--mode', nargs='+', help='Output basenames for the input image mode.') parser.add_argument('-a', '--atlas', metavar='DIR', help='Point to an alternative atlas directory.') + parser.add_argument('--init-reg', metavar='FILE', help='Initial affine registration.') parser.add_argument('--deformation-hyperprior', type=float, default=20.0, help='Strength of the latent deformation hyperprior.') parser.add_argument('--gmm-hyperprior', type=float, default=0.5, help='Strength of the latent GMM hyperprior.') + parser.add_argument('--options', metavar='FILE', help='Override advanced options via a json file.') parser.add_argument('--pallidum-separate', action='store_true', default=False, help='Move pallidum outside of global white matter class. Use this flag when T2/flair is used.') parser.add_argument('--threads', type=int, default=default_threads, help='Number of threads to use. Defaults to current OMP_NUM_THREADS or 1.') parser.add_argument('--tp-to-base-transform', nargs='+', required=False, help='Transformation file for each time point to base.') @@ -95,6 +98,17 @@ def main(): # Setup the visualization tool visualizer = samseg.initVisualizer(args.showfigs, args.movie) + # Load user options from a JSON file + userModelSpecifications = {} + userOptimizationOptions = {} + if args.options: + with open(args.options) as f: + userOptions = json.load(f) + if userOptions.get('modelSpecifications') is not None: + userModelSpecifications = userOptions.get('modelSpecifications') + if userOptions.get('optimizationOptions') is not None: + userOptimizationOptions = userOptions.get('optimizationOptions') + # Start the process timer timer = samseg.Timer() @@ -164,6 +178,8 @@ def check_lta_file(filepath): imageFileNamesList=args.timepoint, atlasDir=atlasDir, savePath=args.output, + userModelSpecifications=userModelSpecifications, + userOptimizationOptions=userOptimizationOptions, targetIntensity=110, targetSearchStrings=['Cerebral-White-Matter'], modeNames=args.mode, @@ -200,7 +216,7 @@ def check_lta_file(filepath): else: samsegLongitudinal = samseg.SamsegLongitudinal(**samseg_kwargs) - samsegLongitudinal.segment(saveWarp=args.save_warp) + samsegLongitudinal.segment(saveWarp=args.save_warp, initTransformFile=args.init_reg) timer.mark('run_samseg_long complete')