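"""Preprocessing modules (``mipcandy.common.module.preprocess``): padding to size multiples and the matching crop, intensity normalization, and label colorization."""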

from math import ceil
from typing import Literal

import torch
from torch import nn

from mipcandy.types import Colormap, Shape2d, Shape3d, Shape, Paddings2d, Paddings3d, Paddings


def reverse_paddings(paddings: Paddings) -> Paddings:
    """Reorder per-dimension ``(before, after)`` pairs from first-dimension-first to last-dimension-first.

    ``torch.nn.functional.pad`` consumes pairs starting at the last dimension, while the paddings in this
    module are stored in natural order. A 4-entry input is treated as 2D ``(h0, h1, w0, w1)``; anything else
    is treated as 3D and its first six entries ``(d0, d1, h0, h1, w0, w1)`` are reordered.

    :return: a tuple with the pairs in reversed dimension order
    """
    order = (2, 3, 0, 1) if len(paddings) == 4 else (4, 5, 2, 3, 0, 1)
    return tuple(paddings[index] for index in order)
class Pad(nn.Module):
    """Shared state and size arithmetic for the padding modules.

    Stores the padding fill value and mode, whether inputs carry a batch dimension, and the paddings applied
    by the most recent forward pass (``None`` until one has happened). Subclasses provide ``forward``.
    """

    def __init__(self, *, value: int = 0, mode: str = "constant", batch: bool = True) -> None:
        super().__init__()
        self.batch: bool = batch
        self._mode: str = mode
        self._value: int = value
        # Filled in by a subclass's forward pass; read back by the matching restore module.
        self._paddings: Paddings | None = None
        self.requires_grad_(False)

    @staticmethod
    def _c_t(size: int, min_factor: int) -> int:
        """Target length for one dimension: ``size`` rounded up to a multiple of ``min_factor``."""
        blocks = ceil(size / min_factor)
        return blocks * min_factor

    @staticmethod
    def _c_p(size: int, min_factor: int) -> tuple[int, int]:
        """``(before, after)`` padding that brings one dimension to its target; an odd excess puts the extra after."""
        total = Pad._c_t(size, min_factor) - size
        half = total // 2
        return half, total - half
class Pad2d(Pad):
    """Pad the last two dimensions of ``x`` up to multiples of ``min_factor``.

    Each dimension is padded as evenly as possible, with the extra element (if any) placed after. The applied
    paddings are remembered so ``Restore2d`` can crop them away again.

    :param min_factor: required multiple, either one int for both dimensions or one per dimension
    :param value: fill value for ``mode="constant"``
    :param mode: padding mode forwarded to ``torch.nn.functional.pad``
    :param batch: whether inputs are ``(N, C, H, W)`` (``True``) or ``(C, H, W)`` (``False``)
    """

    def __init__(self, min_factor: int | Shape2d, *, value: int = 0, mode: str = "constant",
                 batch: bool = True) -> None:
        super().__init__(value=value, mode=mode, batch=batch)
        self._min_factor: Shape2d = (min_factor,) * 2 if isinstance(min_factor, int) else min_factor

    def paddings(self) -> Paddings2d | None:
        """Return ``(h_before, h_after, w_before, w_after)`` from the last forward pass, or ``None`` if none ran."""
        return self._paddings

    def padded_shape(self, in_shape: tuple[int, ...]) -> tuple[int, ...]:
        """Return the shape ``forward`` produces for an input of ``in_shape``; only the last two entries change."""
        return (*in_shape[:-2],
                self._c_t(in_shape[-2], self._min_factor[0]),
                self._c_t(in_shape[-1], self._min_factor[1]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpacking doubles as a rank check: 4-D when batched, 3-D otherwise.
        if self.batch:
            _, _, h, w = x.shape
        else:
            _, h, w = x.shape
        self._paddings = self._c_p(h, self._min_factor[0]) + self._c_p(w, self._min_factor[1])
        # F.pad consumes pairs from the last dimension backwards, so listing only the spatial pairs is enough.
        # Appending zero pairs for the channel/batch dims was a no-op for "constant" but made F.pad reject every
        # non-constant mode ("reflect", "replicate", "circular").
        return nn.functional.pad(x, reverse_paddings(self._paddings), self._mode, self._value)
class Pad3d(Pad):
    """Pad the last three dimensions of ``x`` up to multiples of ``min_factor``.

    Each dimension is padded as evenly as possible, with the extra element (if any) placed after. The applied
    paddings are remembered so ``Restore3d`` can crop them away again.

    :param min_factor: required multiple, either one int for all three dimensions or one per dimension
    :param value: fill value for ``mode="constant"``
    :param mode: padding mode forwarded to ``torch.nn.functional.pad``
    :param batch: whether inputs are ``(N, C, D, H, W)`` (``True``) or ``(C, D, H, W)`` (``False``)
    """

    def __init__(self, min_factor: int | Shape3d, *, value: int = 0, mode: str = "constant",
                 batch: bool = True) -> None:
        super().__init__(value=value, mode=mode, batch=batch)
        self._min_factor: Shape3d = (min_factor,) * 3 if isinstance(min_factor, int) else min_factor

    def paddings(self) -> Paddings3d | None:
        """Return ``(d_before, d_after, h_before, h_after, w_before, w_after)`` from the last forward pass, or ``None``."""
        return self._paddings

    def padded_shape(self, in_shape: tuple[int, ...]) -> tuple[int, ...]:
        """Return the shape ``forward`` produces for an input of ``in_shape``; only the last three entries change."""
        return (*in_shape[:-3],
                self._c_t(in_shape[-3], self._min_factor[0]),
                self._c_t(in_shape[-2], self._min_factor[1]),
                self._c_t(in_shape[-1], self._min_factor[2]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpacking doubles as a rank check: 5-D when batched, 4-D otherwise.
        if self.batch:
            _, _, d, h, w = x.shape
        else:
            _, d, h, w = x.shape
        self._paddings = (self._c_p(d, self._min_factor[0]) + self._c_p(h, self._min_factor[1])
                          + self._c_p(w, self._min_factor[2]))
        # F.pad consumes pairs from the last dimension backwards, so listing only the spatial pairs is enough.
        # Appending zero pairs for the channel/batch dims was a no-op for "constant" but made F.pad reject every
        # non-constant mode ("reflect", "replicate", "circular").
        return nn.functional.pad(x, reverse_paddings(self._paddings), self._mode, self._value)
class Restore2d(nn.Module):
    """Crop away the padding that the paired ``Pad2d`` applied on its most recent forward pass."""

    def __init__(self, conjugate_padding: Pad2d) -> None:
        super().__init__()
        self.conjugate_padding: Pad2d = conjugate_padding
        self.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Slice the recorded ``(before, after)`` margins off the last two dimensions of ``x``.

        :raises ValueError: if the paired padding module has not padded anything yet
        """
        pads = self.conjugate_padding.paddings()
        if not pads:
            raise ValueError("Paddings are not set yet, did you forget to pad before restoring?")
        top, bottom, left, right = pads
        # The unpacking enforces the same rank the paired Pad2d accepts.
        if self.conjugate_padding.batch:
            _, _, height, width = x.shape
        else:
            _, height, width = x.shape
        return x[..., top:height - bottom, left:width - right]
class Restore3d(nn.Module):
    """Crop away the padding that the paired ``Pad3d`` applied on its most recent forward pass."""

    def __init__(self, conjugate_padding: Pad3d) -> None:
        super().__init__()
        self.conjugate_padding: Pad3d = conjugate_padding
        self.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Slice the recorded ``(before, after)`` margins off the last three dimensions of ``x``.

        :raises ValueError: if the paired padding module has not padded anything yet
        """
        pads = self.conjugate_padding.paddings()
        if not pads:
            raise ValueError("Paddings are not set yet, did you forget to pad before restoring?")
        front, back, top, bottom, left, right = pads
        # The unpacking enforces the same rank the paired Pad3d accepts.
        if self.conjugate_padding.batch:
            _, _, depth, height, width = x.shape
        else:
            _, depth, height, width = x.shape
        return x[..., front:depth - back, top:height - bottom, left:width - right]
class PadTo(Pad):
    """Pad the spatial dimensions of ``x`` so each is at least the matching entry of ``min_shape``.

    Dimensions that already meet their minimum are left untouched, and deficient ones are padded to exactly the
    minimum. Padding is split as evenly as possible, with the extra element (if any) placed after. Inputs that
    already satisfy every minimum are returned unchanged.

    Previously the inner ``Pad2d``/``Pad3d`` was used with ``min_shape`` as a *multiple*. As a result, one
    undersized dimension caused every other dimension to be rounded up to a multiple of its minimum as well.

    :param min_shape: minimum spatial size, 2 entries for 2D inputs or 3 for 3D inputs
    :param value: fill value for ``mode="constant"``
    :param mode: padding mode forwarded to ``torch.nn.functional.pad``
    :param batch: whether inputs carry a leading batch dimension before the channel dimension
    """

    def __init__(self, min_shape: Shape, *, value: int = 0, mode: str = "constant", batch: bool = True) -> None:
        super().__init__(value=value, mode=mode, batch=batch)
        self._min_shape: Shape = min_shape
        # Kept for backward compatibility (e.g. ``Restore2d(pad_to._pad)``); forward mirrors the applied
        # paddings into it so such restores keep cropping exactly what was added.
        self._pad: Pad2d | Pad3d = (Pad2d if len(min_shape) == 2 else Pad3d)(min_shape, value=value, mode=mode,
                                                                            batch=batch)

    def paddings(self) -> Paddings | None:
        """Return the per-dimension ``(before, after)`` paddings of the last padding forward pass, or ``None``."""
        return self._paddings

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        offset = 2 if self.batch else 1
        expected_ndim = offset + len(self._min_shape)
        if x.ndim != expected_ndim:
            raise ValueError(f"Expected a {expected_ndim}-D input, got {x.ndim}-D")
        paddings: list[int] = []
        for size, min_size in zip(x.shape[offset:], self._min_shape):
            excess = max(min_size - size, 0)
            before = excess // 2
            paddings += (before, excess - before)
        if not any(paddings):
            return x
        self._paddings = self._pad._paddings = tuple(paddings)
        return nn.functional.pad(x, reverse_paddings(self._paddings), self._mode, self._value)
[docs] class Normalize(nn.Module): def __init__(self, *, domain: tuple[float | None, float | None] = (0, None), strict: bool = False, method: Literal["linear", "intercept", "cut", "zscore"] = "linear") -> None: super().__init__() self._domain: tuple[float | None, float | None] = domain self._strict: bool = strict self._method: Literal["linear", "intercept", "cut", "zscore"] = method self.requires_grad_(False)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: left, right = self._domain if left is None and right is None and self._method != "zscore": return x r_l, r_r = x.min(), x.max() match self._method: case "linear": if left is None or (left < r_l and not self._strict): left = r_l if right is None or (right > r_r and not self._strict): right = r_r numerator = right - left if numerator == 0: numerator = 1 denominator = r_r - r_l if denominator == 0: denominator = 1 return (x - r_l) * numerator / denominator + left case "intercept": if left is not None and right is None: return x - r_l + left if r_l < left or self._strict else x elif left is None and right is not None: return x - r_r + right if r_r > right or self._strict else x else: raise ValueError("Cannot use intercept normalization when both ends are fixed") case "cut": if self._strict: raise ValueError("Method \"cut\" cannot be strict") if left is not None: x = x.clamp(min=left) if right is not None: x = x.clamp(max=right) return x case "zscore": if left is not None or right is not None: raise ValueError("Method \"zscore\" cannot have fixed ends") return (x - x.mean()) / max(x.std(), torch.tensor(1e-8, device=x.device))
class CTNormalize(nn.Module):
    """Clip intensities to a window, then standardize them with fixed statistics.

    The output is ``(clip(x, lower_bound, upper_bound) - mean_intensity) / max(std_intensity, 1e-8)``.
    """

    def __init__(self, mean_intensity: float, std_intensity: float, lower_bound: float,
                 upper_bound: float) -> None:
        super().__init__()
        self._lower_bound: float = lower_bound
        self._upper_bound: float = upper_bound
        self._mean_intensity: float = mean_intensity
        self._std_intensity: float = std_intensity
        self.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        windowed = torch.clamp(x, min=self._lower_bound, max=self._upper_bound)
        # Floor the divisor so a zero std cannot cause a division by zero.
        scale = 1e-8 if 1e-8 > self._std_intensity else self._std_intensity
        return (windowed - self._mean_intensity) / scale
class ColorizeLabel(nn.Module):
    """Map label maps (or binary masks) to RGB images through a color lookup table.

    :param colormap: rows of ``[r, g, b]`` indexed by label; when empty or ``None`` a 2048-entry default
        palette is generated
    :param batch: whether inputs carry a leading batch dimension
    """

    def __init__(self, *, colormap: Colormap | None = None, batch: bool = True) -> None:
        super().__init__()
        if not colormap:
            # Default palette of 8 x 8 x 32 colors, entry index ``r * 256 + g * 32 + b``. The blue channel used
            # to be ``255 - b * 32``, which is negative for b >= 8 (down to -737). The first 8 blue levels are
            # kept exactly as before. The remaining 24 are interleaved 8, 16 and 24 units below them, so every
            # entry stays within [0, 255] and all 2048 entries remain distinct.
            colormap = [[r * 32, g * 32, 255 - (b % 8) * 32 - (b // 8) * 8]
                        for r in range(8) for g in range(8) for b in range(32)]
        self._colormap: torch.Tensor = torch.tensor(colormap)
        self._batch: bool = batch
        self.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Colorize ``x`` and move the RGB channel to dim 1 (dim 0 for unbatched input).

        If ``0 <= x.min() < x.max() <= 1``, ``x`` is treated as a mask and binarized with ``x > 0``.
        Otherwise its values are cast to int and used directly as palette indices.
        NOTE(review): assumes ``x`` has no channel dimension, e.g. ``(N, H, W)`` — confirm with callers.
        """
        if not self._batch:
            x = x.unsqueeze(0)
        cmap = self._colormap.to(x.device)
        x = (cmap[(x > 0).int()] if 0 <= x.min() < x.max() <= 1 else cmap[x.int()]).movedim(-1, 1)
        return x if self._batch else x.squeeze(0)