Source code for pytomography.transforms.SPECT.cutoff
import torch
from pytomography.transforms import Transform
from pytomography.utils.spatial import pad_proj
[docs]class CutOffTransform(Transform):
def __init__(self, mask):
"""Transform that cuts off the projection data outside of a certain region. This is used to remove the background from the projection data.
Args:
mask (torch.Tensor): Mask to cut off the projection data
"""
super(CutOffTransform, self).__init__()
self.padded_mask = pad_proj(mask)
self.mask = mask
@torch.no_grad()
[docs] def forward(
self,
proj: torch.Tensor,
padded: bool = True,
) -> torch.tensor:
"""Cuts off the projection data outside of a certain region.
Args:
proj (torch.Tensor): Projection data
padded (bool, optional): Whether or not the projection data is padded. Defaults to True.
Returns:
torch.Tensor: Projection data with cutoff applied
"""
if padded:
return proj * self.padded_mask
else:
return proj * self.mask
@torch.no_grad()
[docs] def backward(
self,
proj: torch.Tensor,
padded: bool = True,
) -> torch.tensor:
"""Returns the projection data without the cutoff.
Args:
proj (torch.Tensor): Projection data
padded (bool, optional): Whether or not the projection data is padded. Defaults to True.
Returns:
torch.Tensor: Projection data without cutoff"""
if padded:
return proj * self.padded_mask
else:
return proj * self.mask