# Source code for pytomography.algorithms.fbp

"""This module contains classes that implement filtered back projection reconstruction algorithms.
"""
from __future__ import annotations
import pytomography
import torch
from pytomography.projectors import SystemMatrix
from pytomography.utils import RampFilter

class FilteredBackProjection:
    r"""Implementation of filtered back projection reconstruction
    :math:`\hat{f} = \frac{\pi}{N_{\text{proj}}} \mathcal{R}^{T}\mathcal{F}^{-1}\Pi\mathcal{F} g`
    where :math:`N_{\text{proj}}` is the number of projections, :math:`\mathcal{R}` is the 3D
    Radon transform (so :math:`\mathcal{R}^{T}` is back projection, performed here by
    ``system_matrix.backward``), :math:`\mathcal{F}` is the 1D Fourier transform taken along
    the second-to-last axis of the projection data, and :math:`\Pi` is the filter applied in
    Fourier space, which is by default the ramp filter.

    Args:
        projections (torch.Tensor): projection data :math:`g` to be reconstructed. They are
            stored and used by ``__call__`` when no projections are passed to it.
        system_matrix (SystemMatrix): system matrix for the imaging system. Its ``backward``
            method performs the back projection and ``proj_meta.shape[0]`` is used as
            :math:`N_{\text{proj}}`. In FBP, phenomena such as attenuation and PSF should not
            be implemented in the system matrix.
        filter (Callable, optional): Fourier space filter :math:`\Pi`. It is called with no
            arguments to create a filter object, which is then called with a column tensor of
            FFT frequencies and must return values that broadcast against the Fourier
            transformed projections. Defaults to ``RampFilter``.
    """

    def __init__(
        self,
        projections: torch.Tensor,
        system_matrix: SystemMatrix,
        filter=RampFilter
    ) -> None:
        # Keep the projections: ``__call__`` reconstructs them when called without data.
        # (Previously they were accepted here but silently discarded.)
        self.proj = projections
        # Radon transform equivalent to the SPECT system matrix; used for back projection.
        self.system_matrix = system_matrix
        self.filter = filter

    def __call__(self, projections: torch.Tensor | None = None) -> torch.Tensor:
        """Applies reconstruction.

        Args:
            projections (torch.Tensor, optional): projection data to reconstruct. If ``None``
                (the default), the projections given to the constructor are used.

        Returns:
            torch.Tensor: Reconstructed object prediction
        """
        if projections is None:
            projections = self.proj
        # FFT frequencies along the filtered axis (-2). The (-1, 1) column shape lets the
        # filter values broadcast across the last axis. Build them on the data's device so
        # the multiplication below never mixes devices.
        # NOTE(review): only works for SPECT-style data layout; confirm for other modalities.
        freqs = torch.fft.fftfreq(projections.shape[-2]).reshape((-1, 1)).to(projections.device)
        filter_values = self.filter()(freqs)
        proj_fft = torch.fft.fft(projections, dim=-2) * filter_values
        proj_filtered = torch.fft.ifft(proj_fft, dim=-2).real
        # N_proj is an integer entry of the shape, not a sequence, so no ``len()`` here.
        # Assumes the "angle" index is the first dimension of the system matrix metadata.
        num_projections = self.system_matrix.proj_meta.shape[0]
        return self.system_matrix.backward(proj_filtered) * torch.pi / num_projections