atlinter.pair_interpolation module
Volume interpolation based on pairwise interpolation between slices.
- class atlinter.pair_interpolation.AntsPairInterpolationModel(register_fn, transform_fn)
Bases: PairInterpolationModel
Pairwise image interpolation using AntsPy registration.
Typical use is:
>>> from atlannot.ants import register, transform
>>> ants_interpolation_model = AntsPairInterpolationModel(register, transform)
- Parameters:
register_fn (atlannot.ants.register) – The AntsPy registration function.
transform_fn (atlannot.ants.transform) – The AntsPy transformation function.
- class atlinter.pair_interpolation.CAINPairInterpolationModel(cain_model)
Bases: PairInterpolationModel
Pairwise image interpolation using the CAIN model.
Typical use is:
>>> import torch
>>> from atlinter.vendor.cain.cain import CAIN
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> cain_model = CAIN().to(device)
>>> cain_checkpoint = torch.load("pretrained_cain.pth", map_location=device)
>>> cain_model.load_state_dict(cain_checkpoint)
>>> cain_interpolation_model = CAINPairInterpolationModel(cain_model)
- Parameters:
cain_model (atlinter.vendor.cain.cain.CAIN or torch.nn.DataParallel) – The CAIN model instance.
- interpolate(img1, img2)
Interpolate two images using CAIN.
Note: img1 and img2 need to have the same shape. If img1 and img2 are grayscale, their shape should be (height, width). If they are RGB images, their shape should be (height, width, 3).
- Parameters:
img1 (np.ndarray) – The left image.
img2 (np.ndarray) – The right image.
- Returns:
img_mid – The interpolated image.
- Return type:
np.ndarray
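The shape requirement in the note above can be checked before calling interpolate. The helper below, check_image_pair, is hypothetical and not part of atlinter; it is a minimal sketch that only encodes the stated contract (identical shapes, either (height, width) or (height, width, 3)).

```python
import numpy as np


def check_image_pair(img1: np.ndarray, img2: np.ndarray) -> str:
    """Validate the shape contract stated for ``interpolate`` (hypothetical helper).

    Returns "grayscale" for (height, width) inputs and "rgb" for
    (height, width, 3) inputs; raises ValueError otherwise.
    """
    if img1.shape != img2.shape:
        raise ValueError(f"shape mismatch: {img1.shape} vs {img2.shape}")
    if img1.ndim == 2:
        return "grayscale"
    if img1.ndim == 3 and img1.shape[2] == 3:
        return "rgb"
    raise ValueError(
        f"expected (height, width) or (height, width, 3), got {img1.shape}"
    )
```

For example, check_image_pair(np.zeros((4, 5)), np.ones((4, 5))) returns "grayscale".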
- class atlinter.pair_interpolation.GeneInterpolate(gene_data: GeneDataset, model: PairInterpolationModel, border_predictions: bool = True)
Bases: object
Interpolation of a gene dataset.
- Parameters:
gene_data (GeneDataset) – Gene dataset to interpolate. It contains a volume of reference shape with all known slices located at the right place, and a metadata dictionary containing information about the axis of the dataset and the section numbers.
model (PairInterpolationModel) – Pair-interpolation model.
border_predictions (bool) – If False, slices before the first and after the last known slice are left as background. Otherwise, the first and last known slices are copied into them.
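The effect of border_predictions can be illustrated with a small sketch. The function fill_borders is hypothetical, and it assumes that the background value is 0 and that the volume is indexed by section number along axis 0; the actual GeneInterpolate logic may differ.

```python
import numpy as np


def fill_borders(volume, known_sections, border_predictions=True):
    """Fill slices outside the known range (hypothetical helper).

    Assumes background is 0 and that ``volume`` is indexed by
    section number along axis 0.
    """
    out = volume.copy()
    first, last = min(known_sections), max(known_sections)
    if border_predictions:
        # Copy the extremity slices outwards (broadcast over the border slices).
        out[:first] = volume[first]
        out[last + 1:] = volume[last]
    else:
        # Leave the borders as background.
        out[:first] = 0
        out[last + 1:] = 0
    return out
```

With known sections 1 and 3 in a five-slice volume, slice 0 receives a copy of slice 1 and slice 4 a copy of slice 3 when border_predictions is True; both stay at 0 otherwise.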
- get_all_interpolation() → tuple[np.ndarray, np.ndarray]
Compute pair interpolation for the entire volume.
- Returns:
all_interpolated_images (np.ndarray) – Interpolated images for the entire volume. Array of shape (N, dim1, dim2, 3), where N is the number of interpolated images.
all_predicted_section_numbers (np.ndarray) – Section numbers of the predicted images. Array of shape (N, 1), where N is the number of interpolated images.
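One plausible way to assemble the whole-volume result from pairwise results is to walk over consecutive known sections and concatenate what each pair returns. The sketch below assumes exactly that; all_interpolation_sketch and the get_interpolation callable it receives are hypothetical stand-ins, not the library's code.

```python
import numpy as np


def all_interpolation_sketch(known_sections, get_interpolation):
    """Collect interpolations over consecutive pairs of known sections.

    ``get_interpolation(left, right)`` is assumed to behave like
    ``GeneInterpolate.get_interpolation`` and return either
    ``(images, section_numbers)`` or ``(None, None)``.
    Assumes at least one pair produces predictions.
    """
    all_images, all_sections = [], []
    for left, right in zip(known_sections[:-1], known_sections[1:]):
        images, sections = get_interpolation(left, right)
        if images is None:
            # Nothing to predict for this pair (e.g. adjacent sections).
            continue
        all_images.append(images)
        all_sections.append(sections)
    return np.concatenate(all_images), np.concatenate(all_sections)
```

Concatenating along axis 0 keeps the (N, dim1, dim2, 3) and (N, 1) layouts described above.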
- get_interpolation(left: int, right: int) → tuple[np.ndarray | None, np.ndarray | None]
Compute the interpolation for a pair of images.
- Parameters:
left – Slice number of the left image to consider.
right – Slice number of the right image to consider.
- Returns:
interpolated_images (np.ndarray or None) – Interpolated images for the given pair of images. Array of shape (N, dim1, dim2, 3), where N is the number of interpolated images.
predicted_section_numbers (np.ndarray or None) – Section numbers of the predicted images. Array of shape (N, 1), where N is the number of interpolated images.
- static get_predicted_section_numbers(left: int, right: int, n_repeat: int) → ndarray
Get the section numbers of the predicted images.
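A minimal sketch of what these section numbers could look like, assuming the predicted images are evenly spaced strictly between left and right and that n_repeat rounds of midpoint insertion yield 2^n_repeat - 1 images. The function name is hypothetical, and it returns a 1-D array rather than the (N, 1) layout used by get_interpolation.

```python
import numpy as np


def predicted_section_numbers_sketch(left: int, right: int, n_repeat: int) -> np.ndarray:
    """Evenly spaced section numbers strictly between ``left`` and ``right``.

    Assumes ``n_repeat`` rounds of midpoint insertion, i.e.
    2**n_repeat - 1 predicted images.
    """
    n_points = 2**n_repeat + 1  # predicted images plus both endpoints
    return np.linspace(left, right, n_points)[1:-1]
```

For example, left=10, right=20, n_repeat=2 gives the three values 12.5, 15.0 and 17.5.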
- class atlinter.pair_interpolation.LinearPairInterpolationModel
Bases: PairInterpolationModel
Linear pairwise interpolation.
This is the simplest possible interpolation model where the middle image is the average of the left and right images.
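The averaging rule can be written in a few lines of NumPy. This is a sketch of the rule described above, not the class's actual code; casting to float before adding is a choice made here to avoid overflow with integer images such as uint8, and the library may handle dtypes differently.

```python
import numpy as np


def linear_interpolate(img1: np.ndarray, img2: np.ndarray) -> np.ndarray:
    """Middle image as the pixel-wise average of the two inputs."""
    if img1.shape != img2.shape:
        raise ValueError("img1 and img2 must have the same shape")
    # Compute in float to avoid integer overflow (e.g. uint8 + uint8).
    return (img1.astype(np.float64) + img2.astype(np.float64)) / 2
```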
- class atlinter.pair_interpolation.PairInterpolate(n_repeat=1)
Bases: object
Runner for pairwise interpolation using different models.
- Parameters:
n_repeat (int, optional) – The number of times the interpolation is iterated. In each iteration, an interpolated image is inserted between every pair of adjacent images from the previous iteration, so the number of interpolated images grows as n_{i+1} = n_i + (n_i + 1) = 2 n_i + 1. For example, for n_repeat=3 the number of interpolated images progresses as 0 -> 1 -> 3 -> 7; in general, n_repeat iterations produce 2^n_repeat - 1 interpolated images.
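The iteration described for n_repeat can be sketched as repeated midpoint insertion. The function pair_interpolate_sketch is hypothetical and takes a plain interpolate callable instead of a PairInterpolationModel; it returns the stack with both inputs as first and last items, so its length is 2^n_repeat + 1.

```python
from typing import Callable, List

import numpy as np


def pair_interpolate_sketch(
    img1: np.ndarray,
    img2: np.ndarray,
    interpolate: Callable[[np.ndarray, np.ndarray], np.ndarray],
    n_repeat: int = 1,
) -> np.ndarray:
    """Repeatedly insert a middle image between every adjacent pair.

    Returns the stack including both inputs, so its length is
    2**n_repeat + 1.
    """
    images: List[np.ndarray] = [img1, img2]
    for _ in range(n_repeat):
        refined = [images[0]]
        for left, right in zip(images[:-1], images[1:]):
            # Insert the interpolated image, then keep the right neighbour.
            refined.append(interpolate(left, right))
            refined.append(right)
        images = refined
    return np.stack(images)
```

With an averaging callable, inputs filled with 0 and 8, and n_repeat=3, the stack holds the values 0, 1, ..., 8: seven interpolated images between the two inputs.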
- class atlinter.pair_interpolation.PairInterpolationModel
Bases: ABC
Base class for pair-interpolation models.
Subclasses of this class implement an interpolation between two given images img1 and img2 to produce an intermediate image img_mid.
This class and its subclasses are used by the PairInterpolate class, which applies a given interpolation model to concrete data.
- after_interpolation(interpolated_images)
Run any post-processing after all interpolation is done.
Typical applications are padding and cropping of the image stack, as well as any clean-up of the model state.
- Parameters:
interpolated_images (np.ndarray) – The stacked interpolated images. The array includes the input images as its first and last items and therefore has shape (n_interpolated + 2, height, width).
- Returns:
The post-processed interpolated images.
- Return type:
np.ndarray
- before_interpolation(img1, img2)
Run initialization and pre-processing steps before interpolation.
Typical applications of this method are padding and cropping of input images to fit the model requirements, as well as initialization of any internal state, should one be necessary.
- Parameters:
img1 (np.ndarray) – The left image of shape (height, width).
img2 (np.ndarray) – The right image of shape (height, width).
- Returns:
img1 (np.ndarray) – The pre-processed left image.
img2 (np.ndarray) – The pre-processed right image.
- abstract interpolate(img1, img2)
Interpolate two images.
In the typical setting, the input images are in the format returned by before_interpolation.
- Parameters:
img1 (np.ndarray) – The left image.
img2 (np.ndarray) – The right image.
- Returns:
img_mid – The interpolated image.
- Return type:
np.ndarray
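The three hooks can be mirrored by a small stand-in class to show how they fit together. PairModelSketch, MeanModel and run_single_step are hypothetical names, and the call order in run_single_step (pre-process, interpolate, stack with the inputs first and last, post-process) is an assumption drawn from the method descriptions above rather than the PairInterpolate source.

```python
from abc import ABC, abstractmethod

import numpy as np


class PairModelSketch(ABC):
    """Hypothetical stand-in mirroring the PairInterpolationModel hooks."""

    def before_interpolation(self, img1, img2):
        # Default: no pre-processing.
        return img1, img2

    @abstractmethod
    def interpolate(self, img1, img2):
        """Return the middle image."""

    def after_interpolation(self, interpolated_images):
        # Default: no post-processing.
        return interpolated_images


class MeanModel(PairModelSketch):
    def interpolate(self, img1, img2):
        return (img1 + img2) / 2


def run_single_step(model, img1, img2):
    """Assumed call order: pre-process, interpolate, stack, post-process."""
    img1, img2 = model.before_interpolation(img1, img2)
    img_mid = model.interpolate(img1, img2)
    # Inputs first and last, as described for after_interpolation.
    stack = np.stack([img1, img_mid, img2])
    return model.after_interpolation(stack)
```

Because interpolate is abstract, PairModelSketch itself cannot be instantiated; only subclasses such as MeanModel can.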
- class atlinter.pair_interpolation.RIFEPairInterpolationModel(rife_model, rife_device)
Bases: PairInterpolationModel
Pairwise image interpolation using the RIFE model.
Typical use is:
>>> from atlinter.vendor.rife.RIFE_HD import Model as RifeModel
>>> from atlinter.vendor.rife.RIFE_HD import device as rife_device
>>> rife_model = RifeModel()
>>> rife_model.load_model("/path/to/train_log", -1)
>>> rife_model.eval()
>>> rife_interpolation_model = RIFEPairInterpolationModel(rife_model, rife_device)
- Parameters:
rife_model (atlinter.vendor.rife.RIFE_HD.Model) – The RIFE model instance.
rife_device (atlinter.vendor.rife.RIFE_HD.device) – The RIFE device.
- after_interpolation(interpolated_images)
Undo the padding added in before_interpolation.
- Parameters:
interpolated_images (np.ndarray) – The stacked interpolated images. If the input images are grayscale, the shape should be (n_img, height, width) or (height, width). If they are RGB images, the shape should be (n_img, height, width, 3) or (height, width, 3).
- Returns:
The stacked interpolated images with padding removed.
- Return type:
np.ndarray
- before_interpolation(img1, img2)
Pad input images to a multiple of 32 pixels.
- Parameters:
img1 (np.ndarray) – The left image, of shape (height, width) or (height, width, 3).
img2 (np.ndarray) – The right image, of shape (height, width) or (height, width, 3).
- Returns:
img1 (np.ndarray) – The padded left image.
img2 (np.ndarray) – The padded right image.
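The padding to a multiple of 32 pixels in before_interpolation, and its removal in after_interpolation, can be sketched with numpy.pad and slicing. The helpers pad_to_multiple and crop_to_shape are hypothetical; placing the zero padding at the bottom and right edges is an assumption, and the RIFE wrapper may distribute or fill the padding differently.

```python
import numpy as np


def pad_to_multiple(img: np.ndarray, multiple: int = 32) -> np.ndarray:
    """Zero-pad height and width up to the next multiple of ``multiple``.

    Padding goes at the bottom and right; a trailing channel axis is left alone.
    """
    height, width = img.shape[:2]
    pad_h = (-height) % multiple
    pad_w = (-width) % multiple
    pad_width = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width)  # default mode="constant" pads with zeros


def crop_to_shape(images: np.ndarray, height: int, width: int) -> np.ndarray:
    """Undo the padding on a stack of shape (n_img, H, W) or (n_img, H, W, 3)."""
    return images[:, :height, :width]
```

For example, a (100, 70, 3) image is padded to (128, 96, 3), and cropping a stack of such padded images with height=100 and width=70 recovers the original size.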
- interpolate(img1, img2)
Interpolate two images using RIFE.
Note: img1 and img2 need to have the same shape. If img1 and img2 are grayscale, their shape should be (height, width). If they are RGB images, their shape should be (height, width, 3).
- Parameters:
img1 (np.ndarray) – The left image.
img2 (np.ndarray) – The right image.
- Returns:
img_mid – The interpolated image.
- Return type:
np.ndarray