Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added realignment to dlgp and pdm #428

Open
wants to merge 14 commits into
base: release-1.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/main/scala/scalismo/numerics/GramDiagonalize.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package scalismo.numerics

import breeze.linalg.{*, given}

object GramDiagonalize {

/**
* Given a non orthogonal basis nxr and the variance (squared [eigen]scalars) of that basis, returns an orthonormal
* basis with the adjusted variance. sets small eigenvalues to zero.
*/
def rediagonalizeGram(basis: DenseMatrix[Double],
s: DenseVector[Double]
): (DenseMatrix[Double], DenseVector[Double]) = {
// val l: DenseMatrix[Double] = basis(*, ::) * breeze.numerics.sqrt(s)
val l: DenseMatrix[Double] = DenseMatrix.zeros[Double](basis.rows, basis.cols)
val sqs: DenseVector[Double] = breeze.numerics.sqrt(s)
for i <- 0 until basis.cols do l(::, i) := sqs(i) * basis(::, i)

val gram = l.t * l
val svd = breeze.linalg.svd(gram)
val newS: DenseVector[Double] = breeze.numerics.sqrt(svd.S).map(d => if (d > 1e-10) 1.0 / d else 0.0)

// val newbasis: DenseMatrix[Double] = l * (svd.U(*, ::) * newS)
val inner: DenseMatrix[Double] = DenseMatrix.zeros[Double](gram.rows, gram.cols)
for i <- 0 until basis.cols do inner(::, i) := newS(i) * svd.U(::, i)
val newbasis = l * inner

(newbasis, svd.S)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,19 @@ import breeze.linalg.svd.SVD
import breeze.linalg.{diag, DenseMatrix, DenseVector}
import breeze.stats.distributions.Gaussian
import scalismo.common.DiscreteField.vectorize
import scalismo.common._
import scalismo.common.*
import scalismo.common.interpolation.{FieldInterpolator, NearestNeighborInterpolator}
import scalismo.geometry._
import scalismo.geometry.*
import scalismo.image.StructuredPoints
import scalismo.kernels.{DiscreteMatrixValuedPDKernel, MatrixValuedPDKernel}
import scalismo.numerics.{PivotedCholesky, Sampler}
import scalismo.statisticalmodel.DiscreteLowRankGaussianProcess.{Eigenpair => DiscreteEigenpair, _}
import scalismo.numerics.{GramDiagonalize, PivotedCholesky, Sampler}
import scalismo.statisticalmodel.DiscreteLowRankGaussianProcess.{Eigenpair as DiscreteEigenpair, *}
import scalismo.statisticalmodel.LowRankGaussianProcess.Eigenpair
import scalismo.statisticalmodel.NaNStrategy.NanIsNumericValue
import scalismo.statisticalmodel.dataset.DataCollection
import scalismo.utils.{Memoize, Random}

import scala.annotation.threadUnsafe
import scala.language.higherKinds
import scala.collection.parallel.immutable.ParVector

Expand Down Expand Up @@ -358,6 +359,29 @@ class DiscreteLowRankGaussianProcess[D: NDSpace, DDomain[DD] <: DiscreteDomain[D
)
}

/**
* realigns the model on the provided part of the domain. By default aligns over the translation and approximately
* over rotation as well. The rotation will always be calculated around the center of the provided ids. Rotations are
* around the cardinal directions.
*
* @param ids
* these define the parts of the domain that are aligned to. Depending on the withExtendedBasis parameter has a
* minimum requirements (default basis extension in 3D should be used with >=4 provided ids)
* @param withExtendedBasis
* True if the extended basis should be included. By default this uses a rotation extension. that means False makes
* the realignment over translation exact.
* @param diagonalize
* True if a diagonal basis should be returned. False is cheaper for exclusively drawing samples.
* @return
* The resulting [[DiscreteLowRankGaussianProcess]] aligned on the provided instances of [[PointId]]
*/
def realign(ids: IndexedSeq[PointId], withExtendedBasis: Boolean = true, diagonalize: Boolean = true)(using
vectorizer: Vectorizer[Value],
realigning: RealignExtendedBasis[D, Value]
): DiscreteLowRankGaussianProcess[D, DDomain, Value] = {
DiscreteLowRankGaussianProcess.realignment(this, ids, withExtendedBasis, diagonalize)
}

protected[statisticalmodel] def instanceVector(alpha: DenseVector[Double]): DenseVector[Double] = {
require(rank == alpha.size)

Expand Down Expand Up @@ -638,6 +662,75 @@ object DiscreteLowRankGaussianProcess {
DiscreteMatrixValuedPDKernel(domain, cov, outputDim)
}

def realignment[D: NDSpace, DDomain[DD] <: DiscreteDomain[DD], Value](
model: DiscreteLowRankGaussianProcess[D, DDomain, Value],
ids: IndexedSeq[PointId],
withExtendedBasis: Boolean,
diagonalize: Boolean
)(using
vectorizer: Vectorizer[Value],
realigning: RealignExtendedBasis[D, Value]
): DiscreteLowRankGaussianProcess[D, DDomain, Value] = {
val d = NDSpace.apply[D].dimensionality
// build the projection matrix for the desired pose
val p = {
@threadUnsafe
lazy val pt = breeze.linalg.tile(DenseMatrix.eye[Double](d), model.domain.pointSet.numberOfPoints, 1)
if withExtendedBasis then
val center = ids.map(id => model.domain.pointSet.point(id).toVector).reduce(_ + _).map(_ / ids.length).toPoint
val pr = realigning.getBasis[DDomain](model, center)
if realigning.useTranslation then DenseMatrix.horzcat(pt, pr)
else pr
else pt
}
// call the realignment implementation
val (nmean, nbasis, nvar) = realignmentComputation(model.meanVector,
model.basisMatrix,
model.variance,
p,
ids.map(_.id),
dim = d,
diagonalize = diagonalize,
projectMean = false
)

new DiscreteLowRankGaussianProcess[D, DDomain, Value](model.domain, nmean, nvar, nbasis)
}

private def realignmentComputation(mean: DenseVector[Double],
basis: DenseMatrix[Double],
s: DenseVector[Double],
p: DenseMatrix[Double],
ids: IndexedSeq[Int],
dim: Int,
diagonalize: Boolean,
projectMean: Boolean
): (DenseVector[Double], DenseMatrix[Double], DenseVector[Double]) = {
val x = for // prepare indices
id <- ids
d <- 0 until dim
yield id * dim + d
// prepare the majority of the projection matrix
val px = p(x, ::).toDenseMatrix
val ptpipt = breeze.linalg.pinv(px.t * px) * px.t

// performs the actual projection. batches all basis vectors
// p -> projection rank, n number of indexes*dim, r cols of basis, N rows of basis
val alignedC = ptpipt * basis(x, ::).toDenseMatrix // pxn * nxr
val alignedEigf = basis - p * alignedC // Nxr - Nxp * pxr
val alignedMean = if projectMean then // if desired projects the mean vector as well
val alignedMc = ptpipt * mean // same projection with r==1
mean - p * alignedMc
else mean

// rediagonalize. You can skip this if you ONLY sample from the resulting model
val (newbasis, news) =
if diagonalize then GramDiagonalize.rediagonalizeGram(alignedEigf, s)
else (alignedEigf, s)

(alignedMean, newbasis, news)
}

}

// Explicit variants for 1D, 2D and 3D
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,16 @@ case class PointDistributionModel[D: NDSpace, DDomain[D] <: DiscreteDomain[D]](
PointDistributionModel(newGP)
}

/**
* realigns the [[DiscreteLowRankGaussianProcess]] and returns the resulting [[PointDistributionModel]]
*/
def realign(ids: IndexedSeq[PointId], withExtendedBasis: Boolean = true, diagonalize: Boolean = true)(using
vectorizer: Vectorizer[EuclideanVector[D]],
realign: RealignExtendedBasis[D, EuclideanVector[D]]
): PointDistributionModel[D, DDomain] = {
new PointDistributionModel[D, DDomain](this.gp.realign(ids, withExtendedBasis, diagonalize))
}

}

object PointDistributionModel {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package scalismo.statisticalmodel

import breeze.linalg.DenseMatrix
import scalismo.common.DiscreteDomain
import scalismo.geometry.*

/**
* types the whole discrete low rank gp to make sure that it is applied to the appropriate models. The value type could
* be left out if the user knows what to do.
*/
trait RealignExtendedBasis[D, Value]:

def useTranslation: Boolean
def getBasis[DDomain[DD] <: DiscreteDomain[DD]](model: DiscreteLowRankGaussianProcess[D, DDomain, Value],
center: Point[D]
): DenseMatrix[Double]
def centeredP[DDomain[DD] <: DiscreteDomain[DD]](domain: DDomain[D], center: Point[D]): DenseMatrix[Double] = {
// build centered data matrix
val x = DenseMatrix.zeros[Double](center.dimensionality, domain.pointSet.numberOfPoints)
val c = center.toBreezeVector
for (p, i) <- domain.pointSet.points.zipWithIndex do x(::, i) := p.toBreezeVector - c
x
}

object RealignExtendedBasis:
/**
* returns a projection basis for rotation. that is the tangential speed for the rotations around the three cardinal
* directions.
*/
given realignBasis3D: RealignExtendedBasis[_3D, EuclideanVector[_3D]] with
def useTranslation: Boolean = true
def getBasis[DDomain[DD] <: DiscreteDomain[DD]](
model: DiscreteLowRankGaussianProcess[_3D, DDomain, EuclideanVector[_3D]],
center: Point[_3D]
): DenseMatrix[Double] = {
val np = model.domain.pointSet.numberOfPoints
val x = centeredP(model.domain, center)

val pr = DenseMatrix.zeros[Double](np * 3, 3)
// the derivative of the rotation matrices
val dr = new DenseMatrix[Double](9,
3,
Array(1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 1.0,
0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 1.0)
)
// get tangential speed
val dx = dr * x
for i <- 0 until 3 do
val v = dx(3 * i until 3 * i + 3, ::).toDenseVector
pr(::, i) := v / breeze.linalg.norm(v)
pr
}

/**
* returns a projection basis for rotation. that is the tangential speed for the single 2d rotation.
*/
given realignBasis2D: RealignExtendedBasis[_2D, EuclideanVector[_2D]] with
def useTranslation: Boolean = true
def getBasis[DDomain[DD] <: DiscreteDomain[DD]](
model: DiscreteLowRankGaussianProcess[_2D, DDomain, EuclideanVector[_2D]],
center: Point[_2D]
): DenseMatrix[Double] = {
val np = model.domain.pointSet.numberOfPoints
val x = centeredP(model.domain, center)

// derivative of the rotation matrix
val dr = new DenseMatrix[Double](2, 2, Array(0.0, -1.0, 1.0, 0.0))
val dx = (dr * x).reshape(2 * np, 1)
val n = breeze.linalg.norm(dx, breeze.linalg.Axis._0)
dx / n(0)
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package scalismo.statisticalmodel

import breeze.linalg.{DenseMatrix, DenseVector}
import breeze.linalg.{rand, DenseMatrix, DenseVector}
import breeze.stats.distributions.Gaussian
import scalismo.ScalismoTestSuite
import scalismo.common.*
Expand All @@ -28,6 +28,7 @@ import scalismo.io.statisticalmodel.StatismoIO
import scalismo.kernels.{DiagonalKernel, GaussianKernel, MatrixValuedPDKernel}
import scalismo.numerics.PivotedCholesky.RelativeTolerance
import scalismo.numerics.{GridSampler, UniformSampler}
import scalismo.transformations.TranslationAfterRotation
import scalismo.utils.Random

import java.net.URLDecoder
Expand Down Expand Up @@ -723,6 +724,54 @@ class GaussianProcessTests extends ScalismoTestSuite {
}
}

describe("when realigning a model") {
it("the translation aligned model should be exactly orthogonal to translations on the aligned ids.") {
val f = Fixture
val dgp = f.discreteLowRankGp

val alignedDgp = dgp.realign(dgp.mean.pointsWithIds.map(t => t._2).toIndexedSeq)

val shifts: IndexedSeq[Double] = alignedDgp.klBasis
.map(klp => {
val ef = klp.eigenfunction
ef.data.reduce(_ + _).norm
})
.toIndexedSeq
val res = shifts.sum

res shouldBe 0.0 +- 1e-7
}

it("the rotation aligned model should exhibit only small rotations on the aligned ids.") {
val f = Fixture
val dgp = f.discreteLowRankGp

val ids = {
val sorted = dgp.mean.pointsWithIds.toIndexedSeq.sortBy(t => t._1.toArray.sum)
sorted.take(10).map(_._2)
}
val coef = (0 until 5).map(_ => this.random.scalaRandom.nextInt(100))
val alignedDgp = dgp.realign(ids)
val res = IndexedSeq(dgp, alignedDgp).map(model => {
val samples = coef
.map(i => model.instance(DenseVector.tabulate[Double](model.rank) { j => if i == j then 0.1 else 0.0 }))
.map(_ => model.sample())
val rp = ids.map(id => model.mean.domain.pointSet.point(id).toVector).reduce(_ + _).map(d => d / ids.length)
val rotations = samples.map(sample => {
val ldms = ids.map(id =>
(model.domain.pointSet.point(id) + model.mean.data(id.id),
sample.domain.pointSet.point(id.id) + sample.data(id.id)
)
)
val rigidTransform: TranslationAfterRotation[_3D] =
scalismo.registration.LandmarkRegistration.rigid3DLandmarkRegistration(ldms, rp.toPoint)
rigidTransform.rotation.parameters
})
rotations.map(m => m.data.map(math.abs).sum).sum
})
res(1) shouldBe <(res(0) * 0.6)
}
}
}

describe("when comparing marginalLikelihoods") {
Expand Down
Loading