Source code for pytomography.utils.cardiac_spect

import math
import torch
import torch.nn.functional as F
import pytomography
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import center_of_mass, shift, affine_transform
from matplotlib.patches import FancyArrowPatch
from scipy.ndimage import rotate

[docs]def shift_object(object: torch.Tensor, tx: float, ty: float, tz: float) -> torch.Tensor: """ Shift the 3D volume based on the given translation parameters. Parameters: object: Input 3D volume to be shifted (torch.Tensor[Lx, Ly, Lz]). tx (float): Translation in x direction (voxels). ty (float): Translation in y direction (voxels). tz (float): Translation in z direction (voxels). Returns: Shifted 3D volume: The shifted object (torch.Tensor[Lx, Ly, Lz]). """ object = object.cpu().numpy() if isinstance(object, torch.Tensor) else object translation = [ty, tx, tz] identity_matrix = np.eye(3) shifted_object = affine_transform(object, identity_matrix, offset=translation, order=1) return torch.tensor(shifted_object).to(pytomography.dtype).to(pytomography.device)
[docs]def rotate_object(object: torch.Tensor, transaxial_angle: float, transsagittal_angle: float) -> torch.Tensor: """ Rotates the volume around the transaxial and transsagittal slices. Parameters: object: The 3D volume to be rotated (torch.Tensor[Lx, Ly, Lz]). angle_transaxial (float): The angle to rotate around the transaxial slice (z-axis). angle_transsagittal (float): The angle to rotate around the transsagittal slice (y-axis). Returns: The reoriented object (torch.Tensor[Lx, Ly, Lz]). """ rotated_object = rotate(object.cpu(), transaxial_angle, axes=(1, 0), reshape=False) rotated_object = rotate(rotated_object, transsagittal_angle, axes=(2, 0), reshape=False) return torch.tensor(rotated_object).to(pytomography.dtype).to(pytomography.device)
[docs]def get_mask(object: torch.Tensor, lowerThreshold: float, upperThreshold: float) -> torch.Tensor: """ Segment the object based on the given thresholds. Args: object: torch.Tensor[Lx, Ly, Lz] lowerThreshold: Lower threshold for the binary mask upperThreshold: Upper threshold for the binary mask """ lower_threshold, upper_threshold = lowerThreshold * object.max(), upperThreshold * object.max() mask = (object > lower_threshold) & (object < upper_threshold) return torch.tensor(mask).to(pytomography.dtype).to(pytomography.device)
[docs]def get_shift_values(slice: torch.Tensor) -> torch.Tensor: """ Calculate the shift values for the given object based on the center of mass. Args: slice (torch.Tensor): torch.Tensor[Lx, Ly] representing the slice. Returns: shifted_image (torch.Tensor): The shifted image. shift_values (torch.Tensor): The shift values in x and y directions. """ com = center_of_mass(slice.cpu().numpy()) center_of_image = torch.tensor(slice.shape, dtype=torch.float32) / 2 shift_values = center_of_image - torch.tensor(com, dtype=torch.float32) shifted_image = shift(slice.cpu().numpy().astype(float), shift=shift_values.cpu().numpy(), order=0) # shifted_image = torch.from_numpy(shifted_image) shifted_image = torch.tensor(shifted_image).to(pytomography.dtype).to(pytomography.device) return shifted_image, shift_values[0], shift_values[1]
[docs]def get_angle(slice: torch.Tensor, angle_range_degrees: torch.Tensor) -> float: """ Get the angle of the best line passing through the center of the image. Args: slice (torch.Tensor): torch.Tensor[Lx, Ly] representing the slice. angle_range_degrees (torch.Tensor): torch.Tensor[N] with angles in degrees. Returns: best_angle_degree: The best angle in degrees. """ angle_range_radians = angle_range_degrees * (math.pi / 180) center_x, center_y = slice.shape[1] // 2, slice.shape[0] // 2 y_indices, x_indices = torch.nonzero(slice, as_tuple=True) slopes = torch.tan(angle_range_radians) intercepts = center_y - slopes * center_x total_distances = torch.sum( torch.abs(slopes[:, None] * x_indices - y_indices + intercepts[:, None]) / torch.sqrt(slopes[:, None] ** 2 + 1), dim=1 ) best_idx = torch.argmin(total_distances) best_angle_radians = angle_range_radians[best_idx] best_angle_degree = math.degrees(best_angle_radians) return best_angle_degree
[docs]def plot_arrow(image, angle_degrees, fontsize=10): """ Draws an arrow from the center of the image based on the given angle and displays a polar plot around it. The angle is also displayed inside the plot. Parameters: image (np.ndarray): 2D numpy array representing the image. angle_degrees (float): Angle of the arrow with respect to the vertical, measured clockwise. """ center_x, center_y = image.shape[1] // 2, image.shape[0] // 2 angle_radians = np.deg2rad(angle_degrees) length = min(image.shape) // 2 # Ensure the arrow length is within the image bounds end_x = center_x + length * np.sin(angle_radians) end_y = center_y - length * np.cos(angle_radians) # Note the subtraction for y fig, ax = plt.subplots(figsize=(5, 5)) ax.imshow(image.cpu().numpy(), cmap='gray') ax.add_patch(FancyArrowPatch( (center_x, center_y), (end_x, end_y), color='red', arrowstyle='-|>', mutation_scale=15 )) num_degrees = 360 circle = plt.Circle((center_x, center_y), length, color='red', fill=False, linestyle='--') ax.add_artist(circle) for angle in range(0, num_degrees, 10): angle_rad = np.deg2rad(angle) x = center_x + (length + 10) * np.sin(angle_rad) y = center_y - (length + 10) * np.cos(angle_rad) ax.text(x, y, f'{angle}', fontsize=fontsize, color='red', ha='center', va='center') ax.text(center_x, center_y, f'{angle_degrees:.2f}°', fontsize=fontsize, color='red', ha='center', va='center', bbox=dict(facecolor='white', alpha=0.5, edgecolor='red')) ax.axis('off') ax.set_xlim(0, image.shape[1]) ax.set_ylim(image.shape[0], 0) ax.set_aspect('equal') plt.show()
[docs]def create_circular_mask(height, width, radius): """ Create a circular mask. """ Y, X = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij') center_y, center_x = height // 2, width // 2 dist_from_center = torch.sqrt((X - center_x)**2 + (Y - center_y)**2) mask = dist_from_center <= radius return mask
[docs]def masking(object: torch.Tensor, radius: int) -> torch.Tensor: """ Apply a circular mask to the object in all three views. """ depth, height, width = object.shape mask_axial = create_circular_mask(height, width, radius) mask_coronal = create_circular_mask(depth, width, radius) mask_sagittal = create_circular_mask(depth, height, radius) masked_object = object.clone() for i in range(depth): masked_object[i, ~mask_axial] = 0 for j in range(height): masked_object[:, j, :][~mask_coronal] = 0 for k in range(width): masked_object[:, :, k][~mask_sagittal] = 0 return torch.tensor(masked_object).to(pytomography.dtype).to(pytomography.device)