merge in Doug's changes to samseg scripts
yhuang43 committed Mar 15, 2024
1 parent 8087e69 commit 61b7cf1
Showing 3 changed files with 49 additions and 11 deletions.
8 changes: 4 additions & 4 deletions samseg/ProbabilisticAtlas.py
@@ -191,10 +191,10 @@ def deformMesh(self, mesh, transform, data, mask, means, variances, mixtureWeigh
globalTic = time.perf_counter()
minLogLikelihoodTimesDeformationPrior, maximalDeformation = optimizer.step_optimizer_samseg()
globalToc = time.perf_counter()
print("maximalDeformation=%.4f minLogLikelihood=%.4f" % (
maximalDeformation, minLogLikelihoodTimesDeformationPrior))
#print("maximalDeformation=%.4f minLogLikelihood=%.4f" % (
#maximalDeformation, minLogLikelihoodTimesDeformationPrior))
historyOfDeformationCost.append(minLogLikelihoodTimesDeformationPrior)
print( f" Total time spent: {globalToc-globalTic:0.4f} sec" )
#print( f" Total time spent: {globalToc-globalTic:0.4f} sec" )
historyOfMaximalDeformation.append(maximalDeformation)
if maximalDeformation == 0:
break
@@ -325,7 +325,7 @@ def deformMesh(self, mesh, transform, data, mask, means, variances, mixtureWeigh
historyOfMaximalDeformation.append(maximalDeformation)

globalToc = time.perf_counter()
print( f" Total time spent: {globalToc-globalTic:0.4f} sec" )
#print( f" Total time spent: {globalToc-globalTic:0.4f} sec" )

if computeHistoryOfDeformationCost:
tic = time.perf_counter()
34 changes: 28 additions & 6 deletions samseg/SamsegLongitudinal.py
@@ -113,6 +113,7 @@ def __init__(self,
self.pallidumAsWM = pallidumAsWM
self.savePosteriors = savePosteriors
self.tpToBaseTransforms = tpToBaseTransforms
self.saveModelProbabilities = saveModelProbabilities

# Check if all time point to base transforms are identity matrices.
# If so, we can derive a combined 4D mask during preprocessing
@@ -177,24 +178,30 @@ def __init__(self,
self.historyOfTotalTimepointCost = None
self.historyOfLatentAtlasCost = None

def segment(self, saveWarp=False):
def segment(self, saveWarp=False, initTransformFile=None):
# =======================================================================================
#
# Main function that runs the whole longitudinal segmentation pipeline
#
# =======================================================================================
self.constructAndRegisterSubjectSpecificTemplate()
self.constructAndRegisterSubjectSpecificTemplate(initTransformFile)
self.preProcess()
self.fitModel()
return self.postProcess(saveWarp=saveWarp)

def constructAndRegisterSubjectSpecificTemplate(self):
def constructAndRegisterSubjectSpecificTemplate(self, initTransformFile=None):
# =======================================================================================
#
# Construction and affine registration of subject-specific template (sst)
#
# =======================================================================================

# Initialization transform for registration
initTransform = None
if initTransformFile:
trg = self.validateTransform(sf.load_affine(initTransformFile))
initTransform = convertRASTransformToLPS(trg.convert(space='world').matrix)
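# Illustrative comment (an assumption, not shown in this diff): convertRASTransformToLPS
# is taken to conjugate the world-space affine by the RAS<->LPS axis flip, roughly
#   flip = np.diag([-1, -1, 1, 1]); lps = flip @ ras @ flip
# so a registration estimated in RAS world coordinates can seed the LPS-based atlas registration.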

# Generate the subject specific template (sst)
self.sstFileNames = self.generateSubjectSpecificTemplate()
sstDir, _ = os.path.split(self.sstFileNames[0])
@@ -208,7 +215,7 @@ def constructAndRegisterSubjectSpecificTemplate(self):
affine = Affine(imageFileName=self.sstFileNames[0],
meshCollectionFileName=affineRegistrationMeshCollectionFileName,
templateFileName=templateFileName)
self.imageToImageTransformMatrix, _ = affine.registerAtlas(savePath=sstDir, visualizer=self.visualizer)
self.imageToImageTransformMatrix, _ = affine.registerAtlas(savePath=sstDir, visualizer=self.visualizer, initTransform=initTransform)


def preProcess(self):
@@ -752,6 +759,17 @@ def postProcess(self, saveWarp=False):
# Using estimated parameters, segment and write out results for each time point
#
# =======================================================================================
#

sstDir = os.path.join(self.savePath, 'base')
os.makedirs(sstDir, exist_ok=True)
baseModel = self.sstModel
# Save the final mesh collection
if self.saveModelProbabilities:
print('Saving base model probs')
baseModel.saveGaussianProbabilities(os.path.join(sstDir, 'probabilities') )
if saveWarp:
baseModel.saveWarpField(os.path.join(sstDir, 'template.m3z'))

self.timepointVolumesInCubicMm = []
for timepointNumber in range(self.numberOfTimepoints):
@@ -785,7 +803,10 @@ def postProcess(self, saveWarp=False):
deformedAtlasFileName = os.path.join(timepointModel.savePath, 'mesh.txt')
timepointModel.probabilisticAtlas.saveDeformedAtlas(timepointModel.modelSpecifications.atlasFileName,
deformedAtlasFileName, nodePositions)

if self.saveModelProbabilities:
print('Saving model probs')
timepointModel.saveGaussianProbabilities( os.path.join(timepointModel.savePath, 'probabilities') )

# Save the history of the parameter estimation process
if self.saveHistory:
history = {'input': {
@@ -821,7 +842,7 @@ def postProcess(self, saveWarp=False):
with open(os.path.join(self.savePath, 'history.p'), 'wb') as file:
pickle.dump(self.history, file, protocol=pickle.HIGHEST_PROTOCOL)

def generateSubjectSpecificTemplate(self):
def generateSubjectSpecificTemplate(self, saveWarp=False):

sstDir = os.path.join(self.savePath, 'base')
os.makedirs(sstDir, exist_ok=True)
@@ -866,6 +887,7 @@ def constructSstModel(self):
userOptimizationOptions=self.userOptimizationOptions,
visualizer=self.visualizer,
saveHistory=True,
savePosteriors=self.savePosteriors,
targetIntensity=self.targetIntensity,
targetSearchStrings=self.targetSearchStrings,
modeNames=self.modeNames,
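For illustration, a minimal sketch of driving the updated class directly, assuming the remaining constructor arguments keep their defaults and that the constructor exposes a saveModelProbabilities keyword (only its assignment in __init__ appears in this diff); the file paths are placeholders:

import samseg

samseg_long = samseg.SamsegLongitudinal(
    imageFileNamesList=[['tp1.mgz'], ['tp2.mgz']],  # placeholder: one contrast per time point
    atlasDir='/path/to/atlas',                      # placeholder atlas directory
    savePath='/path/to/output',
    saveModelProbabilities=True,                    # assumed keyword; stored by __init__ above
)
# segment() now also accepts an optional initial affine registration file
samseg_long.segment(saveWarp=True, initTransformFile='init.lta')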
18 changes: 17 additions & 1 deletion samseg/cli/run_samseg_long.py
@@ -5,6 +5,7 @@
import os
import argparse
import surfa as sf
import json
import samseg
import numpy as np
from samseg import SAMSEGDIR
@@ -31,8 +32,10 @@ def parseArguments(argv):
# optional processing options
parser.add_argument('-m', '--mode', nargs='+', help='Output basenames for the input image mode.')
parser.add_argument('-a', '--atlas', metavar='DIR', help='Point to an alternative atlas directory.')
parser.add_argument('--init-reg', metavar='FILE', help='Initial affine registration.')
parser.add_argument('--deformation-hyperprior', type=float, default=20.0, help='Strength of the latent deformation hyperprior.')
parser.add_argument('--gmm-hyperprior', type=float, default=0.5, help='Strength of the latent GMM hyperprior.')
parser.add_argument('--options', metavar='FILE', help='Override advanced options via a json file.')
parser.add_argument('--pallidum-separate', action='store_true', default=False, help='Move pallidum outside of global white matter class. Use this flag when T2/flair is used.')
parser.add_argument('--threads', type=int, default=default_threads, help='Number of threads to use. Defaults to current OMP_NUM_THREADS or 1.')
parser.add_argument('--tp-to-base-transform', nargs='+', required=False, help='Transformation file for each time point to base.')
@@ -95,6 +98,17 @@ def main():
# Setup the visualization tool
visualizer = samseg.initVisualizer(args.showfigs, args.movie)

# Load user options from a JSON file
userModelSpecifications = {}
userOptimizationOptions = {}
if args.options:
with open(args.options) as f:
userOptions = json.load(f)
if userOptions.get('modelSpecifications') is not None:
userModelSpecifications = userOptions.get('modelSpecifications')
if userOptions.get('optimizationOptions') is not None:
userOptimizationOptions = userOptions.get('optimizationOptions')

# Start the process timer
timer = samseg.Timer()

@@ -164,6 +178,8 @@ def check_lta_file(filepath):
imageFileNamesList=args.timepoint,
atlasDir=atlasDir,
savePath=args.output,
userModelSpecifications=userModelSpecifications,
userOptimizationOptions=userOptimizationOptions,
targetIntensity=110,
targetSearchStrings=['Cerebral-White-Matter'],
modeNames=args.mode,
@@ -200,7 +216,7 @@ def check_lta_file(filepath):
else:
samsegLongitudinal = samseg.SamsegLongitudinal(**samseg_kwargs)

samsegLongitudinal.segment(saveWarp=args.save_warp)
samsegLongitudinal.segment(saveWarp=args.save_warp, initTransformFile=args.init_reg)

timer.mark('run_samseg_long complete')

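As a usage sketch for the two new command-line flags, the snippet below writes an options file with the two top-level keys that main() reads above; the inner settings are left empty because the concrete SAMSEG option names are not shown in this diff, and the flag spellings other than --init-reg and --options in the example invocation are assumptions:

import json

# Only the top-level keys matter to run_samseg_long; fill in real SAMSEG
# model/optimization overrides as needed (left empty here as placeholders).
options = {
    'modelSpecifications': {},
    'optimizationOptions': {},
}
with open('samseg_options.json', 'w') as f:
    json.dump(options, f, indent=2)

# Hypothetical invocation (timepoint/output flag spellings assumed):
# run_samseg_long -t tp1.mgz -t tp2.mgz -o outdir \
#     --init-reg init.lta --options samseg_options.json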
