from dataclasses import dataclass, asdict
from json import dump, load
from math import ceil
from os import PathLike
from random import randint, choice
from typing import Sequence, override, Callable, Self, Any, Literal
import numpy as np
import torch
from rich.console import Console
from rich.progress import Progress, SpinnerColumn
from torch import nn
from mipcandy.data.dataset import SupervisedDataset
from mipcandy.data.geometric import crop
from mipcandy.types import Shape, AmbiguousShape
@dataclass
class InspectionAnnotation(object):
    """
    Per-case summary of a label volume.

    Bounding boxes are stored as flat ``(min, max)`` pairs, one pair per spatial axis, so a 2D box has four
    entries and a 3D box has six. When produced by ``inspect`` the ``max`` entries are exclusive.
    """
    # Spatial shape of the label (channel axis excluded when produced by ``inspect``).
    shape: AmbiguousShape
    # Bounding box enclosing every non-background voxel.
    foreground_bbox: tuple[int, int, int, int] | tuple[int, int, int, int, int, int]
    # Class ids present in the case; ``inspect`` leaves the background id out of this tuple.
    class_ids: tuple[int, ...]
    # Number of voxels per class id.
    class_counts: dict[int, int]
    # Bounding box per class id, same layout as ``foreground_bbox``.
    class_bboxes: dict[int, tuple[int, int, int, int] | tuple[int, int, int, int, int, int]]
    # Voxel coordinates per class id (``inspect`` stores at most ``max_samples`` of them).
    class_locations: dict[int, tuple[tuple[int, int] | tuple[int, int, int], ...]]
    spacing: Shape | None = None

    def foreground_shape(self) -> Shape:
        """
        :return: the extent of the foreground bounding box along each spatial axis
        """
        bbox = self.foreground_bbox
        extents = (bbox[1] - bbox[0], bbox[3] - bbox[2])
        if len(bbox) == 4:
            return extents
        return extents + (bbox[5] - bbox[4],)

    def center_of_foreground(self) -> tuple[int, int] | tuple[int, int, int]:
        """
        :return: the rounded midpoint of the foreground bounding box along each spatial axis
        """
        bbox = self.foreground_bbox
        center = (round((bbox[1] + bbox[0]) * .5), round((bbox[3] + bbox[2]) * .5))
        if len(self.shape) == 2:
            return center
        return center + (round((bbox[5] + bbox[4]) * .5),)
