
NGPConv - Scaling 3D ConvNets to gigavoxel images using Implicit Neural Representations


TL;DR

3D ConvNets are limited to low resolution images due to the cubic scaling of the number of voxels. However, applications like medical imaging, weather forecasting, high-fidelity scene representations, etc. require ultra high resolution images on which 3D ConvNets run out of memory 😓

To scale 3D ConvNets to gigavoxel images, we use InstantNGP to compress the data into an implicit codebook and define a convolution operator on the codebook itself. Our CUDA kernel is up to 10x faster than a PyTorch implementation 🚀🚀

Introduction

Convolutional architectures have been the backbone of computer vision and natural language processing. Even with transformers dominating large language models, convolutional networks have stayed competitive for visual processing. However, while standard 2D convolutional architectures support very high resolution images (up to 4K), 3D convolutional networks have been limited to much lower resolutions.

3D convolutional networks are ubiquitously used in fields like medical imaging, yet they remain restricted to modest resolutions. For example, the 3D convolutional architectures in VoxelMorph, SynthSeg and nnUNet operate on volumes around $160\times 192\times 224$ or $128\times 128\times 128$ voxels. This is because the number of activations in a convolutional architecture is proportional to the number of voxels in the input volume.

Larger datasets like the CCFv3 Atlas have resolutions of up to $1320\times 1140\times 800$ voxels, and BigBrain has $1700 \times 1600 \times 1300$ voxels, more than $500$ times larger than the current capacity of 3D deep learning models. Even a single 3D convolutional layer with 32 hidden channels would generate an activation volume of about 420GB ($1700 \times 1600 \times 1300$ voxels $\times$ 32 channels $\times$ 4 bytes). This is not feasible on current hardware.

📝 Note: In this blog, I will use medical images as a representative example, but other representations like NeRFs, signed distance fields (SDFs), and occupancy fields are also commonly stored on regular grids and suffer from the same cubic scaling of memory with resolution.

Figure: Most Deep Learning models are trained on small 3D datasets like OASIS. On the contrary, large datasets like CCFv3 and BigBrain are out of reach for current deep learning models.

There is a tradeoff between resolution and computation: to work with high-resolution voxel grids (medical images, 3D scene representations) or densely sampled point clouds, we have to sacrifice network capacity. Fortunately, most of the information stored in these representations is highly redundant. This motivates the use of implicit neural representations.

Implicit Neural Representations

Implicit Neural Representations (INRs) are a class of models that represent high-dimensional data as a function of the coordinates. The function is typically a neural network that takes a coordinate as input and outputs a scalar or vector value (for example, the intensity of a medical image, or the color and density of a radiance field). For redundant data, the content is encoded succinctly in the weights of the neural network.


Figure: The data (brain on left) is encoded into an implicit neural representation (colored neurons on the right). The representation is typically the weights of a neural network that resolves a coordinate to a scalar or vector value corresponding to the data.
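As a toy illustration (hypothetical code, not the representation used later in this post), an INR for a single 3D volume is just a coordinate network fit by regression:

```python
import torch
import torch.nn as nn

volume = torch.rand(64, 64, 64)        # stand-in for a 3D image
coords = torch.stack(torch.meshgrid(
    *[torch.linspace(0, 1, s) for s in volume.shape], indexing="ij"), dim=-1).reshape(-1, 3)
values = volume.reshape(-1, 1)

# The "representation" is just the weights of this MLP. (Positional encodings or SIREN
# activations work far better in practice; a plain ReLU MLP is only illustrative.)
inr = nn.Sequential(nn.Linear(3, 128), nn.ReLU(),
                    nn.Linear(128, 128), nn.ReLU(),
                    nn.Linear(128, 1))
opt = torch.optim.Adam(inr.parameters(), lr=1e-3)
for _ in range(1000):
    idx = torch.randint(0, coords.shape[0], (8192,))       # random batch of coordinates
    loss = ((inr(coords[idx]) - values[idx]) ** 2).mean()  # reconstruction (MSE) loss
    opt.zero_grad(); loss.backward(); opt.step()
```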

In principle, a neural network can then process the encoded data for downstream tasks.


Figure: A typical network operating on the implicit representation would take the INR of the image as input, and emit the INR of the segmentation map as output. If the network operates entirely in the INR space, it can work with larger images out of the box.

However, there are two main challenges with the formulation above:

  • Unlike the data it represents, the implicit representation has no spatial locality, since every weight contributes to the output at every location. This makes it difficult to leverage the inductive bias of the original data (which is why we use convolutional networks in the first place).
  • The INR representations for different images do not share the same basis functions. The neural representation might encode data differently for different images – making it hard for a downstream network to infer the basis functions used to encode the data.

Instant NGP

Instant NGP was proposed to represent high resolution data like gigapixel images, signed distance fields, and NeRFs using an implicit representation. The idea is to use a hash function that converts quantized spatial coordinates to a set of indices that are used to index into a codebook. This is done across multiple resolutions and the codebook entries are concatenated together to form a single vector for a given coordinate. The vector is then fed into an MLP decoder to get the final output. The codebook and decoder are trained end-to-end to minimize the reconstruction loss.
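In simplified pseudocode, querying such a representation might look like the sketch below (a hypothetical Python illustration, not the reference implementation; for brevity it looks up the nearest grid corner, whereas the real encoding interpolates, as described in the following sections):

```python
import torch

def encode(x, codebooks, resolutions, hash_fn):
    """Multiresolution hash encoding of a normalized coordinate x in [0, 1]^3.

    codebooks[k] is a learnable (T, F) table for level k, and hash_fn(q, L) maps an
    integer grid coordinate at resolution L to a row index. This sketch uses the
    nearest grid corner; the real encoding trilinearly interpolates the 8 corners
    of the enclosing cell.
    """
    feats = []
    for k, L in enumerate(resolutions):
        q = (x * L).long()                          # quantize to the level-k grid
        feats.append(codebooks[k][hash_fn(q, L)])   # codebook lookup via the hash
    return torch.cat(feats)                         # (K * F,) vector fed to the MLP decoder

def query(x, codebooks, resolutions, hash_fn, mlp):
    return mlp(encode(x, codebooks, resolutions, hash_fn))   # shared MLP decoder
```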

Figure: Architecture of Instant NGP. Unlike other neural representations, this representation has some spatial locality since only a few fixed entries of the codebook are used for a given coordinate, solving our first problem. To solve the second problem, we initialize a different codebook for each image (which forms our implicit representation), and use a shared MLP decoder (see below).

Figure: The codebooks are initialized with random weights, and the decoder is shared across all codebooks.

This design ensures that different images are decoded the same way, solving our second problem – now multiple codebooks lie in the same encoded space. Given a dataset of $N$ images, we perform a joint optimization of $N$ codebooks and a shared MLP decoder to encode the data.

This deviates from how Implicit Neural Representations are typically trained, where a single datum is considered at a time. However, this idea allows compressing the data into a spatially local representation amenable to convolutional operations.
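A minimal sketch of this joint optimization, assuming a toy single-level codebook with nearest-corner lookup (hypothetical code, not the actual training script):

```python
import torch
import torch.nn as nn

N, F, T = 8, 16, 2 ** 19
images = [torch.rand(64, 64, 64) for _ in range(N)]         # stand-ins for a real dataset

# One codebook per image (the implicit representation) and a single shared decoder.
# For brevity this toy uses one level and nearest-corner lookup; the real setup uses
# the full multiresolution hash encoding described above.
codebooks = nn.ParameterList(nn.Parameter(1e-4 * torch.randn(T, F)) for _ in range(N))
decoder = nn.Sequential(nn.Linear(F, 64), nn.ReLU(), nn.Linear(64, 1))
opt = torch.optim.Adam(list(codebooks.parameters()) + list(decoder.parameters()), lr=1e-2)

for step in range(2000):
    i = step % N                                             # cycle over the dataset
    q = torch.randint(0, 64, (4096, 3))                      # random voxel coordinates
    target = images[i][q[:, 0], q[:, 1], q[:, 2]].unsqueeze(1)
    idx = (q[:, 0] + 64 * (q[:, 1] + 64 * q[:, 2])) % T      # tiled hash at resolution 64
    pred = decoder(codebooks[i][idx])                        # shared decoder, per-image codebook
    loss = ((pred - target) ** 2).mean()                     # reconstruction loss
    opt.zero_grad(); loss.backward(); opt.step()
```

The only component shared across images is the decoder; each codebooks[i] is free to specialize to its own image.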

Does joint optimization regress performance?

📝 Note: I’m using 3D images from the BraTS brain MRI dataset as a representative example, but other data like scene representations, NeRFs, or SDFs can be used as well.

The first experiment is to train $N$ encoders and a single decoder jointly.

My intuition was that as we add more encoders and a single decoder, the reconstruction performance would degrade. However, I saw quite the opposite trend - the average PSNR of the reconstructions improved as we added more encoders.

Figure: The average PSNR of the reconstructions improves as we add more encoders.

This is an interesting result, suggesting that the decoder is not throttled by the number of encoders, but rather leverages shared knowledge across the different codebooks to reconstruct each datum better.

I suspected some images might be easier to learn than others and might be skewing this plot, so I plotted the PSNR for each image individually.

Figure: The PSNR of the reconstructions for each image individually. Almost all images benefit from training with more encoders.

Indeed, some images are easier to learn than others (as seen by the different PSNR bands), but most of the PSNR curves go up as we add more encoders. Even the “worst” images have a PSNR of more than 35 dB!

This is the first promising result - we are not losing reconstruction performance as the dataset size increases, in fact, quite the contrary.

Convolutional Neural Networks on Instant NGP codebooks

The second (and more interesting) piece of the puzzle is to actually define the convolutional operator on the implicit representation. This is where we dive into the dirty details.

The coordinate encoding

The data or scene is assumed to live inside a continuous bounding box $\mathcal{B} = [0, 1]^d$. The implicit representation defines a set of levels with resolutions $L_1, L_2, \ldots, L_K$. At level $k$, the input space $\mathcal{B}$ is quantized to a regular grid of size $L_k^d$, and the feature vector at a location $x_0 \in \mathcal{B}$ is computed by scaling it to $x_k = L_k x_0$ and trilinearly interpolating the codebook values at the 8 nearest grid points.
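A per-level interpolation helper might look like this (a hypothetical sketch; hash_fn stands for one of the hash functions described next):

```python
import torch

def interpolate(x, codebook, L, hash_fn):
    """Trilinearly interpolate level-L codebook features at a coordinate x in [0, 1]^3."""
    q = x * L
    q0 = q.floor().long()                     # lower corner of the enclosing grid cell
    w = q - q0                                # fractional offsets in [0, 1)
    out = torch.zeros(codebook.shape[1])
    for dz in (0, 1):                         # accumulate over the 8 cell corners
        for dy in (0, 1):
            for dx in (0, 1):
                corner = q0 + torch.tensor([dx, dy, dz])
                weight = ((w[0] if dx else 1 - w[0]) *
                          (w[1] if dy else 1 - w[1]) *
                          (w[2] if dz else 1 - w[2]))
                out = out + weight * codebook[hash_fn(corner, L)]
    return out
```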

Random hash function

The index in the codebooks for a quantized grid location $q$ is computed as:

\[h_k(q) = \left( \bigoplus_{i=1}^{d} q(i) \pi_i \right) \bmod T,\]

where the $\pi_i$ are predetermined large prime numbers, $q(i)$ is the $i$-th coordinate of the quantized grid location $q$, $\bigoplus$ denotes bitwise XOR, and $T$ is the total number of codebook entries.
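In Python, this hash can be sketched as follows (illustrative code; the constants are the ones used in the Instant NGP paper, where the first is chosen as 1):

```python
# Per-dimension constants from the Instant NGP paper (the first is 1 for cache coherence).
PRIMES = (1, 2654435761, 805459861)

def random_hash(q, T):
    """XOR-prime spatial hash of an integer 3D grid coordinate q into [0, T)."""
    h = 0
    for qi, p in zip(q, PRIMES):
        h ^= qi * p
    return h % T

# e.g. random_hash((12, 7, 3), T=2**19)
```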

Tiled hash function

An alternative hash function is to simply compute the flattened (row-major) index of the quantized grid location $q$:

\[h_k(q) = \left( \sum_{i=1}^{d} q(i)\, L_k^{(i-1)} \right) \bmod T.\]
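As a sketch, the tiled variant is just a row-major flattening wrapped modulo the table size:

```python
def tiled_hash(q, L, T):
    """Flattened (row-major) index of grid coordinate q at resolution L, wrapped mod T."""
    return (q[0] + L * (q[1] + L * q[2])) % T
```

This is exactly what compute_ravel_hash computes in the CUDA kernel below.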

Convolution operator

Recall that a convolution (written in the cross-correlation form used by CNN libraries) is defined as:

\[g(q) = \sum_{\Delta q \in \mathcal{K}} w(\Delta q)\, f(q + \Delta q),\]

where $w$ is the kernel, $f$ is the input feature map, and $\Delta q$ ranges over the offsets in the kernel support $\mathcal{K}$ (e.g. $\Delta q \in \{-1, 0, 1\}^3$ for a $3\times 3\times 3$ kernel).

For the tiled hash function, note that if $s = h(q)$ is the entry corresponding to location $q$, then the hash for the neighboring location $q + \Delta q$ is given by:

\[\begin{align*} h(q + \Delta q) &= \left( \sum_{i=1}^{d} (q(i) + \Delta q(i)) L_k^{(i-1)} \right) \bmod T \\ &= \left( \sum_{i=1}^{d} q(i) L_k^{(i-1)} + \sum_{i=1}^{d} \Delta q(i) L_k^{(i-1)} \right) \bmod T \\ &= \big(s + h(\Delta q)\big) \bmod T \end{align*}\]

📝 Note: The hash indices of the neighbors of a location $q$ are simply constant offsets $h(\Delta q)$ added (mod $T$) to the hash of the original location $q$.

Therefore, the convolution for an input codebook $f$ at entry $s$ can be computed as:

\[g(s) = \sum_{\Delta q \in \mathcal{K}} w(\Delta q)\, f\big((s + h(\Delta q)) \bmod T\big)\]

Unlike convolution over a spatially regular grid, the 3D “convolution” here is actually performed on a 1D codebook with irregular (but constant) offsets determined by the hash function. This motivates a fast, custom convolution operator that works directly on the implicit neural representation.
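To make the gather pattern concrete, here is a naive reference implementation of the hashed convolution at a single level (a hypothetical sketch corresponding to the large-resolution case, not the CUDA kernel or the exact PyTorch baseline used in the benchmarks):

```python
import torch

def hashed_conv3d(f, weight, L, K=3):
    """Naive reference for the hashed convolution at one level.

    f: (T, C_in) slice of the codebook for this level.
    weight: (K, K, K, C_in, C_out) convolution kernel.
    Each codebook row s gathers its "neighbors" at (s + h(dq)) mod T, where h(dq) is
    the tiled-hash offset of the kernel displacement dq -- a constant per kernel tap.
    """
    T = f.shape[0]
    out = torch.zeros(T, weight.shape[-1])
    s = torch.arange(T)
    r = K // 2
    for k1 in range(-r, r + 1):
        for k2 in range(-r, r + 1):
            for k3 in range(-r, r + 1):
                h_dq = (k1 + L * (k2 + L * k3)) % T      # constant offset for this tap
                neighbors = f[(s + h_dq) % T]            # (T, C_in) gather
                w = weight[k1 + r, k2 + r, k3 + r]       # (C_in, C_out)
                out += neighbors @ w
    return out
```

The small-resolution case, where codebook rows map one-to-one to grid locations, additionally has to mask out-of-bounds neighbors, as the CUDA kernel below does.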

CUDA implementation

Now that we have the math out of the way, let’s implement the convolution operator in CUDA. The code is available here.

First, we need to convert a 3D quantized coordinate to a hash index using the tiled hash function.

__device__ __forceinline__ int compute_ravel_hash(const int *coord, const int resolution, const int hashmap_size) {
    // compute hash function for tiled grid
    int index;
    index = coord[0] + resolution*(coord[1] + resolution*coord[2]);
    return modpow2(index, hashmap_size);
}

We also compute the hash offset $h(\Delta q)$ of a kernel displacement $\Delta q = (k_1, k_2, k_3)$, wrapped to a non-negative value.

__device__ __forceinline__ int compute_diff_hash(const int k1, const int k2, const int k3, const int lvl_res, const int hashmap_size) {
    // hash offset h(dq) of the kernel displacement dq = (k1, k2, k3) under the tiled hash
    int index = k1 + lvl_res*(k2 + lvl_res*k3);
    if(index >= 0)
        return index;
    // wrap negative offsets to an equivalent non-negative value modulo the hashmap size
    return (hashmap_size - (modpow2(-index, hashmap_size)));
}

Note that the neighboring coordinates, and hence the hash offset, can be negative (e.g. for $\Delta q = (k_1, k_2, k_3) = (-1, -1, -1)$). If the offset $h(\Delta q)$ is negative, we need to compute the value

\[\begin{align*} h(q + \Delta q) &= \big(h(q) + h(\Delta q)\big) \bmod T \\ &= \left[h(q) \bmod T + \textcolor{red}{\left((T - \left[-h(\Delta q)\right]) \bmod T\right)}\right] \bmod T \\ \end{align*}\]

using the fact that $-h(\Delta q)$ is positive, and computing the value in red.
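This wrap is needed because C/CUDA’s % truncates toward zero and returns negative remainders for negative operands. The identity relied on here can be sanity-checked with ordinary modular arithmetic (Python’s % is already non-negative):

```python
# In C/CUDA, (-5) % 8 == -5 (truncation toward zero), so the kernel wraps negative
# offsets explicitly. A quick check of the identity it relies on:
T = 2**19
for h_dq in (-1, -513, -(2**19) - 7):                # a few negative example offsets
    wrapped = T - ((-h_dq) % T)                      # value returned by compute_diff_hash
    assert (100 + wrapped) % T == (100 + h_dq) % T   # same neighbor index either way
```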

Now let’s dive into the core logic of the forward pass of the convolution operator.

    CUDA_KERNEL_LOOP(index, num_outputs) {
        // get n, batch index, and output channel index
        // int c_out = index % output_channels;
        int c_out = modpow2(index, output_channels);    // assume channels are powers of 2 
        int ibyout = divpow2(index, output_channels);
        int b_idx = modpow2(ibyout, batch_size);
        int n_idx = divpow2(ibyout, batch_size);
        // get level information
        int level = get_level(offsets_shared, n_idx, num_levels);
        int offset_lvl = offsets_shared[level];
        int local_n = n_idx - offset_lvl;
        int lvl_res = resolutions_shared[level];
        int lvl_res3 = lvl_res*lvl_res*lvl_res;
        // if this is a tail-end entry beyond this level's grid, skip it
        if(local_n >= lvl_res3)
            continue;
        // now we have n, b, c --> time to get y[n, b, cout]
        scalar_t res = 0;
        int weight_index = level*kernel_volume*iosize + c_out;
        // cache this result in the beginning
        int coordstart[3];
        unravel_index(local_n, lvl_res, coordstart);

This code resolves the flat thread index index into the output-channel index c_out, the batch index b_idx, and the codebook entry index n_idx. The global index n_idx is further resolved into a level level and a local codebook index local_n within that level. Finally, we unravel local_n into the 3D grid coordinate (x, y, z) to start the convolution.
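For clarity, the same index arithmetic in plain Python (modpow2 and divpow2 are just % and // specialized to power-of-two divisors; the exact layout of offsets is an assumption here):

```python
import bisect

def decompose(index, output_channels, batch_size, offsets):
    """Mirror of the kernel's index math: flat thread index -> (level, local_n, b, c_out).

    Assumes offsets[k] is the first codebook row of level k, sorted ascending with
    offsets[0] == 0 (num_levels + 1 entries in total).
    """
    c_out = index % output_channels
    ibyout = index // output_channels
    b_idx = ibyout % batch_size
    n_idx = ibyout // batch_size
    level = bisect.bisect_right(offsets, n_idx) - 1   # level whose range contains n_idx
    local_n = n_idx - offsets[level]
    return level, local_n, b_idx, c_out
```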

The actual convolution is simply a loop:

int fwd_index = n_idx * kernel_volume;   // set to n_idx, and loop over all kernel indices
for(int k1idx=0; k1idx<K1; k1idx++) {
    int k1 = k1idx - K1/2;
    for(int k2idx=0; k2idx<K2; k2idx++) {
        int k2 = k2idx - K2/2;
        for(int k3idx=0; k3idx<K3; k3idx++) {
            int k3 = k3idx - K3/2;
            // get neighboring index
            ...
        }
    }
}

This is simply a loop over the kernel indices $(k_1, k_2, k_3)$. Note that we do not perform tiling operations since the actual codebook entries are not spatially colocated. A tiled implementation using shared memory is also present in the code, but it is not as performant.

Inside the loop, this snippet locates the index of the neighboring codebook entry:

int coord[3];
#pragma unroll
for(int i=0; i<3; i++)
    coord[i] = coordstart[i];
// resolve x index
if(lvl_res3 > hashmap_size) {   // this is a big resolution, simply compute the (x + dx) % T
    x_index = modpow2((local_n + compute_diff_hash(k1, k2, k3, lvl_res, hashmap_size)), hashmap_size) + offset_lvl;
}
else {  // only one point corresponds to this n, find it
    coord[0] += k1; coord[1] += k2; coord[2] += k3;
    if(out_of_bounds(coord, lvl_res))
        x_index = -1;
    else
        x_index = compute_ravel_hash(coord, lvl_res, hashmap_size) + offset_lvl;
}

Pretty straightforward, right? If the resolution is large (the if-branch), we compute the hash index exactly as derived above. Otherwise, there is a one-to-one mapping between codebook entries and grid locations, so we simply add the kernel offset to the coordinate and compute its hash index, returning -1 for out-of-bounds neighbors.

The backward pass is implemented similarly. For more details, please refer to the code. I have tested four versions of the forward pass; all of them are in the code.

Building a 3D Implicit CNN

Once we have defined the convolution operator, we can build a 3D implicit CNN. Alongside the AbstractConv3D class, we define a pooling layer, AbstractContextLayer, that pools features from finer grids into coarser grids, and an AbstractLayerNorm class that performs layer normalization on the features at each level. Our new implicit ResNet block looks like this:

class Resblock(nn.Module):
    '''
    General residual block containing convolutions+non-linearity followed by residual block 
    the residual block can either be a simple per-grid layer, or a context layer which interpolates
    features from the previous grid
    '''
    def __init__(self, in_channels, out_channels, resolutions, offsets, layers=2, context=True, affine_context=True, 
                 kernel_size=3,
                 num_levels=16, log_hashmap_size=19, activation=nn.LeakyReLU(negative_slope=0.1),
                 layernorm=False):
        super().__init__()
        self.context = None
        self.activation = activation
        self.layernorm = layernorm
        if context:
            self.context = AbstractContextLayer(in_channels, out_channels, resolutions=resolutions, offsets=offsets, \
                                    affine=(in_channels!=out_channels or affine_context), num_levels=num_levels, log_hashmap_size=log_hashmap_size)
        else:
            self.context = nn.Linear(in_channels, out_channels)
            nn.init.kaiming_uniform_(self.context.weight)
            nn.init.zeros_(self.context.bias)
        # define convolutions
        convs = []
        lns = []
        for _ in range(layers):
            convs.append(AbstractConv3D(in_channels, out_channels, resolutions=resolutions, offsets=offsets,
                                        kernel_size=kernel_size, num_levels=num_levels, log_hashmap_size=log_hashmap_size))
            if layernorm:
                lns.append(AbstractLayerNorm(out_channels, resolutions, offsets))
            in_channels = out_channels
        self.convs = nn.ModuleList(convs)
        if layernorm:
            self.lns = nn.ModuleList(lns)

Note that the inputs to the ResNet block also include resolutions and offsets, which describe the multi-resolution grids of the Instant NGP codebooks.
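For illustration, a hypothetical instantiation could look like the following; the growth factor, the exact format of resolutions and offsets, and the feature layout are placeholders inferred from the kernel code, not the package’s documented API:

```python
import itertools

num_levels, log_hashmap_size = 16, 19
# Per-level grid resolutions with geometric growth (as in Instant NGP) -- placeholder values.
resolutions = [int(16 * 1.38 ** k) for k in range(num_levels)]
# Each level stores min(L^3, T) rows; offsets[k] is the first codebook row of level k.
sizes = [min(r ** 3, 2 ** log_hashmap_size) for r in resolutions]
offsets = list(itertools.accumulate([0] + sizes))

block = Resblock(in_channels=2, out_channels=32,
                 resolutions=resolutions, offsets=offsets,
                 num_levels=num_levels, log_hashmap_size=log_hashmap_size)
# The block consumes and produces codebook-shaped features (the kernel indexes its
# output as y[n, b, c_out], i.e. entries x batch x channels) rather than a dense voxel grid.
```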

Performance Benchmark

We ablate the performance of the 3D implicit convolution with respect to two parameters: the maximum codebook resolution $L$ and the desired grid resolution $D$. Lower values of $L$ mean a more compressed representation, and higher values of $D$ mean the underlying grid is higher resolution, implying more compression for a fixed codebook size.

The ablations on the maximum codebook resolution $L$ are shown below (Fast = our CUDA kernel, Slow = the PyTorch baseline):

| L  | Fast Fwd | Slow Fwd | Fast Bwd | Slow Bwd | Fwd Speedup | Bwd Speedup |
|----|----------|----------|----------|----------|-------------|-------------|
| 17 | 0.0301   | 0.5940   | 0.0203   | 0.2214   | 19.7472     | 10.9056     |
| 18 | 0.0111   | 0.0797   | 0.0281   | 0.2047   | 7.1665      | 7.2932      |
| 19 | 0.0192   | 0.1222   | 0.0480   | 0.3257   | 6.3616      | 6.7842      |
| 20 | 0.0329   | 0.1989   | 0.0883   | 0.5161   | 6.0363      | 5.8422      |
| 21 | 0.0546   | 0.2673   | 0.1469   | 0.6719   | 4.8981      | 4.5728      |
| 22 | 0.0877   | 0.3955   | 0.2323   | 0.9471   | 4.5092      | 4.0777      |

Our method is consistently 6-7x faster than the PyTorch implementation for $L = 19, 20$ (the default values), showing that the CUDA forward pass is efficient despite its non-coalesced memory accesses.

The ablations on the desired resolution $D$ are shown below:

| Resolution | CUDA Fwd | PyTorch Fwd | CUDA Bwd | PyTorch Bwd | Fwd Speedup | Bwd Speedup |
|------------|----------|-------------|----------|-------------|-------------|-------------|
| 32         | 0.0011   | 0.0066      | 0.0048   | 0.0076      | 5.9026      | 1.5876      |
| 48         | 0.0024   | 0.0061      | 0.0073   | 0.0104      | 2.5590      | 1.4132      |
| 64         | 0.0047   | 0.0076      | 0.0131   | 0.0194      | 1.6327      | 1.4824      |
| 96         | 0.0109   | 0.0421      | 0.0268   | 0.0971      | 3.8480      | 3.6272      |
| 128        | 0.0138   | 0.0737      | 0.0356   | 0.1778      | 5.3606      | 4.9871      |
| 196        | 0.0172   | 0.1062      | 0.0449   | 0.2726      | 6.1640      | 6.0705      |
| 256        | 0.0190   | 0.1240      | 0.0500   | 0.3269      | 6.5094      | 6.5437      |
| 384        | 0.0210   | 0.1388      | 0.0555   | 0.3859      | 6.6049      | 6.9488      |

In the above table, we see that the CUDA implementation is consistently faster than the PyTorch implementation for all resolutions. The speedup is more pronounced for higher resolutions because higher desired resolutions imply more compression at the same codebook size.
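For reference, timings like these are typically collected with CUDA events; a generic harness (with hypothetical layer names, not the actual benchmark script) might look like:

```python
import torch

def time_cuda(fn, warmup=10, iters=50):
    """Average wall-clock time of fn() in seconds, using CUDA events for accuracy."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0 / iters   # elapsed_time returns milliseconds

# e.g. fast = time_cuda(lambda: conv_cuda(features))      # custom CUDA kernel (hypothetical name)
#      slow = time_cuda(lambda: conv_pytorch(features))   # PyTorch baseline (hypothetical name)
#      print(f"speedup: {slow / fast:.2f}x")
```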

Conclusion

In this blog post, we introduced NGPConv, a method for scaling 3D convolutional networks to process much larger volumes by leveraging implicit neural representations. By using Instant NGP’s hash-based encoding with shared decoders, we can represent high-resolution 3D data efficiently while maintaining spatial locality and consistent basis functions across samples. This allows us to perform convolutions directly on the compressed implicit representation rather than the full voxel grid.

Our CUDA implementation demonstrates significant speedups over a naive PyTorch implementation, achieving 6-7x faster forward passes and 4-6x faster backward passes for typical parameter settings. The performance benefits are particularly pronounced at higher resolutions, where the compression ratio is greater. This makes NGPConv practical for processing large 3D volumes that were previously intractable, like high-resolution medical imaging data or neural radiance fields.

The approach opens up new possibilities for 3D deep learning by breaking through the memory bottleneck that has traditionally limited 3D ConvNet architectures. Future work could explore applying this technique to other 3D learning tasks and further optimizing the implementation for even better performance.

This post is licensed under CC BY 4.0 by the author.