# Copyright 2021, Blue Brain Project, EPFL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Volume interpolation based on pairwise interpolation between slices."""
from __future__ import annotations
import logging
import warnings
from abc import ABC, abstractmethod
from math import ceil, log2
import numpy as np
import torch
from torchvision.transforms import ToTensor
from atlinter.data import GeneDataset
from atlinter.utils import find_closest
logger = logging.getLogger(__name__)
class PairInterpolationModel(ABC):
    """Abstract interface for pairwise image interpolation models.

    A concrete model receives two images, `img1` and `img2`, and produces
    an intermediate image `img_mid`. Instances are driven by the
    `PairInterpolate` runner, which calls the three hooks below in the
    order `before_interpolation` -> `interpolate` (possibly many times)
    -> `after_interpolation`.
    """

    def before_interpolation(self, img1, img2):
        """Prepare the two input images for interpolation.

        Subclasses can override this hook to pad or crop the inputs so
        that they satisfy the model requirements, or to initialise any
        internal state. The default implementation is a no-op.

        Parameters
        ----------
        img1 : np.ndarray
            The left image.
        img2 : np.ndarray
            The right image.

        Returns
        -------
        img1 : np.ndarray
            The pre-processed left image (unchanged by default).
        img2 : np.ndarray
            The pre-processed right image (unchanged by default).
        """
        return img1, img2

    @abstractmethod
    def interpolate(self, img1, img2):
        """Compute the image halfway between two images.

        The inputs are normally the images produced by
        `before_interpolation`, or images returned by earlier calls to
        this very method.

        Parameters
        ----------
        img1 : np.ndarray
            The left image.
        img2 : np.ndarray
            The right image.

        Returns
        -------
        img_mid : np.ndarray
            The interpolated image.
        """

    def after_interpolation(self, interpolated_images):
        """Post-process the stack of interpolated images.

        Subclasses can override this hook to undo padding or cropping
        done in `before_interpolation`, or to clean up internal state.
        The default implementation is a no-op.

        Parameters
        ----------
        interpolated_images : np.ndarray
            The stacked interpolated images. When called by
            `PairInterpolate` the stack contains only the interpolated
            images (the two input images are not part of it).

        Returns
        -------
        np.ndarray
            The post-processed interpolated images (unchanged by default).
        """
        return interpolated_images
class LinearPairInterpolationModel(PairInterpolationModel):
    """Pairwise interpolation by plain averaging.

    The intermediate image is the element-wise arithmetic mean of the
    left and the right image.
    """

    def interpolate(self, img1, img2):
        """Average two images element-wise.

        Parameters
        ----------
        img1 : np.ndarray
            The left image.
        img2 : np.ndarray
            The right image.

        Returns
        -------
        img_mid : np.ndarray
            The element-wise mean of `img1` and `img2`.
        """
        # Stack the pair along a new leading axis and reduce over it.
        pair = np.asanyarray([img1, img2])
        return pair.mean(axis=0)