class InspectionAnnotations(Sequence[InspectionAnnotation]):
    """
    Read-only sequence of per-case annotations together with dataset-level statistics derived from them.

    Several derived values (image shapes, foreground shapes, the heatmap centre, its offsets and the ROI shape)
    are computed on first request and cached on the instance.
    """

    def __init__(self, dataset: SupervisedDataset, background: int, intensity_stats: tuple[float, float, float, float],
                 *annotations: InspectionAnnotation) -> None:
        """
        :param dataset: the dataset the annotations describe; case ``i`` is read via ``dataset.image(i)`` and
            ``dataset.label(i)``
        :param background: the label value treated as background
        :param intensity_stats: mean, std, 0.5th percentile, 99.5th percentile
        :param annotations: one annotation per case, in dataset order
        """
        self._dataset: SupervisedDataset = dataset
        self._background: int = background
        self._intensity_stats: tuple[float, float, float, float] = intensity_stats
        self._annotations: tuple[InspectionAnnotation, ...] = annotations
        self._shapes: tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape] | None = None
        self._foreground_shapes: tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape] | None = None
        self._statistical_foreground_shape: Shape | None = None
        self._center_of_foregrounds: tuple[int, int] | tuple[int, int, int] | None = None
        self._foreground_offsets: tuple[int, int] | tuple[int, int, int] | None = None
        self._roi_shape: Shape | None = None

    def dataset(self) -> SupervisedDataset:
        """
        :return: the dataset these annotations were created for
        """
        return self._dataset

    def background(self) -> int:
        """
        :return: the label value treated as background
        """
        return self._background

    def intensity_stats(self) -> tuple[float, float, float, float]:
        """
        :return: mean, std, 0.5th percentile, 99.5th percentile
        """
        return self._intensity_stats

    def annotations(self) -> tuple[InspectionAnnotation, ...]:
        """
        :return: all per-case annotations in dataset order
        """
        return self._annotations

    @override
    def __getitem__(self, item: int) -> InspectionAnnotation:
        return self._annotations[item]

    @override
    def __len__(self) -> int:
        return len(self._annotations)

    def save(self, path: str | PathLike[str]) -> None:
        """
        Write the background value, the intensity statistics and every annotation to ``path`` as JSON.

        :param path: destination file; it is opened in text write mode and overwritten
        """
        with open(path, "w") as f:
            dump({
                "background": self._background, "intensity_stats": self._intensity_stats,
                "annotations": [asdict(a) for a in self._annotations]
            }, f)

    def _get_shapes(self, get_shape: Callable[[InspectionAnnotation], AmbiguousShape]) -> tuple[
        tuple[int, ...] | None, tuple[int, ...], tuple[int, ...]]:
        """
        Collect one shape per annotation and regroup the values per axis.

        :param get_shape: extracts the shape of interest from an annotation
        :return: ``(depths, heights, widths)``; ``depths`` is ``None`` when no annotation yielded a 3D shape
        """
        depths = []
        widths = []
        heights = []
        for annotation in self._annotations:
            shape = get_shape(annotation)
            if len(shape) == 2:
                heights.append(shape[0])
                widths.append(shape[1])
            else:
                depths.append(shape[0])
                heights.append(shape[1])
                widths.append(shape[2])
        return tuple(depths) if depths else None, tuple(heights), tuple(widths)

    @staticmethod
    def _percentile_shape(shapes: tuple[tuple[int, ...] | None, tuple[int, ...], tuple[int, ...]],
                          percentile: float) -> Shape:
        """
        Reduce per-axis size collections to a single shape by taking a percentile along every axis.

        :param shapes: ``(depths, heights, widths)`` as returned by ``_get_shapes``
        :param percentile: fraction in ``[0, 1]``; it is scaled to the 0-100 range expected by ``np.percentile``
        :return: the rounded percentile of each axis, with the depth first when depths are available
        """
        depths, heights, widths = shapes
        percentile *= 100
        planar = (round(np.percentile(heights, percentile)), round(np.percentile(widths, percentile)))
        return (round(np.percentile(depths, percentile)),) + planar if depths else planar

    def shapes(self) -> tuple[tuple[int, ...] | None, tuple[int, ...], tuple[int, ...]]:
        """
        :return: ``(depths, heights, widths)`` of the annotated images; cached after the first call
        """
        if self._shapes is None:
            self._shapes = self._get_shapes(lambda annotation: annotation.shape)
        return self._shapes

    def statistical_shape(self, *, percentile: float = .95) -> Shape:
        """
        :param percentile: fraction in ``[0, 1]`` of the per-axis image sizes to report
        :return: the per-axis percentile of the image shapes
        """
        return self._percentile_shape(self.shapes(), percentile)

    def foreground_shapes(self) -> tuple[tuple[int, ...] | None, tuple[int, ...], tuple[int, ...]]:
        """
        :return: ``(depths, heights, widths)`` of the foreground bounding boxes; cached after the first call
        """
        if self._foreground_shapes is None:
            self._foreground_shapes = self._get_shapes(lambda annotation: annotation.foreground_shape())
        return self._foreground_shapes

    def statistical_foreground_shape(self, *, percentile: float = .95) -> Shape:
        """
        :param percentile: fraction in ``[0, 1]`` of the per-axis foreground sizes to report
        :return: the per-axis percentile of the foreground bounding-box shapes
        """
        return self._percentile_shape(self.foreground_shapes(), percentile)

    def crop_foreground(self, i: int, *, expand_ratio: float = 1) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Crop case ``i`` to its foreground bounding box, optionally enlarged around the box.

        :param i: case index
        :param expand_ratio: factor applied to the box size; the extra margin is split between both sides and the
            box is clipped to the image bounds
        :return: the cropped image and label
        """
        image, label = self._dataset.image(i), self._dataset.label(i)
        annotation = self._annotations[i]
        bbox = list(annotation.foreground_bbox)
        shape = annotation.foreground_shape()
        for dim_idx, size in enumerate(shape):
            left = int((expand_ratio - 1) * size // 2)
            right = int((expand_ratio - 1) * size - left)
            bbox[dim_idx * 2] = max(0, bbox[dim_idx * 2] - left)
            bbox[dim_idx * 2 + 1] = min(bbox[dim_idx * 2 + 1] + right, annotation.shape[dim_idx])
        # NOTE(review): `crop` is a project helper; the extra leading axis is added for it and removed afterwards.
        return crop(image.unsqueeze(0), bbox).squeeze(0), crop(label.unsqueeze(0), bbox).squeeze(0)

    def foreground_heatmap(self) -> torch.Tensor:
        """
        Accumulate the foreground masks of all cases after cropping each one to its foreground bounding box and
        zero-padding it symmetrically up to the largest foreground shape in the dataset.

        :return: a tensor shaped like the largest foreground box holding, per position, the number of cases whose
            padded foreground mask is set there
        """
        depths, heights, widths = self.foreground_shapes()
        max_shape = (max(depths), max(heights), max(widths)) if depths else (max(heights), max(widths))
        accumulated_label = torch.zeros((1, *max_shape), device=self._dataset.device())
        for i in range(len(self._dataset)):
            label = self._dataset.label(i)
            annotation = self._annotations[i]
            shape = annotation.foreground_shape()
            # `nn.functional.pad` expects (left, right) pairs starting from the LAST axis, so walk the spatial axes
            # backwards and finish with zero padding for the two leading non-spatial axes.
            paddings = []
            for axis in reversed(range(len(max_shape))):
                left = (max_shape[axis] - shape[axis]) // 2
                right = max_shape[axis] - shape[axis] - left
                paddings.extend((left, right))
            paddings.extend((0, 0, 0, 0))
            accumulated_label += nn.functional.pad(
                crop((label != self._background).unsqueeze(0), annotation.foreground_bbox), paddings
            ).squeeze(0)
        return accumulated_label.squeeze(0).detach()

    def center_of_foregrounds(self) -> tuple[int, int] | tuple[int, int, int]:
        """
        :return: for every axis, the index at which the foreground heatmap summed over the other axes is largest;
            cached after the first call
        """
        if self._center_of_foregrounds is not None:
            return self._center_of_foregrounds
        heatmap = self.foreground_heatmap()
        center = (heatmap.sum(dim=1).argmax().item(), heatmap.sum(dim=0).argmax().item()) if heatmap.ndim == 2 else (
            heatmap.sum(dim=(1, 2)).argmax().item(),
            heatmap.sum(dim=(0, 2)).argmax().item(),
            heatmap.sum(dim=(0, 1)).argmax().item(),
        )
        self._center_of_foregrounds = center
        return self._center_of_foregrounds

    def center_of_foregrounds_offsets(self) -> tuple[int, int] | tuple[int, int, int]:
        """
        :return: the displacement of the heatmap centre from the geometric centre of the largest foreground box,
            per axis; cached after the first call
        """
        if self._foreground_offsets is not None:
            return self._foreground_offsets
        center = self.center_of_foregrounds()
        depths, heights, widths = self.foreground_shapes()
        max_shape = (max(depths), max(heights), max(widths)) if depths else (max(heights), max(widths))
        offsets = (round(center[0] - max_shape[0] * .5), round(center[1] - max_shape[1] * .5))
        self._foreground_offsets = offsets + (round(center[2] - max_shape[2] * .5),) if depths else offsets
        return self._foreground_offsets

    def set_roi_shape(self, roi_shape: Shape | None) -> None:
        """
        Override the ROI shape returned by ``roi_shape``, or clear it with ``None`` so it is recomputed.

        :param roi_shape: the ROI shape to use, or ``None`` to reset
        :raises ValueError: if the ROI shape has a different number of axes than the images or exceeds the
            smallest image along any axis
        """
        if roi_shape is not None:
            depths, heights, widths = self.shapes()
            minimums = (min(depths), min(heights), min(widths)) if depths else (min(heights), min(widths))
            # Reject rank mismatches explicitly: a shorter ROI would otherwise fail with a bare IndexError and a
            # longer one would have its trailing axes silently ignored by the size check.
            if len(roi_shape) != len(minimums):
                raise ValueError(
                    f"ROI shape {roi_shape} must have {len(minimums)} dimensions to match the image shapes")
            if any(size > minimum for size, minimum in zip(roi_shape, minimums)):
                raise ValueError(f"ROI shape {roi_shape} exceeds minimum image shape {minimums}")
        self._roi_shape = roi_shape

    def roi_shape(self, *, clamp: bool = True, percentile: float = .95) -> Shape:
        """
        Return the ROI shape, computing it from the statistical foreground shape when none is stored yet.

        The result is cached, so ``clamp`` and ``percentile`` only take effect on the call that computes it; use
        ``set_roi_shape(None)`` to force a recomputation.

        :param clamp: limit every axis to the smallest image size along that axis
        :param percentile: fraction in ``[0, 1]`` forwarded to ``statistical_foreground_shape``
        :return: the ROI shape
        """
        if self._roi_shape is not None:
            return self._roi_shape
        sfs = self.statistical_foreground_shape(percentile=percentile)
        if not clamp:
            self._roi_shape = sfs
            return self._roi_shape
        if len(sfs) == 2:
            # Left-pad with a placeholder so height and width sit at indices 1 and 2 in both the 2D and 3D case.
            sfs = (None, *sfs)
        depths, heights, widths = self.shapes()
        roi_shape = (min(min(heights), sfs[1]), min(min(widths), sfs[2]))
        if depths:
            roi_shape = (min(min(depths), sfs[0]),) + roi_shape
        self._roi_shape = roi_shape
        return self._roi_shape

    def roi(self, i: int, *, clamp: bool = True, percentile: float = .95) -> tuple[int, int, int, int] | tuple[
        int, int, int, int, int, int]:
        """
        Compute the ROI bounding box of case ``i``.

        The box has the shared ROI shape and is centred on the case's foreground centre shifted by the
        dataset-level offsets; the shift is limited so the box stays inside the image where the ROI fits.

        :param i: case index
        :param clamp: forwarded to ``roi_shape``
        :param percentile: forwarded to ``roi_shape``
        :return: flat ``(min, max)`` pairs, one pair per spatial axis
        """
        annotation = self._annotations[i]
        roi_shape = self.roi_shape(clamp=clamp, percentile=percentile)
        offsets = self.center_of_foregrounds_offsets()
        roi = []
        for axis, position in enumerate(annotation.center_of_foreground()):
            left = roi_shape[axis] // 2
            right = roi_shape[axis] - left
            offset = min(max(offsets[axis], left - position), annotation.shape[axis] - right - position)
            roi.extend((position + offset - left, position + offset + right))
        return tuple(roi)

    def crop_roi(self, i: int, *, clamp: bool = True, percentile: float = .95) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Crop the image and label of case ``i`` to the box returned by ``roi``.

        :param i: case index
        :param clamp: forwarded to ``roi``
        :param percentile: forwarded to ``roi``
        :return: the cropped image and label
        """
        image, label = self._dataset.image(i), self._dataset.label(i)
        roi = self.roi(i, clamp=clamp, percentile=percentile)
        return crop(image.unsqueeze(0), roi).squeeze(0), crop(label.unsqueeze(0), roi).squeeze(0)
