r"""This module consists of preconditioned gradient ascent (PGA) algorithms: these algorithms are both statistical (since they depend on a likelihood function dependent on the imaging system) and iterative. Common clinical reconstruction algorithms, such as OSEM, correspond to a subclass of PGA algorithms. PGA algorithms are characterized by the update rule :math:`f^{n+1} = f^{n} + C^{n}(f^{n}) \left[\nabla_{f} L(g^n|f^{n}) - \beta \nabla_{f} V(f^{n}) \right]` where :math:`L(g^n|f^{n})` is the likelihood function, :math:`V(f^{n})` is the prior function, :math:`C^{n}(f^{n})` is the preconditioner, and :math:`\beta` is a scalar used to scale the prior function."""
from __future__ import annotations
from collections.abc import Callable, Sequence
import pytomography
import torch
from pytomography.callbacks import Callback, DataStorageCallback
from pytomography.likelihoods import Likelihood, SARTWeightedNegativeMSELikelihood
from pytomography.priors import Prior
from pytomography.io.SPECT import dicom
from pytomography.projectors import SystemMatrix
class PreconditionedGradientAscentAlgorithm:
    r"""Generic class for preconditioned gradient ascent algorithms: i.e. those that have the form :math:`f^{n+1} = f^{n} + C^{n}(f^{n}) \left[\nabla_{f} L(g^n|f^{n}) - \beta \nabla_{f} V(f^{n}) \right]`.

    Args:
        likelihood (Likelihood): Likelihood class that facilitates computation of :math:`L(g^n|f^{n})` and its associated derivatives.
        prior (Prior | None, optional): Prior class that facilitates the computation of function :math:`V(f)` and its associated derivatives. If None, then no prior is used. Defaults to None.
        object_initial (torch.Tensor | None, optional): Initial object for reconstruction algorithm. If None, the initial object is obtained from ``likelihood.system_matrix._get_object_initial``. The tensor is copied, so the caller's tensor is never modified by the in-place updates performed during reconstruction. Defaults to None.
        addition_after_iteration (float, optional): Floor applied to the object after each subiteration: every voxel less than or equal to this value is set to this value. This prevents image voxels getting "locked" at values of 0 for certain algorithms. Defaults to 0.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        likelihood: Likelihood,
        prior: Prior | None = None,
        object_initial: torch.Tensor | None = None,
        addition_after_iteration: float = 0,
        **kwargs,
    ) -> None:
        self.likelihood = likelihood
        if object_initial is None:
            self.object_prediction = self.likelihood.system_matrix._get_object_initial(pytomography.device)
        else:
            # copy=True: ``__call__`` updates ``object_prediction`` in place, and
            # ``Tensor.to`` returns the *same* tensor when device and dtype already
            # match, which would silently overwrite the caller's initial object.
            self.object_prediction = object_initial.to(
                device=pytomography.device,
                dtype=pytomography.dtype,
                copy=True,
            )
        self.prior = prior
        if self.prior is not None:
            self.prior.set_object_meta(self.likelihood.system_matrix.object_meta)
            self.prior.set_FOV_scale(self.likelihood.system_matrix._get_prior_FOV_scale())
        # These are if objects / FPS are stored during reconstruction for uncertainty analysis afterwards
        self.objects_stored = []
        self.projections_predicted_stored = []
        self.addition_after_iteration = addition_after_iteration

    def _set_n_subsets(self, n_subsets: int):
        """Sets the number of subsets used in the reconstruction algorithm (on this object and on the likelihood).

        Args:
            n_subsets (int): Number of subsets
        """
        self.n_subsets = n_subsets
        self.likelihood._set_n_subsets(n_subsets)

    def _compute_preconditioner(
        self,
        object: torch.Tensor,
        n_iter: int,
        n_subset: int
    ) -> torch.Tensor:
        r"""Computes the preconditioner factor :math:`C^{n}(f^{n})`. Must be implemented by any reconstruction algorithm that inherits from this generic class.

        Args:
            object (torch.Tensor): Object :math:`f^n`
            n_iter (int): Iteration number
            n_subset (int): Subset number (None when a single subset is used)

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("_compute_preconditioner not implemented for this reconstruction algorithm; this must be implemented by any subclass of PreconditionedGradientAscentAlgorithm")

    def _compute_callback(self, n_iter: int, n_subset: int):
        """Runs the callback after a subiteration and replaces the current object estimate by whatever the callback returns.

        Args:
            n_iter (int): Iteration index
            n_subset (int): Subset index
        """
        self.object_prediction = self.callback.run(self.object_prediction, n_iter, n_subset)

    def __call__(
        self,
        n_iters: int,
        n_subsets: int = 1,
        n_subset_specific: int | None = None,
        callback: Callback | None = None,
    ):
        """Runs the reconstruction loop, updating the object estimate in place.

        Args:
            n_iters (int): Number of iterations
            n_subsets (int): Number of subsets. Defaults to 1.
            n_subset_specific (int | None, optional): If given, ignore all updates except for this subset. Defaults to None.
            callback (Callback | None, optional): Callback to be run after each subiteration and finalized after the last one. Defaults to None.

        Returns:
            torch.Tensor: Reconstructed object.
        """
        self.callback = callback
        self.n_iters = n_iters
        self._set_n_subsets(n_subsets)
        # Perform reconstruction loop
        for j in range(n_iters):
            for k in range(n_subsets):
                if n_subset_specific is not None and n_subset_specific != k:
                    continue
                # A single subset is signalled to the likelihood / system matrix with None
                subset_idx = None if n_subsets == 1 else k
                # Adjust object before iteration: note, because of uncertainty analysis this must be at beginning
                if self.prior is not None:
                    self.prior.set_object(torch.clone(self.object_prediction).to(pytomography.device))
                    self.prior.set_beta_scale(self.likelihood.system_matrix.get_weighting_subset(subset_idx))
                    self.prior_gradient = self.prior(derivative_order=1)
                else:
                    self.prior_gradient = 0
                likelihood_gradient = self.likelihood.compute_gradient(self.object_prediction, subset_idx)
                preconditioner = self._compute_preconditioner(self.object_prediction, j, subset_idx)
                self.object_prediction += preconditioner * (likelihood_gradient - self.prior_gradient)
                # Floor the estimate so that voxels cannot get stuck at (or below) zero
                self.object_prediction[self.object_prediction <= self.addition_after_iteration] = self.addition_after_iteration
                if self.callback is not None:
                    self._compute_callback(n_iter=j, n_subset=k)
        # The floor applied above is intentionally left in the returned object
        if self.callback is not None:
            self.callback.finalize(self.object_prediction)
        return self.object_prediction