class RIFEPairInterpolationModel(PairInterpolationModel):
    """Pairwise image interpolation using the RIFE model.

    The typical use is

    >>> from atlinter.vendor.rife.RIFE_HD import Model as RifeModel
    >>> from atlinter.vendor.rife.RIFE_HD import device as rife_device
    >>> rife_model = RifeModel()
    >>> rife_model.load_model("/path/to/train_log", -1)
    >>> rife_model.eval()
    >>> rife_interpolation_model = RIFEPairInterpolationModel(rife_model, rife_device)

    Parameters
    ----------
    rife_model : atlinter.vendor.rife.RIFE_HD.Model
        The RIFE model instance. Only its `inference` method is used here.
    rife_device : from atlinter.vendor.rife.RIFE_HD.device
        The device the input tensors are moved to before inference.
    """

    def __init__(self, rife_model, rife_device):
        # RIFE triggers a torch.nn.functional.interpolate warning about a
        # changed default. It does not seem to affect the results, so it
        # is silenced here. Note that this installs a process-wide filter.
        # TODO: check the RIFE code for updates and remove the filter if
        # it is no longer necessary.
        warnings.filterwarnings(
            action="ignore",
            message=(
                "The default behavior for interpolate/upsample with float scale_factor"
            ),
            category=UserWarning,
        )
        self.rife_model = rife_model
        self.rife_device = rife_device
        # Spatial shape of the last un-padded input; set for real in
        # `before_interpolation` and used to crop in `after_interpolation`.
        self.shape = (0, 0)

    def before_interpolation(self, img1, img2):
        """Zero-pad both images so each spatial side is a multiple of 32.

        The un-padded spatial shape of `img1` is remembered in
        `self.shape` so that `after_interpolation` can crop it back.

        Parameters
        ----------
        img1 : np.ndarray
            The left image, 2D or 3D with a trailing axis of length 3.
        img2 : np.ndarray
            The right image; padded by the same amounts as `img1`.

        Returns
        -------
        img1 : np.ndarray
            The padded left image.
        img2 : np.ndarray
            The padded right image.
        """
        full_shape = img1.shape
        has_channels = len(full_shape) == 3 and full_shape[-1] == 3
        spatial_shape = full_shape[:-1] if has_channels else full_shape
        self.shape = np.array(spatial_shape)

        # Distance from each side length up to the next multiple of 32.
        pad_x, pad_y = (-self.shape) % 32

        # Pad only at the end of the spatial axes, never along channels.
        if has_channels:
            pad_width = ((0, pad_x), (0, pad_y), (0, 0))
        else:
            pad_width = ((0, pad_x), (0, pad_y))

        return np.pad(img1, pad_width), np.pad(img2, pad_width)

    def _to_rife_tensor(self, img, has_channels):
        """Turn one image into a float32 (1, 3, H, W) tensor on the device."""
        if has_channels:
            # (H, W, 3) -> (1, 3, H, W)
            batched = np.transpose(img, (2, 0, 1))[np.newaxis]
            tensor = torch.tensor(batched, dtype=torch.float32)
        else:
            # (H, W) -> (1, 3, H, W) by replicating the single channel
            tensor = torch.tensor(img, dtype=torch.float32).repeat((1, 3, 1, 1))
        return tensor.to(self.rife_device)

    def interpolate(self, img1, img2):
        """Interpolate two images using RIFE.

        Note: img1 and img2 need to have the same shape.
        If img1, img2 are grayscale, the dimension should be (height, width).
        If img1, img2 are RGB images, the dimension should be (height, width, 3).

        Parameters
        ----------
        img1 : np.ndarray
            The left image.
        img2 : np.ndarray
            The right image.

        Returns
        -------
        img_mid : np.ndarray
            The interpolated image, in the same layout as the inputs.
        """
        # Anything that is not 2D is treated as RGB; img1 decides for both.
        has_channels = len(img1.shape) != 2
        left = self._to_rife_tensor(img1, has_channels)
        right = self._to_rife_tensor(img2, has_channels)

        # The actual interpolation
        prediction = self.rife_model.inference(left, right)
        img_mid = prediction.detach().cpu().numpy().squeeze()

        if has_channels:
            # Put the RGB channel at the end
            return np.transpose(img_mid, (1, 2, 0))
        # Average out the replicated RGB dimension
        return img_mid.mean(axis=0)

    def after_interpolation(self, interpolated_images):
        """Undo the padding added in `before_interpolation`.

        Parameters
        ----------
        interpolated_images : np.ndarray
            The stacked interpolated images.
            If input images are grayscale,
            the dimension should be (n_img, height, width) or (height, width).
            If input images are RGB images,
            the dimension should be (n_img, height, width, 3) or (height, width, 3).

        Returns
        -------
        np.ndarray
            The interpolated images cropped to the shape stored in
            `self.shape`.
        """
        height, width = self.shape[0], self.shape[1]
        dims = interpolated_images.shape

        # A 3D array whose last axis has length 3 is read as one RGB image,
        # not as a stack of grayscale images.
        is_single_image = len(dims) == 2 or (len(dims) == 3 and dims[-1] == 3)
        if is_single_image:
            return interpolated_images[:height, :width]
        # Leading n_img axis: (n_img, height, width[, 3])
        return interpolated_images[:, :height, :width]
