Source code for pytomography.io.PET.gate

from __future__ import annotations
from collections.abc import Sequence
import pytomography
from pytomography.metadata import ObjectMeta
import torch
import numpy as np
import numpy.linalg as npl
import uproot
import nibabel as nib
from scipy.ndimage import affine_transform
from ..shared import get_header_value, get_attenuation_map_interfile
from .shared import listmode_to_sinogram, sinogram_to_listmode, get_detector_ids_from_trans_axial_ids, get_axial_trans_ids_from_info, get_scanner_LUT, smooth_randoms_sinogram, randoms_sinogram_to_sinogramTOF

[docs]def get_aligned_attenuation_map( headerfile: str, object_meta: ObjectMeta ) -> torch.tensor: """Returns an aligned attenuation map in units of inverse mm for reconstruction. This assumes that the attenuation map shares the same center point with the reconstruction space. Args: headerfile (str): Filepath to the header file of the attenuation map object_meta (ObjectMeta): Object metadata providing spatial information about the reconstructed dimensions. Returns: torch.Tensor: Aligned attenuation map """ amap = get_attenuation_map_interfile(headerfile).cpu().numpy() # Load metadata with open(headerfile) as f: headerdata = f.readlines() headerdata = np.array(headerdata) dx = get_header_value(headerdata, 'scaling factor (mm/pixel) [1]', np.float32) dy = get_header_value(headerdata, 'scaling factor (mm/pixel) [2]', np.float32) dz = get_header_value(headerdata, 'scaling factor (mm/pixel) [3]', np.float32) dr_amap = (dx, dy, dz) shape_amap = amap.shape object_origin_amap = (- np.array(shape_amap) / 2 + 0.5) * (np.array(dr_amap)) dr = object_meta.dr shape = object_meta.shape object_origin = object_origin = (- np.array(shape) / 2 + 0.5) * (np.array(dr)) M_PET = np.array([ [dr[0],0,0,object_origin[0]], [0,dr[1],0,object_origin[1]], [0,0,dr[2],object_origin[2]], [0,0,0,1] ]) M_CT = np.array([ [dr_amap[0],0,0,object_origin_amap[0]], [0,dr_amap[1],0,object_origin_amap[1]], [0,0,dr_amap[2],object_origin_amap[2]], [0,0,0,1] ]) amap = affine_transform(amap, npl.inv(M_CT)@M_PET, output_shape = shape, order=1) amap = torch.tensor(amap, device=pytomography.device) / 10 # to mm^-1 return amap
[docs]def get_detector_info( path: str, init_volume_name: str = 'crystal', mean_interaction_depth: float = 0, min_rsector_difference: int = 0 ) -> dict: """Generates detector geometry information dictionary from GATE macro file Args: path (str): Path to GATE macro file that defines geometry: should end in ".mac" init_volume_name (str, optional): Initial volume name in the GATE file. Defaults to 'crystal'. mean_interaction_depth (float, optional): Mean interaction depth of photons within crystal. Defaults to 0. min_rsector_difference (int, optional): Minimum r_sector difference for retained events. Defaults to 0. Returns: dict: PET geometry information dictionary """ with open(path) as f: headerdata = f.readlines() headerdata = np.array(headerdata) parents = ['crystal', 'submodule', 'module', 'rsector', 'world'] positions = [] for parent in parents: try: x = get_header_value(headerdata, f'/gate/{parent}/placement/setTranslation', split_substr=None, split_idx=1) y = get_header_value(headerdata, f'/gate/{parent}/placement/setTranslation', split_substr=None, split_idx=2) z = get_header_value(headerdata, f'/gate/{parent}/placement/setTranslation', split_substr=None, split_idx=3) except: x = y = z = 0 positions.append([x,y,z]) positions = np.array(positions) x_crystal, y_crystal, z_crystal = positions.sum(axis=0) x_crystal = np.array([x_crystal]) y_crystal = np.array([y_crystal]) z_crystal = np.array([z_crystal]) # Get edges of crystal (assume original in +X) TODO: fix info = {} info['min_rsector_difference'] = min_rsector_difference info['crystal_length'] = get_header_value(headerdata, f'/gate/{init_volume_name}/geometry/setXLength', split_substr=None, split_idx=1) info['radius'] = x_crystal[0] - info['crystal_length']/2 + mean_interaction_depth for parent in parents: repeaters = get_header_value(headerdata, f'/gate/{parent}/repeaters/insert', split_substr=None, split_idx=1, dtype=str, return_all=True) if not(repeaters): if parent=='submodule': info['submoduleAxialNr'] = 1 info['submoduleAxialSpacing'] = 0 info['submoduleTransNr'] = 1 info['submoduleTransSpacing'] = 0 continue for repeater in repeaters: if repeater=='cubicArray': repeat_numbers = np.array([get_header_value(headerdata, f'/gate/{parent}/{repeater}/setRepeatNumber{coord}', split_substr=None, split_idx=1) for coord in ['X', 'Y', 'Z']]) repeat_vector = np.array([get_header_value(headerdata, f'/gate/{parent}/{repeater}/setRepeatVector', split_substr=None, split_idx=i) for i in range(1,4)]) idx_trans = np.argmax(repeat_numbers[:2]) info[f'{parent}TransNr'] = int(repeat_numbers[idx_trans]) info[f'{parent}TransSpacing'] = repeat_vector[idx_trans] info[f'{parent}AxialNr'] = int(repeat_numbers[2]) info[f'{parent}AxialSpacing'] = repeat_vector[2] elif repeater=='linear': repeat_number = get_header_value(headerdata, f'/gate/{parent}/{repeater}/setRepeatNumber', split_substr=None, split_idx=1) repeat_vector = [get_header_value(headerdata, f'/gate/{parent}/{repeater}/setRepeatVector', split_substr=None, split_idx=i) for i in range(1,4)] # append to axial/trans information info[f'{parent}AxialNr'] = int(repeat_number) info[f'{parent}AxialSpacing'] = repeat_vector[2] elif repeater=='ring': repeat_number = get_header_value(headerdata, f'/gate/{parent}/{repeater}/setRepeatNumber', split_substr=None, split_idx=1) # Repeat number for a ring is in the axial direction info[f'{parent}TransNr'] = int(repeat_number) info[f'{parent}AxialNr'] = 1 info['NrCrystalsPerRing'] = info['crystalTransNr'] * info['moduleTransNr'] * info['submoduleTransNr'] * info['rsectorTransNr'] info['NrRings'] = info['crystalAxialNr'] * info['submoduleAxialNr'] * info['moduleAxialNr'] * info['rsectorAxialNr'] info['firstCrystalAxis'] = 1 return info
[docs]def get_axial_trans_ids_from_ROOT( f: object, info: dict, j: int = None, substr: str = 'Coincidences') -> Sequence[torch.Tensor]: """Obtain transaxial and axial IDS (for crystals, submodules, modules, and rsectors) corresponding to each listmode event in an opened ROOT file Args: f (object): Opened ROOT file info (dict): PET geometry information dictionary j (int, optional): Which of the detectors to consider in a coincidence event OR which detector to consider for a single (None). Defaults to None. substr (str, optional): Whether to consider coincidences or singles. Defaults to 'Coincidences'. Returns: Sequence[torch.Tensor]: Sequence of IDs (transaxial/axial) for all components (crystals, submodules, modules, and rsectors) """ if j is None: idx_str = '' else: idx_str = f'{j+1}' ids_rsector = torch.tensor(f[substr][f'rsectorID{idx_str}'].array(library="np")) ids_module = torch.tensor(f[substr][f'moduleID{idx_str}'].array(library="np")) ids_submodule = torch.tensor(f[substr][f'submoduleID{idx_str}'].array(library="np")) ids_crystal = torch.tensor(f[substr][f'crystalID{idx_str}'].array(library="np")) ids_trans_rsector = ids_rsector % info['rsectorTransNr'] ids_axial_rsector = ids_rsector // info['rsectorTransNr'] ids_trans_module = ids_module % info['moduleTransNr'] ids_axial_module = ids_module // info['moduleTransNr'] ids_trans_submodule = ids_submodule % info['submoduleTransNr'] ids_axial_submodule = ids_submodule // info['submoduleTransNr'] ids_trans_crystal = ids_crystal % info['crystalTransNr'] ids_axial_crystal = ids_crystal // info['crystalTransNr'] return ids_trans_crystal, ids_axial_crystal, ids_trans_submodule, ids_axial_submodule, ids_trans_module, ids_axial_module, ids_trans_rsector, ids_axial_rsector
[docs]def get_detector_ids_from_root( paths: Sequence[str], info: dict, tof_meta = None, substr: str = 'Coincidences', include_randoms: bool = True, include_scatters: bool = True, randoms_only: bool = False, scatters_only: bool = False ) -> torch.Tensor: """Obtain detector IDs corresponding to each listmode event in a set of ROOT files Args: paths (Sequence[str]): List of ROOT files to consider info (dict): PET geometry information dictionary tof_meta (PETTOFMeta, optional): PET time of flight metadata for binning. If none, then TOF is not considered Defaults to None. substr (str, optional): Name of events to consider in the ROOT file. Defaults to 'Coincidences'. include_randoms (bool, optional): Whether or not to include random events in the returned listmode events. Defaults to True. include_scatters (bool, optional): Whether or not to include scatter events in the returned listmode events. Defaults to True. randoms_only (bool, optional): Flag to return only random events. Defaults to False. scatters_only (bool, optional): Flag to return only scatter events. Defaults to False. Returns: torch.Tensor: Tensor of shape [N_events,2] (non-TOF) or [N_events,3] (TOF) """ if tof_meta is not None: TOF_bin_edges = tof_meta.bin_edges detector_ids_trio = [[],[],[]] for i,path in enumerate(paths): print(i) with uproot.open(path) as f: N_events = f[substr]['sourcePosX1'].array(library='np').shape[0] valid_indices = torch.ones(N_events).to(torch.bool) if not(include_randoms) or randoms_only or not(include_scatters) or scatters_only: xs1 = torch.tensor(f[substr]['sourcePosX1'].array(library='np')) xs2 = torch.tensor(f[substr]['sourcePosX2'].array(library='np')) ys1 = torch.tensor(f[substr]['sourcePosY1'].array(library='np')) ys2 = torch.tensor(f[substr]['sourcePosY2'].array(library='np')) zs1 = torch.tensor(f[substr]['sourcePosZ1'].array(library='np')) zs2 = torch.tensor(f[substr]['sourcePosZ2'].array(library='np')) random_indices = ~((xs1==xs2)*(ys1==ys2)*(zs1==zs2)) if not(include_scatters) or scatters_only: scatter_raleigh_1 = torch.tensor(f['Coincidences']['RayleighPhantom1'].array(library='np')) scatter_raleigh_2 = torch.tensor(f['Coincidences']['RayleighPhantom2'].array(library='np')) scatter_compton_1 = torch.tensor(f['Coincidences']['comptonPhantom1'].array(library='np')) scatter_compton_2 = torch.tensor(f['Coincidences']['comptonPhantom2'].array(library='np')) scatter_indices = (scatter_raleigh_1+scatter_raleigh_2+scatter_compton_1+scatter_compton_2).to(torch.bool) # Adjust indices we're looking for based on the events we want if randoms_only: valid_indices *= random_indices elif scatters_only: # Only include scatter events that arent from randoms valid_indices *= (scatter_indices)*(~random_indices) else: if not(include_randoms): valid_indices *= ~random_indices if not(include_scatters): valid_indices *= ~scatter_indices for j in range(2): ids_trans_crystal, ids_axial_crystal, ids_trans_submodule, ids_axial_submodule, ids_trans_module, ids_axial_module, ids_trans_rsector, ids_axial_rsector = get_axial_trans_ids_from_ROOT(f, info, j, substr) detector_ids = get_detector_ids_from_trans_axial_ids(ids_trans_crystal, ids_trans_submodule, ids_trans_module, ids_trans_rsector, ids_axial_crystal, ids_axial_submodule, ids_axial_module, ids_axial_rsector, info) detector_ids = detector_ids[valid_indices] detector_ids_trio[j].append(detector_ids.to(torch.int32)) if tof_meta is not None: t1 = f[substr]['time1'].array(library='np') t2 = f[substr]['time2'].array(library='np') tof_pos = 1e12*(t2 - t1) * 0.15 # ps to mm detector_id = np.digitize(-tof_pos, TOF_bin_edges) - 1 # First see if only binning scatters/randoms detector_id = detector_id[valid_indices] detector_ids_trio[2].append(torch.tensor(detector_id)) if tof_meta is not None: return torch.vstack([ torch.concatenate(detector_ids_trio[0]), torch.concatenate(detector_ids_trio[1]), torch.concatenate(detector_ids_trio[2])]).T else: return torch.vstack([ torch.concatenate(detector_ids_trio[0]), torch.concatenate(detector_ids_trio[1])]).T
[docs]def get_symmetry_histogram_from_ROOTfile( f: object, info: dict, substr: str = 'Coincidences', include_randoms: bool = True ) -> torch.Tensor: """Obtains a histogram that exploits symmetries when computing normalization factors from calibration ROOT scans Args: f (object): Opened ROOT file info (dict): PET geometry information dictionary substr (str, optional): Name of events to consider in ROOT file. Defaults to 'Coincidences'. include_randoms (bool, optional): Whether or not to include random events from data. Defaults to True. Returns: torch.Tensor: Symmetry histogram """ ids1_trans_crystal, ids1_axial_crystal, ids1_trans_submodule, ids1_axial_submodule, ids1_trans_module, ids1_axial_module, ids1_trans_rsector, ids1_axial_rsector = get_axial_trans_ids_from_ROOT(f, info, 0, substr= substr) ids2_trans_crystal, ids2_axial_crystal, ids2_trans_submodule, ids2_axial_submodule, ids2_trans_module, ids2_axial_module, ids2_trans_rsector, ids2_axial_rsector = get_axial_trans_ids_from_ROOT(f, info, 1, substr= substr) ids_trans_crystal = torch.vstack([ids1_trans_crystal, ids2_trans_crystal]) ids_axial_crystal = torch.vstack([ids1_axial_crystal, ids2_axial_crystal]) ids_axial_submodule = torch.vstack([ids1_axial_submodule, ids2_axial_submodule]) ids_axial_module = torch.vstack([ids1_axial_module, ids2_axial_module]) ids_trans_rsector = torch.vstack([ids1_trans_rsector, ids2_trans_rsector]) # Make sure smallest detector ID always comes first detector_ids1 = get_detector_ids_from_trans_axial_ids(ids1_trans_crystal, ids1_trans_submodule, ids1_trans_module, ids1_trans_rsector, ids1_axial_crystal, ids1_axial_submodule, ids1_axial_module, ids1_axial_rsector, info) detector_ids2 = get_detector_ids_from_trans_axial_ids(ids2_trans_crystal, ids2_trans_submodule, ids2_trans_module, ids2_trans_rsector, ids2_axial_crystal, ids2_axial_submodule, ids2_axial_module, ids2_axial_rsector, info) detector_ids = torch.vstack([detector_ids1, detector_ids2]) idx_min = detector_ids.min(axis=0).indices idx_max = detector_ids.max(axis=0).indices # Compute histogram quantities ids_delta_axial_submodule = ( torch.take_along_dim(ids_axial_submodule, idx_max.unsqueeze(0), 0) \ - torch.take_along_dim(ids_axial_submodule, idx_min.unsqueeze(0), 0) )\ + (info['submoduleAxialNr'] - 1) ids_delta_axial_module = ( torch.take_along_dim(ids_axial_module, idx_max.unsqueeze(0), 0) \ - torch.take_along_dim(ids_axial_module, idx_min.unsqueeze(0), 0) )\ + (info['moduleAxialNr'] - 1) ids_delta_trans_rsector = ( torch.take_along_dim(ids_trans_rsector, idx_max.unsqueeze(0), 0) \ - torch.take_along_dim(ids_trans_rsector, idx_min.unsqueeze(0), 0) )\ % info['rsectorTransNr'] histo = torch.vstack([ torch.take_along_dim(ids_axial_crystal, idx_min.unsqueeze(0), 0), torch.take_along_dim(ids_axial_crystal, idx_max.unsqueeze(0), 0), torch.take_along_dim(ids_trans_crystal, idx_min.unsqueeze(0), 0), torch.take_along_dim(ids_trans_crystal, idx_max.unsqueeze(0), 0), ids_delta_axial_submodule, ids_delta_axial_module, ids_delta_trans_rsector ]).T if include_randoms: xs1 = torch.tensor(f[substr]['sourcePosX1'].array(library='np')) xs2 = torch.tensor(f[substr]['sourcePosX2'].array(library='np')) ys1 = torch.tensor(f[substr]['sourcePosY1'].array(library='np')) ys2 = torch.tensor(f[substr]['sourcePosY2'].array(library='np')) zs1 = torch.tensor(f[substr]['sourcePosZ1'].array(library='np')) zs2 = torch.tensor(f[substr]['sourcePosZ2'].array(library='np')) same_location_idxs = (xs1==xs2)*(ys1==ys2)*(zs1==zs2) return histo[same_location_idxs] else: return histo
[docs]def get_symmetry_histogram_all_combos(info: dict) -> torch.Tensor: """Obtains the symmetry histogram for detector sensitivity corresponding to all possible detector pair combinations Args: info (dict): PET geometry information dictionary Returns: torch.Tensor: Histogram corresponding to all possible detector pair combinations. This simply counts the number of detector pairs in each bin of the histogram. """ ids_trans_crystal, ids_axial_crystal, ids_trans_submodule, ids_axial_submodule, ids_trans_module, ids_axial_module, ids_trans_rsector, ids_axial_rsector = get_axial_trans_ids_from_info(info, return_combinations=True, sort_by_detector_ids=True) ids_delta_axial_submodule = (ids_axial_submodule[:,1] - ids_axial_submodule[:,0]) + (info['submoduleAxialNr'] - 1) ids_delta_axial_module = (ids_axial_module[:,1] - ids_axial_module[:,0]) + (info['moduleAxialNr'] - 1) ids_delta_trans_rsector = (ids_trans_rsector[:,1] - ids_trans_rsector[:,0]) % info['rsectorTransNr'] # because of circle return torch.vstack([ids_axial_crystal[:,0], ids_axial_crystal[:,1], ids_trans_crystal[:,0], ids_trans_crystal[:,1], ids_delta_axial_submodule, ids_delta_axial_module, ids_delta_trans_rsector]).T
[docs]def get_normalization_weights_cylinder_calibration( paths: Sequence[str], info: dict, cylinder_radius: float, include_randoms: bool = True, ) -> torch.tensor: """Function to get sensitivty factor from a cylindrical calibration phantom Args: paths (Sequence[str]): List of paths corresponding to calibration scan info (dict): PET geometry information dictionary cylinder_radius (float): Radius of cylindrical phantom used in scan include_randoms (bool, optional): Whether or not to include random events from the cylinder calibration. Defaults to True. Returns: torch.tensor: Sensitivty factor for all possible detector combinations """ # Part 1: Geometry correction factor for non-unform exposure from cylindrical shell scanner_LUT = get_scanner_LUT(info) all_LOR_ids = torch.combinations(torch.arange(scanner_LUT.shape[0]).to(torch.int32), 2) geometric_correction_factor = 1/(torch.sqrt(1-(torch.abs(get_radius(all_LOR_ids, scanner_LUT)) / cylinder_radius )**2) + pytomography.delta) # Part 2: Detector sensitivity correction factor (exploits symmetries) Nr_crystal_axial_bins = info['crystalAxialNr'] Nr_crystal_trans_bins = info['crystalTransNr'] Nr_delta_submodule_axial_bins = info['submoduleAxialNr'] * 2 - 1 Nr_delta_module_axial_bins = info['moduleAxialNr'] * 2 - 1 Nr_delta_rsector_trans_bins = info['rsectorTransNr'] # b/c circle histo = torch.zeros([Nr_crystal_axial_bins, Nr_crystal_axial_bins, Nr_crystal_trans_bins, Nr_crystal_trans_bins, Nr_delta_submodule_axial_bins, Nr_delta_module_axial_bins, Nr_delta_rsector_trans_bins]) bin_edges = [torch.arange(x).to(torch.float32)-0.5 for x in [Nr_crystal_axial_bins+1, Nr_crystal_axial_bins+1, Nr_crystal_trans_bins+1, Nr_crystal_trans_bins+1, Nr_delta_submodule_axial_bins+1, Nr_delta_module_axial_bins+1, Nr_delta_rsector_trans_bins+1]] for path in paths: with uproot.open(path) as f: vals = get_symmetry_histogram_from_ROOTfile(f, info, include_randoms=include_randoms) histo += torch.histogramdd(vals.to(torch.float32), bin_edges)[0] vals_all_pairs = get_symmetry_histogram_all_combos(info) N_bins = torch.histogramdd(vals_all_pairs.to(torch.float32), bin_edges)[0] # exploits the fact that vals_all_pairs is in order of ascending detector ids return (histo/N_bins)[vals_all_pairs[:,0], vals_all_pairs[:,1], vals_all_pairs[:,2], vals_all_pairs[:,3], vals_all_pairs[:,4], vals_all_pairs[:,5], vals_all_pairs[:,6]] * geometric_correction_factor
[docs]def get_norm_sinogram_from_listmode_data( weights_sensitivity: torch.Tensor, info: dict ) -> torch.Tensor: """Obtains normalization "sensitivty" sinogram from listmode data Args: weights_sensitivity (torch.Tensor): Sensitivty weight corresponding to all possible detector pairs info (dict): PET geometry information dictionary Returns: torch.Tensor: PET sinogram """ scanner_LUT = get_scanner_LUT(info) all_LOR_ids = torch.combinations(torch.arange(scanner_LUT.shape[0]).to(torch.int32), 2) return listmode_to_sinogram(all_LOR_ids, info, weights=weights_sensitivity, normalization=True)
[docs]def get_norm_sinogram_from_root_data( normalization_paths: Sequence[str], info: dict, cylinder_radius: float, include_randoms: bool =True, ) -> torch.Tensor: """Obtain normalization "sensitivity" sinogram directly from ROOT files Args: normalization_paths (Sequence[str]): Paths to all ROOT files corresponding to calibration scan info (dict): PET geometry information dictionary cylinder_radius (float): Radius of cylinder used in calibration scan include_randoms (bool, optional): Whether or not to include randoms in loaded data. Defaults to True. Returns: torch.Tensor: PET sinogram """ eta = get_normalization_weights_cylinder_calibration( normalization_paths, info, include_randoms=include_randoms, cylinder_radius = cylinder_radius ) return get_norm_sinogram_from_listmode_data(eta, info)
[docs]def get_sinogram_from_root_data( paths: Sequence[str], info: dict, include_randoms: bool = True, include_scatters: bool = True, randoms_only: bool = False, scatters_only: bool = False ) -> torch.Tensor: """Get PET sinogram directly from ROOT data Args: paths (Sequence[str]): GATE generated ROOT files info (dict): PET geometry information dictionary include_randoms (bool, optional): Whether or not to include random events in the sinogram. Defaults to True. include_scatters (bool, optional): Whether or not to include scatter events in the sinogram. Defaults to True. randoms_only (bool, optional): Flag for only binning randoms. Defaults to False. scatters_only (bool, optional): Flag for only binning scatters. Defaults to False. Returns: torch.Tensor: PET sinogram """ detector_ids = get_detector_ids_from_root( paths, info, include_randoms=include_randoms, include_scatters=include_scatters, randoms_only=randoms_only, scatters_only=scatters_only, TOF=False) return listmode_to_sinogram(detector_ids, info)
[docs]def get_radius(detector_ids: torch.tensor, scanner_LUT: torch.tensor) -> torch.tensor: """Gets the radial position of all LORs Args: detector_ids (torch.tensor): Detector ID pairs corresponding to LORs scanner_LUT (torch.tensor): scanner look up table Returns: torch.tensor: radii of all detector ID pairs provided """ x1, y1, z1 = scanner_LUT[detector_ids[:,0]].T x2, y2, z2 = scanner_LUT[detector_ids[:,1]].T return torch.where( (x1==x2)*(y1==y2), torch.sqrt(x1**2+y1**2), (x1*y2-y1*x2)/torch.sqrt((x1-x2)**2+(y1-y2)**2) )
[docs]def get_angle(detector_ids: torch.tensor, scanner_LUT: torch.tensor) -> torch.tensor: """Gets the angular position of all LORs Args: detector_ids (torch.tensor): Detector ID pairs corresponding to LORs scanner_LUT (torch.tensor): scanner look up table Returns: torch.tensor: angle of all detector ID pairs provided """ x1, y1, z1 = scanner_LUT[detector_ids[:,0]].T x2, y2, z2 = scanner_LUT[detector_ids[:,1]].T return torch.where( (x1==x2)*(y1==y2), torch.inf, torch.arccos(torch.abs(x1-x2)/torch.sqrt((x1-x2)**2+(y1-y2)**2)) )
# Removes all LORs not intersecting with reconstruction cube
[docs]def remove_events_out_of_bounds( detector_ids: torch.tensor, scanner_LUT: torch.tensor, object_meta: ObjectMeta ) -> torch.tensor: r"""Removes all detected LORs outside of the reconstruced volume given by ``object_meta``. Args: detector_ids (torch.tensor): :math:`N \times 2` (non-TOF) or :math:`N \times 3` (TOF) tensor that provides detector ID pairs (and TOF bin) for coincidence events. scanner_LUT (torch.tensor): scanner lookup table that provides spatial coordinates for all detector ID pairs object_meta (ObjectMeta): object metadata providing the region of reconstruction Returns: torch.tensor: all detector ID pairs corresponding to coincidence events """ bmin = -torch.tensor(object_meta.shape) * torch.tensor(object_meta.dr) / 2 bmax = torch.tensor(object_meta.shape) * torch.tensor(object_meta.dr) / 2 bmin = bmin.to(detector_ids.device); bmax=bmax.to(detector_ids.device) origin = scanner_LUT[detector_ids[:,0]] direction = scanner_LUT[detector_ids[:,1]] - origin t1 = torch.where( direction>=0, (bmin - origin) / direction, (bmax - origin) / direction ) t2 = torch.where( direction>=0, (bmax - origin) / direction, (bmin - origin) / direction ) intersect = (t1[:,0]>t2[:,1])+(t1[:,1]>t2[:,0])+((t1[:,0]>t2[:,2]))+(t1[:,2]>t2[:,0]) return detector_ids[~intersect]
[docs]def get_attenuation_map_nifti(path, object_meta): # If img is none, extract data from path data = nib.load(path) img = data.get_fdata() Sx, Sy, Sz = -(np.array(img.shape)-1) / 2 dx, dy, dz = data.header['pixdim'][1:4] # Convert from RAS to LPS space for DICOM dx*=-1; dy*=-1 M_highres = np.zeros((4,4)) M_highres[0] = np.array([dx, 0, 0, Sx*dx]) M_highres[1] = np.array([0, dy, 0, Sy*dy]) M_highres[2] = np.array([0, 0, dz, Sz*dz]) M_highres[3] = np.array([0, 0, 0, 1]) dx, dy, dz = object_meta.dr Sx, Sy, Sz = -(np.array(object_meta.shape)-1) / 2 M_pet = np.zeros((4,4)) M_pet[0] = np.array([dx, 0, 0, Sx*dx]) M_pet[1] = np.array([0, dy, 0, Sy*dy]) M_pet[2] = np.array([0, 0, dz, Sz*dz]) M_pet[3] = np.array([0, 0, 0, 1]) M = npl.inv(M_highres) @ M_pet return torch.tensor(affine_transform(img, M, output_shape=object_meta.shape, mode='constant', order=1)) / 10