from __future__ import annotations
from typing import Sequence
import torch
import pytomography
from pytomography.io.PET import shared
from pytomography.projectors.PET import PETLMSystemMatrix
import numpy as np
import parallelproj
from torchrbf import RBFInterpolator
from torch.nn.functional import grid_sample
from pytomography.io.PET.shared import sinogram_coordinates, sinogram_to_spatial, listmode_to_sinogram
from pytomography.projectors.PET import create_sinogramSM_from_LMSM
from pytomography.metadata.PET import PETTOFMeta
from pytomography.metadata import ObjectMeta, ProjMeta
from pytomography.projectors import SystemMatrix
def total_compton_cross_section(energy: torch.Tensor) -> torch.Tensor:
    r"""Computes the total Compton (Klein-Nishina) cross section of interaction :math:`\sigma` at the given photon energies.

    Uses

    .. math::

        \sigma = \frac{3}{4}\sigma_0 \left[\frac{1+a}{a^2}\left(\frac{2(1+a)}{1+2a} - \frac{\ln(1+2a)}{a}\right)
        + \frac{\ln(1+2a)}{2a} - \frac{1+3a}{(1+2a)^2}\right]

    where :math:`a = E / 511` keV and :math:`\sigma_0` is the Thomson cross section.

    Args:
        energy (torch.Tensor): Energies of photons considered (keV). Must be strictly positive
            (``a = 0`` divides by zero).

    Returns:
        torch.Tensor: Cross section (cm^2, per electron) at each corresponding energy
    """
    # Photon energy in units of the electron rest energy (511 keV)
    a = energy / 511
    log_term = torch.log(1 + 2 * a)
    # Thomson cross section (cm^2); 3/4 * sigma0 == 2 * pi * r_e^2
    sigma0 = 6.65e-25
    return 0.75 * sigma0 * (
        (1 + a) / a**2 * (2 * (1 + a) / (1 + 2 * a) - log_term / a)
        + log_term / (2 * a)
        - (1 + 3 * a) / (1 + 2 * a) / (1 + 2 * a)
    )
def photon_energy_after_compton_scatter_511kev(cos_theta: torch.Tensor) -> torch.Tensor:
    r"""Returns the energy of an initially 511 keV photon after a single Compton scatter.

    The Compton relation :math:`E' = E / (1 + (E / m_e c^2)(1 - \cos\theta))` with
    :math:`E = m_e c^2 = 511` keV simplifies to :math:`E' = 511 / (2 - \cos\theta)`.

    Args:
        cos_theta (torch.Tensor): Cosine of the scattering angle (not the angle itself).

    Returns:
        torch.Tensor: Photon energy (keV) after scattering.
    """
    denominator = 2 - cos_theta
    return 511 / denominator
def diff_compton_cross_section(cos_theta: torch.Tensor, energy: torch.Tensor) -> torch.Tensor:
    r"""Evaluates the Klein-Nishina differential cross section :math:`d\sigma/d\Omega`.

    With :math:`P = 1 / (1 + (E/511)(1 - \cos\theta))` this is
    :math:`\frac{r_e^2}{2} P (1 - P\sin^2\theta + P^2)`.

    Args:
        cos_theta (torch.Tensor): Cosine of the scattering angle
        energy (torch.Tensor): Energy (keV) of the incident photon before scattering

    Returns:
        torch.Tensor: Differential Compton cross section (cm^2 per steradian, per electron)
    """
    # Ratio of scattered to incident photon energy
    energy_ratio = 1 / (1 + energy / 511 * (1 - cos_theta))
    sin_squared = 1 - cos_theta**2
    # Classical electron radius (cm)
    electron_radius = 2.818e-13
    return electron_radius**2 / 2 * energy_ratio * (1 - energy_ratio * sin_squared + energy_ratio**2)
def detector_efficiency(
    scatter_energy: torch.Tensor,
    energy_resolution: float = 0.15,
    energy_threshhold: float = 430
) -> torch.Tensor:
    """Returns the probability that a photon of the given energy is accepted by the detector energy window.

    The measured energy is modelled as Gaussian around ``scatter_energy`` with a FWHM of
    ``energy_resolution * 511`` keV; the result is the probability that the measurement
    exceeds ``energy_threshhold``.

    Args:
        scatter_energy (torch.Tensor): Energy (keV) of the photon impinging the detector
        energy_resolution (float, optional): FWHM energy resolution of the crystals, as a fraction of 511 keV. Defaults to 0.15.
        energy_threshhold (float, optional): Lower energy limit (keV) for an event to be registered. Defaults to 430.

    Returns:
        torch.Tensor: Probability that the photon gets detected
    """
    # FWHM -> standard deviation conversion factor for a Gaussian
    fwhm_to_sigma = 2 * np.sqrt(2 * np.log(2))
    sigma = 511 * energy_resolution / fwhm_to_sigma
    standardized = (energy_threshhold - scatter_energy) / (np.sqrt(2) * sigma)
    return 0.5 * (1 - torch.erf(standardized))
