atlinter.pair_interpolation module

Volume interpolation based on pairwise interpolation between slices.

class atlinter.pair_interpolation.AntsPairInterpolationModel(register_fn, transform_fn)

Bases: PairInterpolationModel

Pairwise image interpolation using AntsPy registration.

Typical use is:

>>> from atlannot.ants import register, transform
>>> ants_interpolation_model = AntsPairInterpolationModel(register, transform)

Parameters:
  • register_fn (atlannot.ants.register) – The AntsPy registration function

  • transform_fn (atlannot.ants.transform) – The AntsPy transformation function

interpolate(img1, img2)

Interpolate two images using AntsPy registration.

Parameters:
  • img1 (np.ndarray) – The left image.

  • img2 (np.ndarray) – The right image.

Returns:

img_mid – The interpolated image.

Return type:

np.ndarray

class atlinter.pair_interpolation.CAINPairInterpolationModel(cain_model)

Bases: PairInterpolationModel

Pairwise image interpolation using the CAIN model.

The typical use is:

>>> import torch
>>> from atlinter.vendor.cain.cain import CAIN
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> cain_model = CAIN().to(device)
>>> cain_checkpoint = torch.load("pretrained_cain.pth", map_location=device)
>>> cain_model.load_state_dict(cain_checkpoint)
>>> cain_interpolation_model = CAINPairInterpolationModel(cain_model)

Parameters:

cain_model (atlinter.vendor.cain.cain.CAIN or torch.nn.DataParallel) – The CAIN model instance.

interpolate(img1, img2)

Interpolate two images using CAIN.

Note: img1 and img2 need to have the same shape. For grayscale images the shape should be (height, width); for RGB images it should be (height, width, 3).

Parameters:
  • img1 (np.ndarray) – The left image.

  • img2 (np.ndarray) – The right image.

Returns:

img_mid – The interpolated image.

Return type:

np.ndarray

class atlinter.pair_interpolation.GeneInterpolate(gene_data: GeneDataset, model: PairInterpolationModel, border_predictions: bool = True)

Bases: object

Interpolation of a gene dataset.

Parameters:
  • gene_data (GeneDataset) – Gene dataset to interpolate. It contains a volume of reference shape, with all known slices placed at their correct positions, and a metadata dictionary with information about the axis of the dataset and the section numbers.

  • model (PairInterpolationModel) – Pair-interpolation model.

  • border_predictions (boolean) – If False, slices before the first known slice and after the last known slice are left as background. Otherwise, they are filled with copies of the first and last known slice, respectively.
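The border_predictions option can be pictured with the following minimal sketch. It is not atlinter's implementation: the helper fill_borders is hypothetical, it assumes the volume is a NumPy array with slices along the first axis, and it assumes that "background" means zeros.

```python
import numpy as np


def fill_borders(volume: np.ndarray, known: list, border_predictions: bool) -> np.ndarray:
    """Hypothetical sketch of the border_predictions behaviour.

    volume has shape (n_slices, height, width, ...) and known lists the
    indices of the slices that contain real data.
    """
    out = volume.copy()
    first, last = min(known), max(known)
    if border_predictions:
        # Copy the extremity slices outwards (broadcast over the border range).
        out[:first] = volume[first]
        out[last + 1:] = volume[last]
    else:
        # Assumption: background is represented by zeros.
        out[:first] = 0
        out[last + 1:] = 0
    return out
```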

get_all_interpolation() → tuple[np.ndarray, np.ndarray]

Compute pair interpolation for the entire volume.

Returns:

  • all_interpolated_images (np.array) – Interpolated images for the entire volume. Array of shape (N, dim1, dim2, 3) with N the number of interpolated images.

  • all_predicted_section_numbers (np.array) – Slice value of the predicted images. Array of shape (N, 1) with N the number of interpolated images.

get_interpolation(left: int, right: int) → tuple[np.ndarray | None, np.ndarray | None]

Compute the interpolation for a pair of images.

Parameters:
  • left – Slice number of the left image to consider.

  • right – Slice number of the right image to consider.

Returns:

  • interpolated_images (np.array or None) – Interpolated images for the given pair of images. Array of shape (N, dim1, dim2, 3) with N the number of interpolated images.

  • predicted_section_numbers (np.array or None) – Slice value of the predicted images. Array of shape (N, 1) with N the number of interpolated images.

static get_n_repeat(diff: int) → int

Determine the number of interpolation iterations (n_repeat) to compute for a difference diff between two slice numbers.

static get_predicted_section_numbers(left: int, right: int, n_repeat: int) → ndarray

Get slice values of predicted images.
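One plausible reading of these two helpers, given that PairInterpolate with n_repeat iterations yields 2^n_repeat - 1 interpolated images, is sketched below. Both formulas are assumptions, not taken from the atlinter source: n_repeat is chosen as the smallest value whose output covers the diff - 1 missing slices, and the predicted images are assumed to be evenly spaced between left and right.

```python
import numpy as np


def get_n_repeat(diff: int) -> int:
    """Smallest n such that 2**n - 1 >= diff - 1 (assumed rule)."""
    n_missing = max(diff - 1, 0)
    # For m >= 0, m.bit_length() is the smallest n with 2**n - 1 >= m.
    return n_missing.bit_length()


def get_predicted_section_numbers(left: int, right: int, n_repeat: int) -> np.ndarray:
    """Evenly spaced positions of the 2**n_repeat - 1 images, shape (N, 1)."""
    n_images = 2**n_repeat - 1
    # Drop the two endpoints, which are the known left and right slices.
    positions = np.linspace(left, right, n_images + 2)[1:-1]
    return positions.reshape(-1, 1)
```

With diff=5 this sketch returns n_repeat=3, i.e. seven predictions at partly fractional positions for four missing slices; how the real class reconciles such cases is not documented here.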

predict_slice(slice_number: int) → ndarray

Predict one gene slice.

Parameters:

slice_number – Section number of the slice to predict.

Returns:

Predicted gene slice. Array of shape (dim1, dim2, 3), where (dim1, dim2) is (528, 320) for the sagittal dataset and (320, 456) for the coronal dataset.

Return type:

np.ndarray

predict_volume() → ndarray

Predict the entire volume from the known gene slices.

This function might be slow.

class atlinter.pair_interpolation.LinearPairInterpolationModel

Bases: PairInterpolationModel

Linear pairwise interpolation.

This is the simplest possible interpolation model where the middle image is the average of the left and right images.
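A minimal NumPy sketch of this averaging follows. The function name linear_interpolate is hypothetical, and promoting to float is a choice made here to avoid uint8 overflow and truncation; the actual class may handle dtypes differently.

```python
import numpy as np


def linear_interpolate(img1: np.ndarray, img2: np.ndarray) -> np.ndarray:
    """Return the pixel-wise mean of two equally shaped images."""
    if img1.shape != img2.shape:
        raise ValueError(f"shape mismatch: {img1.shape} vs {img2.shape}")
    # Promote to float so that integer images neither overflow nor truncate.
    return (img1.astype(np.float64) + img2.astype(np.float64)) / 2


left = np.zeros((2, 3), dtype=np.uint8)
right = np.full((2, 3), 255, dtype=np.uint8)
mid = linear_interpolate(left, right)  # every pixel is 127.5
```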

interpolate(img1, img2)

Interpolate two images using linear interpolation.

Parameters:
  • img1 (np.ndarray) – The left image.

  • img2 (np.ndarray) – The right image.

Returns:

img_mid – The interpolated image.

Return type:

np.ndarray

class atlinter.pair_interpolation.PairInterpolate(n_repeat=1)

Bases: object

Runner for pairwise interpolation using different models.

Parameters:

n_repeat (int, optional) – The number of times the interpolation is iterated. In each iteration, an interpolated image is inserted between every pair of adjacent images from the previous iteration. With n_i interpolated images after iteration i there are n_i + 2 images in total and n_i + 1 gaps, so n_{i+1} = n_i + (n_i + 1) = 2 n_i + 1, which gives 2^n_repeat - 1 interpolated images. For example, for n_repeat=3 the number of interpolated images progresses as 0 -> 1 -> 3 -> 7.
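The iteration scheme can be sketched as follows. This is a hypothetical stand-alone function, not the PairInterpolate API: it takes any callable interpolate(img1, img2) and returns the stack with the input images first and last, matching the layout described for after_interpolation.

```python
from typing import Callable, List

import numpy as np


def pair_interpolate(
    img1: np.ndarray,
    img2: np.ndarray,
    interpolate: Callable[[np.ndarray, np.ndarray], np.ndarray],
    n_repeat: int = 1,
) -> np.ndarray:
    """Return [img1, *interpolated, img2] with 2**n_repeat - 1 interpolated images."""
    images: List[np.ndarray] = [img1, img2]
    for _ in range(n_repeat):
        # Keep the first image, then insert a midpoint before every following image.
        refined = [images[0]]
        for left, right in zip(images[:-1], images[1:]):
            refined.append(interpolate(left, right))
            refined.append(right)
        images = refined
    return np.stack(images)


def midpoint(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Toy interpolation model: the pixel-wise mean.
    return (a + b) / 2


for n in range(4):
    stack = pair_interpolate(np.zeros((2, 2)), np.ones((2, 2)), midpoint, n)
    print(n, len(stack) - 2)  # prints "0 0", "1 1", "2 3", "3 7"
```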

repeat(n_repeat)

Set the number of interpolation iterations.

Parameters:

n_repeat (int) – The new number of interpolation iterations. See __init__ for more details.

class atlinter.pair_interpolation.PairInterpolationModel

Bases: ABC

Base class for pair-interpolation models.

Subclasses of this class implement an interpolation between two given images img1 and img2 to produce an intermediate image img_mid.

This class and its subclasses are used by the PairInterpolate class, which applies a given interpolation model to concrete data.

after_interpolation(interpolated_images)

Run any post-processing after all interpolation is done.

Typical applications are padding and cropping of the image stack, as well as any clean-up of the model state.

Parameters:

interpolated_images (np.ndarray) – The stacked interpolated images. The array includes the input images as its first and last items, respectively, and therefore has the shape (n_interpolated + 2, height, width).

Returns:

The post-processed interpolated images.

Return type:

np.ndarray

before_interpolation(img1, img2)

Run initialization and pre-processing steps before interpolation.

Typical applications of this method are padding and cropping of input images to fit the model requirements, as well as initialization of any internal state, should one be necessary.

Parameters:
  • img1 (np.ndarray) – The left image of shape (height, width).

  • img2 (np.ndarray) – The right image of shape (height, width).

Returns:

  • img1 (np.ndarray) – The pre-processed left image.

  • img2 (np.ndarray) – The pre-processed right image.

abstract interpolate(img1, img2)

Interpolate two images.

Typically, the input images are in the format returned by before_interpolation.

Parameters:
  • img1 (np.ndarray) – The left image.

  • img2 (np.ndarray) – The right image.

Returns:

img_mid – The interpolated image.

Return type:

np.ndarray
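The hook order of a pair-interpolation model can be illustrated with a small stand-in. PairInterpolationModelSketch, MaxPairInterpolationModel and run_once are hypothetical names; the stand-in only mirrors the three documented methods, and its default hooks are assumed to be no-ops. With atlinter installed, one would subclass PairInterpolationModel directly instead.

```python
from abc import ABC, abstractmethod

import numpy as np


class PairInterpolationModelSketch(ABC):
    """Stand-in mirroring the documented hooks; not the atlinter class itself."""

    def before_interpolation(self, img1, img2):
        # Assumed default: no pre-processing.
        return img1, img2

    @abstractmethod
    def interpolate(self, img1, img2):
        """Return the intermediate image img_mid."""

    def after_interpolation(self, interpolated_images):
        # Assumed default: no post-processing.
        return interpolated_images


class MaxPairInterpolationModel(PairInterpolationModelSketch):
    """Toy model: the middle image is the pixel-wise maximum."""

    def interpolate(self, img1, img2):
        return np.maximum(img1, img2)


def run_once(model, img1, img2):
    """One interpolation step calling the hooks in the documented order."""
    img1, img2 = model.before_interpolation(img1, img2)
    img_mid = model.interpolate(img1, img2)
    # Inputs first and last: shape (n_interpolated + 2, height, width).
    stack = np.stack([img1, img_mid, img2])
    return model.after_interpolation(stack)
```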

class atlinter.pair_interpolation.RIFEPairInterpolationModel(rife_model, rife_device)

Bases: PairInterpolationModel

Pairwise image interpolation using the RIFE model.

The typical use is:

>>> from atlinter.vendor.rife.RIFE_HD import Model as RifeModel
>>> from atlinter.vendor.rife.RIFE_HD import device as rife_device
>>> rife_model = RifeModel()
>>> rife_model.load_model("/path/to/train_log", -1)
>>> rife_model.eval()
>>> rife_interpolation_model = RIFEPairInterpolationModel(rife_model, rife_device)

Parameters:
  • rife_model (atlinter.vendor.rife.RIFE_HD.Model) – The RIFE model instance.

  • rife_device (atlinter.vendor.rife.RIFE_HD.device) – The RIFE device.

after_interpolation(interpolated_images)

Undo the padding added in before_interpolation.

Parameters:

interpolated_images (np.ndarray) – The stacked interpolated images. For grayscale input images the shape should be (n_img, height, width) or (height, width); for RGB input images it should be (n_img, height, width, 3) or (height, width, 3).

Returns:

The stacked interpolated images with padding removed.

Return type:

np.ndarray

before_interpolation(img1, img2)

Pad the input images so that their height and width are multiples of 32 pixels.

Parameters:
  • img1 (np.ndarray) – The left image.

  • img2 (np.ndarray) – The right image.

Returns:

  • img1 (np.ndarray) – The padded left image.

  • img2 (np.ndarray) – The padded right image.
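Padding to a multiple of 32 and removing that padding again can be sketched with np.pad. The helper names are hypothetical, and zero padding at the bottom and right edges is an assumption; the RIFE model class may place the padding differently.

```python
import numpy as np


def pad_to_multiple(img: np.ndarray, multiple: int = 32) -> np.ndarray:
    """Zero-pad bottom and right so that height and width are multiples of `multiple`."""
    height, width = img.shape[:2]
    # -x % m is the smallest non-negative amount that makes x a multiple of m.
    pad_h = -height % multiple
    pad_w = -width % multiple
    # Leave any trailing channel axis (e.g. RGB) unpadded.
    pad_width = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width)


def crop_stack(images: np.ndarray, height: int, width: int) -> np.ndarray:
    """Undo pad_to_multiple on a stack of shape (n_img, H, W) or (n_img, H, W, 3)."""
    return images[:, :height, :width]
```

For a coronal slice of shape (320, 456, 3), pad_to_multiple returns shape (320, 480, 3), and crop_stack restores the original height and width for every image in the stack.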

interpolate(img1, img2)

Interpolate two images using RIFE.

Note: img1 and img2 need to have the same shape. For grayscale images the shape should be (height, width); for RGB images it should be (height, width, 3).

Parameters:
  • img1 (np.ndarray) – The left image.

  • img2 (np.ndarray) – The right image.

Returns:

img_mid – The interpolated image.

Return type:

np.ndarray