# Source code for pytomography.callbacks.callback
from __future__ import annotations
import abc
import torch
class Callback():
    """Abstract class used for callbacks. Subclasses must redefine the ``__init__`` and ``run`` methods. If a callback is used as an argument in an iterative reconstruction algorithm, the ``run`` method is called after each subiteration.
    """
    # NOTE: this class does not use ``abc.ABCMeta`` (nor inherit ``abc.ABC``), so the
    # ``@abc.abstractmethod`` decorators below are NOT enforced at instantiation time.
    # They are kept as documentation; switching to ABCMeta would break existing
    # subclasses that do not override every decorated method.

    @abc.abstractmethod
    def __init__(self):
        """Abstract method for ``__init__``.
        """
        ...

    @abc.abstractmethod
    def run(self, object: torch.Tensor, n_iter: int, n_subset: int | None = None) -> torch.Tensor:
        """Abstract method for ``run``; the default implementation returns ``object`` unchanged.

        Args:
            object (torch.Tensor[Lx, Ly, Lz]): Object at current iteration/subset in the reconstruction algorithm
            n_iter (int): The iteration number
            n_subset (int | None): The subset number. Optional so that callers which do not
                track subsets can still call ``run(object, n_iter)``; ``MultiCallback`` always
                passes it, so the base signature must accept it.

        Returns:
            torch.Tensor: Modified object from callback. This must be returned by all callbacks (if the callback doesn't change the object, then the passed object is returned)
        """
        return object

    def finalize(self, object: torch.Tensor) -> None:
        """Hook for ``finalize``, invoked with the final reconstructed object; the default implementation does nothing.

        Args:
            object (torch.Tensor[Lx, Ly, Lz]): Reconstructed object (all iterations/subsets completed)

        Returns:
            None
        """
        return None
class MultiCallback(Callback):
    """Bundle several callbacks into one so that a single object can be handed to an
    iterative reconstruction algorithm.

    Args:
        callbacks (list[Callback]): The callbacks to apply, in the order they should run.
    """
    def __init__(self, callbacks: list[Callback]):
        # The list is stored as given (not copied).
        self.callbacks = callbacks

    def run(self, object: torch.Tensor, n_iter: int, n_subset: int) -> torch.Tensor:
        """Apply every stored callback in sequence, chaining their outputs.

        Each callback receives the object returned by the previous one; the first
        receives ``object`` itself.

        Args:
            object (torch.Tensor): Object at current iteration/subset in the reconstruction algorithm
            n_iter (int): Iteration number
            n_subset (int): Subset number

        Returns:
            torch.Tensor: The object returned by the last callback (or ``object`` unchanged
            when there are no callbacks).
        """
        current = object
        for cb in self.callbacks:
            current = cb.run(current, n_iter, n_subset)
        return current

    def finalize(self, object: torch.Tensor):
        """Call ``finalize`` on each stored callback, in order, with the same reconstructed object.

        Args:
            object (torch.Tensor): Reconstructed object
        """
        for cb in self.callbacks:
            cb.finalize(object)