From meshes to deformation fields

The goal of this tutorial is to understand how to view a set of triangle meshes as a set of deformation fields that can be modeled by a Gaussian Process.

If you would like to learn more about the syntax of the Scala programming language, have a look at the Scala Cheat Sheet, which is accessible under:

Documents -> Scala Cheat Sheet

Note: for optimal performance and to keep the memory footprint low, we recommend restarting Scalismo Lab at the beginning of every tutorial.

Gaussian processes and shape modelling:

Gaussian Processes (GPs) are a mathematical concept for modelling normal distributions over functions, whether scalar- or vector-valued. Using this concept, one can define a probability distribution over a space of functions from which random functions can be sampled; the sampled functions share the properties (such as smoothness) encoded in the GP's mean and covariance functions.
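A GP evaluated at finitely many inputs is simply a multivariate normal distribution whose covariance matrix is built from the GP's covariance function (kernel). The following plain-Scala sketch uses this fact to draw one sample of a zero-mean, scalar-valued GP on a small grid. It does not use Scalismo; the squared-exponential kernel and the helper names seKernel, cholesky and sampleGP are assumptions made for illustration.

```scala
import scala.util.Random

// Hypothetical kernel choice: squared exponential, k(x, y) = sigma^2 * exp(-(x - y)^2 / (2 * l^2)).
def seKernel(x: Double, y: Double, sigma: Double = 1.0, lengthScale: Double = 1.0): Double =
  sigma * sigma * math.exp(-(x - y) * (x - y) / (2 * lengthScale * lengthScale))

// Cholesky factorization K = L * L^T of a symmetric positive definite matrix (L lower triangular).
def cholesky(k: Array[Array[Double]]): Array[Array[Double]] = {
  val n = k.length
  val l = Array.ofDim[Double](n, n)
  for (i <- 0 until n; j <- 0 to i) {
    val s = (0 until j).map(m => l(i)(m) * l(j)(m)).sum
    l(i)(j) = if (i == j) math.sqrt(k(i)(i) - s) else (k(i)(j) - s) / l(j)(j)
  }
  l
}

// Sample the zero-mean GP at the inputs xs: f = L * z with z ~ N(0, I), so Cov(f) = K.
def sampleGP(xs: IndexedSeq[Double], rng: Random): IndexedSeq[Double] = {
  val n = xs.length
  val jitter = 1e-6 // small diagonal term that keeps K numerically positive definite
  val k = Array.tabulate(n, n)((i, j) => seKernel(xs(i), xs(j)) + (if (i == j) jitter else 0.0))
  val l = cholesky(k)
  val z = IndexedSeq.fill(n)(rng.nextGaussian())
  (0 until n).map(i => (0 to i).map(j => l(i)(j) * z(j)).sum)
}

val xs = (0 to 8).map(_ * 0.5)            // 9 inputs: 0.0, 0.5, ..., 4.0
val sample = sampleGP(xs, new Random(42)) // one random function, evaluated at xs
```

Changing the seed yields a different random function; increasing lengthScale makes neighbouring values more strongly correlated, so the samples vary more slowly.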

In Statistical Shape Modelling (SSM), one is interested in shape variations and seeks to model those statistically.

When point correspondences are given, a simple approach to modelling shapes is to study how the position of every point varies across the dataset and to approximate this variation with a multivariate normal distribution.
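Written out, and assuming each point is modelled independently (one possible choice; a joint model would instead stack all 3N coordinates of a shape into a single vector), the estimate for point i from n shapes in correspondence could look as follows, where x_i^{(k)} denotes the position of point i in shape k:

```latex
\mu_i = \frac{1}{n}\sum_{k=1}^{n} x_i^{(k)}, \qquad
\Sigma_i = \frac{1}{n-1}\sum_{k=1}^{n} \left(x_i^{(k)} - \mu_i\right)\left(x_i^{(k)} - \mu_i\right)^{\top}, \qquad
x_i \sim \mathcal{N}(\mu_i, \Sigma_i)
```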

For example, let's load a set of 3 faces in correspondence:

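import java.io.File
import scalismo.io.MeshIO

// read the first 3 meshes of the dataset and display them
val files = new File("datasets/testFaces/").listFiles.take(3)
val dataset = files.map{f: File => MeshIO.readMesh(f).get}
(0 until 3).foreach{i: Int => show(dataset(i), "face_" + i)}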

Since these meshes are in correspondence, we can already perform basic statistical operations, such as computing the mean by averaging the positions of all points that share the same point id:

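// average a sequence of points by averaging their position vectors
def computeCenter(points: Seq[Point[_3D]]): Point[_3D] = {
    val nbPoints = points.size
    val vectors = points.map{p: Point[_3D] => p.toVector}
    val sumVectors: Vector[_3D] = vectors.reduce{(a, b) => a + b}
    val centerVec: Vector[_3D] = sumVectors * (1f / nbPoints)
    centerVec.toPoint
}

// for every point id, average the positions of that point over all meshes
val averagePoints: Iterator[Point[_3D]] = dataset(0).pointIds.map { id: PointId =>
    val positions = dataset.map{ m: TriangleMesh => m.point(id)}
    computeCenter(positions)
}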

val averageMesh = TriangleMesh(averagePoints.toIndexedSeq, dataset(0).cells)
show(averageMesh, "averageMesh")

Here, for every point id, we look up the position of that point in each of our 3 faces and compute the center of these positions. A new mesh is then formed from the average position of every point, keeping the same cell structure as the original meshes. This mesh can be seen as the mean face of our dataset. (It doesn't look that mean, really! :) )
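The same averaging can be sketched without Scalismo, which makes the role of correspondence explicit: point id k must denote the same point in every mesh, so averaging by index is meaningful. Pt, Mesh, center and meanMesh below are hypothetical stand-ins for Scalismo's Point[_3D], TriangleMesh and the code above, with triangles stored as triples of point ids.

```scala
// Hypothetical minimal types standing in for Scalismo's Point[_3D] and TriangleMesh.
case class Pt(x: Double, y: Double, z: Double)
case class Mesh(points: IndexedSeq[Pt], cells: IndexedSeq[(Int, Int, Int)])

// Center of a set of points: the average of each coordinate.
def center(ps: Seq[Pt]): Pt = {
  val n = ps.size.toDouble
  Pt(ps.map(_.x).sum / n, ps.map(_.y).sum / n, ps.map(_.z).sum / n)
}

// Mean mesh: average every point id over all meshes and reuse the first mesh's cells.
def meanMesh(meshes: Seq[Mesh]): Mesh = {
  val nPoints = meshes.head.points.size
  require(meshes.forall(_.points.size == nPoints), "meshes must be in correspondence")
  val avg = (0 until nPoints).map(id => center(meshes.map(_.points(id))))
  Mesh(avg, meshes.head.cells)
}

val cells = IndexedSeq((0, 1, 2))
val m0 = Mesh(IndexedSeq(Pt(0, 0, 0), Pt(1, 0, 0), Pt(0, 1, 0)), cells)
val m1 = Mesh(IndexedSeq(Pt(0, 0, 2), Pt(3, 0, 0), Pt(0, 1, 4)), cells)
val meanFace = meanMesh(Seq(m0, m1))
```

Only the point positions change; the triangles are copied unchanged, which is why the mean is again a valid mesh with the same topology.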

From meshes to deformations:

Presented this way, GPs and SSM do not seem to have much in common: the former models distributions over functions, the latter distributions over point clouds.

This changes, however, when we adopt a different view of the dataset above. Instead of viewing the dataset as a set of point clouds, or point positions, we can see it as a set of deformation fields.

From now on, let us consider the mesh face_0 (the first item in our dataset) to be the reference of our dataset.

val reference = dataset(0)

The dataset can then be represented equivalently by the reference mesh and the set of deformations from the reference to the other items of the dataset.
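A minimal sketch of this equivalence, using a hypothetical Vec3 type for both positions and displacement vectors (Scalismo itself distinguishes Point[_3D] from Vector[_3D]): subtracting the reference from a shape gives its deformations, and adding them back to the reference recovers the shape exactly. The helper names deformations and applyDeformations are illustrative only.

```scala
// Hypothetical 3D type used for both points and displacement vectors in this sketch.
case class Vec3(x: Double, y: Double, z: Double) {
  def +(o: Vec3): Vec3 = Vec3(x + o.x, y + o.y, z + o.z)
  def -(o: Vec3): Vec3 = Vec3(x - o.x, y - o.y, z - o.z)
}

// Deformation at each point id: target position minus reference position.
def deformations(refPoints: IndexedSeq[Vec3], target: IndexedSeq[Vec3]): IndexedSeq[Vec3] = {
  require(refPoints.size == target.size, "shapes must be in correspondence")
  refPoints.indices.map(id => target(id) - refPoints(id))
}

// Inverse step: moving every reference point by its deformation recovers the target shape.
def applyDeformations(refPoints: IndexedSeq[Vec3], defs: IndexedSeq[Vec3]): IndexedSeq[Vec3] =
  refPoints.indices.map(id => refPoints(id) + defs(id))

val refPoints = IndexedSeq(Vec3(0, 0, 0), Vec3(1, 0, 0), Vec3(0, 1, 0))
val face1 = IndexedSeq(Vec3(0.5, 0, 0), Vec3(1, 2, 0), Vec3(0, 1, -1))
val u1 = deformations(refPoints, face1)
val recovered = applyDeformations(refPoints, u1)
```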

For example, let's see how to deform the reference into face_1:

val deformations1: IndexedSeq[Vector[_3D]] = reference.pointIds.map{
  id: PointId => dataset(1).point(id) - reference.point(id)
}.toIndexedSeq

In the code above, we compute the deformation vector between the reference and face_1 for every point of the reference mesh.

Let's now visualize it:

val vecField1 = DiscreteVectorField(reference, deformations1)
show(vecField1, "deformations1") 

(For optimal visualization, make the reference transparent and zoom in on the nose region.)

Exercise: Make face_1 and the reference visible in the scene and verify visually that the deformation vectors point from the reference to face_1.

Using the deformation data, we created a DiscreteVectorField that maps 3-dimensional input positions (the reference points) to 3-dimensional output vectors (the deformations).

As with discrete scalar images, a DiscreteVectorField is defined over a discrete domain. Notice, however, that the domain does not need to be structured (like a grid, for example) and can be any finite set of points. In the example code above, we defined the domain to be the points of the reference mesh:

vecField1.domain == reference
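Conceptually, a discrete vector field is little more than two aligned sequences: the domain points and one vector per point, both indexed by point id. The hypothetical DiscreteField3 class below (not Scalismo's DiscreteVectorField) sketches this with a scattered, non-grid domain.

```scala
// Hypothetical 3D type for both domain points and field values.
case class Vec3(x: Double, y: Double, z: Double)

// Hypothetical minimal discrete vector field: a finite, unstructured set of
// domain points plus one vector per point, both indexed by the same point id.
case class DiscreteField3(domain: IndexedSeq[Vec3], values: IndexedSeq[Vec3]) {
  require(domain.size == values.size, "one value per domain point")
  def apply(id: Int): Vec3 = values(id) // look up the vector stored at a point id
  def pointsWithValues: IndexedSeq[(Vec3, Vec3)] = domain.zip(values)
}

// Scattered points rather than a grid: any finite point set is a valid domain.
val domain = IndexedSeq(Vec3(0, 0, 0), Vec3(2.5, -1, 3), Vec3(7, 7, 0.1))
val values = IndexedSeq(Vec3(1, 0, 0), Vec3(0, 1, 0), Vec3(0, 0, 1))
val field = DiscreteField3(domain, values)
```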

The deformation vector associated with a particular point id in a DiscreteVectorField can be retrieved as follows (e.g. for id 0):

vecField1(PointId(0))

Exercise: Generate the deformation fields that represent the other faces in the dataset and display them.
dataset.zipWithIndex.foreach { case (mesh, i) =>
    // deformation from the reference to the i-th face, for every point id
    val deformations: IndexedSeq[Vector[_3D]] = reference.pointIds.map{
        id: PointId => mesh.point(id) - reference.point(id)
    }.toIndexedSeq

    show(DiscreteVectorField(reference, deformations), "def_" + i)
}

Deformation fields over discrete and continuous domains:

The deformation field that we computed above is discrete as it is defined over a finite number of mesh points.

Given that the real-world objects that we model are continuous, we are generally interested in deformation fields defined over continuous domains.

Such deformation fields are useful for deforming higher-resolution meshes, or even continuous surfaces and images: they make it possible to warp an entire image rather than just a few chosen points.

To obtain such a vector field from our discrete one, we can interpolate:

val contVectorField : VectorField[_3D, _3D] = vecField1.interpolateNearestNeighbor  

Here we created a deformation field defined over a continuous domain: every point in space, including points that do not belong to the reference mesh, is assigned the deformation vector of its closest point on the reference mesh.

This vector field is now defined over the entire real space and can be evaluated at any point, even one that is not a vertex of the reference mesh.
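The idea behind nearest-neighbour interpolation can be sketched in a few lines of plain Scala. This illustrates the general technique and is not a copy of Scalismo's implementation: the hypothetical nearestNeighborInterpolation searches all domain points by brute force, whereas a practical implementation would typically use a spatial index such as a k-d tree.

```scala
// Hypothetical 3D type with the two operations the search needs.
case class Vec3(x: Double, y: Double, z: Double) {
  def -(o: Vec3): Vec3 = Vec3(x - o.x, y - o.y, z - o.z)
  def norm2: Double = x * x + y * y + z * z // squared length, enough for comparisons
}

// Turn a discrete field (domain points + values) into a function defined everywhere:
// a query point receives the value stored at its closest domain point.
def nearestNeighborInterpolation(domain: IndexedSeq[Vec3], values: IndexedSeq[Vec3]): Vec3 => Vec3 = {
  require(domain.nonEmpty && domain.size == values.size)
  (p: Vec3) => {
    val closestId = domain.indices.minBy(id => (domain(id) - p).norm2) // brute-force O(n) search
    values(closestId)
  }
}

val domain = IndexedSeq(Vec3(0, 0, 0), Vec3(10, 0, 0))
val values = IndexedSeq(Vec3(1, 0, 0), Vec3(0, -1, 0))
val cont = nearestNeighborInterpolation(domain, values)
```

The resulting field reproduces the stored vectors exactly at the domain points, but it is piecewise constant: its value jumps wherever the closest domain point changes.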

Exercise: Evaluate the new vector field above at the vertices of dataset(2) (face_2) and display the resulting vector field.
val deformations2: IndexedSeq[Vector[_3D]] = dataset(2).points.map{ p: Point[_3D] =>
    contVectorField(p) // deformation of the closest reference point
}.toIndexedSeq

val vecField2 = DiscreteVectorField(dataset(2), deformations2)

show(vecField2, "resampledAt2")
Exercise: Compute the mesh resulting from warping every point of face_2 with the vector field computed above and display it. Hint: you can define a transform as we did in the rigid alignment tutorial and use it to transform the mesh. (Do not expect a pretty result :))
def transform(p: Point[_3D]) = p + contVectorField(p)

val trans2 = dataset(2).transform(transform)

show(trans2, "trans2")
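One reason the warped face is not pretty can be seen in a small plain-Scala sketch (hypothetical Vec3, nearestNeighborField and warp helpers, mirroring the transform p + contVectorField(p) above): all points that share the same closest reference point are shifted by the same vector, so the warp is piecewise constant rather than smooth.

```scala
// Hypothetical 3D type used for both points and displacement vectors.
case class Vec3(x: Double, y: Double, z: Double) {
  def +(o: Vec3): Vec3 = Vec3(x + o.x, y + o.y, z + o.z)
  def -(o: Vec3): Vec3 = Vec3(x - o.x, y - o.y, z - o.z)
  def norm2: Double = x * x + y * y + z * z
}

// Nearest-neighbour field: every query point gets the vector of its closest domain point.
def nearestNeighborField(domain: IndexedSeq[Vec3], values: IndexedSeq[Vec3]): Vec3 => Vec3 =
  (p: Vec3) => values(domain.indices.minBy(id => (domain(id) - p).norm2))

// Warp: move every point p to p + u(p).
def warp(points: IndexedSeq[Vec3], u: Vec3 => Vec3): IndexedSeq[Vec3] = points.map(p => p + u(p))

val refPts = IndexedSeq(Vec3(0, 0, 0), Vec3(1, 0, 0))
val defs = IndexedSeq(Vec3(0, 1, 0), Vec3(0, 3, 0))
val u = nearestNeighborField(refPts, defs)
// The first two query points lie on the same side of the midpoint x = 0.5 and get the
// same shift; the third crosses it and jumps to a different displacement.
val warped = warp(IndexedSeq(Vec3(0.1, 0, 0), Vec3(0.4, 0, 0), Vec3(0.6, 0, 0)), u)
```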