class CAINPairInterpolationModel(PairInterpolationModel):
    """Pairwise image interpolation using the CAIN model.

    The typical use is

    >>> from atlinter.vendor.cain.cain import CAIN
    >>> device = "cuda" if torch.cuda.is_available() else "cpu"
    >>> cain_model = CAIN().to(device)
    >>> cain_checkpoint = torch.load("pretrained_cain.pth", map_location=device)
    >>> cain_model.load_state_dict(cain_checkpoint)
    >>> cain_interpolation_model = CAINPairInterpolationModel(cain_model)

    Parameters
    ----------
    cain_model : atlinter.vendor.cain.cain.CAIN or torch.nn.DataParallel
        The CAIN model instance. It is called as ``cain_model(img1, img2)``
        and must return a pair whose first item is the interpolated batch.
    """

    def __init__(self, cain_model):
        self.cain_model = cain_model
        self.to_tensor = ToTensor()

    def interpolate(self, img1, img2):
        """Interpolate two images using CAIN.

        Note: img1 and img2 need to have the same shape.
        If img1, img2 are grayscale, the dimension should be (height, width).
        If img1, img2 are RGB images, the dimension should be (height, width, 3).

        Parameters
        ----------
        img1 : np.ndarray
            The left image.
        img2 : np.ndarray
            The right image.

        Returns
        -------
        img_mid : np.ndarray
            The interpolated image, in the same layout as the inputs.
        """
        # Build (1, 3, height, width) batches. The tensors are not moved
        # to any device in this method.
        if len(img1.shape) == 2:
            rgb = False
            # ToTensor turns (H, W) into (1, H, W); replicate it to 3 channels.
            img1 = self.to_tensor(img1).repeat((1, 3, 1, 1))
            img2 = self.to_tensor(img2).repeat((1, 3, 1, 1))
        else:
            rgb = True
            # ToTensor already converts an (H, W, C) array to (C, H, W).
            # Transposing beforehand would scramble the axes into
            # (W, C, H), so the image is passed in its original layout.
            img1 = self.to_tensor(img1)[None]
            img2 = self.to_tensor(img2)[None]

        # The actual interpolation
        img_mid, _ = self.cain_model(img1, img2)
        img_mid = img_mid.detach().cpu().numpy()
        img_mid = img_mid.squeeze()

        if rgb:
            # Put the RGB channel at the end
            img_mid = np.transpose(img_mid, (1, 2, 0))
        else:
            # Average out the replicated RGB dimension
            img_mid = img_mid.mean(axis=0)

        return img_mid
class AntsPairInterpolationModel(PairInterpolationModel):
    """Pairwise image interpolation using AntsPy registration.

    Typical use is

    >>> from atlannot.ants import register, transform
    >>> ants_interpolation_model = AntsPairInterpolationModel(register, transform)

    Parameters
    ----------
    register_fn : atlannot.ants.register
        The AntsPy registration function. It is called with the keyword
        arguments ``fixed`` and ``moving``.
    transform_fn : atlannot.ants.transform
        The AntsPy transformation function. It is called with an image
        and the (halved) registration result.
    """

    def __init__(self, register_fn, transform_fn):
        self.register_fn = register_fn
        self.transform_fn = transform_fn

    def interpolate(self, img1, img2):
        """Interpolate two images using AntsPy registration.

        The left image is registered onto the right one, and half of the
        registration result is then applied to the left image.

        Parameters
        ----------
        img1 : np.ndarray
            The left image.
        img2 : np.ndarray
            The right image.

        Returns
        -------
        img_mid : np.ndarray
            The interpolated image.
        """
        # Both functions are fed float32 images.
        moving = img1.astype(np.float32)
        fixed = img2.astype(np.float32)

        # NOTE(review): the registration result is presumably a displacement
        # field, so halving it moves img1 half-way towards img2 - confirm
        # against atlannot.ants.
        registration = self.register_fn(fixed=fixed, moving=moving)
        return self.transform_fn(moving, registration / 2)
class PairInterpolate:
    """Runner for pairwise interpolation using different models.

    Parameters
    ----------
    n_repeat : int (optional)
        The number of times the interpolation should be iterated. For each
        iteration an interpolated image is inserted between each pair of
        images from the previous iteration. Therefore n_{i+1} = n_i + (n_i + 1).
        For example, for n_repeat=3 the progression of the number of images
        will be the following: input = 0 -> 1 -> 3 -> 7
    """

    def __init__(self, n_repeat=1):
        self.n_repeat = n_repeat

    def repeat(self, n_repeat):
        """Set the number of interpolation iterations.

        Parameters
        ----------
        n_repeat : int
            The new number of interpolation iterations. See `__init__` for more
            details.

        Returns
        -------
        PairInterpolate
            This instance, so that calls can be chained.
        """
        self.n_repeat = n_repeat
        return self

    def __call__(self, img1, img2, model: PairInterpolationModel):
        """Run the interpolation with the given interpolation model.

        Parameters
        ----------
        img1 : np.ndarray
            The left input image.
        img2 : np.ndarray
            The right input image.
        model : PairInterpolationModel
            The interpolation model.

        Returns
        -------
        interpolated_images : np.ndarray
            A stack of 2**n_repeat - 1 interpolated images. The input
            images are not included in this stack. For n_repeat <= 0 the
            stack is empty, with shape (0, *image_shape).
        """
        img1, img2 = model.before_interpolation(img1, img2)
        interpolated_images = self._repeated_interpolation(
            img1, img2, model, self.n_repeat
        )

        if interpolated_images:
            stacked = np.stack(interpolated_images)
        else:
            # np.stack raises ValueError on an empty list, so build the
            # empty stack by hand from the (pre-processed) input image.
            reference = np.asarray(img1)
            stacked = np.empty((0, *reference.shape), dtype=reference.dtype)

        return model.after_interpolation(stacked)

    def _repeated_interpolation(self, img1, img2, model, n_repeat):
        """Recursively bisect the interval between `img1` and `img2`.

        Returns the list of interpolated images ordered from left to
        right; the list is empty when `n_repeat` is not positive.
        """
        # End of recursion
        if n_repeat <= 0:
            return []

        # Recursion step: interpolate the middle, then refine both halves.
        img_mid = model.interpolate(img1, img2)
        left_images = self._repeated_interpolation(img1, img_mid, model, n_repeat - 1)
        right_images = self._repeated_interpolation(img_mid, img2, model, n_repeat - 1)

        return [*left_images, img_mid, *right_images]