def tof_efficiency(
    offset: torch.Tensor,
    tof_bins_dense_centers: torch.Tensor,
    tof_meta: PETTOFMeta
) -> torch.Tensor:
    """Computes the probability that a coincidence with timing difference ``offset`` is detected in each TOF bin.

    A Gaussian kernel of standard deviation ``tof_meta.sigma`` is evaluated between every
    offset and every dense TOF bin center, then normalized so that it sums to one over the
    first (TOF bin) dimension.

    Args:
        offset (torch.Tensor): Timing offset (in spatial units) of coincidence events. In SSS this has shape
            :math:`(N_{TOF}, N_{coinc})`, where :math:`N_{coinc}` is the number of coincidences considered and
            :math:`N_{TOF}` the number of TOF bins in the sinogram.
        tof_bins_dense_centers (torch.Tensor): Centers of the dense TOF bins. These are separate from the sinogram
            TOF bins: they partition the integrals of Watson (2007) Equation 2. In SSS this has shape
            :math:`(N_{coinc}, N_{denseTOF})`.
        tof_meta (PETTOFMeta): TOF metadata for the sinogram (only ``sigma`` is read)

    Returns:
        torch.Tensor: Relative detection probabilities with shape ``(*offset.shape, N_denseTOF)``.
    """
    variance = tof_meta.sigma.item() ** 2
    delta = offset.unsqueeze(-1) - tof_bins_dense_centers.unsqueeze(0)
    weights = torch.exp(-delta**2 / (2 * variance))
    # Normalize over the TOF-bin axis
    return weights / weights.sum(dim=0, keepdim=True)
def get_sample_scatter_points(
    attenuation_map: torch.Tensor,
    stepsize: float = 4,
    attenuation_cutoff: float = 0.004
) -> torch.Tensor:
    """Selects a subset of points in the attenuation map used as scatter points.

    The (3D) map is sampled every ``stepsize`` voxels along each axis, starting at index 0,
    and only sampled voxels whose attenuation exceeds ``attenuation_cutoff`` are kept.

    Args:
        attenuation_map (torch.Tensor): 3D attenuation map
        stepsize (float, optional): Stepsize in x/y/z between sampled points (used as a slice step, so it
            must be an integer value). Defaults to 4.
        attenuation_cutoff (float, optional): Only consider points above this threshold. Defaults to 0.004.

    Returns:
        torch.Tensor: Voxel indices of the selected points, shape ``(3, N_points)`` (one column per point),
        on ``pytomography.device``.
    """
    grid_axes = [torch.arange(0, s, stepsize) for s in attenuation_map.shape]
    # 'ij' matches the C-order flattening of the subsampled map below (explicit to avoid torch's warning)
    mgrid = torch.meshgrid(*grid_axes, indexing='ij')
    coords = torch.vstack([m.flatten() for m in mgrid])
    subsampled = attenuation_map[::stepsize, ::stepsize, ::stepsize]
    idx_above_cutoff = (subsampled > attenuation_cutoff).flatten().cpu()
    coords = coords[:, idx_above_cutoff]
    return coords.to(pytomography.device)
