Source code for pytomography.callbacks.data_saving
from __future__ import annotations
import torch
from .callback import Callback
from pytomography.likelihoods import Likelihood
import torch
[docs]class DataStorageCallback(Callback):
"""Callback that stores the object and forward projection at each iteration
Args:
likelihood (Likelihood): Likelihood function used in reconstruction
object_initial (torch.Tensor[Lx, Ly, Lz]): Initial object in the reconstruction algorithm
"""
def __init__(
self,
likelihood: Likelihood,
object_initial: torch.Tensor
) -> None:
self.object_previous = torch.clone(object_initial)
self.objects = []
self.projections_predicted = []
self.likelihood = likelihood
[docs] def run(self, object: torch.Tensor, n_iter: int, n_subset: int) -> torch.Tensor:
"""Applies the callback
Args:
object (torch.Tensor[Lx, Ly, Lz]): Object at current iteration
n_iter (int): Current iteration number
n_subset (int): Current subset index
Returns:
torch.Tensor: Original object passed (object is not modifed)
"""
# Append from previous iteration
self.objects.append(self.object_previous.cpu())
# FP contains scatter
self.projections_predicted.append(self.likelihood.projections_predicted.cpu())
self.object_previous = torch.clone(object)
return object
[docs] def finalize(self, object: torch.Tensor):
"""Finalizes the callback after all iterations are called
Args:
object (torch.Tensor[Lx, Ly, Lz]): Reconstructed object (all iterations/subsets completed)
"""
self.objects.append(object.cpu())
self.projections_predicted.append(None)