Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added realignment to dlgp and pdm #428

Open
wants to merge 14 commits into
base: release-1.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/main/scala/scalismo/numerics/GramDiagonalize.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package scalismo.numerics

import breeze.linalg.{*, given}

object GramDiagonalize {

/**
* Given a non orthogonal basis nxr and the variance (squared [eigen]scalars) of that basis, returns an orthonormal
* basis with the adjusted variance. sets small eigenvalues to zero.
*/
def rediagonalizeGram(basis: DenseMatrix[Double],
s: DenseVector[Double]
): (DenseMatrix[Double], DenseVector[Double]) = {
// val l: DenseMatrix[Double] = basis(*, ::) * breeze.numerics.sqrt(s)
val l: DenseMatrix[Double] = DenseMatrix.zeros[Double](basis.rows, basis.cols)
val sqs: DenseVector[Double] = breeze.numerics.sqrt(s)
for i <- 0 until basis.cols do l(::, i) := sqs(i) * basis(::, i)

val gram = l.t * l
val svd = breeze.linalg.svd(gram)
val newS: DenseVector[Double] = breeze.numerics.sqrt(svd.S).map(d => if (d > 1e-10) 1.0 / d else 0.0)

// val newbasis: DenseMatrix[Double] = l * (svd.U(*, ::) * newS)
val inner: DenseMatrix[Double] = DenseMatrix.zeros[Double](gram.rows, gram.cols)
for i <- 0 until basis.cols do inner(::, i) := newS(i) * svd.U(::, i)
val newbasis = l * inner

(newbasis, svd.S)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,19 @@ import breeze.linalg.svd.SVD
import breeze.linalg.{diag, DenseMatrix, DenseVector}
import breeze.stats.distributions.Gaussian
import scalismo.common.DiscreteField.vectorize
import scalismo.common._
import scalismo.common.*
import scalismo.common.interpolation.{FieldInterpolator, NearestNeighborInterpolator}
import scalismo.geometry._
import scalismo.geometry.*
import scalismo.image.StructuredPoints
import scalismo.kernels.{DiscreteMatrixValuedPDKernel, MatrixValuedPDKernel}
import scalismo.numerics.{PivotedCholesky, Sampler}
import scalismo.statisticalmodel.DiscreteLowRankGaussianProcess.{Eigenpair => DiscreteEigenpair, _}
import scalismo.numerics.{GramDiagonalize, PivotedCholesky, Sampler}
import scalismo.statisticalmodel.DiscreteLowRankGaussianProcess.{Eigenpair as DiscreteEigenpair, *}
import scalismo.statisticalmodel.LowRankGaussianProcess.Eigenpair
import scalismo.statisticalmodel.NaNStrategy.NanIsNumericValue
import scalismo.statisticalmodel.dataset.DataCollection
import scalismo.utils.{Memoize, Random}

import scala.annotation.threadUnsafe
import scala.language.higherKinds
import scala.collection.parallel.immutable.ParVector

Expand Down Expand Up @@ -358,6 +359,37 @@ class DiscreteLowRankGaussianProcess[D: NDSpace, DDomain[DD] <: DiscreteDomain[D
)
}

/**
* realigns the model on the provided part of the domain. Aligns over the translation and, when using
* withExtendedBasis = true, over the extended basis (the default implicit [[RealignExtendedBasis]] adds rotation.
* This rotation will always be calculated around the center of the provided ids. Rotations are around the cardinal
* directions.).
*
* @param ids
* these define the parts of the domain that are aligned to. Depending on the withExtendedBasis parameter has a
* minimum length requirements (default basis extension in 3D should be used with >=4 provided ids for example)
* @param withExtendedBasis
* True if the extended basis should be included. By default this uses a rotation extension. False makes the
* realignment only over translation. Translational alignment can be done exactly. For more information see
* [[RealignExtendedBasis]].
* @param diagonalize
* True if a diagonal basis should be returned. In general, it is strongly recommended to use a orthonormal basis -
* here referred to as diagonal. This does not increase complexity and is a more intuitive formulation of the model.
* If internal fields are accessed diagonalize should be set to true. This option can be set to false to make the
* same coefficient lead to very similar shapes in the pre- and after realignment model (or exactly the same shapes
* if withExtendedBasis = false).
* @return
* The resulting [[DiscreteLowRankGaussianProcess]] aligned on the provided instances of [[PointId]]. If
* withExtendedBasis = false then the original and the returned model can produce the same mesh with different
* translations. That means the shape spaces are the same but the fieldsa are translated.
*/
def realign(ids: IndexedSeq[PointId], withExtendedBasis: Boolean = true, diagonalize: Boolean = true)(using
vectorizer: Vectorizer[Value],
realigning: RealignExtendedBasis[D, Value]
): DiscreteLowRankGaussianProcess[D, DDomain, Value] = {
DiscreteLowRankGaussianProcess.realignment(this, ids, withExtendedBasis, diagonalize)
}

protected[statisticalmodel] def instanceVector(alpha: DenseVector[Double]): DenseVector[Double] = {
require(rank == alpha.size)

Expand Down Expand Up @@ -638,6 +670,75 @@ object DiscreteLowRankGaussianProcess {
DiscreteMatrixValuedPDKernel(domain, cov, outputDim)
}

def realignment[D: NDSpace, DDomain[DD] <: DiscreteDomain[DD], Value](
model: DiscreteLowRankGaussianProcess[D, DDomain, Value],
ids: IndexedSeq[PointId],
withExtendedBasis: Boolean,
diagonalize: Boolean
)(using
vectorizer: Vectorizer[Value],
realigning: RealignExtendedBasis[D, Value]
): DiscreteLowRankGaussianProcess[D, DDomain, Value] = {
val d = NDSpace.apply[D].dimensionality
// build the projection matrix for the desired pose
val p = {
@threadUnsafe
lazy val pt = breeze.linalg.tile(DenseMatrix.eye[Double](d), model.domain.pointSet.numberOfPoints, 1)
if withExtendedBasis then
val center = ids.map(id => model.domain.pointSet.point(id).toVector).reduce(_ + _).map(_ / ids.length).toPoint
val pr = realigning.getBasis[DDomain](model, center)
if realigning.useTranslation then DenseMatrix.horzcat(pt, pr)
else pr
else pt
}
// call the realignment implementation
val (nmean, nbasis, nvar) = realignmentComputation(model.meanVector,
model.basisMatrix,
model.variance,
p,
ids.map(_.id),
dim = d,
diagonalize = diagonalize,
projectMean = false
)

new DiscreteLowRankGaussianProcess[D, DDomain, Value](model.domain, nmean, nvar, nbasis)
}

private def realignmentComputation(mean: DenseVector[Double],
basis: DenseMatrix[Double],
s: DenseVector[Double],
p: DenseMatrix[Double],
ids: IndexedSeq[Int],
dim: Int,
diagonalize: Boolean,
projectMean: Boolean
): (DenseVector[Double], DenseMatrix[Double], DenseVector[Double]) = {
val x = for // prepare indices
id <- ids
d <- 0 until dim
yield id * dim + d
// prepare the majority of the projection matrix
val px = p(x, ::).toDenseMatrix
val ptpipt = breeze.linalg.pinv(px.t * px) * px.t

// performs the actual projection. batches all basis vectors
// p -> projection rank, n number of indexes*dim, r cols of basis, N rows of basis
val alignedC = ptpipt * basis(x, ::).toDenseMatrix // pxn * nxr
val alignedEigf = basis - p * alignedC // Nxr - Nxp * pxr
val alignedMean = if projectMean then // if desired projects the mean vector as well
val alignedMc = ptpipt * mean // same projection with r==1
mean - p * alignedMc
else mean

// rediagonalize. You can skip this if you ONLY sample from the resulting model
val (newbasis, news) =
if diagonalize then GramDiagonalize.rediagonalizeGram(alignedEigf, s)
else (alignedEigf, s)

(alignedMean, newbasis, news)
}

}

// Explicit variants for 1D, 2D and 3D
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,17 @@ case class PointDistributionModel[D: NDSpace, DDomain[D] <: DiscreteDomain[D]](
PointDistributionModel(newGP)
}

/**
* realigns the internal [[DiscreteLowRankGaussianProcess]] and returns the resulting [[PointDistributionModel]]. this
* calls [[DiscreteLowRankGaussianProcess.realign]].
*/
def realign(ids: IndexedSeq[PointId], withExtendedBasis: Boolean = true, diagonalize: Boolean = true)(using
vectorizer: Vectorizer[EuclideanVector[D]],
realign: RealignExtendedBasis[D, EuclideanVector[D]]
): PointDistributionModel[D, DDomain] = {
new PointDistributionModel[D, DDomain](this.gp.realign(ids, withExtendedBasis, diagonalize))
}

}

object PointDistributionModel {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package scalismo.statisticalmodel

import breeze.linalg.DenseMatrix
import scalismo.common.DiscreteDomain
import scalismo.geometry.*

/**
* ideally used to represent linear effects that should be normalized between the training data. The realignment process
* then builds a linear projection matrix that is applied to an existing model. the space of the effects that should be
* normalized needs to be spanned by the returned matrix
*/
trait RealignExtendedBasis[D, Value]:

/**
* whether or not the default translation basis should also be used. that means false does not perform a translation
* realignment. This in combination with getBasis allows for complete control of the projection matrix.
*/
def useTranslation: Boolean

/**
* basis to span the kernel of the projection. for example, a translation alignment could be performed by spanning
* that space with constant vectors for each cardinal direction.
*/
def getBasis[DDomain[DD] <: DiscreteDomain[DD]](model: DiscreteLowRankGaussianProcess[D, DDomain, Value],
center: Point[D]
): DenseMatrix[Double]

/**
* includes the additional default rotation centerpoint implementation which is useful to calculate the rotation basis.
*/
trait RealignExtendedBasisRotation[D, Value] extends RealignExtendedBasis[D, Value]:
def centeredP[D: NDSpace, DDomain[DD] <: DiscreteDomain[DD]](domain: DDomain[D],
center: Point[D]
): DenseMatrix[Double] = {
// build centered data matrix
val x = DenseMatrix.zeros[Double](center.dimensionality, domain.pointSet.numberOfPoints)
val c = center.toBreezeVector
for (p, i) <- domain.pointSet.points.zipWithIndex do x(::, i) := p.toBreezeVector - c
x
}

object RealignExtendedBasis:
/**
* returns a projection basis for rotation - the tangential speed for the rotations around the three cardinal
* directions.
*/
given realignBasis3D: RealignExtendedBasisRotation[_3D, EuclideanVector[_3D]] with
def useTranslation: Boolean = true
def getBasis[DDomain[DD] <: DiscreteDomain[DD]](
model: DiscreteLowRankGaussianProcess[_3D, DDomain, EuclideanVector[_3D]],
center: Point[_3D]
): DenseMatrix[Double] = {
val np = model.domain.pointSet.numberOfPoints
val x = centeredP(model.domain, center)

val pr = DenseMatrix.zeros[Double](np * 3, 3)
// the derivative of the rotation matrices
val dr = new DenseMatrix[Double](9,
3,
Array(1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 1.0,
0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 1.0)
)
// get tangential speed
val dx = dr * x
for i <- 0 until 3 do
val v = dx(3 * i until 3 * i + 3, ::).toDenseVector
pr(::, i) := v / breeze.linalg.norm(v)
pr
}

/**
* returns a projection basis for rotation - the tangential speed for the single 2d rotation.
*/
given realignBasis2D: RealignExtendedBasisRotation[_2D, EuclideanVector[_2D]] with
def useTranslation: Boolean = true
def getBasis[DDomain[DD] <: DiscreteDomain[DD]](
model: DiscreteLowRankGaussianProcess[_2D, DDomain, EuclideanVector[_2D]],
center: Point[_2D]
): DenseMatrix[Double] = {
val np = model.domain.pointSet.numberOfPoints
val x = centeredP(model.domain, center)

// derivative of the rotation matrix
val dr = new DenseMatrix[Double](2, 2, Array(0.0, -1.0, 1.0, 0.0))
val dx = (dr * x).reshape(2 * np, 1)
val n = breeze.linalg.norm(dx, breeze.linalg.Axis._0)
dx / n(0)
}
Loading
Loading