Source code for pytomography.algorithms.dip_recon
from __future__ import annotations
import torch
import torch.nn as nn
from pytomography.likelihoods import Likelihood
from .preconditioned_gradient_ascent import OSEM
class DIPRecon:
    r"""Implementation of the Deep Image Prior reconstruction technique (see https://ieeexplore.ieee.org/document/8581448).

    This reconstruction technique requires an instance of a user-defined ``prior_network`` that implements two functions: (i) a ``fit`` method that takes in an ``object`` (:math:`x`) which the network :math:`f(z;\theta)` is subsequently fit to, and (ii) a ``predict`` function that returns the current network prediction :math:`f(z;\theta)`. For more details, see the Deep Image Prior tutorial.

    Args:
        likelihood (Likelihood): Initialized likelihood function for the imaging system considered
        prior_network (nn.Module): User defined prior network that implements the neural network :math:`f(z;\theta)` that predicts an object given a prior image :math:`z`. This network also implements a ``fit`` method that takes in an object and fits the network to the object (for a specified number of iterations: SubIt2 in the paper).
        rho (float, optional): Value of :math:`\rho` used in the optimization procedure. Larger values of :math:`\rho` give larger weight to the neural network, while smaller values of :math:`\rho` give larger weight to the EM updates. Defaults to 3e-3.
    """

    def __init__(
        self,
        likelihood: Likelihood,
        prior_network: nn.Module,
        rho: float = 3e-3,
    ) -> None:
        # The EM solver is seeded with the rectified (non-negative) network
        # prediction; its ``object_prediction`` is overwritten before every
        # EM update performed in ``__call__``.
        self.EM_algorithm = OSEM(
            likelihood,
            object_initial=nn.ReLU()(prior_network.predict().clone()),
        )
        self.likelihood = likelihood
        self.prior_network = prior_network
        self.rho = rho

    def _compute_callback(self, n_iter: int, n_subset: int):
        """Method for computing callbacks after each reconstruction iteration.

        The value returned by ``self.callback.run`` replaces ``self.object_prediction``.

        Args:
            n_iter (int): Index of the outer iteration that has just completed
            n_subset (int): Subset index forwarded to the callback (``__call__`` passes ``None``)
        """
        self.object_prediction = self.callback.run(self.object_prediction, n_iter, n_subset)

    def __call__(
        self,
        n_iters: int,
        subit1: int,
        n_subsets_osem: int = 1,
        callback=None,
    ) -> torch.Tensor:
        r"""Implementation of Algorithm 1 in https://ieeexplore.ieee.org/document/8581448. This implementation gives the additional option to use ordered subsets. The quantity SubIt2 specified in the paper is controlled by the user-defined ``prior_network`` class.

        Args:
            n_iters (int): Number of iterations (MaxIt in paper)
            subit1 (int): Number of OSEM iterations before retraining neural network (SubIt1 in paper)
            n_subsets_osem (int, optional): Number of subsets to use in OSEM reconstruction. Defaults to 1.
            callback (optional): Object exposing ``run(object, n_iter, n_subset)``; it is invoked after every outer iteration and its return value replaces the current object prediction. Defaults to None.

        Returns:
            torch.Tensor: Reconstructed image. If ``n_iters`` is 0, this is the rectified initial network prediction.
        """
        self.callback = callback
        relu = nn.ReLU()
        # Initialize quantities
        mu = 0
        norm_BP = self.likelihood.system_matrix.compute_normalization_factor()
        x = self.prior_network.predict()
        x_network = x.clone()
        # Seed the result so that it is defined even when no outer iteration
        # runs (n_iters == 0) instead of failing on an unset attribute.
        self.object_prediction = relu(x_network)
        for n_iter in range(n_iters):
            # ``x_network`` and ``mu`` only change once per outer iteration, so
            # this term (and its square) is constant across all EM sub-updates.
            offset = x_network - mu - norm_BP / self.rho
            offset_squared = offset**2
            for _ in range(subit1):
                for subset_idx in range(n_subsets_osem):
                    # One EM update on a single subset, started from the
                    # rectified current estimate.
                    self.EM_algorithm.object_prediction = relu(x.clone())
                    x_EM = self.EM_algorithm(
                        n_iters=1,
                        n_subsets=n_subsets_osem,
                        n_subset_specific=subset_idx,
                    )
                    # Closed-form voxel-wise update combining the EM estimate
                    # with the current network output.
                    x = 0.5 * offset + 0.5 * torch.sqrt(
                        offset_squared + 4 * x_EM * norm_BP / self.rho
                    )
            # Refit the network to the updated target, then update ``mu`` with
            # the residual between ``x`` and the new network prediction.
            self.prior_network.fit(x + mu)
            x_network = self.prior_network.predict()
            mu += x - x_network
            self.object_prediction = relu(x_network)
            # evaluate callback
            if self.callback is not None:
                self._compute_callback(n_iter=n_iter, n_subset=None)
        return self.object_prediction