class LinearPreconditionedGradientAscentAlgorithm(PreconditionedGradientAscentAlgorithm):
    r"""Implementation of a special case of ``PreconditionedGradientAscentAlgorithm`` whereby :math:`C^{n}(f^n) = D^{n} f^{n}`

    Args:
        likelihood (Likelihood): Likelihood class that facilitates computation of :math:`L(g^n|f^{n})` and its associated derivatives.
        prior (Prior, optional): Prior class that facilitates the computation of function :math:`V(f)` and its associated derivatives. If None, then no prior is used. Defaults to None.
        object_initial (torch.Tensor | None, optional): Initial object for reconstruction algorithm. If None, the default initial object of the parent class is used. Defaults to None.
        addition_after_iteration (float, optional): Floor applied to the object after each subiteration (voxels less than or equal to this value are set to it). This prevents image voxels getting "locked" at values of 0 for certain algorithms. Defaults to 0.
    """

    def _linear_preconditioner_factor(self, n_iter: int, n_subset: int):
        r"""Implementation of object independent scaling factor :math:`D^{n}` in :math:`C^{n}(f^{n}) = D^{n} f^{n}`

        Args:
            n_iter (int): iteration number
            n_subset (int): subset number

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("_linear_preconditioner_factor not implemented for this reconstruction algorithm; this must be implemented by any subclass of LinearPreconditionedGradientAscentAlgorithm")

    def _compute_preconditioner(
        self,
        object: torch.Tensor,
        n_iter: int,
        n_subset: int
    ) -> torch.Tensor:
        r"""Computes the preconditioner :math:`C^{n}(f^n) = D^{n} \text{diag}\left(f^{n}\right)` using the associated `_linear_preconditioner_factor` method.

        Args:
            object (torch.Tensor): Object :math:`f^{n}`
            n_iter (int): Iteration :math:`n`
            n_subset (int): Subset :math:`m`

        Returns:
            torch.Tensor: Preconditioner factor
        """
        return object * self._linear_preconditioner_factor(n_iter, n_subset)

    def _subset_idx_from_subiteration(self, n: int) -> int | None:
        """Maps a subiteration number to the subset index used for that update.

        Args:
            n (int): Subiteration number

        Returns:
            int | None: ``None`` when a single subset is used, otherwise ``n % self.n_subsets``.
        """
        if self.n_subsets == 1:
            return None
        return n % self.n_subsets

    def compute_uncertainty(
        self,
        mask: torch.Tensor,
        data_storage_callback: DataStorageCallback,
        subiteration_number : int | None = None,
        return_pct: bool = False,
        include_additive_term: bool = False,
        # NOTE(review): ``Transform`` is not imported in the visible part of this file;
        # this is harmless at runtime because annotations are lazy here -- confirm.
        post_recon_filter: Transform | None = None
    ) -> float | Sequence[float]:
        """Estimates the uncertainty of the sum of voxels in a reconstructed image. Calling this method requires a masked region `mask` as well as an instance of `DataStorageCallback` that has been used in a reconstruction algorithm: this data storage contains the estimated object and associated forward projection at each subiteration number.

        Args:
            mask (torch.Tensor): Masked region of the reconstructed object: a boolean Tensor.
            data_storage_callback (Callback): Callback that has been used in a reconstruction algorithm.
            subiteration_number (int | None, optional): Subiteration number to compute the uncertainty for. If None, then computes the uncertainty for the last stored object. Must be at least 1. Defaults to None.
            return_pct (bool, optional): If true, then additionally returns the percent uncertainty for the sum of counts. Defaults to False.
            include_additive_term (bool): Whether or not to include uncertainty contribution from the additive term. This requires the ``additive_term_variance_estimate`` as an argument to the initialized likelihood. Defaults to False.
            post_recon_filter (Transform | None, optional): Callable applied to both the mask and the reconstructed object before the uncertainty is propagated. Defaults to None.

        Raises:
            ValueError: If the (resolved) subiteration number is smaller than 1, i.e. no update has been stored to propagate the uncertainty through.

        Returns:
            float | Sequence[float]: Absolute uncertainty in the sum of counts in the masked region (if `return_pct` is False) OR absolute uncertainty and relative uncertainty in percent (if `return_pct` is True)
        """
        if subiteration_number is None:
            subiteration_number = len(data_storage_callback.objects) - 1
        # With fewer than one update the accumulation loop below never runs and
        # there is nothing to index; fail with an explicit message instead.
        if subiteration_number < 1:
            raise ValueError(
                "compute_uncertainty requires subiteration_number >= 1 "
                f"(got {subiteration_number}); the data storage callback must contain at least one update"
            )
        # Get final reconstruction
        final_recon = data_storage_callback.objects[subiteration_number].to(pytomography.device)
        # Apply filter if provided
        if post_recon_filter is not None:
            Q_sequence_current = post_recon_filter(mask)
            final_recon = post_recon_filter(final_recon)
        else:
            Q_sequence_current = mask.clone()
        # Walk backwards through the updates, accumulating the B terms and
        # propagating the current input through Q
        V = 0
        for n in range(subiteration_number, 0, -1):
            V += self._compute_B(Q_sequence_current, data_storage_callback, n - 1, include_additive_term=include_additive_term)
            if n > 1:
                Q_sequence_current = self._compute_Q(Q_sequence_current, data_storage_callback, n - 1)
        uncertainty_abs2 = torch.sum(V[0] * self.likelihood.projections * V[0])
        # If uncertainty estimated in the additive term
        if include_additive_term:
            uncertainty_abs2 += torch.sum(self.likelihood.additive_term_variance_estimate(V[1]))
        uncertainty_abs = torch.sqrt(uncertainty_abs2).item()
        if not return_pct:
            return uncertainty_abs
        uncertainty_rel = uncertainty_abs / (final_recon * mask).sum().item() * 100
        return uncertainty_abs, uncertainty_rel

    def _compute_Q(
        self,
        input: torch.Tensor,
        data_storage_callback: Callback,
        n: int,
    ) -> torch.Tensor:
        """Computes the operation of :math:`Q` on an input object; this is a helper function for ``compute_uncertainty``. For more details, see the uncertainty paper.

        Args:
            input (torch.Tensor): Object on which Q operates
            data_storage_callback (Callback): Data storage callback containing all objects and forward projections at each subiteration
            n (int): Subiteration number

        Returns:
            torch.Tensor: Resulting output object from the operation of :math:`Q` on the input object
        """
        subset_idx = self._subset_idx_from_subiteration(n)
        object_current_update = data_storage_callback.objects[n].to(pytomography.device)
        object_future_update = data_storage_callback.objects[n + 1].to(pytomography.device)
        FP_current_update = data_storage_callback.projections_predicted[n].to(pytomography.device)
        likelihood_grad_ff = self.likelihood.compute_gradient_ff(object_current_update, FP_current_update, subset_idx)
        # TODO Fix None argument later (required for relaxation sequence)
        output = input * object_current_update * self._linear_preconditioner_factor(None, subset_idx)
        if self.prior is not None:
            self.prior.set_beta_scale(self.likelihood.system_matrix.get_weighting_subset(subset_idx))
            self.prior.set_object(object_current_update)
            output = likelihood_grad_ff(output) - self.prior(derivative_order=2)(output)
        else:
            output = likelihood_grad_ff(output)
        # pytomography.delta guards the element-wise ratio against division by zero
        output += (object_future_update / (object_current_update + pytomography.delta)) * input
        return output

    def _compute_B(
        self,
        input: torch.Tensor,
        data_storage_callback: Callback,
        n: int,
        include_additive_term: bool = False
    ) -> torch.Tensor:
        """Computes the operation of :math:`B` on an input object; this is a helper function for ``compute_uncertainty``. For more details, see the uncertainty paper.

        Args:
            input (torch.Tensor): Object on which B operates
            data_storage_callback (Callback): Data storage callback containing all objects and forward projections at each subiteration
            n (int): Subiteration number
            include_additive_term (bool, optional): Whether or not to include uncertainty estimation for the additive term. Defaults to False.

        Returns:
            torch.Tensor: Resulting output projections from the operation of :math:`B` on the input object. The leading dimension has length 1 (primary term only) or 2 (primary term followed by the additive term).
        """
        proj_shape = self.likelihood.system_matrix.proj_meta.shape
        subset_idx = self._subset_idx_from_subiteration(n)
        if subset_idx is None:
            # Single subset: every projection index is used
            subset_indices_array = torch.arange(proj_shape[0]).to(torch.long).to(pytomography.device)
        else:
            subset_indices_array = self.likelihood.system_matrix.subset_indices_array[subset_idx]
        object_current_update = data_storage_callback.objects[n].to(pytomography.device)
        FP_current_update = data_storage_callback.projections_predicted[n].to(pytomography.device)
        output = input * object_current_update * self._linear_preconditioner_factor(None, subset_idx)
        # One entry per returned term: primary first, then (optionally) additive
        terms = [
            self.likelihood.compute_gradient_gf(object_current_update, FP_current_update, subset_idx)(output)
        ]
        if include_additive_term:
            terms.append(
                self.likelihood.compute_gradient_sf(object_current_update, FP_current_update, subset_idx)(output)
            )
        # the subset dimension is always 0 unless dual photopeak is used, in which case it may be dimension 1
        dim_idx = len(FP_current_update.shape) - len(proj_shape)
        extra_dims = [FP_current_update.shape[i] for i in range(dim_idx)]
        output_total = torch.zeros((
            len(terms),
            *extra_dims,
            *proj_shape,
        )).to(pytomography.device)
        subset_indices_array = subset_indices_array.to(pytomography.device)
        # Scatter each subset-sized term into its slot of the full projection space
        for i, term in enumerate(terms):
            output_total[i].index_copy_(dim_idx, subset_indices_array, term)
        return output_total
class OSEM(LinearPreconditionedGradientAscentAlgorithm):
    r"""Implementation of the ordered subset expectation maximum algorithm :math:`f^{n+1} = f^{n} + \frac{f^n}{H_n^T 1} \nabla_{f} L(g^n|f^{n})`.

    Args:
        likelihood (Likelihood): Likelihood function :math:`L`.
        object_initial (torch.Tensor | None, optional): Initial object for reconstruction algorithm. If None, the default initial object of the parent class is used. Defaults to None.
    """

    def __init__(
        self,
        likelihood: Likelihood,
        object_initial: torch.Tensor | None = None,
    ):
        # No prior is forwarded: this algorithm is constructed without one
        super().__init__(
            likelihood=likelihood,
            object_initial=object_initial,
        )

    def _linear_preconditioner_factor(self, n_iter: int, n_subset: int) -> torch.Tensor:
        """Computes the linear preconditioner factor :math:`D^n = 1/H_n^T 1`

        Args:
            n_iter (int): iteration number (unused here)
            n_subset (int): subset number

        Returns:
            torch.Tensor: linear preconditioner factor
        """
        # pytomography.delta keeps the division finite where the normalization is zero
        return 1 / (self.likelihood._get_normBP(n_subset) + pytomography.delta)
class OSMAPOSL(PreconditionedGradientAscentAlgorithm):
    r"""Implementation of the ordered subset maximum a posteriori one step late algorithm :math:`f^{n+1} = f^{n} + \frac{f^n}{H_n^T+\nabla_f V(f^n)} \left[ \nabla_{f} L(g^n|f^{n}) - \nabla_f V(f^n) \right]`

    Args:
        likelihood (Likelihood): Likelihood function :math:`L`.
        object_initial (torch.Tensor | None, optional): Initial object for reconstruction algorithm. If None, the default initial object of the parent class is used. Defaults to None.
        prior (Prior, optional): Prior class that facilitates the computation of function :math:`V(f)` and its associated derivatives. If None, then no prior is used. Defaults to None.
    """

    def __init__(
        self,
        likelihood: Likelihood,
        object_initial: torch.Tensor | None = None,
        prior: Prior | None = None,
    ):
        super().__init__(
            likelihood=likelihood,
            object_initial=object_initial,
            prior=prior,
        )

    def _compute_preconditioner(self, object: torch.Tensor, n_iter: int, n_subset: int) -> torch.Tensor:
        r"""Computes the preconditioner factor :math:`C^n(f^n) = \frac{f^n}{H_n^T+\nabla_f V(f^n)}`

        Relies on ``self.prior_gradient``, which the reconstruction loop of the parent class sets (to 0 when there is no prior) right before calling this method.

        Args:
            object (torch.Tensor): Object estimate :math:`f^n`
            n_iter (int): iteration number (unused here)
            n_subset (int): subset number

        Returns:
            torch.Tensor: preconditioner factor.
        """
        denominator = self.likelihood._get_normBP(n_subset) + self.prior_gradient + pytomography.delta
        return object / denominator
class RBIEM(LinearPreconditionedGradientAscentAlgorithm):
    r"""Implementation of the rescaled block iterative expectation maximum algorithm

    Args:
        likelihood (Likelihood): Likelihood function :math:`L`.
        object_initial (torch.Tensor | None, optional): Initial object for reconstruction algorithm. If None, the default initial object of the parent class is used. Defaults to None.
        prior (Prior, optional): Prior class that facilitates the computation of function :math:`V(f)` and its associated derivatives. If None, then no prior is used. Defaults to None.
    """

    def __init__(
        self,
        likelihood: Likelihood,
        object_initial: torch.Tensor | None = None,
        prior: Prior | None = None,
    ):
        super().__init__(
            likelihood=likelihood,
            object_initial=object_initial,
            prior=prior,
        )

    def _rescaled_norm_BP(self, n_subset: int) -> torch.Tensor:
        """Computes the object-independent denominator shared by the preconditioner and its linear factor.

        Args:
            n_subset (int): subset number

        Returns:
            torch.Tensor: ``norm_BP_allsubsets * rm + pytomography.delta``, where ``rm`` is the maximum over voxels of ``norm_BP / (norm_BP_allsubsets + pytomography.delta)``.
        """
        norm_BP = self.likelihood._get_normBP(n_subset)
        # presumably the normalization summed over all subsets -- confirm against Likelihood._get_normBP
        norm_BP_allsubsets = self.likelihood._get_normBP(n_subset, return_sum=True)
        rm = torch.max(norm_BP / (norm_BP_allsubsets + pytomography.delta))
        return norm_BP_allsubsets * rm + pytomography.delta

    def _linear_preconditioner_factor(self, n_iter: int, n_subset: int) -> torch.Tensor:
        r"""Computes the linear preconditioner factor :math:`D^n` such that :math:`C^n(f^n) = D^n f^n`.

        Without this method the inherited ``compute_uncertainty`` (which calls ``_linear_preconditioner_factor``) would raise ``NotImplementedError`` for this algorithm.

        Args:
            n_iter (int): iteration number (unused here)
            n_subset (int): subset number

        Returns:
            torch.Tensor: linear preconditioner factor
        """
        return 1 / self._rescaled_norm_BP(n_subset)

    def _compute_preconditioner(self, object: torch.Tensor, n_iter: int, n_subset: int) -> torch.Tensor:
        r"""Computes the preconditioner factor: the object divided by the rescaled normalization of ``_rescaled_norm_BP``. The prior does not enter this preconditioner.

        Args:
            object (torch.Tensor): Object estimate :math:`f^n`
            n_iter (int): iteration number (unused here)
            n_subset (int): subset number

        Returns:
            torch.Tensor: preconditioner factor.
        """
        # Kept as a direct division (rather than object * factor) so reconstruction
        # results stay bit-for-bit what they were before the factor was added.
        return object / self._rescaled_norm_BP(n_subset)
class RBIMAP(PreconditionedGradientAscentAlgorithm):
    r"""Implementation of the rescaled block iterative maximum a posteriori algorithm

    Args:
        likelihood (Likelihood): Likelihood function :math:`L`.
        object_initial (torch.Tensor | None, optional): Initial object for reconstruction algorithm. If None, the default initial object of the parent class is used. Defaults to None.
        prior (Prior, optional): Prior class that facilitates the computation of function :math:`V(f)` and its associated derivatives. If None, then no prior is used. Defaults to None.
    """

    def __init__(
        self,
        likelihood: Likelihood,
        object_initial: torch.Tensor | None = None,
        prior: Prior | None = None,
    ):
        super().__init__(
            likelihood=likelihood,
            object_initial=object_initial,
            prior=prior,
        )

    def _compute_preconditioner(self, object: torch.Tensor, n_iter: int, n_subset: int) -> torch.Tensor:
        r"""Computes the preconditioner factor ``object / (norm_BP_allsubsets * rm + prior_gradient + delta)``, where ``rm`` is the maximum over voxels of ``(norm_BP + prior_gradient) / (norm_BP_allsubsets + prior_gradient + delta)``.

        Relies on ``self.prior_gradient``, which the reconstruction loop of the parent class sets (to 0 when there is no prior) right before calling this method.

        Args:
            object (torch.Tensor): Object estimate :math:`f^n`
            n_iter (int): iteration number (unused here)
            n_subset (int): subset number

        Returns:
            torch.Tensor: preconditioner factor.
        """
        norm_BP = self.likelihood._get_normBP(n_subset)
        # presumably the normalization summed over all subsets -- confirm against Likelihood._get_normBP
        norm_BP_allsubsets = self.likelihood._get_normBP(n_subset, return_sum=True)
        rm = torch.max(
            (norm_BP + self.prior_gradient)
            / (norm_BP_allsubsets + self.prior_gradient + pytomography.delta)
        )
        return object / (norm_BP_allsubsets * rm + self.prior_gradient + pytomography.delta)
class BSREM(LinearPreconditionedGradientAscentAlgorithm):
    r"""Implementation of the block sequential regularized expectation maximum algorithm :math:`f^{n+1} = f^{n} + \frac{\alpha(n) f^n}{\omega_n H^T 1} \left[\nabla_{f} L(g^n|f^{n}) - \nabla_f V(f^n) \right]`

    Args:
        likelihood (Likelihood): likelihood function :math:`L`
        object_initial (torch.Tensor | None, optional): Initial object for reconstruction algorithm. If None, the default initial object of the parent class is used. Defaults to None.
        prior (Prior, optional): Prior class that facilitates the computation of function :math:`V(f)` and its associated derivatives. If None, then no prior is used. Defaults to None.
        relaxation_sequence (Callable, optional): Relaxation sequence :math:`\alpha(n)` used to scale future updates. Defaults to 1 for all :math:`n`. Note that when this function is provided, it takes the zero-based iteration index (not the subiteration) so that e.g. if 4 iterations and 8 subsets are used, it would call :math:`\alpha(3)` for all 8 subiterations of the final iteration. The uncertainty helpers of the parent class call it with ``None``, so a custom sequence must accept ``None`` if uncertainty estimation is used.
        addition_after_iteration (float, optional): Floor applied to the object after each subiteration (voxels less than or equal to this value are set to it). This prevents image voxels getting "locked" at values of 0. Defaults to 1e-4.
    """

    def __init__(
        self,
        likelihood: Likelihood,
        object_initial: torch.Tensor | None = None,
        prior: Prior | None = None,
        relaxation_sequence: Callable = lambda _: 1,
        addition_after_iteration: float = 1e-4,  # good for typical counts in Lu177 SPECT
    ):
        self.relaxation_sequence = relaxation_sequence
        super().__init__(
            likelihood=likelihood,
            object_initial=object_initial,
            prior=prior,
            addition_after_iteration=addition_after_iteration,
        )

    def _linear_preconditioner_factor(self, n_iter: int, n_subset: int) -> torch.Tensor:
        r"""Computes the linear preconditioner factor :math:`D^n = \alpha(n)/(\omega_n H^T 1)` where :math:`\omega_n` corresponds to the fraction of subsets at subiteration :math:`n`.

        Args:
            n_iter (int): iteration number, forwarded to the relaxation sequence
            n_subset (int): subset number

        Returns:
            torch.Tensor: linear preconditioner factor
        """
        relaxation_factor = self.relaxation_sequence(n_iter)
        norm_BP = self.likelihood._get_normBP(n_subset, return_sum=True)
        norm_BP_weight = self.likelihood.system_matrix.get_weighting_subset(n_subset)
        # pytomography.delta keeps the division finite where the normalization is zero
        return relaxation_factor / (norm_BP_weight * norm_BP + pytomography.delta)
class KEM(OSEM):
    r"""Kernelized variant of OSEM: the update :math:`\alpha^{n+1} = \alpha^{n} + \frac{\alpha^n}{\tilde{H}_n^T} \nabla_{f} L(g^n|\alpha^{n})` is performed on the coefficients :math:`\alpha`, and the final predicted object is :math:`f^n = K \hat{\alpha}^{n}`. The system matrix :math:`\tilde{H}` includes the kernel transform :math:`K`.

    Args:
        likelihood (Likelihood): Likelihood function :math:`L`.
        object_initial (torch.Tensor | None, optional): Initial object for reconstruction algorithm. If None, the default initial object of the parent class is used. Defaults to None.
    """

    def _compute_callback(self, n_iter: int, n_subset: int):
        r"""Runs the callback on :math:`f^n = K \hat{\alpha}^{n}` rather than on :math:`\hat{\alpha}^{n}`.

        Unlike the parent implementation, the value returned by the callback is discarded, so the stored coefficients are left untouched.

        Args:
            n_iter (int): Iteration index
            n_subset (int): Subset index
        """
        run_callback = self.callback.run
        kem_transform = self.likelihood.system_matrix.kem_transform
        run_callback(kem_transform.forward(self.object_prediction), n_iter, n_subset)

    def __call__(
        self, *args, **kwargs
    ):
        r"""Runs the parent reconstruction and returns :math:`f^n = K \hat{\alpha}^{n}` instead of :math:`\hat{\alpha}^{n}`.

        Returns:
            torch.Tensor: reconstructed object
        """
        # NOTE(review): the parent loop finalizes the callback on the coefficients,
        # not on the transformed object -- confirm this is intended.
        alpha_prediction = super().__call__(*args, **kwargs)
        kem_transform = self.likelihood.system_matrix.kem_transform
        return kem_transform.forward(alpha_prediction)
class MLEM(OSEM):
    r"""Implementation of the maximum likelihood expectation maximum algorithm :math:`f^{n+1} = f^{n} + \frac{f^n}{H^T} \nabla_{f} L(g|f^{n})`.

    Args:
        likelihood (Likelihood): Likelihood function :math:`L`.
        object_initial (torch.Tensor | None, optional): Initial object for reconstruction algorithm. If None, the default initial object of the parent class is used. Defaults to None.
    """

    def __call__(self, n_iters: int, callback: Callback | None = None):
        """Runs the parent reconstruction with the number of subsets fixed to 1.

        Args:
            n_iters (int): Number of iterations
            callback (Callback | None, optional): Callback forwarded to the parent reconstruction. Defaults to None.

        Returns:
            torch.Tensor: Reconstructed object.
        """
        single_subset = 1
        return super().__call__(n_iters, n_subsets=single_subset, callback=callback)
class SART(PreconditionedGradientAscentAlgorithm):
    r"""Implementation of the SART algorithm. This algorithm takes as input the system matrix and projections (as opposed to a likelihood). This is an implementation of equation 3 of https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8506772/

    Args:
        system_matrix (SystemMatrix): System matrix for the imaging system.
        projections (torch.Tensor): Projections for the imaging system.
        additive_term (torch.Tensor | None, optional): Additive term for the imaging system. If None, then no additive term is used. Defaults to None.
        object_initial (torch.Tensor | None, optional): Initial object for reconstruction algorithm. If None, the default initial object of the parent class is used. Defaults to None.
        relaxation_sequence (Callable, optional): Function of the zero-based iteration index whose value scales the preconditioner at that iteration. Defaults to 1 for every iteration.
    """

    def __init__(
        self,
        system_matrix: SystemMatrix,
        projections: torch.Tensor,
        additive_term: torch.Tensor | None = None,
        object_initial: torch.Tensor | None = None,
        relaxation_sequence: Callable = lambda _: 1,
    ):
        # The likelihood is built internally from the system matrix and projections
        likelihood = SARTWeightedNegativeMSELikelihood(system_matrix, projections, additive_term=additive_term)
        super().__init__(
            likelihood=likelihood,
            object_initial=object_initial,
        )
        self.relaxation_sequence = relaxation_sequence

    def _compute_preconditioner(self, object: torch.Tensor, n_iter: int, n_subset: int) -> torch.Tensor:
        r"""Computes the preconditioner factor: the relaxation factor for this iteration divided by the normalization ``likelihood._get_normBP(n_subset)``. It depends on neither the object nor a prior.

        Args:
            object (torch.Tensor): Object estimate :math:`f^n` (unused here)
            n_iter (int): iteration number, forwarded to the relaxation sequence
            n_subset (int): subset number

        Returns:
            torch.Tensor: preconditioner factor.
        """
        relaxation_factor = self.relaxation_sequence(n_iter)
        # pytomography.delta keeps the division finite where the normalization is zero
        return relaxation_factor / (self.likelihood._get_normBP(n_subset) + pytomography.delta)
class PGAAMultiBedSPECT(PreconditionedGradientAscentAlgorithm):
    """Assistant class for performing reconstruction on multi-bed SPECT data. This class is a wrapper around reconstruction algorithms that are called for each bed and then the results are stitched together.

    Note that the parent constructor is not called: this wrapper holds no likelihood or prior of its own.

    Args:
        files_NM (Sequence[str]): Sequence of SPECT raw data paths corresponding to each bed position
        reconstruction_algorithms (Sequence[object]): Reconstruction algorithms, one per bed position
    """

    def __init__(
        self,
        files_NM: Sequence[str],
        reconstruction_algorithms: Sequence[object],
    ) -> None:
        self.files_NM = files_NM
        self.reconstruction_algorithms = reconstruction_algorithms

    def __call__(
        self,
        n_iters: int,
        n_subsets: int,
        callback: Callback | Sequence[Callback] | None = None,
    ) -> torch.Tensor:
        """Perform reconstruction of each bed position for specified iterations and subsets, and return the stitched image

        Args:
            n_iters (int): Number of iterations to perform reconstruction for.
            n_subsets (int): Number of subsets to perform reconstruction for.
            callback (Callback | Sequence[Callback] | None, optional): Callback function. If a single Callback is given, then the callback is computed for the stitched image. If a list or tuple of callbacks is given, then it must be the same length as the number of bed positions; each callback is called on the reconstruction for each bed position. If None, no Callback is used. Defaults to None.

        Returns:
            torch.Tensor: Stitched reconstructed object.
        """
        self.callback = callback
        for i in range(n_iters):
            for j in range(n_subsets):
                # Advance every bed by exactly one subiteration (subset j only).
                # NOTE(review): each bed algorithm is always asked for a single
                # iteration, so it sees iteration index 0 on every call -- confirm
                # this is intended for iteration-dependent relaxation sequences.
                self.recons = [
                    recon_algo(1, n_subsets, n_subset_specific=j)
                    for recon_algo in self.reconstruction_algorithms
                ]
                self.object_prediction = dicom.stitch_multibed(
                    recons=torch.stack(self.recons),
                    files_NM=self.files_NM,
                )
                self._compute_callback(i, j)
        self._finalize_callback()
        return self.object_prediction

    def _compute_callback(self, n_iter: int, n_subset: int):
        """Computes the callback at iteration ``n_iter`` and subset ``n_subset``.

        With a list or tuple of callbacks, each one is attached to and run by the algorithm of its bed position. With a single callback, it is run on the stitched object already stored in ``self.object_prediction`` by ``__call__``.

        Args:
            n_iter (int): Iteration number
            n_subset (int): Subset index
        """
        if self.callback is None:
            return
        # isinstance rather than an exact type check, so a tuple of callbacks works too
        if isinstance(self.callback, (list, tuple)):
            for recon_algo_k, callback_k in zip(self.reconstruction_algorithms, self.callback):
                recon_algo_k.callback = callback_k
                recon_algo_k._compute_callback(n_iter=n_iter, n_subset=n_subset)
        else:
            # The stitched object was computed by __call__ just before this call,
            # so it is not stitched a second time here.
            super()._compute_callback(n_iter=n_iter, n_subset=n_subset)

    def _finalize_callback(self):
        """Finalizes callbacks after reconstruction. This method is called after the reconstruction algorithm has finished.
        """
        if self.callback is None:
            return
        if isinstance(self.callback, (list, tuple)):
            for recon_algo_k, callback_k in zip(self.reconstruction_algorithms, self.callback):
                recon_algo_k.callback = callback_k
                recon_algo_k.callback.finalize(recon_algo_k.object_prediction)
        else:
            self.callback.finalize(self.object_prediction)

    def compute_uncertainty(
        self,
        mask: torch.Tensor,
        data_storage_callbacks: Sequence[Callback],
        subiteration_number: int | None = None,
        return_pct: bool = False,
        include_additive_term: bool = False
    ):
        """Estimates the uncertainty in a mask (should be same shape as the stitched image). Calling this method requires a sequence of ``DataStorageCallback`` instances that have been used in a reconstruction algorithm: these data storage contain required information for each bed position.

        The per-bed absolute uncertainties are combined in quadrature; bed positions whose weighted mask is empty are skipped.

        Args:
            mask (torch.Tensor): Masked region of the reconstructed object: a boolean Tensor. This mask should be the same shape as the stitched object.
            data_storage_callbacks (Sequence[Callback]): Sequence of data storage callbacks used in reconstruction corresponding to each bed position.
            subiteration_number (int | None, optional): Subiteration number to compute the uncertainty for. If None, then the last stored object of the first callback determines it. Defaults to None.
            return_pct (bool, optional): If true, then additionally returns the percent uncertainty for the sum of counts. Defaults to False.
            include_additive_term (bool): Whether or not to include uncertainty contribution from the additive term. This requires the ``additive_term_variance_estimate`` as an argument to the initialized likelihood. Defaults to False.

        Returns:
            float | Sequence[float]: Absolute uncertainty in the sum of counts in the masked region (if `return_pct` is False) OR absolute uncertainty and relative uncertainty in percent (if `return_pct` is True)
        """
        if subiteration_number is None:
            subiteration_number = len(data_storage_callbacks[0].objects) - 1
        recons = [
            data_storage_callback.objects[subiteration_number].to(pytomography.device)
            for data_storage_callback in data_storage_callbacks
        ]
        stitching_weights, zs = dicom.stitch_multibed(torch.stack(recons), self.files_NM, return_stitching_weights=True)
        # Every bed is cropped with the axial extent of the first bed's reconstruction
        bed_depth = recons[0].shape[-1]
        uncertainty_abs = []
        total_counts = 0
        for k, recon_k in enumerate(recons):
            # Crop mask to FOV region of this bed and apply its stitching weights
            mask_k = mask[:, :, zs[k]:zs[k] + bed_depth] * stitching_weights[k]
            if mask_k.sum() == 0:
                continue
            uncertainty_abs_k = self.reconstruction_algorithms[k].compute_uncertainty(
                mask_k,
                data_storage_callbacks[k],
                subiteration_number,
                return_pct=False,
                include_additive_term=include_additive_term,
            )
            total_counts += (recon_k * mask_k).sum().item()
            uncertainty_abs.append(uncertainty_abs_k)
        uncertainty_abs_total = torch.sqrt(torch.sum(torch.tensor(uncertainty_abs) ** 2)).item()
        if return_pct:
            uncertainty_pct = uncertainty_abs_total / total_counts * 100
            return uncertainty_abs_total, uncertainty_pct
        return uncertainty_abs_total