def get_sample_detector_ids(
    proj_meta: ProjMeta,
    sinogram_interring_stepsize: int = 4,
    sinogram_intraring_stepsize: int = 4
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Selects a subset of detector IDs in the PET scanner used for obtaining scatter estimates in the sparse sinogram

    Args:
        proj_meta (ProjMeta): PET projection metadata (sinogram/listmode); ``info['NrCrystalsPerRing']`` and
            ``info['NrRings']`` are read.
        sinogram_interring_stepsize (int, optional): Axial stepsize between rings. Defaults to 4.
        sinogram_intraring_stepsize (int, optional): Stepsize of crystals within a given ring. Defaults to 4.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ``(idx_intraring, idx_ring, detector_id_pairs)``:
        the sampled crystal indices within a ring, the sampled ring indices (the last ring is always
        included), and an ``(N_pairs, 2)`` CPU tensor of every unordered pair of sampled detectors, where
        detector ID = crystal index + ring index * ``NrCrystalsPerRing``.
    """
    n_crystals_per_ring = proj_meta.info['NrCrystalsPerRing']
    n_rings = proj_meta.info['NrRings']
    idx_intraring = torch.arange(0, n_crystals_per_ring, sinogram_intraring_stepsize)
    idx_ring = torch.arange(0, n_rings, sinogram_interring_stepsize)
    # Include the top ring so that axial interpolation never has to extrapolate
    if n_rings - 1 not in idx_ring:
        idx_ring = torch.cat((idx_ring, torch.tensor([n_rings - 1])))
    ring_idx, crystal_idx = torch.cartesian_prod(idx_ring, idx_intraring).T
    detector_ids = crystal_idx + ring_idx * n_crystals_per_ring
    return idx_intraring, idx_ring, torch.combinations(detector_ids.cpu(), 2)
def compute_sss_sparse_sinogram(
    object_meta: ObjectMeta,
    proj_meta: ProjMeta,
    pet_image: torch.Tensor,
    attenuation_image: torch.Tensor,
    image_stepsize: int = 4,
    attenuation_cutoff: float = 0.004,
    sinogram_interring_stepsize: int = 4,
    sinogram_intraring_stepsize: int = 4
) -> torch.Tensor:
    """Generates a sparse single scatter simulation sinogram for non-TOF PET data.

    Scatter points are sampled from ``attenuation_image`` (one randomly jittered position inside each
    selected voxel). For every sampled LOR (A, B), the single-scatter contribution of each scatter point
    is evaluated, and the contributions are averaged over all scatter points.

    Args:
        object_meta (ObjectMeta): Object metadata corresponding to reconstructed PET image used in the simulation
        proj_meta (ProjMeta): Projection metadata specifying the details of the PET scanner
        pet_image (torch.Tensor): PET image used to estimate the scatter
        attenuation_image (torch.Tensor): Attenuation map used in scatter simulation
        image_stepsize (int, optional): Stepsize in x/y/z between sampled scatter points. Defaults to 4.
        attenuation_cutoff (float, optional): Only consider points above this threshold. Defaults to 0.004.
        sinogram_interring_stepsize (int, optional): Axial stepsize between rings. Defaults to 4.
        sinogram_intraring_stepsize (int, optional): Stepsize of crystals within a given ring. Defaults to 4.

    Returns:
        torch.Tensor: Estimated sparse single scatter simulation sinogram.

    Raises:
        ValueError: If no sampled voxel of ``attenuation_image`` exceeds ``attenuation_cutoff`` (there are no
            scatter points to average over).
    """
    device = pytomography.device
    # Important quantities
    E_PET = torch.tensor(511).to(device)
    dr = torch.tensor(object_meta.dr)
    shape = torch.tensor(object_meta.shape)
    object_origin = (-np.array(object_meta.shape) / 2 + 0.5) * np.array(object_meta.dr)
    voxel_volume = np.prod(object_meta.dr)
    total_compton_cross_section_511keV = total_compton_cross_section(E_PET)
    # Loop-invariant device copies (previously re-transferred for every scatter point)
    scanner_LUT = proj_meta.scanner_lut.to(device)
    attenuation_image_device = attenuation_image.to(pytomography.dtype).to(device)
    # Get sample image/sinogram points
    coords = get_sample_scatter_points(attenuation_image, stepsize=image_stepsize, attenuation_cutoff=attenuation_cutoff)
    n_scatter_points = coords.shape[1]
    if n_scatter_points == 0:
        raise ValueError(
            "No scatter points found: no sampled voxel of attenuation_image exceeds attenuation_cutoff"
        )
    coords_position = (coords - shape.unsqueeze(1).to(device) / 2 + 0.5) * dr.unsqueeze(1).to(device)
    _, _, detector_ids_scatter = get_sample_detector_ids(proj_meta, sinogram_interring_stepsize, sinogram_intraring_stepsize)
    idxA, idxB = detector_ids_scatter.to(device).T
    rA = scanner_LUT[idxA]
    rB = scanner_LUT[idxB]
    # In-plane (x/y) distance of each detector from the scanner axis, used for incidence angles
    rA_radial_norm = torch.norm(rA[:, :2], dim=1)
    rB_radial_norm = torch.norm(rB[:, :2], dim=1)
    probability = 0
    for point_idx in range(n_scatter_points):
        # Get position and add random offset within the voxel
        scatter_point_position = coords_position[:, point_idx] + ((torch.rand(3) - 0.5) * dr).to(device)
        # Value of attenuation coefficient at scatter point
        mu_value = attenuation_image[tuple(coords[:, point_idx].tolist())]
        # Emission/transmission integrals from the scatter point to every detector
        scatter_point_per_detector = scatter_point_position.unsqueeze(0).expand(scanner_LUT.shape[0], -1)
        emission_integrals = parallelproj.joseph3d_fwd(
            scatter_point_per_detector,
            scanner_LUT,
            pet_image,
            object_origin,
            object_meta.dr,
        )
        transmission_integrals = parallelproj.joseph3d_fwd(
            scatter_point_per_detector,
            scanner_LUT,
            attenuation_image_device,
            object_origin,
            object_meta.dr,
        )
        transmission_integrals_exp = torch.exp(-transmission_integrals)
        attenuation_A = transmission_integrals_exp[idxA]
        attenuation_B = transmission_integrals_exp[idxB]
        # Scatter geometry
        rSA = rA - scatter_point_position
        rSB = rB - scatter_point_position
        rSA_norm = torch.norm(rSA, dim=1)  # distance between S and A
        rSB_norm = torch.norm(rSB, dim=1)  # distance between S and B
        # cos(scattering_angle) = cos(pi - angle_between_vectors) = -cos(angle_between_vectors)
        cos_theta = -(rSA * rSB).sum(dim=1) / rSA_norm / rSB_norm
        E_new = photon_energy_after_compton_scatter_511kev(cos_theta)
        energy_efficiency = detector_efficiency(E_new)
        # Angle of impingement upon detectors (assumes a circular ring geometry)
        cos_thetaA_incidence = (rSA[:, :2] * rA[:, :2]).sum(dim=1) / rSA_norm / rA_radial_norm
        cos_thetaB_incidence = (rSB[:, :2] * rB[:, :2]).sum(dim=1) / rSB_norm / rB_radial_norm
        compton_cross_section_ratio = total_compton_cross_section(E_new) / total_compton_cross_section_511keV
        # Extra attenuation of the lower-energy scattered photon on the S->A or S->B leg
        scattered_attenuation_A = attenuation_A ** (compton_cross_section_ratio - 1)
        scattered_attenuation_B = attenuation_B ** (compton_cross_section_ratio - 1)
        probability += 1 / (rSB_norm**2 * rSA_norm**2) * (
            emission_integrals[idxA] * scattered_attenuation_B + emission_integrals[idxB] * scattered_attenuation_A
        ) * attenuation_B * attenuation_A * mu_value * energy_efficiency * cos_thetaA_incidence \
            * cos_thetaB_incidence * diff_compton_cross_section(cos_theta, E_PET) \
            / total_compton_cross_section_511keV * voxel_volume
    scatter_sinogram_sparse = shared.listmode_to_sinogram(
        detector_ids_scatter, proj_meta.info, weights=(probability / n_scatter_points).cpu()
    )
    return scatter_sinogram_sparse
def _unnormalized_tof_weights(
    offset: torch.Tensor,
    tof_bins_dense_centers: torch.Tensor,
    sigma: float
) -> torch.Tensor:
    """Gaussian TOF kernel of ``tof_efficiency`` *before* normalization over TOF bins.

    Args:
        offset (torch.Tensor): Timing offsets (spatial units), shape ``(N_TOF, N_coinc)``.
        tof_bins_dense_centers (torch.Tensor): Dense TOF bin centers, shape ``(N_coinc, N_denseTOF)``.
        sigma (float): Standard deviation of the TOF kernel (spatial units).

    Returns:
        torch.Tensor: Unnormalized weights of shape ``(N_TOF, N_coinc, N_denseTOF)``.
    """
    return torch.exp(-(offset.unsqueeze(-1) - tof_bins_dense_centers.unsqueeze(0))**2 / (2 * sigma**2))


def compute_sss_sparse_sinogram_TOF(
    object_meta: ObjectMeta,
    proj_meta: ProjMeta,
    pet_image: torch.Tensor,
    attenuation_image: torch.Tensor,
    tof_meta: PETTOFMeta,
    image_stepsize: int = 4,
    attenuation_cutoff: float = 0.004,
    sinogram_interring_stepsize: int = 4,
    sinogram_intraring_stepsize: int = 4,
    num_dense_tof_bins: int = 25,
    N_splits: int = 1
) -> torch.Tensor:
    """Generates a sparse single scatter simulation sinogram for TOF PET data.

    Same model as the non-TOF estimate, but the emission integral from each scatter point to each
    detector is split into ``num_dense_tof_bins`` equal segments, so each segment can be weighted
    by its probability of landing in every TOF bin.

    Args:
        object_meta (ObjectMeta): Object metadata corresponding to reconstructed PET image used in the simulation
        proj_meta (ProjMeta): Projection metadata specifying the details of the PET scanner
        pet_image (torch.Tensor): PET image used to estimate the scatter
        attenuation_image (torch.Tensor): Attenuation map used in scatter simulation
        tof_meta (PETTOFMeta): PET TOF Metadata corresponding to the sinogram estimate
        image_stepsize (int, optional): Stepsize in x/y/z between sampled scatter points. Defaults to 4.
        attenuation_cutoff (float, optional): Only consider points above this threshold. Defaults to 0.004.
        sinogram_interring_stepsize (int, optional): Axial stepsize between rings. Defaults to 4.
        sinogram_intraring_stepsize (int, optional): Stepsize of crystals within a given ring. Defaults to 4.
        num_dense_tof_bins (int, optional): Number of dense TOF bins used when partitioning the emission integrals
            (these integrals must be partitioned for TOF-based estimation). Defaults to 25.
        N_splits (int, optional): Number of chunks the sinogram TOF bins are split into and processed
            sequentially, to reduce peak memory. The result does not depend on this value. Defaults to 1.

    Returns:
        torch.Tensor: Estimated sparse single scatter simulation sinogram.

    Raises:
        ValueError: If no sampled voxel of ``attenuation_image`` exceeds ``attenuation_cutoff`` (there are no
            scatter points to average over).
    """
    device = pytomography.device
    # Important quantities
    E_PET = torch.tensor(511).to(device)
    dr = torch.tensor(object_meta.dr)
    shape = torch.tensor(object_meta.shape)
    object_origin = (-np.array(object_meta.shape) / 2 + 0.5) * np.array(object_meta.dr)
    voxel_volume = np.prod(object_meta.dr)
    total_compton_cross_section_511keV = total_compton_cross_section(E_PET)
    # Loop-invariant quantities (previously recomputed / re-transferred for every scatter point)
    scanner_LUT = proj_meta.scanner_lut.to(device)
    attenuation_image_device = attenuation_image.to(pytomography.dtype).to(device)
    bin_edges_scaling = torch.linspace(0, 1, num_dense_tof_bins + 1).to(device)
    sigma_tof = tof_meta.sigma.item()
    # Get sample image/sinogram points
    coords = get_sample_scatter_points(attenuation_image, stepsize=image_stepsize, attenuation_cutoff=attenuation_cutoff)
    n_scatter_points = coords.shape[1]
    if n_scatter_points == 0:
        raise ValueError(
            "No scatter points found: no sampled voxel of attenuation_image exceeds attenuation_cutoff"
        )
    coords_position = (coords - shape.unsqueeze(1).to(device) / 2 + 0.5) * dr.unsqueeze(1).to(device)
    _, _, detector_ids_scatter = get_sample_detector_ids(proj_meta, sinogram_interring_stepsize, sinogram_intraring_stepsize)
    idxA, idxB = detector_ids_scatter.to(device).T
    rA = scanner_LUT[idxA]
    rB = scanner_LUT[idxB]
    # In-plane (x/y) distance of each detector from the scanner axis, used for incidence angles
    rA_radial_norm = torch.norm(rA[:, :2], dim=1)
    rB_radial_norm = torch.norm(rB[:, :2], dim=1)
    probability = torch.zeros([tof_meta.num_bins, detector_ids_scatter.shape[0]]).to(device)
    tof_bin_splits = torch.tensor_split(torch.arange(tof_meta.num_bins), N_splits)
    tof_bin_positions = tof_meta.bin_positions.to(device)
    for point_idx in range(n_scatter_points):
        # Get position and add random offset within the voxel
        scatter_point_position = coords_position[:, point_idx] + ((torch.rand(3) - 0.5) * dr).to(device)
        # Value of attenuation coefficient at scatter point
        mu_value = attenuation_image[tuple(coords[:, point_idx].tolist())]
        # Split each scatter-point -> detector segment into num_dense_tof_bins equal pieces
        rSD = scanner_LUT - scatter_point_position
        rSD_norm = torch.norm(rSD, dim=1)
        bin_edges_distance_along_LOR = bin_edges_scaling.reshape((1, -1)) * rSD_norm.reshape((-1, 1))
        bin_centers_distance_along_LOR = (bin_edges_distance_along_LOR[:, 1:] + bin_edges_distance_along_LOR[:, :-1]) / 2
        bin_edges = scatter_point_position.reshape((1, 1, -1)) \
            + bin_edges_distance_along_LOR.unsqueeze(-1) * (rSD / rSD_norm.unsqueeze(-1)).unsqueeze(1)
        # Emission integral over every segment between scatter point and detectors (used for TOF)
        emission_integrals = parallelproj.joseph3d_fwd(
            bin_edges[:, :-1].flatten(end_dim=-2),
            bin_edges[:, 1:].flatten(end_dim=-2),
            pet_image,
            object_origin,
            object_meta.dr,
        ).reshape((scanner_LUT.shape[0], num_dense_tof_bins))
        transmission_integrals = parallelproj.joseph3d_fwd(
            scatter_point_position.unsqueeze(0).expand(scanner_LUT.shape[0], -1),
            scanner_LUT,
            attenuation_image_device,
            object_origin,
            object_meta.dr,
        )
        transmission_integrals_exp = torch.exp(-transmission_integrals)
        attenuation_A = transmission_integrals_exp[idxA]
        attenuation_B = transmission_integrals_exp[idxB]
        rSA = rA - scatter_point_position
        rSB = rB - scatter_point_position
        rSA_norm = torch.norm(rSA, dim=1)  # distance between S and A
        rSB_norm = torch.norm(rSB, dim=1)  # distance between S and B
        # --- TOF-independent factors (previously recomputed for every TOF split) ---
        # cos(scattering_angle) = -cos(angle between S->A and S->B)
        cos_theta = -(rSA * rSB).sum(dim=1) / rSA_norm / rSB_norm
        E_new = photon_energy_after_compton_scatter_511kev(cos_theta)
        energy_efficiency = detector_efficiency(E_new)
        # Angle of impingement upon detectors (assumes a circular ring geometry)
        cos_thetaA_incidence = (rSA[:, :2] * rA[:, :2]).sum(dim=1) / rSA_norm / rA_radial_norm
        cos_thetaB_incidence = (rSB[:, :2] * rB[:, :2]).sum(dim=1) / rSB_norm / rB_radial_norm
        compton_cross_section_ratio = total_compton_cross_section(E_new) / total_compton_cross_section_511keV
        scattered_attenuation_A = attenuation_A ** (compton_cross_section_ratio - 1)
        scattered_attenuation_B = attenuation_B ** (compton_cross_section_ratio - 1)
        geometric_factor = 1 / (rSB_norm**2 * rSA_norm**2)
        common_factor = attenuation_B * attenuation_A * mu_value * energy_efficiency * cos_thetaA_incidence \
            * cos_thetaB_incidence * diff_compton_cross_section(cos_theta, E_PET) \
            / total_compton_cross_section_511keV * voxel_volume
        # --- TOF weighting ---
        offset_SA = -((rSB_norm - rSA_norm).unsqueeze(0) / 2 + tof_bin_positions.unsqueeze(1))  # first dim TOF bin
        offset_SB = -offset_SA
        centers_A = bin_centers_distance_along_LOR[idxA]
        centers_B = bin_centers_distance_along_LOR[idxB]
        emission_A = emission_integrals[idxA].unsqueeze(0)
        emission_B = emission_integrals[idxB].unsqueeze(0)
        if len(tof_bin_splits) > 1:
            # The TOF kernel is normalized over *all* TOF bins; when bins are processed in chunks the
            # normalization must be accumulated over every chunk (normalizing per chunk is wrong).
            norm_A = sum(
                _unnormalized_tof_weights(offset_SA[split], centers_A, sigma_tof).sum(dim=0)
                for split in tof_bin_splits
            )
            norm_B = sum(
                _unnormalized_tof_weights(offset_SB[split], centers_B, sigma_tof).sum(dim=0)
                for split in tof_bin_splits
            )
        else:
            norm_A = norm_B = None
        for tof_bin_idxs_partial in tof_bin_splits:
            prob_SA = _unnormalized_tof_weights(offset_SA[tof_bin_idxs_partial], centers_A, sigma_tof)
            prob_SB = _unnormalized_tof_weights(offset_SB[tof_bin_idxs_partial], centers_B, sigma_tof)
            prob_SA = prob_SA / (prob_SA.sum(dim=0) if norm_A is None else norm_A).unsqueeze(0)
            prob_SB = prob_SB / (prob_SB.sum(dim=0) if norm_B is None else norm_B).unsqueeze(0)
            # TOF-weighted emission integrals, shape (n_bins_in_split, N_coinc)
            emission_integralsA = (prob_SA * emission_A).sum(dim=-1)
            emission_integralsB = (prob_SB * emission_B).sum(dim=-1)
            probability[tof_bin_idxs_partial] += geometric_factor * (
                emission_integralsA * scattered_attenuation_B + emission_integralsB * scattered_attenuation_A
            ) * common_factor
    # Flatten TOF-bin-major: entry k * N_pairs + j is (TOF bin k, detector pair j)
    probability = probability.ravel()
    TOF_bins = torch.arange(tof_meta.num_bins).repeat_interleave(detector_ids_scatter.shape[0])
    detector_ids_scatter_with_TOF = torch.cat(
        [detector_ids_scatter.repeat(tof_meta.num_bins, 1), TOF_bins.unsqueeze(1)], dim=-1
    )
    scatter_sinogram_sparse = shared.listmode_to_sinogram(
        detector_ids_scatter_with_TOF, proj_meta.info, tof_meta=tof_meta, weights=(probability / n_scatter_points).cpu()
    )
    return scatter_sinogram_sparse
def interpolate_sparse_sinogram(
    scatter_sinogram_sparse: torch.Tensor,
    proj_meta: ProjMeta,
    idx_intraring: torch.Tensor,
    idx_ring: torch.Tensor
) -> torch.Tensor:
    """Interpolates a sparse SSS sinogram estimate using linear interpolation on all oblique planes.

    Interpolation happens in two stages:

    1. Within each sampled ring-pair plane, values are interpolated over the two in-plane sinogram
       indices (named angular/radial in the code) with a linear-kernel RBF interpolator.
    2. Across ring pairs, values are bilinearly interpolated with ``grid_sample`` over the axial
       coordinates of the two rings.

    Args:
        scatter_sinogram_sparse (torch.Tensor): Estimated sparse SSS sinogram from the ``compute_sss_sparse_sinogram`` or ``compute_sss_sparse_sinogram_TOF`` functions (for TOF, a single TOF bin, i.e. 3 dimensions). It is indexed as ``[index_0, index_1, plane]``.
        proj_meta (ProjMeta): PET projection metadata corresponding to the sinogram
        idx_intraring (torch.Tensor): Intraring indices corresponding to non-zero locations of the sinogram (obtained via the ``get_sample_detector_ids`` function)
        idx_ring (torch.Tensor): Interring indices corresponding to non-zero locations of the sinogram (obtained via the ``get_sample_detector_ids`` function). Assumed to start at ring 0 and end at the last ring, as ``get_sample_detector_ids`` produces.

    Returns:
        torch.Tensor: Interpolated SSS sinogram (on the CPU).
    """
    # As used below: lor_coordinates[c1, c2] gives the two in-plane sinogram indices of intra-ring
    # crystal pair (c1, c2); sinogram_index[r1, r2] gives the sinogram plane of ring pair (r1, r2).
    lor_coordinates, sinogram_index = sinogram_coordinates(proj_meta.info)
    # NOTE(review): ring_coordinates[:, 0] is used below as the axial position of each ring; confirm.
    _, ring_coordinates = sinogram_to_spatial(proj_meta.info)
    # First interpolate r/theta in all separate oblique planes.
    # NOTE(review): despite the names, ``intra_crystal_index_pairs_sparse`` holds *all* crystal pairs
    # within a ring (the dense target grid), while ``intra_crystal_index_pairs`` holds only the
    # sampled (sparse) pairs.
    intra_crystal_index_pairs_sparse = torch.combinations(torch.arange(proj_meta.info['NrCrystalsPerRing']),2).T
    intra_crystal_index_pairs = torch.combinations(idx_intraring,2).T
    # All ordered pairs of sampled rings, first ring varying slowest
    inter_crystal_index_pairs = torch.cartesian_prod(idx_ring, idx_ring).T
    angular_radial_idx = lor_coordinates[intra_crystal_index_pairs_sparse[0], intra_crystal_index_pairs_sparse[1]]
    angular_radial_idx_sparse = lor_coordinates[intra_crystal_index_pairs[0], intra_crystal_index_pairs[1]]
    sinogram_plane_idx_sparse = sinogram_index[inter_crystal_index_pairs[0], inter_crystal_index_pairs[1]]
    # Multi-output RBF: one value column per sampled ring-pair plane
    interpolator = RBFInterpolator(
        angular_radial_idx_sparse.to(torch.float32).to(pytomography.device),
        scatter_sinogram_sparse[angular_radial_idx_sparse.T[0], angular_radial_idx_sparse.T[1]][:,sinogram_plane_idx_sparse].to(pytomography.device),
        kernel='linear',
        device=pytomography.device
    )
    interp_vals = interpolator(angular_radial_idx.to(torch.float32).to(pytomography.device))
    # In-plane bins not reached by any crystal pair stay zero
    scatter_sinogram_interp_rtheta = torch.zeros(*scatter_sinogram_sparse.shape[:2], sinogram_plane_idx_sparse.shape[0]).to(pytomography.device)
    scatter_sinogram_interp_rtheta[angular_radial_idx.T[0], angular_radial_idx.T[1]] = interp_vals
    # Plane axis -> (ring1, ring2) grid; valid because the planes follow cartesian_prod(idx_ring, idx_ring) order
    scatter_sinogram_interp_rtheta = scatter_sinogram_interp_rtheta.reshape(scatter_sinogram_interp_rtheta.shape[0], scatter_sinogram_interp_rtheta.shape[1], len(idx_ring), len(idx_ring))
    # Now interpolate Z using grid_sample.
    # Axial coordinates of the sampled rings and of all rings (z2 / z2_sparse are not used beyond len(z2)).
    z1_sparse = z2_sparse = ring_coordinates[idx_ring][:,0].cpu().numpy().astype(np.float32)
    z1 = z2 = ring_coordinates[np.arange(proj_meta.info['NrRings'])][:,0].cpu().numpy().astype(np.float32)
    # Fractional position of every interior ring among the sampled rings. The negation makes
    # searchsorted's sorted input ascending; presumably z decreases with ring index, so confirm.
    idx = torch.searchsorted(torch.tensor(-z1_sparse), torch.tensor(-z1[1:-1]), side='right') - 1
    # NOTE(review): ``idx`` is an integer tensor and the increment is a float numpy array; confirm this
    # in-place update yields fractional positions on the torch/numpy versions in use.
    idx += -(z1_sparse[idx] - z1[1:-1]) / (z1_sparse[idx+1] - z1_sparse[idx])
    # The first/last rings coincide with the first/last sampled rings
    idx = torch.concatenate([torch.tensor([0]), idx, torch.tensor([z1_sparse.shape[0]-1])])
    # Map positions in [0, len(idx_ring) - 1] to grid_sample's normalized [-1, 1] range (align_corners=True)
    idx = 2/idx.max() * idx - 1
    interp_mesh = np.stack(np.meshgrid(idx,idx, indexing='ij'), axis=-1)
    interp_mesh = torch.tensor(interp_mesh).to(torch.float32).to(pytomography.device)
    # The flattened (r, theta) axes become grid_sample's channel dimension (batch size 1).
    # grid_sample reads grid[..., 0] as x (input's last dim, ring2) and grid[..., 1] as y (ring1), so
    # output[..., i, j] samples at ring1 ~ idx[j], ring2 ~ idx[i] (transposed); the indexing further
    # below appears to undo this.
    scatter_sinogram_interp_all = grid_sample(
        scatter_sinogram_interp_rtheta.flatten(start_dim=0, end_dim=1).unsqueeze(0),
        interp_mesh.unsqueeze(0),
        align_corners=True
    ).reshape((*scatter_sinogram_interp_rtheta.shape[:2], len(z1), len(z2))).cpu()
    # Reorder the ring grid into sinogram plane order: for flat position f = r1 * N + r2 of
    # sinogram_index, plane sinogram_index[r1, r2] is read from [f % N, f // N] = [r2, r1].
    idx_ring1 = torch.argsort(sinogram_index.ravel()) % sinogram_index.shape[-1]
    idx_ring2 = torch.argsort(sinogram_index.ravel()) // sinogram_index.shape[-1]
    scatter_sinogram_interp_all = scatter_sinogram_interp_all[:,:,idx_ring1,idx_ring2]
    return scatter_sinogram_interp_all
def scale_estimated_scatter(
    proj_scatter: torch.Tensor,
    system_matrix: SystemMatrix,
    proj_data: torch.Tensor,
    attenuation_image: torch.Tensor,
    attenuation_image_cutoff: float = 0.004,
    sinogram_random: torch.Tensor | None = None
) -> torch.Tensor:
    """Given an interpolated (but unscaled) SSS sinogram/listmode, scales the scatter estimate by considering back projection of masked data. The mask corresponds to all locations below a certain attenuation value, where it is likely that all detected events are purely due to scatter.

    The scale factor ``s`` is the least-squares solution of
    ``s * BP(masked scatter) ~= max(BP(masked data) - BP(masked randoms), 0)``, where every back
    projection is divided by ``system_matrix.compute_normalization_factor()``.

    Note:
        ``system_matrix`` is modified in place. Its ``TOF`` flag is set according to the
        dimensionality of ``proj_data``, and it is configured with 20 subsets via ``set_n_subsets``.
        These settings are not restored on return.

    Args:
        proj_scatter (torch.Tensor): Estimated (but unscaled) SSS data.
        system_matrix (SystemMatrix): PET system matrix
        proj_data (torch.Tensor): PET projection data corresponding to all detected events
        attenuation_image (torch.Tensor): Attenuation map
        attenuation_image_cutoff (float, optional): Mask considers regions below this value (forward projected). In particular, the attenuation map is masked above this value, then forward projected. Regions equal to zero in the forward projection are considered for the mask. This allows for hollow regions within the attenuation map to still be considered. Defaults to 0.004.
        sinogram_random (torch.Tensor | None, optional): Projection data of estimated random events. Defaults to None.

    Returns:
        torch.Tensor: Scaled SSS projection data (sinogram/listmode).

    Raises:
        ValueError: If the masked scatter estimate back projects to zero everywhere, so no scale factor can be
            determined (previously this silently produced NaNs).
    """
    system_matrix.TOF = False
    norm_BP = system_matrix.compute_normalization_factor()
    proj_data_mask = system_matrix.forward((attenuation_image > attenuation_image_cutoff).to(torch.float32)) > 0
    # LORs that do not intersect the (thresholded) object; counts there are treated as scatter (+ randoms).
    # Inverted once here instead of twice per subset, since the mask can be very large.
    outside_object_mask = ~proj_data_mask
    # Random
    if sinogram_random is not None:
        BP_random_mask = system_matrix.backward(outside_object_mask * sinogram_random.to(system_matrix.output_device)) / norm_BP
    else:
        BP_random_mask = 0
    if len(proj_data.shape) > 3:  # TOF dimension added
        system_matrix.TOF = True
        outside_object_mask = outside_object_mask.unsqueeze(-1)
    else:
        system_matrix.TOF = False
    # Scatter: back project the masked scatter estimate and the masked total data.
    # Split into subsets to preserve memory, since this requires making copies of
    # potentially very large sinogram tensors.
    N_SUBSETS = 20
    system_matrix.set_n_subsets(N_SUBSETS)
    BP_scatter_mask = 0
    BP_total_mask = 0
    for subset_idx in range(N_SUBSETS):
        mask_subset = system_matrix.get_projection_subset(outside_object_mask, subset_idx)
        proj_scatter_masked = system_matrix.get_projection_subset(proj_scatter, subset_idx) * mask_subset
        proj_total_masked = system_matrix.get_projection_subset(proj_data, subset_idx) * mask_subset
        BP_scatter_mask += system_matrix.backward(proj_scatter_masked, subset_idx=subset_idx) / norm_BP
        BP_total_mask += system_matrix.backward(proj_total_masked, subset_idx=subset_idx) / norm_BP
    BP_scatter_estimated_mask = BP_total_mask - BP_random_mask
    BP_scatter_estimated_mask[BP_scatter_estimated_mask < 0] = 0
    denominator = (BP_scatter_mask**2).sum()
    if denominator == 0:
        raise ValueError(
            "Cannot scale scatter estimate: the masked scatter estimate back projects to zero everywhere"
        )
    scale_factor = ((BP_scatter_mask * BP_scatter_estimated_mask).sum() / denominator).item()
    return scale_factor * proj_scatter
def get_sss_scatter_estimate(
    object_meta: ObjectMeta,
    proj_meta: ProjMeta,
    pet_image: torch.Tensor,
    attenuation_image: torch.Tensor,
    system_matrix: SystemMatrix,
    proj_data: torch.Tensor | None = None,
    image_stepsize: int = 4,
    attenuation_cutoff: float = 0.004,
    sinogram_interring_stepsize: int = 4,
    sinogram_intraring_stepsize: int = 4,
    sinogram_random: torch.Tensor | None = None,
    tof_meta: PETTOFMeta | None = None,
    num_dense_tof_bins: int = 25,
    N_splits: int = 1
) -> torch.Tensor:
    """Main function used to get SSS scatter estimation during PET reconstruction

    Steps: compute a sparse SSS sinogram, interpolate it to the full sinogram (per TOF bin for TOF
    data), then scale it to the measured data with ``scale_estimated_scatter``.

    Args:
        object_meta (ObjectMeta): Object metadata corresponding to ``pet_image``.
        proj_meta (ProjMeta): Projection metadata corresponding to ``proj_data``.
        pet_image (torch.Tensor): Reconstructed PET image used to get SSS estimate
        attenuation_image (torch.Tensor): Attenuation map corresponding to PET image
        system_matrix (SystemMatrix): PET system matrix
        proj_data (torch.Tensor | None): All measured coincident events (sinogram/listmode). If None, then assumes listmode (coincidence events stored in ``proj_meta``). Ignored when ``system_matrix`` is a ``PETLMSystemMatrix``.
        image_stepsize (int, optional): Spacing between points in object space used to obtain initial sparse sinogram estimate. Defaults to 4.
        attenuation_cutoff (float, optional): Only consider point located at attenuation values above this value as scatter points. Defaults to 0.004.
        sinogram_interring_stepsize (int, optional): Sinogram interring spacing for initial sparse sinogram estimate. Defaults to 4.
        sinogram_intraring_stepsize (int, optional): Sinogram intraring spacing for initial sparse sinogram estimate. Defaults to 4.
        sinogram_random (torch.Tensor | None, optional): Estimated randoms. Defaults to None.
        tof_meta (PETTOFMeta | None, optional): TOFMetadata corresponding to ``proj_data`` (if TOF is considered). Defaults to None.
        num_dense_tof_bins (int, optional): Number of dense TOF bins to use for partitioning emission integrals when performing a TOF estimate. This is separate from TOF bins used in the PET data. Defaults to 25.
        N_splits (int, optional): Splits the TOF bins into subsets and loops over them sequentially (as opposed to parallel) for scatter estimation. Defaults to 1.

    Returns:
        torch.Tensor: Estimated SSS projection data (sinogram/listmode)

    Raises:
        ValueError: If ``proj_data`` is None and ``system_matrix`` is not a ``PETLMSystemMatrix``. This is checked
            before the (expensive) scatter simulation.
    """
    # NOTE(review): exact type check kept from the original (not isinstance); presumably deliberate in
    # case other PET system matrices subclass PETLMSystemMatrix. Confirm.
    listmode = type(system_matrix) is PETLMSystemMatrix
    if proj_data is None and not listmode:
        raise ValueError(
            "proj_data must be provided when system_matrix is not a PETLMSystemMatrix"
        )
    # Sampled crystal/ring indices (deterministic); computed once and reused for every TOF bin
    idx_intraring, idx_ring, _ = get_sample_detector_ids(proj_meta, sinogram_interring_stepsize, sinogram_intraring_stepsize)
    if tof_meta is None:
        # Get sparse sinogram
        scatter_sinogram_sparse_unscaled = compute_sss_sparse_sinogram(
            object_meta, proj_meta, pet_image, attenuation_image, image_stepsize, attenuation_cutoff,
            sinogram_interring_stepsize, sinogram_intraring_stepsize
        )
        # Interpolate sparse sinogram
        scatter_sinogram_unscaled = interpolate_sparse_sinogram(scatter_sinogram_sparse_unscaled, proj_meta, idx_intraring, idx_ring)
    else:
        # Get sparse sinogram
        scatter_sinogram_sparse_unscaled = compute_sss_sparse_sinogram_TOF(
            object_meta, proj_meta, pet_image, attenuation_image, tof_meta, image_stepsize, attenuation_cutoff,
            sinogram_interring_stepsize, sinogram_intraring_stepsize, num_dense_tof_bins, N_splits
        )
        scatter_sinogram_unscaled = torch.empty(scatter_sinogram_sparse_unscaled.shape, dtype=torch.float32)
        # Interpolate sparse sinogram (loop over TOF bins, last dimension)
        for tof_bin in range(scatter_sinogram_sparse_unscaled.shape[-1]):
            scatter_sinogram_unscaled[..., tof_bin] = interpolate_sparse_sinogram(
                scatter_sinogram_sparse_unscaled[:, :, :, tof_bin], proj_meta, idx_intraring, idx_ring
            )
    del scatter_sinogram_sparse_unscaled  # save memory for next step
    # Need to create a sinogram system matrix (and sinogram data) for scaling
    if listmode:
        system_matrix = create_sinogramSM_from_LMSM(system_matrix)
        if tof_meta is None:
            proj_data = listmode_to_sinogram(proj_meta.detector_ids.cpu(), proj_meta.info)
        else:
            proj_data = listmode_to_sinogram(proj_meta.detector_ids.cpu(), proj_meta.info, tof_meta=tof_meta)
    # Scale sinogram
    proj_scatter = scale_estimated_scatter(
        scatter_sinogram_unscaled, system_matrix, proj_data, attenuation_image, attenuation_cutoff,
        sinogram_random=sinogram_random
    )
    return proj_scatter