[docs]
def _lists_to_tuples(pairs: Sequence[tuple[str, Any]]) -> dict[str, Any]:
return {k: tuple(v) if isinstance(v, list) else v for k, v in pairs}
[docs]
def _str_indices_to_int_indices(obj: dict[str, Any]) -> dict[int, Any]:
return {int(k): v for k, v in obj.items()}
def parse_inspection_annotation(obj: dict[str, Any]) -> InspectionAnnotation:
    """
    Build an ``InspectionAnnotation`` from its decoded JSON representation.

    JSON object keys are always strings, so every class-indexed mapping (``class_counts``, ``class_bboxes`` and
    ``class_locations``) has its keys converted back to integers. The input dict is left unmodified.

    :param obj: the decoded JSON object of one annotation
    :return: the reconstructed annotation
    :raises KeyError: if one of the class-indexed mappings is missing
    """
    fields = dict(obj)  # shallow copy so the caller's dict is not altered
    for key in ("class_counts", "class_bboxes", "class_locations"):
        fields[key] = _str_indices_to_int_indices(fields[key])
    return InspectionAnnotation(**fields)
def load_inspection_annotations(path: str | PathLike[str], dataset: SupervisedDataset) -> InspectionAnnotations:
    """
    Read annotations from a JSON file and bind them to ``dataset``.

    :param path: the JSON file to read; it must provide the keys ``background``, ``intensity_stats`` and
        ``annotations``
    :param dataset: the dataset the loaded annotations describe
    :return: the reconstructed annotations
    """
    with open(path) as f:
        payload = load(f, object_pairs_hook=_lists_to_tuples)
    background = payload["background"]
    intensity_stats = payload["intensity_stats"]
    rows = [parse_inspection_annotation(row) for row in payload["annotations"]]
    return InspectionAnnotations(dataset, background, intensity_stats, *rows)