class GeneInterpolate:
    """Interpolation of a gene dataset.

    Parameters
    ----------
    gene_data : GeneData
        Gene Dataset to interpolate. The attributes read from it are
        `axis`, `volume`, `volume_shape` and `known_slices`, plus the
        method `get_surrounding_slices`.
    model : PairInterpolationModel
        Pair-interpolation model.
    border_predictions: boolean
        If False, slices before the first and after the last known slice are
        just background. Otherwise, a copy of the extremity slice is done.
    """

    def __init__(
        self,
        gene_data: GeneDataset,
        model: PairInterpolationModel,
        border_predictions: bool = True,
    ):
        self.gene_data = gene_data
        self.model = model
        self.border_predictions = border_predictions

        self.axis = self.gene_data.axis
        # Work on a copy so that the dataset's own volume is never modified.
        self.gene_volume = self.gene_data.volume.copy()
        # If sagittal axis, put the sagittal dimension first so that
        # `self.gene_volume[i]` is always section number i.
        if self.axis == "sagittal":
            self.gene_volume = np.moveaxis(self.gene_volume, 2, 0)

    def get_interpolation(
        self, left: int, right: int
    ) -> tuple[np.ndarray | None, np.ndarray | None]:
        """Compute the interpolation for a pair of images.

        Parameters
        ----------
        left
            Slice number of the left image to consider.
        right
            Slice number of the right image to consider.

        Returns
        -------
        interpolated_images : np.array or None
            Interpolated images for the given pair, stacked along the
            first axis. None if no section lies strictly between `left`
            and `right`.
        predicted_section_numbers : np.array or None
            Section numbers (possibly fractional) of the interpolated
            images, array of shape (N,) with N the number of
            interpolated images. None if there is nothing to interpolate.
        """
        diff = right - left
        # Identical or adjacent sections leave no gap to fill. Without
        # this guard diff == 1 gives n_repeat == 0, i.e. an empty list of
        # images, which cannot be stacked.
        if diff <= 1:
            return None, None

        n_repeat = self.get_n_repeat(diff)

        pair_interpolate = PairInterpolate(n_repeat=n_repeat)
        interpolated_images = pair_interpolate(
            self.gene_volume[left], self.gene_volume[right], self.model
        )
        predicted_section_numbers = self.get_predicted_section_numbers(
            left, right, n_repeat
        )
        return interpolated_images, predicted_section_numbers

    def get_all_interpolation(self) -> tuple[np.ndarray, np.ndarray]:
        """Compute pair interpolation for the entire volume.

        Returns
        -------
        all_interpolated_images : np.array
            Interpolated images for all pairs of consecutive known
            slices, concatenated along the first axis. Empty (N = 0) if
            no pair has a gap to fill.
        all_predicted_section_numbers : np.array
            Section numbers of the interpolated images.
            Array of shape (N,) with N the number of interpolated images.
        """
        # TODO: Try to change the implementation of the prediction so that
        # we do not predict slices that are not needed.
        logger.info("Start predicting interpolation between two known slices")

        known_slices = sorted(self.gene_data.known_slices)
        n_pairs = len(known_slices) - 1

        all_interpolated_images = []
        all_predicted_section_numbers = []
        for i, (left, right) in enumerate(zip(known_slices, known_slices[1:])):
            (
                interpolated_images,
                predicted_section_numbers,
            ) = self.get_interpolation(left, right)
            if interpolated_images is None:
                continue
            all_interpolated_images.append(interpolated_images)
            all_predicted_section_numbers.append(predicted_section_numbers)
            if i % 5 == 0:
                logger.info("%s / %s interpolations predicted", i, n_pairs)

        if not all_interpolated_images:
            # np.concatenate raises ValueError on an empty list; return
            # well-formed empty arrays instead.
            no_images = np.empty(
                (0, *self.gene_volume.shape[1:]), dtype=self.gene_volume.dtype
            )
            return no_images, np.empty((0,), dtype=float)

        return (
            np.concatenate(all_interpolated_images),
            np.concatenate(all_predicted_section_numbers),
        )

    def predict_slice(self, slice_number: int) -> np.ndarray:
        """Predict one gene slice.

        Parameters
        ----------
        slice_number
            Slice section to predict.

        Returns
        -------
        np.ndarray
            Predicted gene slice, with the shape of one section of the
            (axis-first) gene volume.
        """
        left, right = self.gene_data.get_surrounding_slices(slice_number)

        # Outside the range of known slices: copy the nearest known one.
        if left is None:
            return self.gene_volume[right]
        if right is None:
            return self.gene_volume[left]

        interpolated_images, predicted_section_numbers = self.get_interpolation(
            left, right
        )
        if interpolated_images is None:
            # No gap between the surrounding slices: fall back to the
            # nearer known slice instead of looking into an empty result.
            nearest = left if slice_number - left <= right - slice_number else right
            return self.gene_volume[nearest]

        # The first item of find_closest's result is used as the index of
        # the prediction closest to the requested section.
        index = find_closest(slice_number, predicted_section_numbers)[0]
        return interpolated_images[index]

    def predict_volume(self) -> np.ndarray:
        """Predict entire volume with known gene slices.

        This function might be slow.

        Returns
        -------
        np.ndarray
            Float32 volume of shape `gene_data.volume_shape` in which
            known sections are copied and the remaining ones are filled
            with the closest interpolated prediction (or, outside the
            range of known slices, according to `border_predictions`).
        """
        volume_shape = self.gene_data.volume_shape
        volume = np.zeros(volume_shape, dtype="float32")
        logger.info("Start predicting the volume of shape %s", volume_shape)

        # Work axis-first, like `self.gene_volume`.
        if self.gene_data.axis == "sagittal":
            volume = np.moveaxis(volume, 2, 0)

        # Get all the predictions
        (
            all_interpolated_images,
            all_predicted_section_numbers,
        ) = self.get_all_interpolation()

        # Build the lookup helpers once instead of on every iteration.
        known_slices = set(self.gene_data.known_slices)
        min_slice_number = min(known_slices)
        max_slice_number = max(known_slices)

        end = volume_shape[0] if self.gene_data.axis == "coronal" else volume_shape[2]

        # Populate the volume
        logger.info("Populate volume with interpolation predictions")
        for slice_number in range(end):
            if slice_number in known_slices:
                # If the slice is known, just copy the gene.
                volume[slice_number] = self.gene_volume[slice_number]
            elif slice_number < min_slice_number:
                # Before the first known slice: copy it if
                # border_predictions is True, else keep background.
                if self.border_predictions:
                    volume[slice_number] = self.gene_volume[min_slice_number]
            elif slice_number > max_slice_number:
                # After the last known slice: copy it if
                # border_predictions is True, else keep background.
                if self.border_predictions:
                    volume[slice_number] = self.gene_volume[max_slice_number]
            else:
                # Surrounded by two known slices: use the prediction
                # whose section number is closest to this one.
                index = find_closest(slice_number, all_predicted_section_numbers)[0]
                volume[slice_number] = all_interpolated_images[index]

            if slice_number % 5 == 0:
                logger.info("%s / %s populated slices", slice_number, end)

        # Restore the original axis order.
        if self.gene_data.axis == "sagittal":
            volume = np.moveaxis(volume, 0, 2)

        return volume

    @staticmethod
    def get_n_repeat(diff: int) -> int:
        """Determine the number of repetitions to compute.

        Returns the smallest n such that 2**n >= diff, or 0 when `diff`
        is not positive.
        """
        if diff <= 0:
            return 0

        return ceil(log2(diff))

    @staticmethod
    def get_predicted_section_numbers(
        left: int, right: int, n_repeat: int
    ) -> np.ndarray:
        """Get slice values of predicted images.

        The interval [left, right] is split into 2**n_repeat equal steps;
        the 2**n_repeat - 1 interior points are returned as an array of
        shape (N,). They are fractional unless right - left is a multiple
        of 2**n_repeat.
        """
        n_steps = 2**n_repeat + 1
        predicted_section_numbers = np.linspace(left, right, n_steps)
        # Drop the two end points: they are the known slices themselves.
        return predicted_section_numbers[1:-1]