TopNet: Topology Preserving Metric Learning

Posted on Fri, Jul 16, 2021 paper-summary

These are some of my notes for the MICCAI 2020 paper TopNet: Topology Preserving Metric Learning.

Motivation

Introduction

Standard segmentation networks make local mistakes, and in tree-like vessel structures these local errors are amplified across the image.

Contributions:

  1. Multi-task network for tree reconstruction which detects centerlines and connectivity between centerlines.
  2. Topology metric which learns both intra- and inter-class topological distances between vascular voxel pairs.
  3. Experiments verifying better accuracy.

Method

Their multi-task network has three parts.

1. Vessel extraction decoder

Separates vessels from the background. This is formulated as a semantic segmentation problem and trained with a Dice loss.

where $v$ is the ground truth binary vessel volume patch and $v'$ is the predicted volume patch.
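A minimal NumPy sketch of a soft Dice loss, assuming the common form $1 - \frac{2\sum v v'}{\sum v + \sum v'}$. The paper may use a variant (e.g. squared terms in the denominator); `dice_loss` and the smoothing constant `eps` are hypothetical names.

```python
import numpy as np

def dice_loss(v: np.ndarray, v_pred: np.ndarray, eps: float = 1e-6) -> float:
    """Soft Dice loss between a binary ground-truth patch v and a predicted
    probability patch v_pred of the same shape (values in [0, 1])."""
    v = v.astype(np.float64)
    v_pred = v_pred.astype(np.float64)
    intersection = np.sum(v * v_pred)
    # eps (assumed) keeps the ratio defined when both volumes are empty
    dice = (2.0 * intersection + eps) / (np.sum(v) + np.sum(v_pred) + eps)
    return 1.0 - dice

# Toy 4x4x4 patch with a straight "vessel" along the first axis
v = np.zeros((4, 4, 4))
v[:, 1, 1] = 1
print(dice_loss(v, v))        # 0.0 (perfect overlap)
print(dice_loss(v, 1 - v))    # close to 1.0 (no overlap)
```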

2. Centerness decoder

The centerness score is formulated as a regression problem (training it as a binary segmentation problem would be unstable).

Let $M$ be the binary volume representing center voxels, with entries $m_i \in M$.

For a given voxel $j$, define the distance transform:

$$S_j = [\text{distance to closest } m_i = 1]$$

Values are computed only for voxels inside the vessel; otherwise the many voxels outside the vessel would skew the output.
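The centerness target can be sketched with a brute-force Euclidean distance transform in NumPy. This assumes Euclidean distance in voxel units and marks voxels outside the vessel as NaN so that a loss can skip them; `centerness_target` is a hypothetical helper, and a real pipeline would more likely use a library distance transform.

```python
import numpy as np

def centerness_target(center: np.ndarray, vessel: np.ndarray) -> np.ndarray:
    """S_j = Euclidean distance (in voxels) from voxel j to the closest voxel
    with m_i = 1, kept only inside the vessel (NaN elsewhere)."""
    coords = np.indices(center.shape).reshape(center.ndim, -1).T  # (N, 3) all voxels
    centers = np.argwhere(center == 1)                             # (M, 3) center voxels
    # Brute-force (N, M) pairwise distances; fine for small toy patches only
    d = np.linalg.norm(coords[:, None, :] - centers[None, :, :], axis=-1)
    S = d.min(axis=1).reshape(center.shape)
    # Ignore voxels outside the vessel so they do not skew the regression
    return np.where(vessel.astype(bool), S, np.nan)

center = np.zeros((5, 5, 5), dtype=int)
center[:, 2, 2] = 1                      # centerline along the first axis
vessel = np.zeros_like(center)
vessel[:, 1:4, 1:4] = 1                  # 3x3 tube around the centerline
S = centerness_target(center, vessel)
print(S[0, 2, 2], S[0, 2, 3], S[0, 1, 1])  # 0.0 1.0 1.4142135623730951
print(np.isnan(S[0, 0, 0]))                # True
```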

Use smooth L1 loss:

The denominator ensures that the outputs are not skewed towards voxels far from the center (there are $\propto S^2$ voxels at a distance $S$ from the centerline).
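The exact loss is not written out in these notes. The sketch below assumes a smooth L1 term with threshold $\beta = 1$, divided per voxel by $S_j^2 + 1$ to counter the growth in voxel count with distance. The $+1$, the mean reduction, and both function names are assumptions, not the paper's formula.

```python
import numpy as np

def smooth_l1(diff: np.ndarray, beta: float = 1.0) -> np.ndarray:
    """Element-wise smooth L1: quadratic for |diff| < beta, linear beyond."""
    a = np.abs(diff)
    return np.where(a < beta, 0.5 * a ** 2 / beta, a - 0.5 * beta)

def centerness_loss(S_pred: np.ndarray, S_true: np.ndarray, beta: float = 1.0) -> float:
    """ASSUMED form: mean over vessel voxels of smooth_l1(S_pred - S_true) / (S_true**2 + 1).
    Voxels outside the vessel are NaN in S_true and are skipped."""
    inside = ~np.isnan(S_true)
    err = smooth_l1(S_pred[inside] - S_true[inside], beta)
    # Divide by S^2 + 1 so the more numerous voxels far from the centerline do not dominate
    return float(np.mean(err / (S_true[inside] ** 2 + 1.0)))

S_true = np.array([0.0, 1.0, 2.0, np.nan])   # last voxel lies outside the vessel
print(centerness_loss(S_true.copy(), S_true))  # 0.0
print(centerness_loss(np.array([0.5, 1.0, 2.0, 9.0]), S_true))
```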

Perform NMS (non-maximum suppression) over a $5 \times 5 \times 5$ window.
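A minimal 3D NMS over a $5 \times 5 \times 5$ window, using NumPy's `sliding_window_view` (NumPy 1.20+). It keeps voxels that equal the maximum of their window and exceed a hypothetical `threshold`, so it assumes center voxels carry the highest score; for a map where center voxels are minima, one would apply it to an inverted map instead.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def nms_3d(score: np.ndarray, window: int = 5, threshold: float = 0.0) -> np.ndarray:
    """Boolean mask of voxels that equal the maximum of their
    window x window x window neighborhood and exceed `threshold`."""
    score = score.astype(np.float64)
    r = window // 2
    # Pad with -inf so border voxels only compete with real neighbors
    padded = np.pad(score, r, mode="constant", constant_values=-np.inf)
    # One window per voxel: shape (D, H, W, window, window, window)
    windows = sliding_window_view(padded, (window,) * 3)
    local_max = windows.max(axis=(-3, -2, -1))
    return (score == local_max) & (score > threshold)

score = np.zeros((7, 7, 7))
score[1, 1, 1] = 0.9   # strongest peak
score[1, 1, 2] = 0.8   # inside the 5x5x5 window of the peak -> suppressed
score[5, 5, 5] = 0.7   # 4 voxels away along each axis -> kept
print(np.argwhere(nms_3d(score)).tolist())   # [[1, 1, 1], [5, 5, 5]]
```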

3. Topology distance Decoder

>>> Most interesting part

This decoder outputs an 8-d feature vector for each voxel in the CT volume.

The idea is that the feature distance $\|x_i - x_j\|$ between voxels $i$ and $j$ corresponds to their topological distance in the vessel space if they are in the same tree.

If they are in different trees, the distance should be large.

Given labels $l_i, l_j$ for voxels $i, j$, the loss function is given by:

For a given input image patch, the loss is:

where $N_i$ is a neighborhood of voxel $i$.

The neighborhood is defined as a radius of 15 voxels ⇒ the topological distance is bounded.

They set $\alpha = \frac{1}{15}$ to approximately normalize the scaled topological distance $\alpha D_{ij}$ to $[0, 1]$. They also set $K = 3$, and $\gamma = \frac{1}{3}$ to balance the loss terms.
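These notes do not reproduce the loss equations, so the sketch below assumes a common contrastive form that matches the prose: same-tree pairs regress $\|x_i - x_j\|$ to $\alpha D_{ij}$, different-tree pairs are pushed past the margin $K$, and $\gamma$ weights the inter-tree term. The mean-over-pairs reduction and the placement of $\gamma$ are assumptions; the paper may combine the terms differently. $\alpha$, $K$, $\gamma$, and the 15-voxel radius are the values quoted above.

```python
import numpy as np

ALPHA, K, GAMMA, RADIUS = 1 / 15, 3.0, 1 / 3, 15.0   # values quoted in the notes

def topology_loss(x, pos, labels, D):
    """ASSUMED contrastive form (not the paper's exact formula).
    x: (N, 8) embeddings, pos: (N, 3) voxel coordinates,
    labels: (N,) tree labels, D: (N, N) topological distances for same-tree pairs."""
    feat_dist = np.linalg.norm(x[:, None] - x[None, :], axis=-1)
    spatial = np.linalg.norm(pos[:, None] - pos[None, :], axis=-1)
    # Neighborhood N_i: other voxels within RADIUS voxels of i
    neigh = (spatial <= RADIUS) & ~np.eye(len(x), dtype=bool)
    same = neigh & (labels[:, None] == labels[None, :])
    diff = neigh & (labels[:, None] != labels[None, :])
    # Same tree: feature distance should match the scaled topological distance
    intra = np.abs(feat_dist - ALPHA * D)[same]
    # Different trees: feature distance should be at least the margin K
    inter = np.maximum(0.0, K - feat_dist)[diff]
    intra_term = intra.mean() if intra.size else 0.0
    inter_term = inter.mean() if inter.size else 0.0
    return float(intra_term + GAMMA * inter_term)

# Tree 0: three voxels along x; tree 1: one voxel elsewhere
pos = np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0], [0, 5, 0]], dtype=float)
labels = np.array([0, 0, 0, 1])
D = np.abs(pos[:, None, 0] - pos[None, :, 0])   # along-vessel distance for tree 0
x = np.zeros((4, 8))
x[:3, 0] = ALPHA * pos[:3, 0]   # tree 0 embedded so that ||x_i - x_j|| = alpha * D_ij
x[3, 1] = 10.0                  # tree 1 placed far beyond the margin K
print(topology_loss(x, pos, labels, D))               # ~0.0: both terms satisfied
print(topology_loss(np.zeros((4, 8)), pos, labels, D))  # collapsed embedding is penalized
```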

>>> Tree reconstruction (at test time, after training)

  1. Use the vessel prediction to mask the centerline distance map.
  2. Apply NMS to the centerline distance $S_i$ to find center voxels.
  3. For each voxel $i$, create edges to all nodes $j$ with $\|x_i - x_j\| \le 2$.
  4. Run Dijkstra's multi-source shortest-path tree algorithm.
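The reconstruction steps can be sketched as a multi-source Dijkstra over the center voxels. The notes do not say which voxels act as sources or what the edge weights are, so this sketch assumes hypothetical root indices (one per tree) and Euclidean distances between voxel positions as weights; steps 1 and 2 are assumed to have already produced the center voxels. Each voxel ends up in the tree of its nearest root, and `parent` encodes the shortest-path tree.

```python
import heapq
import numpy as np

def reconstruct_trees(x, pos, roots, max_feat_dist=2.0):
    """Multi-source Dijkstra over center voxels.
    x: (N, d) topology embeddings, pos: (N, 3) voxel coordinates
    (Euclidean distance is the ASSUMED edge weight), roots: HYPOTHETICAL sources.
    Returns parent[i] (-1 for roots and unreached voxels) and the root owning i."""
    n = len(x)
    feat_dist = np.linalg.norm(x[:, None] - x[None, :], axis=-1)
    eucl = np.linalg.norm(pos[:, None] - pos[None, :], axis=-1)
    # Edge (i, j) whenever the embeddings are within max_feat_dist (2 in the notes)
    adj = (feat_dist <= max_feat_dist) & ~np.eye(n, dtype=bool)

    dist = np.full(n, np.inf)
    parent = np.full(n, -1)
    owner = np.full(n, -1)
    heap = []
    for r in roots:                      # every root starts at distance 0
        dist[r], owner[r] = 0.0, r
        heapq.heappush(heap, (0.0, r))
    while heap:
        d, i = heapq.heappop(heap)
        if d > dist[i]:
            continue                     # stale entry: i was already settled closer
        for j in np.flatnonzero(adj[i]):
            nd = d + eucl[i, j]
            if nd < dist[j]:
                dist[j], parent[j], owner[j] = nd, i, owner[i]
                heapq.heappush(heap, (nd, int(j)))
    return parent, owner

# Two toy trees: voxels 0-2 along x, voxels 3-4 along y, embeddings far apart
pos = np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0], [0, 3, 0], [0, 4, 0]], dtype=float)
x = np.array([[0.0], [1.5], [3.0], [100.0], [101.5]])
parent, owner = reconstruct_trees(x, pos, roots=[0, 3])
print(parent.tolist(), owner.tolist())   # [-1, 0, 1, -1, 3] [0, 0, 0, 3, 3]
```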

Experiments

Datasets

IRCAD public dataset, internal dataset.

Methods

  1. Single-task 3D UNet trained with Dice loss.
  2. Multi-task network with the topology distance decoder replaced by a 2-channel probability map (Dice loss over all center voxels).
  3. Multi-task cosine metric learning ($S_{ij}$ = cosine distance).

    For tree reconstruction, keep only pairs with $S_{ij} \le 0.5$ and take edge weights as:

    $$w_{ij} = E_{ij} (1 - S_{ij})$$

    where $E_{ij}$ is the Euclidean distance between voxels $i$ and $j$.
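The cosine baseline's edge rule can be sketched directly from the formula above. Cosine distance is taken as $1 - \cos(x_i, x_j)$; `cosine_edges` is a hypothetical helper that returns a dict of directed edges and their weights.

```python
import numpy as np

def cosine_edges(x, pos, max_cos_dist=0.5):
    """Keep pairs with cosine distance S_ij <= 0.5 and weight them by
    w_ij = E_ij * (1 - S_ij), with E_ij the Euclidean voxel distance."""
    xn = x / np.linalg.norm(x, axis=1, keepdims=True)
    S = 1.0 - xn @ xn.T                                  # pairwise cosine distance
    E = np.linalg.norm(pos[:, None] - pos[None, :], axis=-1)
    keep = (S <= max_cos_dist) & ~np.eye(len(x), dtype=bool)
    return {(int(i), int(j)): float(E[i, j] * (1.0 - S[i, j])) for i, j in np.argwhere(keep)}

x = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])      # voxel 2 points elsewhere
pos = np.array([[0, 0, 0], [2, 0, 0], [0, 1, 0]], dtype=float)
print(cosine_edges(x, pos))   # {(0, 1): 2.0, (1, 0): 2.0}
```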