[docs]
def bbox_from_indices(indices: torch.Tensor, num_dim: Literal[2, 3]) -> tuple[int, int, int, int]:
mins = indices.min(dim=0)[0].tolist()
maxs = indices.max(dim=0)[0].tolist()
bbox = (mins[1], maxs[1] + 1, mins[2], maxs[2] + 1)
if num_dim == 3:
bbox += (mins[3], maxs[3] + 1)
return bbox
def inspect(dataset: SupervisedDataset, *, background: int = 0, max_samples: int = 10000,
            console: Console = Console()) -> InspectionAnnotations:
    """
    Scan every case of a dataset and collect per-case annotations and foreground intensity statistics.

    For each label the foreground bounding box is recorded together with, per class id, the voxel count, the
    bounding box and a (possibly subsampled) list of voxel locations. Cases without any foreground get an
    all-zero bounding box and empty class information, and do not contribute to the intensity statistics.

    :param dataset: the dataset to inspect; cases are read via ``dataset.label(idx)`` and ``dataset.image(idx)``
    :param background: the label value treated as background
    :param max_samples: maximum number of voxel locations stored per class and case; larger classes are randomly
        subsampled
    :param console: the console the progress bar is rendered on
    :return: the annotations of all cases plus mean, std, 0.5th and 99.5th percentile of the foreground voxels
    :raises ValueError: if an image and its label differ in spatial shape, or if no case has foreground
    """
    r = []
    with torch.no_grad(), Progress(*Progress.get_default_columns(), SpinnerColumn(), console=console) as progress:
        task = progress.add_task("Inspecting dataset...", total=len(dataset))
        foreground_voxels = []
        for idx in range(len(dataset)):
            label = dataset.label(idx).int()
            progress.update(task, advance=1, description=f"Inspecting dataset {tuple(label.shape)}")
            ndim = label.ndim - 1  # the leading axis of the label is not a spatial axis
            fg_mask = label != background
            fg_indices = fg_mask.nonzero()
            if len(fg_indices) == 0:
                r.append(InspectionAnnotation(
                    tuple(label.shape[1:]), (0, 0, 0, 0) if ndim == 2 else (0, 0, 0, 0, 0, 0), (), {}, {}, {})
                )
                continue
            foreground_bbox = bbox_from_indices(fg_indices, ndim)
            class_ids = label.unique().tolist()
            class_counts = {}
            class_bboxes = {}
            class_locations = {}
            for class_id in class_ids:
                class_indices = (label == class_id).nonzero()
                # Count and bounding box use every voxel; only the stored locations are subsampled.
                class_counts[class_id] = len(class_indices)
                class_bboxes[class_id] = bbox_from_indices(class_indices, ndim)
                if len(class_indices) > max_samples:
                    class_indices = class_indices[torch.randperm(len(class_indices))[:max_samples]]
                class_locations[class_id] = tuple(tuple(loc) for loc in class_indices[:, 1:].tolist())
            r.append(InspectionAnnotation(
                tuple(label.shape[1:]), foreground_bbox, tuple(
                    class_id for class_id in class_ids if class_id != background
                ), class_counts, class_bboxes, class_locations
            ))
            image = dataset.image(idx)
            if image.shape[1:] != label.shape[1:]:
                raise ValueError(f"Image shape {image.shape} does not match label shape {label.shape} spatially at "
                                 f"index {idx}")
            if image.shape[0] > 1:
                # Repeat the mask along the leading axis so it selects the foreground in every channel.
                fg_mask = fg_mask.expand_as(image)
            fg = image[fg_mask]
            if len(fg) > 0:
                foreground_voxels.append(fg)
    if len(foreground_voxels) == 0:
        raise ValueError("No foreground voxels found in dataset")
    all_fg = torch.cat(foreground_voxels)
    if not all_fg.is_floating_point():
        # `mean` and `std` are not defined for integer tensors.
        all_fg = all_fg.float()
    all_fg_np = all_fg.cpu().numpy()
    intensity_stats = (all_fg.mean().item(), all_fg.std().item(), float(np.percentile(all_fg_np, 0.5)),
                       float(np.percentile(all_fg_np, 99.5)))
    return InspectionAnnotations(dataset, background, intensity_stats, *r)
class ROIDataset(SupervisedDataset[list[int]]):
    """
    Dataset view that serves every case of an inspected dataset cropped to its region of interest.

    Images and labels are stored as case indices into the inspected dataset; the tensors are produced on demand
    by ``InspectionAnnotations.crop_roi``.
    """

    def __init__(self, annotations: InspectionAnnotations, *, clamp: bool = True, percentile: float = .95) -> None:
        """
        :param annotations: the inspection result providing the ROI of each case
        :param clamp: forwarded to ``InspectionAnnotations.crop_roi``
        :param percentile: forwarded to ``InspectionAnnotations.crop_roi``
        """
        super().__init__(list(range(len(annotations))), list(range(len(annotations))),
                         transform=annotations.dataset().transform(), device=annotations.dataset().device())
        self._annotations: InspectionAnnotations = annotations
        self._clamp: bool = clamp
        self._percentile: float = percentile

    @override
    def construct_new(self, images: list[Any], labels: list[Any]) -> Self:
        """
        Create a dataset of the same type and configuration over a different selection of case indices.

        :param images: case indices used as images
        :param labels: case indices used as labels
        :return: the new dataset
        """
        # Forward `clamp` as well; omitting it would silently reset derived datasets to the default.
        new = self.__class__(self._annotations, clamp=self._clamp, percentile=self._percentile)
        new._images = images
        new._labels = labels
        return new

    @override
    def load_image(self, idx: int) -> torch.Tensor:
        # Images and labels are always produced together by `load`.
        raise NotImplementedError

    @override
    def load_label(self, idx: int) -> torch.Tensor:
        # Images and labels are always produced together by `load`.
        raise NotImplementedError

    @override
    def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        :param idx: position within this dataset
        :return: the ROI crop of the image and label of the referenced case
        :raises ValueError: if the image and label entries at ``idx`` refer to different cases
        """
        i = self._images[idx]
        if i != self._labels[idx]:
            raise ValueError(f"Image {i} and label {self._labels[idx]} indices do not match")
        with torch.no_grad():
            return self._annotations.crop_roi(i, clamp=self._clamp, percentile=self._percentile)
[docs]
def crop_and_pad(x: torch.Tensor, bbox_lbs: list[int], bbox_ubs: list[int], *,
pad_value: int | float = 0) -> torch.Tensor:
shape = x.shape[1:]
dim = len(shape)
valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)]
valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)]
slices = tuple([slice(0, x.shape[0])] + [slice(valid_bbox_lbs[i], valid_bbox_ubs[i]) for i in range(dim)])
cropped = x[slices]
padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)]
padding_torch = []
for left, right in reversed(padding):
padding_torch.extend([left, right])
padded = nn.functional.pad(cropped, padding_torch, mode="constant", value=pad_value)
return padded
class RandomROIDataset(ROIDataset):
    """
    ROI dataset that draws a random patch of fixed shape from a case on every access.

    Depending on the position of the requested index within a batch, the patch is either placed uniformly at
    random or centred on a randomly chosen foreground voxel.
    """

    def __init__(self, annotations: InspectionAnnotations, batch_size: int, *, num_patches_per_case: int = 1,
                 oversample_rate: float = .33, clamp: bool = False, percentile: float = .5,
                 min_factor: int = 16) -> None:
        """
        :param annotations: the inspection result of the underlying dataset
        :param batch_size: batch size used to decide which indices are forced to contain foreground
        :param num_patches_per_case: how many times every case index is repeated in this dataset
        :param oversample_rate: fraction of every batch that is forced to contain foreground
        :param clamp: forwarded to ``ROIDataset``
        :param percentile: percentile of the image shapes the patch shape is derived from
        :param min_factor: every patch axis is rounded up to a multiple of this value
        """
        super().__init__(annotations, clamp=clamp, percentile=percentile)
        if num_patches_per_case > 1:
            repeated = [case for case in self._images for _ in range(num_patches_per_case)]
            self._images, self._labels = repeated, repeated.copy()
        self._batch_size: int = batch_size
        self._oversample_rate: float = oversample_rate
        rounded = [
            ceil(size / min_factor) * min_factor
            for size in self._annotations.statistical_shape(percentile=self._percentile)
        ]
        # Cap the patch size: 2048 per axis for 2D data, 128 per axis otherwise.
        self._roi_shape: Shape
        if len(rounded) == 2:
            self._roi_shape = (min(rounded[0], 2048), min(rounded[1], 2048))
        else:
            self._roi_shape = (min(rounded[0], 128), min(rounded[1], 128), min(rounded[2], 128))

    def convert_idx(self, idx: int) -> int:
        """
        :param idx: position within this dataset
        :return: the case index in the underlying dataset
        :raises ValueError: if the image and label entries at ``idx`` refer to different cases
        """
        image_idx = self._images[idx]
        label_idx = self._labels[idx]
        if image_idx != label_idx:
            raise ValueError(f"Image {image_idx} and label {label_idx} indices do not match")
        return image_idx

    def roi_shape(self, *, roi_shape: Shape | None = None) -> None | Shape:
        """
        Read the patch shape, or replace it when a new shape is given.

        :param roi_shape: the new patch shape; omit it to read the current one
        :return: the current patch shape when reading, ``None`` when setting
        """
        if roi_shape:
            self._roi_shape = roi_shape
            return None
        return self._roi_shape

    @override
    def construct_new(self, images: list[Any], labels: list[Any]) -> Self:
        """
        Create a dataset with the same configuration and patch shape over a different selection of case indices.

        :param images: case indices used as images
        :param labels: case indices used as labels
        :return: the new dataset
        """
        clone = self.__class__(self._annotations, self._batch_size, oversample_rate=self._oversample_rate,
                               clamp=self._clamp, percentile=self._percentile)
        clone._images = images
        clone._labels = labels
        clone._roi_shape = self._roi_shape
        return clone

    def random_roi(self, idx: int, force_foreground: bool) -> tuple[list[int], list[int]]:
        """
        Pick the bounds of a random patch for the case at ``idx``.

        The bounds may extend beyond the image when the patch is larger than the image or when it is centred on a
        voxel close to the border.

        :param idx: position within this dataset
        :param force_foreground: centre the patch on a random stored foreground voxel if the case has any class
        :return: the lower and the exclusive upper bounds of the patch per axis
        """
        annotation = self._annotations[self.convert_idx(idx)]
        patch = self._roi_shape
        axes = range(len(annotation.shape))
        # How far the patch overshoots the image along each axis (zero when it fits).
        deficit = [max(0, patch[axis] - annotation.shape[axis]) for axis in axes]
        lowest = [-deficit[axis] // 2 for axis in axes]
        highest = [
            annotation.shape[axis] + deficit[axis] // 2 + deficit[axis] % 2 - patch[axis] for axis in axes
        ]
        if not force_foreground or len(annotation.class_ids) == 0:
            starts = [randint(lowest[axis], highest[axis]) for axis in axes]
        else:
            picked_class = choice(annotation.class_ids)
            voxel = choice(annotation.class_locations[picked_class])
            starts = [max(lowest[axis], voxel[axis] - patch[axis] // 2) for axis in axes]
        stops = [starts[axis] + patch[axis] for axis in axes]
        return starts, stops

    def oversample_foreground(self, idx: int) -> bool:
        """
        :param idx: position within this dataset
        :return: whether the position of ``idx`` within its batch falls into the foreground-forced tail
        """
        threshold = round(self._batch_size * (1 - self._oversample_rate))
        return idx % self._batch_size >= threshold

    @override
    def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        :param idx: position within this dataset
        :return: a random patch of the image and the matching patch of the label
        """
        wants_foreground = self.oversample_foreground(idx)
        lower, upper = self.random_roi(idx, wants_foreground)
        source = self._annotations.dataset()
        case = self.convert_idx(idx)
        image_patch = crop_and_pad(source.image(case), lower, upper)
        label_patch = crop_and_pad(source.label(case), lower, upper)
        return image_patch, label_patch